From e09d49912f3ad9ebb12db862a25722bc1729819a Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Mon, 19 Feb 2024 14:49:54 +0100 Subject: [PATCH] Add token revocation Closes: https://todo.sr.ht/~emersion/sinwon/6 --- db.go | 5 +++++ main.go | 1 + oauth2.go | 99 ++++++++++++++++++++++++++++++++++++++++++++++------- diff --git a/db.go b/db.go index 2c3fa884a5717ee25cd85251e11843a465cac05b..bc159f78daad22db6c26fd5db240afb775ae7f6e 100644 --- a/db.go +++ b/db.go @@ -222,6 +222,11 @@ RETURNING id `, entityArgs(token)...).Scan(&token.ID) } +func (db *DB) DeleteAccessToken(ctx context.Context, id ID[*AccessToken]) error { + _, err := db.db.ExecContext(ctx, "DELETE FROM AccessToken WHERE id = ?", id) + return err +} + func (db *DB) RevokeAccessTokens(ctx context.Context, clientID ID[*Client], userID ID[*User]) error { _, err := db.db.ExecContext(ctx, ` DELETE FROM AccessToken diff --git a/main.go b/main.go index e1cc459efcb9a1ead2d603025ca915148ea6a7ab..11997a10f89c21321ecb48d55d88b07fe0852bd7 100644 --- a/main.go +++ b/main.go @@ -62,6 +62,7 @@ mux.Get("/.well-known/oauth-authorization-server", getOAuthServerMetadata) mux.HandleFunc("/authorize", authorize) mux.Post("/token", exchangeToken) mux.Post("/introspect", introspectToken) + mux.Post("/revoke", revokeToken) go maintainDBLoop(db) diff --git a/oauth2.go b/oauth2.go index f6d4ff9ac83f5c31f8dedc2ba607ffebaa1dbc4e..cb427641c5a856bc1cf4487498f7a67fbe4f94ea 100644 --- a/oauth2.go +++ b/oauth2.go @@ -33,11 +33,13 @@ Issuer: issuer, AuthorizationEndpoint: issuer + "/authorize", TokenEndpoint: issuer + "/token", IntrospectionEndpoint: issuer + "/introspect", + RevocationEndpoint: issuer + "/revoke", ResponseTypesSupported: []oauth2.ResponseType{oauth2.ResponseTypeCode}, ResponseModesSupported: []oauth2.ResponseMode{oauth2.ResponseModeQuery}, GrantTypesSupported: []oauth2.GrantType{oauth2.GrantTypeAuthorizationCode}, TokenEndpointAuthMethodsSupported: []oauth2.AuthMethod{oauth2.AuthMethodClientSecretBasic}, IntrospectionEndpointAuthMethodsSupported: []oauth2.AuthMethod{oauth2.AuthMethodClientSecretBasic}, + RevocationEndpointAuthMethodsSupported: []oauth2.AuthMethod{oauth2.AuthMethodClientSecretBasic}, }) } @@ -293,19 +295,10 @@ }) return } - var client *Client - if clientID, clientSecret, ok := req.BasicAuth(); ok { - client, err = db.FetchClientByClientID(ctx, clientID) - if err == errNoDBRows || (err == nil && !client.VerifySecret(clientSecret)) { - oauthError(w, &oauth2.Error{ - Code: oauth2.ErrorCodeInvalidClient, - Description: "Invalid client ID or secret", - }) - return - } else if err != nil { - oauthError(w, fmt.Errorf("failed to fetch client: %v", err)) - return - } + client, err := maybeAuthenticateClient(w, req) + if err != nil { + oauthError(w, err) + return } tokenID, secret, _ := UnmarshalSecret[*AccessToken](values.Get("token")) @@ -361,6 +354,64 @@ w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(&resp) } +func revokeToken(w http.ResponseWriter, req *http.Request) { + ctx := req.Context() + db := dbFromContext(ctx) + + values, err := parseRequestBody(req) + if err != nil { + oauthError(w, &oauth2.Error{ + Code: oauth2.ErrorCodeInvalidRequest, + Description: err.Error(), + }) + return + } + + client, err := maybeAuthenticateClient(w, req) + if err != nil { + oauthError(w, err) + return + } + + tokenID, secret, _ := UnmarshalSecret[*AccessToken](values.Get("token")) + token, err := db.FetchAccessToken(ctx, tokenID) + if err == errNoDBRows || (err == nil && !token.VerifySecret(secret)) { + return // ignore + } else if err != nil { + oauthError(w, fmt.Errorf("failed to fetch access token: %v", err)) + return + } + + if client == nil { + client, err = db.FetchClient(ctx, token.Client) + if err != nil { + oauthError(w, fmt.Errorf("failed to fetch client: %v", err)) + return + } + + if !client.IsPublic() { + oauthError(w, &oauth2.Error{ + Code: oauth2.ErrorCodeInvalidClient, + Description: "Missing client ID and secret", + }) + return + } + } + + if client.ID != token.Client { + oauthError(w, &oauth2.Error{ + Code: oauth2.ErrorCodeInvalidClient, + Description: "Invalid client ID or secret", + }) + return + } + + if err := db.DeleteAccessToken(ctx, token.ID); err != nil { + oauthError(w, err) + return + } +} + func parseRequestBody(req *http.Request) (url.Values, error) { ct := req.Header.Get("Content-Type") if ct != "" { @@ -459,3 +510,25 @@ } return false } + +func maybeAuthenticateClient(w http.ResponseWriter, req *http.Request) (*Client, error) { + ctx := req.Context() + db := dbFromContext(ctx) + + clientID, clientSecret, ok := req.BasicAuth() + if !ok { + return nil, nil + } + + client, err := db.FetchClientByClientID(ctx, clientID) + if err == errNoDBRows || (err == nil && !client.VerifySecret(clientSecret)) { + return nil, &oauth2.Error{ + Code: oauth2.ErrorCodeInvalidClient, + Description: "Invalid client ID or secret", + } + } else if err != nil { + return nil, fmt.Errorf("failed to fetch client: %v", err) + } + + return client, nil +} -- 2.48.1