Warning: Due to various recent migrations, viewing non-HEAD refs may be broken.
/oauth2.go (raw)
package main
import (
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"mime"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/emersion/go-oauth2"
)
// Scope names and PKCE code challenge method identifiers recognised by the
// OAuth 2.0 / OpenID Connect handlers in this file.
const (
// Scope values accepted in authorization requests (see allowedScopes).
scopeOpenID = "openid"
scopeProfile = "profile"
scopeEmail = "email"
// offline_access controls whether exchangeToken issues a refresh token.
scopeOfflineAccess = "offline_access"
// Canonical spellings of the supported code_challenge_method values.
pkceMethodPlain = "plain"
pkceMethodS256 = "S256"
)
// allowedScopes is the set of scope values that validateScopes accepts.
// Keys are the lower-case scope names; the empty struct values carry no data.
var allowedScopes = map[string]struct{}{
scopeOpenID: {},
scopeProfile: {},
scopeEmail: {},
scopeOfflineAccess: {},
}
// oidcTokenResponse is the JSON body written by exchangeToken on success.
// Every field except access_token and token_type is omitted from the output
// when it holds its zero value.
type oidcTokenResponse struct {
AccessToken string `json:"access_token"`
TokenType oauth2.TokenType `json:"token_type"`
// ExpiresIn is the remaining access token lifetime in seconds.
ExpiresIn int64 `json:"expires_in,omitempty"`
// RefreshToken is only filled in when the token scope contains offline_access.
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
// IDToken is only filled in when the token scope contains openid.
IDToken string `json:"id_token,omitempty"`
}
// parseScopes splits a whitespace-delimited scope string into individual
// scope values. Every value is lower-cased and duplicates are dropped, keeping
// the position of the first occurrence. It returns nil when the input holds no
// scope value at all (empty or whitespace-only).
func parseScopes(scope string) []string {
	fields := strings.Fields(scope)
	known := make(map[string]struct{}, len(fields))
	var result []string
	for _, field := range fields {
		name := strings.ToLower(field)
		if _, dup := known[name]; !dup {
			known[name] = struct{}{}
			result = append(result, name)
		}
	}
	return result
}
// normalizeScope parses scope with parseScopes and returns both the canonical
// single-space-joined scope string and the individual scope values. When the
// input contains no scope value it returns ("", nil).
func normalizeScope(scope string) (string, []string) {
	if list := parseScopes(scope); len(list) > 0 {
		return strings.Join(list, " "), list
	}
	return "", nil
}
// validateScopes reports an error naming the first entry of scopes that is not
// present in allowedScopes. It returns nil when every entry is allowed,
// including when scopes is empty.
func validateScopes(scopes []string) error {
	for i := range scopes {
		name := scopes[i]
		if _, known := allowedScopes[name]; known {
			continue
		}
		return fmt.Errorf("unsupported scope %q", name)
	}
	return nil
}
// normalizeCodeChallengeMethod maps a client-supplied code_challenge_method to
// its canonical spelling. Matching is case-insensitive and an empty method
// defaults to "plain". Any other value yields an error.
func normalizeCodeChallengeMethod(method string) (string, error) {
	if method == "" || strings.EqualFold(method, pkceMethodPlain) {
		return pkceMethodPlain, nil
	}
	if strings.EqualFold(method, pkceMethodS256) {
		return pkceMethodS256, nil
	}
	return "", fmt.Errorf("unsupported code_challenge_method")
}
// validateCodeVerifier checks the syntax of a PKCE code verifier: it must be
// non-empty, between 43 and 128 bytes long, and consist only of ASCII letters,
// digits and the characters '-', '.', '_' and '~'.
func validateCodeVerifier(verifier string) error {
	const minLen, maxLen = 43, 128

	switch n := len(verifier); {
	case n == 0:
		return fmt.Errorf("missing code_verifier")
	case n < minLen || n > maxLen:
		return fmt.Errorf("invalid code_verifier length")
	}

	for _, c := range verifier {
		switch {
		case 'A' <= c && c <= 'Z', 'a' <= c && c <= 'z', '0' <= c && c <= '9':
			// alphanumeric
		case c == '-', c == '.', c == '_', c == '~':
			// permitted punctuation
		default:
			return fmt.Errorf("invalid character in code_verifier")
		}
	}
	return nil
}
// validateCodeChallenge checks a PKCE code_challenge received on the
// authorization endpoint. The challenge must be non-empty and satisfy the same
// syntax rules as a code verifier (validateCodeVerifier), and method must be
// one of the canonical method names "plain" or "S256".
func validateCodeChallenge(method, challenge string) error {
	if challenge == "" {
		return fmt.Errorf("missing code_challenge")
	}
	// The challenge syntax is checked before the method, so a malformed
	// challenge is reported first.
	if syntaxErr := validateCodeVerifier(challenge); syntaxErr != nil {
		return syntaxErr
	}
	switch method {
	case pkceMethodPlain, pkceMethodS256:
		return nil
	}
	return fmt.Errorf("unsupported code_challenge_method")
}
// verifyCodeVerifier checks a PKCE code_verifier presented at the token
// endpoint against the challenge stored with the authorization code.
//
// For the "plain" method (or an empty method) the verifier must equal the
// challenge. For "S256" the unpadded base64url encoding of the verifier's
// SHA-256 digest must equal the challenge. An error is returned when the
// verifier is syntactically invalid, the method is unknown, or the values do
// not match.
func verifyCodeVerifier(method, challenge, verifier string) error {
	if err := validateCodeVerifier(verifier); err != nil {
		return err
	}

	// Derive the value the stored challenge is expected to equal.
	var derived string
	switch method {
	case "", pkceMethodPlain:
		derived = verifier
	case pkceMethodS256:
		digest := sha256.Sum256([]byte(verifier))
		derived = base64.RawURLEncoding.EncodeToString(digest[:])
	default:
		return fmt.Errorf("unsupported code_challenge_method")
	}

	if derived != challenge {
		return fmt.Errorf("code_verifier mismatch")
	}
	return nil
}
// getOAuthServerMetadata writes the authorization server metadata document as
// JSON. All endpoint URLs are derived from the issuer computed for the
// incoming request by getIssuer.
func getOAuthServerMetadata(w http.ResponseWriter, req *http.Request) {
	issuer := getIssuer(req)
	endpoint := func(path string) string { return issuer + path }

	// The token, introspection and revocation endpoints all advertise the
	// same client authentication methods.
	clientAuthMethods := func() []oauth2.AuthMethod {
		return []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic}
	}

	metadata := oauth2.ServerMetadata{
		Issuer:                 issuer,
		AuthorizationEndpoint:  endpoint("/authorize"),
		TokenEndpoint:          endpoint("/token"),
		IntrospectionEndpoint:  endpoint("/introspect"),
		RevocationEndpoint:     endpoint("/revoke"),
		JWKSURI:                endpoint("/.well-known/jwks.json"),
		ScopesSupported:        []string{scopeOpenID, scopeProfile, scopeEmail, scopeOfflineAccess},
		ResponseTypesSupported: []oauth2.ResponseType{oauth2.ResponseTypeCode},
		ResponseModesSupported: []oauth2.ResponseMode{oauth2.ResponseModeQuery},
		GrantTypesSupported: []oauth2.GrantType{
			oauth2.GrantTypeAuthorizationCode,
			oauth2.GrantTypeRefreshToken,
		},
		TokenEndpointAuthMethodsSupported:          clientAuthMethods(),
		IntrospectionEndpointAuthMethodsSupported:  clientAuthMethods(),
		RevocationEndpointAuthMethodsSupported:     clientAuthMethods(),
		CodeChallengeMethodsSupported:              []string{pkceMethodPlain, pkceMethodS256},
		AuthorizationResponseIssParameterSupported: true,
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(&metadata)
}
// getIssuer returns the issuer URL for the request, built from the request's
// Host header. The scheme is "https" unless the request was not forwarded over
// HTTPS (per isForwardedHTTPS) and its remote address is a loopback address,
// in which case it is "http".
func getIssuer(req *http.Request) string {
	scheme := "https"
	// TODO: add config option for allowed reverse proxy IPs
	if !isForwardedHTTPS(req) && isLoopback(req) {
		scheme = "http"
	}
	issuer := url.URL{Scheme: scheme, Host: req.Host}
	return issuer.String()
}
// isLoopback reports whether the request's RemoteAddr is a "host:port" pair
// whose host is a loopback IP address. An address that cannot be split or
// parsed is treated as non-loopback.
func isLoopback(req *http.Request) bool {
	// The split error is deliberately ignored: on failure the host is empty,
	// ParseIP returns nil, and IsLoopback on a nil IP is false.
	remoteHost, _, _ := net.SplitHostPort(req.RemoteAddr)
	return net.ParseIP(remoteHost).IsLoopback()
}
// authorize handles the authorization endpoint. It reads the request
// parameters from the URL query, validates the client, redirect URI, PKCE
// parameters, response type and scope, requires a logged-in user, shows the
// consent page, and on approval stores a new authorization code and redirects
// the user agent back to the client with the code.
//
// Problems found before a redirect URI has been validated are answered
// directly with an HTTP error; later problems are reported to the client by
// redirecting to the redirect URI with OAuth error parameters.
func authorize(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
db := dbFromContext(ctx)
tpl := templateFromContext(ctx)
q := req.URL.Query()
respType := oauth2.ResponseType(q.Get("response_type"))
clientID := q.Get("client_id")
rawRedirectURI := q.Get("redirect_uri")
scope := q.Get("scope")
state := q.Get("state")
// Track presence separately so that an explicitly empty state parameter is
// still echoed back to the client.
_, stateProvided := q["state"]
codeChallenge := q.Get("code_challenge")
codeChallengeMethod := q.Get("code_challenge_method")
var normalizedCodeChallengeMethod string
nonce := q.Get("nonce")
if clientID == "" {
http.Error(w, "Missing client ID", http.StatusBadRequest)
return
}
client, err := db.FetchClientByClientID(ctx, clientID)
if err == errNoDBRows {
http.Error(w, "Invalid client ID", http.StatusForbidden)
return
} else if err != nil {
httpError(w, fmt.Errorf("failed to fetch client: %v", err))
return
}
requiredPKCE, err := normalizeClientPKCERequirement(client.PKCERequirement)
if err != nil {
httpError(w, fmt.Errorf("invalid PKCE requirement configuration: %v", err))
return
}
// The client's registered redirect URIs are stored newline-separated; empty
// lines are skipped.
var allowedRedirectURIs []*url.URL
for _, s := range strings.Split(client.RedirectURIs, "\n") {
if s == "" {
continue
}
u, err := url.Parse(s)
if err != nil {
httpError(w, fmt.Errorf("failed to parse client redirect URI"))
return
}
allowedRedirectURIs = append(allowedRedirectURIs, u)
}
// Pick the redirect URI: a supplied one must pass validateRedirectURI,
// otherwise the client's first registered URI is used.
var redirectURI *url.URL
if rawRedirectURI != "" {
redirectURI, err = url.Parse(rawRedirectURI)
if err != nil {
http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
return
}
if !validateRedirectURI(redirectURI, allowedRedirectURIs) {
http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
return
}
} else {
if len(allowedRedirectURIs) == 0 {
http.Error(w, "Missing redirect URI", http.StatusBadRequest)
return
}
redirectURI = allowedRedirectURIs[0]
}
// From here on the redirect URI is trusted, so errors are reported to the
// client via redirect.
//
// PKCE: when a challenge is supplied, canonicalise the method and validate
// the challenge syntax.
if codeChallenge != "" {
method, err := normalizeCodeChallengeMethod(codeChallengeMethod)
if err != nil {
redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: err.Error(),
})
return
}
if err := validateCodeChallenge(method, codeChallenge); err != nil {
redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: err.Error(),
})
return
}
normalizedCodeChallengeMethod = method
}
if codeChallenge == "" && codeChallengeMethod != "" {
redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: "code_challenge_method without code_challenge",
})
return
}
// Enforce the client's configured PKCE requirement, if any. Other
// requirement values impose no constraint here.
switch requiredPKCE {
case pkceRequirementPlain:
if codeChallenge == "" {
redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: "PKCE is required",
})
return
}
if !allowPKCERequirement(normalizedCodeChallengeMethod, pkceRequirementPlain) {
redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: "PKCE method does not satisfy requirement",
})
return
}
case pkceRequirementS256:
if codeChallenge == "" {
redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: "PKCE (S256) is required",
})
return
}
if !allowPKCERequirement(normalizedCodeChallengeMethod, pkceRequirementS256) {
redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: "PKCE (S256) is required",
})
return
}
}
// Store the canonical method (empty when no challenge was supplied).
codeChallengeMethod = normalizedCodeChallengeMethod
// Only the authorization code response type is supported.
if respType != oauth2.ResponseTypeCode {
redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
Code: oauth2.ErrorCodeUnsupportedResponseType,
})
return
}
// Scope: must be non-empty, contain only allowed values, and include openid.
normalizedScope, scopes := normalizeScope(scope)
if len(scopes) == 0 {
redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidScope,
Description: "Missing required openid scope",
})
return
}
if err := validateScopes(scopes); err != nil {
redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidScope,
Description: err.Error(),
})
return
}
if !containsScope(scopes, scopeOpenID) {
redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidScope,
Description: "Scope openid is required",
})
return
}
scope = normalizedScope
// Require a logged-in user; otherwise send the user agent to /login with
// this request's URL as the place to return to.
loginToken := loginTokenFromContext(ctx)
if loginToken == nil {
q := make(url.Values)
q.Set("redirect_uri", req.URL.String())
u := url.URL{
Path: "/login",
RawQuery: q.Encode(),
}
http.Redirect(w, req, u.String(), http.StatusFound)
return
}
// The consent decision arrives as a POST form field: "deny" rejects,
// "authorize" approves, and anything else renders the consent page.
_ = req.ParseForm()
if _, ok := req.PostForm["deny"]; ok {
redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
})
return
}
if _, ok := req.PostForm["authorize"]; !ok {
data := struct {
TemplateBaseData
Client *Client
}{
Client: client,
}
tpl.MustExecuteTemplate(req.Context(), w, "authorize.html", &data)
return
}
// Approved: record the authorization code together with everything the
// token endpoint later checks against it.
authCode := AuthCode{
User: loginToken.User,
Client: client.ID,
Scope: scope,
// The raw query value is stored (empty when the parameter was omitted);
// exchangeToken compares it to the redirect_uri sent with the token request.
RedirectURI: rawRedirectURI,
Nonce: nonce,
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
}
secret, err := authCode.Generate()
if err != nil {
httpError(w, fmt.Errorf("failed to generate authentication code: %v", err))
return
}
if err := db.CreateAuthCode(ctx, &authCode); err != nil {
httpError(w, fmt.Errorf("failed to create authentication code: %v", err))
return
}
code := MarshalSecret(authCode.ID, SecretKindAuthCode, secret)
values := make(url.Values)
values.Set("code", code)
if stateProvided {
values.Set("state", state)
}
redirectClient(w, req, redirectURI, values)
}
// exchangeToken handles the token endpoint. It accepts a form-encoded request
// body and supports the authorization_code and refresh_token grant types.
//
// The client may identify itself through HTTP Basic authentication, through a
// client_id body parameter, or both (in which case the two must agree).
// Non-public clients must present a valid secret. On success the response is a
// JSON oidcTokenResponse; a refresh token is included only when the token scope
// contains offline_access and an ID token only when it contains openid. All
// failures are reported through oauthError.
func exchangeToken(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)

	values, err := parseRequestBody(req)
	if err != nil {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: err.Error(),
		})
		return
	}

	clientID := values.Get("client_id")
	grantType := oauth2.GrantType(values.Get("grant_type"))
	scope := values.Get("scope")
	codeVerifier := values.Get("code_verifier")

	authClientID, clientSecret, hasBasicAuth := req.BasicAuth()
	if clientID == "" {
		clientID = authClientID
	} else if hasBasicAuth && clientID != authClientID {
		// Only compare when an Authorization header was actually sent:
		// clients using the advertised "none" auth method identify themselves
		// with client_id in the body alone.
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: "Client ID in request body doesn't match Authorization header field",
		})
		return
	}

	var client *Client
	if clientID != "" {
		client, err = db.FetchClientByClientID(ctx, clientID)
		if err == errNoDBRows {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidClient,
				Description: "Invalid client ID",
			})
			return
		} else if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
			return
		}
		if !client.IsPublic() && !client.VerifySecret(clientSecret) {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid client secret",
			})
			return
		}
	}

	var (
		token             *AccessToken
		authorizationCode *AuthCode
		currentClient     *Client
		nonceValue        string
	)
	switch grantType {
	case oauth2.GrantTypeAuthorizationCode:
		if client == nil {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidRequest,
				Description: "Missing client ID",
			})
			return
		}

		codeID, codeSecret, _ := UnmarshalSecret[*AuthCode](values.Get("code"))
		authorizationCode, err = db.PopAuthCode(ctx, codeID)
		// Handle unexpected database errors first so that authorizationCode
		// is never dereferenced after a failed lookup.
		if err != nil && err != errNoDBRows {
			oauthError(w, fmt.Errorf("failed to fetch authorization code: %v", err))
			return
		}
		if err == errNoDBRows || !authorizationCode.VerifySecret(codeSecret) || authorizationCode.Client != client.ID {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid authorization code",
			})
			return
		}

		if scope != "" && scope != authorizationCode.Scope {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid scope",
			})
			return
		}
		// Must match the raw redirect_uri recorded by the authorization
		// endpoint (empty when it was omitted there).
		if values.Get("redirect_uri") != authorizationCode.RedirectURI {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid redirect URI",
			})
			return
		}

		// PKCE: a verifier is mandatory when a challenge was recorded and
		// forbidden when none was.
		if authorizationCode.CodeChallenge != "" {
			if codeVerifier == "" {
				oauthError(w, &oauth2.Error{
					Code:        oauth2.ErrorCodeInvalidRequest,
					Description: "Missing code_verifier",
				})
				return
			}
			if err := verifyCodeVerifier(authorizationCode.CodeChallengeMethod, authorizationCode.CodeChallenge, codeVerifier); err != nil {
				oauthError(w, &oauth2.Error{
					Code:        oauth2.ErrorCodeInvalidGrant,
					Description: "Invalid code_verifier",
				})
				return
			}
		} else if codeVerifier != "" {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidRequest,
				Description: "Unexpected code_verifier",
			})
			return
		}

		token = NewAccessTokenFromAuthCode(authorizationCode)
		currentClient = client
		nonceValue = authorizationCode.Nonce
	case oauth2.GrantTypeRefreshToken:
		tokenID, refreshSecret, _ := UnmarshalSecret[*AccessToken](values.Get("refresh_token"))
		token, err = db.FetchAccessToken(ctx, tokenID)
		if err == errNoDBRows || (err == nil && !token.VerifyRefreshSecret(refreshSecret)) {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid refresh token",
			})
			return
		} else if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch access token: %v", err))
			return
		}

		// A client that identified itself must own the token.
		if client != nil && client.ID != token.Client {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid refresh token",
			})
			return
		}

		currentClient, err = db.FetchClient(ctx, token.Client)
		if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
			return
		}
		// Tokens of non-public clients may only be refreshed by an
		// identified (and therefore secret-verified) client.
		if !currentClient.IsPublic() && client == nil {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid client secret",
			})
			return
		}

		if scope != "" && scope != token.Scope {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid scope",
			})
			return
		}
	default:
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeUnsupportedGrantType,
			Description: "Unsupported grant type",
		})
		// Stop here: token is nil for unsupported grant types.
		return
	}

	secret, err := token.Generate(accessTokenExpiration)
	if err != nil {
		oauthError(w, err)
		return
	}

	tokenScopes := parseScopes(token.Scope)
	issueRefresh := containsScope(tokenScopes, scopeOfflineAccess)

	var refreshSecret string
	if issueRefresh {
		refreshSecret, err = token.GenerateRefresh()
		if err != nil {
			oauthError(w, err)
			return
		}
	} else {
		// No offline_access: make sure no refresh credentials are stored.
		token.RefreshHash = nil
		token.RefreshExpiresAt = time.Time{}
	}

	if err := db.StoreAccessToken(ctx, token); err != nil {
		oauthError(w, fmt.Errorf("failed to create access token: %v", err))
		return
	}

	accessTokenValue := MarshalSecret(token.ID, SecretKindAccessToken, secret)
	if token.AuthTime.IsZero() {
		token.AuthTime = token.IssuedAt
	}

	var idToken string
	if containsScope(tokenScopes, scopeOpenID) {
		if currentClient == nil {
			currentClient, err = db.FetchClient(ctx, token.Client)
			if err != nil {
				oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
				return
			}
		}
		user, err := db.FetchUser(ctx, token.User)
		if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch user: %v", err))
			return
		}
		issuer := getIssuer(req)
		oidcProvider := oidcProviderFromContext(ctx)
		idToken, err = oidcProvider.MintIDToken(issuer, currentClient, user, token, tokenScopes, nonceValue, accessTokenValue, token.AuthTime)
		if err != nil {
			oauthError(w, fmt.Errorf("failed to mint ID token: %v", err))
			return
		}
	}

	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Cache-Control", "no-store")

	resp := oidcTokenResponse{
		AccessToken: accessTokenValue,
		TokenType:   oauth2.TokenTypeBearer,
		ExpiresIn:   int64(time.Until(token.ExpiresAt).Seconds()),
		Scope:       token.Scope,
		IDToken:     idToken,
	}
	if issueRefresh {
		refreshTokenValue := MarshalSecret(token.ID, SecretKindRefreshToken, refreshSecret)
		resp.RefreshToken = refreshTokenValue
	}
	json.NewEncoder(w).Encode(&resp)
}
// introspectToken handles the token introspection endpoint. It reads the
// "token" form parameter and writes a JSON introspection response. An unknown
// token, or one whose secret does not verify, yields the zero-valued
// (inactive) response. For a valid token the caller must either authenticate
// as the owning client or, when unauthenticated, the owning client must be
// public.
func introspectToken(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)

	form, err := parseRequestBody(req)
	if err != nil {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: err.Error(),
		})
		return
	}

	client, err := maybeAuthenticateClient(w, req)
	if err != nil {
		oauthError(w, err)
		return
	}

	tokenID, secret, _ := UnmarshalSecret[*AccessToken](form.Get("token"))
	token, err := db.FetchAccessToken(ctx, tokenID)
	switch {
	case err == errNoDBRows:
		token = nil
	case err != nil:
		oauthError(w, fmt.Errorf("failed to fetch access token: %v", err))
		return
	case !token.VerifySecret(secret):
		token = nil
	}

	var resp oauth2.IntrospectionResp
	reply := func() {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(&resp)
	}

	// Unknown or unverifiable token: answer with the inactive response.
	if token == nil {
		reply()
		return
	}

	if client == nil {
		// Unauthenticated caller: only allowed for tokens of public clients.
		client, err = db.FetchClient(ctx, token.Client)
		if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
			return
		}
		if !client.IsPublic() {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidClient,
				Description: "Missing client ID and secret",
			})
			return
		}
	}
	if client.ID != token.Client {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidClient,
			Description: "Invalid client ID or secret",
		})
		return
	}

	user, err := db.FetchUser(ctx, token.User)
	if err != nil {
		oauthError(w, fmt.Errorf("failed to fetch user: %v", err))
		return
	}

	resp.Active = true
	resp.TokenType = oauth2.TokenTypeBearer
	resp.ExpiresAt = token.ExpiresAt
	resp.IssuedAt = token.IssuedAt
	resp.ClientID = client.ClientID
	resp.Username = user.Username
	reply()
}
// revokeToken handles the token revocation endpoint. It reads the "token" form
// parameter and deletes the matching access token. An unknown token, or one
// whose secret does not verify, is silently ignored. The caller must either
// authenticate as the owning client or, when unauthenticated, the owning
// client must be public.
func revokeToken(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)

	form, err := parseRequestBody(req)
	if err != nil {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: err.Error(),
		})
		return
	}

	client, err := maybeAuthenticateClient(w, req)
	if err != nil {
		oauthError(w, err)
		return
	}

	tokenID, secret, _ := UnmarshalSecret[*AccessToken](form.Get("token"))
	token, err := db.FetchAccessToken(ctx, tokenID)
	switch {
	case err == errNoDBRows:
		return // ignore
	case err != nil:
		oauthError(w, fmt.Errorf("failed to fetch access token: %v", err))
		return
	case !token.VerifySecret(secret):
		return // ignore
	}

	if client == nil {
		// Unauthenticated caller: only allowed for tokens of public clients.
		owner, fetchErr := db.FetchClient(ctx, token.Client)
		if fetchErr != nil {
			oauthError(w, fmt.Errorf("failed to fetch client: %v", fetchErr))
			return
		}
		if !owner.IsPublic() {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidClient,
				Description: "Missing client ID and secret",
			})
			return
		}
		client = owner
	}
	if client.ID != token.Client {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidClient,
			Description: "Invalid client ID or secret",
		})
		return
	}

	if delErr := db.DeleteAccessToken(ctx, token.ID); delErr != nil {
		oauthError(w, delErr)
	}
}
// parseRequestBody reads the request body and decodes it as URL-encoded form
// values. A Content-Type header, when present, must parse as a media type and
// name application/x-www-form-urlencoded; a missing header is accepted. At
// most 10 MiB of the body is read.
func parseRequestBody(req *http.Request) (url.Values, error) {
	if contentType := req.Header.Get("Content-Type"); contentType != "" {
		mediaType, _, err := mime.ParseMediaType(contentType)
		switch {
		case err != nil:
			return nil, fmt.Errorf("malformed Content-Type header field")
		case mediaType != "application/x-www-form-urlencoded":
			return nil, fmt.Errorf("unsupported request content type")
		}
	}

	const maxBodySize = 10 << 20
	raw, err := io.ReadAll(io.LimitReader(req.Body, maxBodySize))
	if err != nil {
		return nil, fmt.Errorf("failed to read request body: %v", err)
	}

	form, err := url.ParseQuery(string(raw))
	if err != nil {
		return nil, fmt.Errorf("failed to parse request body: %v", err)
	}
	return form, nil
}
// oauthError writes err to w as a JSON OAuth error response. If err is (or
// wraps) an *oauth2.Error it is sent as is; any other error is logged and
// replaced by a generic server_error. The HTTP status is derived from the
// OAuth error code and defaults to 500 for codes without an explicit mapping.
func oauthError(w http.ResponseWriter, err error) {
	var oauthErr *oauth2.Error
	if isOAuth := errors.As(err, &oauthErr); !isOAuth {
		// Internal details are logged, never sent to the client.
		log.Print(err)
		oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
	}

	var statusCode int
	switch oauthErr.Code {
	case oauth2.ErrorCodeInvalidRequest,
		oauth2.ErrorCodeUnsupportedResponseType,
		oauth2.ErrorCodeInvalidScope,
		oauth2.ErrorCodeInvalidClient,
		oauth2.ErrorCodeInvalidGrant,
		oauth2.ErrorCodeUnsupportedGrantType:
		statusCode = http.StatusBadRequest
	case oauth2.ErrorCodeUnauthorizedClient, oauth2.ErrorCodeAccessDenied:
		statusCode = http.StatusForbidden
	case oauth2.ErrorCodeTemporarilyUnavailable:
		statusCode = http.StatusServiceUnavailable
	default:
		statusCode = http.StatusInternalServerError
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)
	json.NewEncoder(w).Encode(oauthErr)
}
// redirectClient sends the user agent back to the client with a 302 redirect.
// The target is redirectURI with its existing query parameters kept, the
// entries of values added (replacing parameters of the same name), and an
// "iss" parameter set to the issuer for this request. redirectURI itself is
// not modified.
func redirectClient(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, values url.Values) {
	query := redirectURI.Query()
	for key := range values {
		query[key] = values[key]
	}
	query.Set("iss", getIssuer(req))

	// Work on a copy so the caller's URL stays untouched.
	target := *redirectURI
	target.RawQuery = query.Encode()
	http.Redirect(w, req, target.String(), http.StatusFound)
}
// redirectClientError reports err to the client by redirecting the user agent
// to redirectURI with OAuth error parameters. If err is (or wraps) an
// *oauth2.Error its code, description and URI are forwarded; any other error
// is logged and reported as a generic server_error. The state parameter is
// echoed back only when stateProvided is true.
func redirectClientError(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, state string, stateProvided bool, err error) {
	var oauthErr *oauth2.Error
	if isOAuth := errors.As(err, &oauthErr); !isOAuth {
		log.Print(err)
		oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
	}

	params := url.Values{"error": {string(oauthErr.Code)}}
	if desc := oauthErr.Description; desc != "" {
		params.Set("error_description", desc)
	}
	if uri := oauthErr.URI; uri != "" {
		params.Set("error_uri", uri)
	}
	if stateProvided {
		params.Set("state", state)
	}
	redirectClient(w, req, redirectURI, params)
}
// validateRedirectURI reports whether u exactly matches (by string form) one
// of allowedURIs.
//
// Loopback redirects get special treatment, see RFC 8252 section 7.3: an
// "http" URI whose host is a loopback IP literal is compared with its host
// replaced by "localhost" and its port dropped, so the client may listen on
// any port. The hostname is taken with url.URL.Hostname so that loopback
// literals without an explicit port (e.g. "http://127.0.0.1/cb") are
// recognised too. u itself is never modified.
func validateRedirectURI(u *url.URL, allowedURIs []*url.URL) bool {
	if ip := net.ParseIP(u.Hostname()); u.Scheme == "http" && ip.IsLoopback() {
		normalized := *u
		normalized.Host = "localhost"
		u = &normalized
	}

	candidate := u.String()
	for _, allowed := range allowedURIs {
		if candidate == allowed.String() {
			return true
		}
	}
	return false
}
// maybeAuthenticateClient authenticates the client from the request's HTTP
// Basic credentials, if any. It returns (nil, nil) when the request carries no
// Basic credentials, the client when the ID exists and the secret verifies, an
// invalid_client *oauth2.Error when either is wrong, and a plain error when
// the lookup itself fails. The w parameter is not used.
func maybeAuthenticateClient(w http.ResponseWriter, req *http.Request) (*Client, error) {
	ctx := req.Context()
	db := dbFromContext(ctx)

	clientID, clientSecret, hasAuth := req.BasicAuth()
	if !hasAuth {
		return nil, nil
	}

	// Unknown ID and wrong secret share one error so callers cannot tell
	// them apart.
	rejected := func() (*Client, error) {
		return nil, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidClient,
			Description: "Invalid client ID or secret",
		}
	}

	client, err := db.FetchClientByClientID(ctx, clientID)
	switch {
	case err == errNoDBRows:
		return rejected()
	case err != nil:
		return nil, fmt.Errorf("failed to fetch client: %v", err)
	case !client.VerifySecret(clientSecret):
		return rejected()
	}
	return client, nil
}