Warning: Due to various recent migrations, viewing non-HEAD refs may be broken.
/oauth2.go (raw)
package main
import (
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"mime"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"
	"github.com/emersion/go-oauth2"
)
// Scope values understood by this server.
const (
	scopeOpenID        = "openid"
	scopeProfile       = "profile"
	scopeEmail         = "email"
	scopeOfflineAccess = "offline_access"
)

// PKCE code_challenge_method values (RFC 7636 section 4.3).
const (
	pkceMethodPlain = "plain"
	pkceMethodS256  = "S256"
)

// allowedScopes is the set of scope values validateScopes accepts; any other
// scope value makes it return an error.
var allowedScopes = map[string]struct{}{
	scopeOfflineAccess: {},
	scopeEmail:         {},
	scopeProfile:       {},
	scopeOpenID:        {},
}
// oidcTokenResponse is the JSON body written by the token endpoint: the
// OAuth 2.0 access token response fields (RFC 6749 section 5.1) plus the
// OpenID Connect id_token. ExpiresIn is the remaining access token lifetime
// in seconds; RefreshToken and IDToken are omitted when empty. Field order
// determines key order in the encoded JSON.
type oidcTokenResponse struct {
	AccessToken  string           `json:"access_token"`
	TokenType    oauth2.TokenType `json:"token_type"`
	ExpiresIn    int64            `json:"expires_in,omitempty"`
	RefreshToken string           `json:"refresh_token,omitempty"`
	Scope        string           `json:"scope,omitempty"`
	IDToken      string           `json:"id_token,omitempty"`
}
// parseScopes splits a space-delimited scope string into its tokens. Each
// token is lower-cased, and duplicates are dropped while the order of first
// occurrence is kept. It returns nil when scope contains no tokens (empty or
// whitespace only).
//
// NOTE: RFC 6749 section 3.3 defines scope tokens as case-sensitive. Folding
// to lower case here makes e.g. "OpenID" equivalent to "openid", which the
// lower-case scope constants in this file are compared against.
func parseScopes(scope string) []string {
	fields := strings.Fields(scope)
	if len(fields) == 0 {
		return nil
	}
	scopes := make([]string, 0, len(fields))
	seen := make(map[string]struct{}, len(fields))
	// strings.Fields never yields empty tokens, so no emptiness check is needed.
	for _, f := range fields {
		f = strings.ToLower(f)
		if _, dup := seen[f]; dup {
			continue
		}
		seen[f] = struct{}{}
		scopes = append(scopes, f)
	}
	return scopes
}
// normalizeScope canonicalizes a raw scope parameter. It returns the tokens
// produced by parseScopes together with their single-space-joined form; when
// there are no tokens both results are zero values ("" and nil).
func normalizeScope(scope string) (string, []string) {
	tokens := parseScopes(scope)
	if len(tokens) > 0 {
		return strings.Join(tokens, " "), tokens
	}
	return "", nil
}
// validateScopes returns an error naming the first element of scopes that is
// not a key of allowedScopes. A nil or empty slice is valid.
func validateScopes(scopes []string) error {
	for i := range scopes {
		if _, known := allowedScopes[scopes[i]]; !known {
			return fmt.Errorf("unsupported scope %q", scopes[i])
		}
	}
	return nil
}
// normalizeCodeChallengeMethod maps a code_challenge_method parameter to its
// canonical spelling (pkceMethodPlain or pkceMethodS256) and matches it
// case-insensitively. An empty method means "plain", the RFC 7636 section 4.3
// default. Any other value is an error.
func normalizeCodeChallengeMethod(method string) (string, error) {
	if method == "" || strings.EqualFold(method, pkceMethodPlain) {
		return pkceMethodPlain, nil
	}
	if strings.EqualFold(method, pkceMethodS256) {
		return pkceMethodS256, nil
	}
	return "", fmt.Errorf("unsupported code_challenge_method")
}
// validateCodeVerifier checks the syntax of a PKCE code verifier (RFC 7636
// section 4.1). The verifier must be 43 to 128 bytes long and consist only of
// unreserved characters: A-Z, a-z, 0-9, '-', '.', '_' and '~'.
func validateCodeVerifier(verifier string) error {
	switch n := len(verifier); {
	case n == 0:
		return fmt.Errorf("missing code_verifier")
	case n < 43, n > 128:
		return fmt.Errorf("invalid code_verifier length")
	}
	for _, c := range verifier {
		switch {
		case 'A' <= c && c <= 'Z', 'a' <= c && c <= 'z', '0' <= c && c <= '9':
			// alphanumeric: allowed
		case c == '-', c == '.', c == '_', c == '~':
			// unreserved punctuation: allowed
		default:
			return fmt.Errorf("invalid character in code_verifier")
		}
	}
	return nil
}
// validateCodeChallenge checks a PKCE code_challenge received at the
// authorization endpoint (RFC 7636 section 4.2). method is expected to be
// already normalized to pkceMethodPlain or pkceMethodS256; any other value is
// rejected.
//
// Challenges for both methods use a subset of the code verifier alphabet and
// length range, so validateCodeVerifier performs the syntax check. An S256
// challenge must also be exactly 43 characters long, the length of an
// unpadded base64url-encoded SHA-256 digest. A challenge of any other length
// could never match at the token endpoint.
func validateCodeChallenge(method, challenge string) error {
	if challenge == "" {
		return fmt.Errorf("missing code_challenge")
	}
	// validateCodeVerifier's messages name "code_verifier"; they end up in
	// error_description, so report the problem in terms of the parameter the
	// client actually sent.
	if err := validateCodeVerifier(challenge); err != nil {
		return fmt.Errorf("invalid code_challenge")
	}
	switch method {
	case pkceMethodPlain:
	case pkceMethodS256:
		if len(challenge) != 43 {
			return fmt.Errorf("invalid code_challenge length for S256")
		}
	default:
		return fmt.Errorf("unsupported code_challenge_method")
	}
	return nil
}
// verifyCodeVerifier checks a code_verifier from the token request against
// the code_challenge stored with the authorization code (RFC 7636 section
// 4.6). For "plain" (or an empty method) the verifier must equal the
// challenge. For "S256" the unpadded base64url encoding of SHA-256(verifier)
// must equal the challenge. The verifier's syntax is validated first.
func verifyCodeVerifier(method, challenge, verifier string) error {
	if err := validateCodeVerifier(verifier); err != nil {
		return err
	}
	var derived string
	switch method {
	case "", pkceMethodPlain:
		derived = verifier
	case pkceMethodS256:
		digest := sha256.Sum256([]byte(verifier))
		derived = base64.RawURLEncoding.EncodeToString(digest[:])
	default:
		return fmt.Errorf("unsupported code_challenge_method")
	}
	if derived != challenge {
		return fmt.Errorf("code_verifier mismatch")
	}
	return nil
}
// getOAuthServerMetadata writes the authorization server metadata document
// (RFC 8414) as JSON. Every endpoint URL is built from the issuer computed
// for the current request.
func getOAuthServerMetadata(w http.ResponseWriter, req *http.Request) {
	issuer := getIssuer(req)
	endpoint := func(path string) string {
		return issuer + path
	}
	// The token, introspection and revocation endpoints advertise the same
	// client authentication methods.
	clientAuthMethods := []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic}
	metadata := oauth2.ServerMetadata{
		Issuer:                            issuer,
		AuthorizationEndpoint:             endpoint("/authorize"),
		TokenEndpoint:                     endpoint("/token"),
		IntrospectionEndpoint:             endpoint("/introspect"),
		RevocationEndpoint:                endpoint("/revoke"),
		JWKSURI:                           endpoint("/.well-known/jwks.json"),
		ScopesSupported:                   []string{scopeOpenID, scopeProfile, scopeEmail, scopeOfflineAccess},
		ResponseTypesSupported:            []oauth2.ResponseType{oauth2.ResponseTypeCode},
		ResponseModesSupported:            []oauth2.ResponseMode{oauth2.ResponseModeQuery},
		GrantTypesSupported:               []oauth2.GrantType{oauth2.GrantTypeAuthorizationCode, oauth2.GrantTypeRefreshToken},
		TokenEndpointAuthMethodsSupported: clientAuthMethods,
		IntrospectionEndpointAuthMethodsSupported:  clientAuthMethods,
		RevocationEndpointAuthMethodsSupported:     clientAuthMethods,
		CodeChallengeMethodsSupported:              []string{pkceMethodPlain, pkceMethodS256},
		AuthorizationResponseIssParameterSupported: true,
	}
	w.Header().Set("Content-Type", "application/json")
	// NOTE(review): the Encode error is discarded; nothing is reported if
	// writing the body fails.
	json.NewEncoder(w).Encode(&metadata)
}
// getIssuer returns the issuer identifier for this server, built from the
// request's Host header. The scheme is "https" unless isForwardedHTTPS reports
// false and the peer is a loopback address, in which case it is "http".
//
// NOTE(review): req.Host comes from the client, so the issuer reflects
// whatever Host the caller sent.
func getIssuer(req *http.Request) string {
	scheme := "https"
	// TODO: add config option for allowed reverse proxy IPs
	if !isForwardedHTTPS(req) && isLoopback(req) {
		scheme = "http"
	}
	u := url.URL{Scheme: scheme, Host: req.Host}
	return u.String()
}
// isLoopback reports whether the request's remote address is a loopback IP.
// A RemoteAddr that is not in "host:port" form is treated as non-loopback.
func isLoopback(req *http.Request) bool {
	if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
		return net.ParseIP(host).IsLoopback()
	}
	return false
}
// authorize handles the OAuth 2.0 authorization endpoint for the
// authorization code flow. All parameters are read from the URL query.
//
// Validation runs in this order: client_id, redirect URI, PKCE parameters
// and the client's PKCE requirement, response_type, then scope (openid is
// mandatory). Errors found before the redirect URI is established are
// reported with plain HTTP errors. Later errors are delivered to the client
// through redirectClientError.
//
// Users without a login token are redirected to /login. Otherwise the
// consent page (authorize.html) is rendered, unless the submitted form
// contains "deny" (access_denied is sent to the client) or "authorize" (an
// authorization code is stored and sent to the client).
func authorize(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)
	tpl := templateFromContext(ctx)
	q := req.URL.Query()
	respType := oauth2.ResponseType(q.Get("response_type"))
	clientID := q.Get("client_id")
	rawRedirectURI := q.Get("redirect_uri")
	scope := q.Get("scope")
	state := q.Get("state")
	// Track whether state was sent at all, so that even an empty state is
	// echoed back to the client.
	_, stateProvided := q["state"]
	codeChallenge := q.Get("code_challenge")
	codeChallengeMethod := q.Get("code_challenge_method")
	// Canonical PKCE method; set only when a code_challenge was supplied.
	var normalizedCodeChallengeMethod string
	nonce := q.Get("nonce")
	if clientID == "" {
		http.Error(w, "Missing client ID", http.StatusBadRequest)
		return
	}
	client, err := db.FetchClientByClientID(ctx, clientID)
	if err == errNoDBRows {
		http.Error(w, "Invalid client ID", http.StatusForbidden)
		return
	} else if err != nil {
		httpError(w, fmt.Errorf("failed to fetch client: %v", err))
		return
	}
	requiredPKCE, err := normalizeClientPKCERequirement(client.PKCERequirement)
	if err != nil {
		httpError(w, fmt.Errorf("invalid PKCE requirement configuration: %v", err))
		return
	}
	// The client's registered redirect URIs are stored newline-separated;
	// empty lines are skipped.
	var allowedRedirectURIs []*url.URL
	for _, s := range strings.Split(client.RedirectURIs, "\n") {
		if s == "" {
			continue
		}
		u, err := url.Parse(s)
		if err != nil {
			// NOTE(review): the url.Parse error is not included in this message.
			httpError(w, fmt.Errorf("failed to parse client redirect URI"))
			return
		}
		allowedRedirectURIs = append(allowedRedirectURIs, u)
	}
	// Resolve the redirect URI: a requested one must pass validateRedirectURI;
	// without a redirect_uri parameter the first registered URI is used.
	var redirectURI *url.URL
	if rawRedirectURI != "" {
		redirectURI, err = url.Parse(rawRedirectURI)
		if err != nil {
			http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
			return
		}
		if !validateRedirectURI(redirectURI, allowedRedirectURIs) {
			http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
			return
		}
	} else {
		// NOTE(review): RFC 6749 section 3.1.2.3 requires redirect_uri when
		// several URIs are registered; this silently picks the first one.
		if len(allowedRedirectURIs) == 0 {
			http.Error(w, "Missing redirect URI", http.StatusBadRequest)
			return
		}
		redirectURI = allowedRedirectURIs[0]
	}
	// From here on the redirect URI is trusted, so errors are sent to the
	// client by redirect.
	if codeChallenge != "" {
		method, err := normalizeCodeChallengeMethod(codeChallengeMethod)
		if err != nil {
			redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidRequest,
				Description: err.Error(),
			})
			return
		}
		if err := validateCodeChallenge(method, codeChallenge); err != nil {
			redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidRequest,
				Description: err.Error(),
			})
			return
		}
		normalizedCodeChallengeMethod = method
	}
	if codeChallenge == "" && codeChallengeMethod != "" {
		redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: "code_challenge_method without code_challenge",
		})
		return
	}
	// Enforce the client's configured PKCE requirement. Whether a given
	// method satisfies it is decided by allowPKCERequirement (defined
	// elsewhere).
	switch requiredPKCE {
	case pkceRequirementPlain:
		if codeChallenge == "" {
			redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidRequest,
				Description: "PKCE is required",
			})
			return
		}
		if !allowPKCERequirement(normalizedCodeChallengeMethod, pkceRequirementPlain) {
			redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidRequest,
				Description: "PKCE method does not satisfy requirement",
			})
			return
		}
	case pkceRequirementS256:
		if codeChallenge == "" {
			redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidRequest,
				Description: "PKCE (S256) is required",
			})
			return
		}
		if !allowPKCERequirement(normalizedCodeChallengeMethod, pkceRequirementS256) {
			redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidRequest,
				Description: "PKCE (S256) is required",
			})
			return
		}
	}
	// Store the canonical method spelling; empty when PKCE is not used.
	codeChallengeMethod = normalizedCodeChallengeMethod
	if respType != oauth2.ResponseTypeCode {
		redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
			Code: oauth2.ErrorCodeUnsupportedResponseType,
		})
		return
	}
	// Scope tokens are lower-cased and de-duplicated by normalizeScope.
	normalizedScope, scopes := normalizeScope(scope)
	if len(scopes) == 0 {
		redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidScope,
			Description: "Missing required openid scope",
		})
		return
	}
	if err := validateScopes(scopes); err != nil {
		redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidScope,
			Description: err.Error(),
		})
		return
	}
	if !containsScope(scopes, scopeOpenID) {
		redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidScope,
			Description: "Scope openid is required",
		})
		return
	}
	scope = normalizedScope
	// Not signed in: redirect to /login, passing this request's URL as its
	// redirect_uri parameter.
	loginToken := loginTokenFromContext(ctx)
	if loginToken == nil {
		q := make(url.Values)
		q.Set("redirect_uri", req.URL.String())
		u := url.URL{
			Path:     "/login",
			RawQuery: q.Encode(),
		}
		http.Redirect(w, req, u.String(), http.StatusFound)
		return
	}
	// The consent decision is read from the POST form. The ParseForm error
	// is discarded; without a "deny" or "authorize" field the consent page
	// is rendered.
	_ = req.ParseForm()
	if _, ok := req.PostForm["deny"]; ok {
		redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
			Code: oauth2.ErrorCodeAccessDenied,
		})
		return
	}
	if _, ok := req.PostForm["authorize"]; !ok {
		data := struct {
			TemplateBaseData
			Client *Client
		}{
			Client: client,
		}
		tpl.MustExecuteTemplate(req.Context(), w, "authorize.html", &data)
		return
	}
	// rawRedirectURI is stored exactly as received (possibly empty); the
	// token endpoint requires its redirect_uri parameter to equal it.
	authCode := AuthCode{
		User:                loginToken.User,
		Client:              client.ID,
		Scope:               scope,
		RedirectURI:         rawRedirectURI,
		Nonce:               nonce,
		CodeChallenge:       codeChallenge,
		CodeChallengeMethod: codeChallengeMethod,
	}
	secret, err := authCode.Generate()
	if err != nil {
		httpError(w, fmt.Errorf("failed to generate authentication code: %v", err))
		return
	}
	if err := db.CreateAuthCode(ctx, &authCode); err != nil {
		httpError(w, fmt.Errorf("failed to create authentication code: %v", err))
		return
	}
	// The code handed to the client combines the stored code's ID with its
	// secret.
	code := MarshalSecret(authCode.ID, SecretKindAuthCode, secret)
	values := make(url.Values)
	values.Set("code", code)
	if stateProvided {
		values.Set("state", state)
	}
	redirectClient(w, req, redirectURI, values)
}
// exchangeToken implements the OAuth 2.0 token endpoint for the
// authorization_code and refresh_token grant types.
//
// The client is identified by the client_id body parameter and/or HTTP Basic
// authentication. When both are present they must agree. Public clients may
// identify themselves with client_id alone. Confidential clients must pass
// VerifySecret with the Basic password.
//
// On success the JSON response carries a fresh access token. It also carries
// a refresh token when the granted scope includes offline_access, and an ID
// token when it includes openid.
func exchangeToken(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)
	values, err := parseRequestBody(req)
	if err != nil {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: err.Error(),
		})
		return
	}
	clientID := values.Get("client_id")
	grantType := oauth2.GrantType(values.Get("grant_type"))
	scope := values.Get("scope")
	codeVerifier := values.Get("code_verifier")
	// authClientID and clientSecret are "" when no Basic credentials are sent.
	authClientID, clientSecret, hasBasicAuth := req.BasicAuth()
	if clientID == "" {
		clientID = authClientID
	} else if hasBasicAuth && clientID != authClientID {
		// Public clients (token endpoint auth method "none") send only the
		// client_id body parameter and no Authorization header. Only flag a
		// conflict when both identifiers are actually present.
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: "Client ID in request body doesn't match Authorization header field",
		})
		return
	}
	var client *Client
	if clientID != "" {
		client, err = db.FetchClientByClientID(ctx, clientID)
		if err == errNoDBRows {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidClient,
				Description: "Invalid client ID",
			})
			return
		} else if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
			return
		}
		if !client.IsPublic() && !client.VerifySecret(clientSecret) {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid client secret",
			})
			return
		}
	}
	var (
		token             *AccessToken
		authorizationCode *AuthCode
		currentClient     *Client
		nonceValue        string
	)
	switch grantType {
	case oauth2.GrantTypeAuthorizationCode:
		if client == nil {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidRequest,
				Description: "Missing client ID",
			})
			return
		}
		codeID, codeSecret, _ := UnmarshalSecret[*AuthCode](values.Get("code"))
		authorizationCode, err = db.PopAuthCode(ctx, codeID)
		if err != nil && err != errNoDBRows {
			oauthError(w, fmt.Errorf("failed to fetch authorization code: %v", err))
			return
		}
		// authorizationCode is only dereferenced once err is known to be nil;
		// previously a database error here caused a nil pointer dereference.
		if err == errNoDBRows || !authorizationCode.VerifySecret(codeSecret) || authorizationCode.Client != client.ID {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid authorization code",
			})
			return
		}
		if scope != "" && scope != authorizationCode.Scope {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid scope",
			})
			return
		}
		// Must equal the redirect_uri sent to the authorization endpoint,
		// including the case where both are empty.
		if values.Get("redirect_uri") != authorizationCode.RedirectURI {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid redirect URI",
			})
			return
		}
		// A code issued with a PKCE challenge requires a matching verifier;
		// a code issued without one must not be redeemed with a verifier.
		if authorizationCode.CodeChallenge != "" {
			if codeVerifier == "" {
				oauthError(w, &oauth2.Error{
					Code:        oauth2.ErrorCodeInvalidRequest,
					Description: "Missing code_verifier",
				})
				return
			}
			if err := verifyCodeVerifier(authorizationCode.CodeChallengeMethod, authorizationCode.CodeChallenge, codeVerifier); err != nil {
				oauthError(w, &oauth2.Error{
					Code:        oauth2.ErrorCodeInvalidGrant,
					Description: "Invalid code_verifier",
				})
				return
			}
		} else if codeVerifier != "" {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidRequest,
				Description: "Unexpected code_verifier",
			})
			return
		}
		token = NewAccessTokenFromAuthCode(authorizationCode)
		currentClient = client
		nonceValue = authorizationCode.Nonce
	case oauth2.GrantTypeRefreshToken:
		tokenID, refreshSecret, _ := UnmarshalSecret[*AccessToken](values.Get("refresh_token"))
		token, err = db.FetchAccessToken(ctx, tokenID)
		if err == errNoDBRows || (err == nil && !token.VerifyRefreshSecret(refreshSecret)) {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid refresh token",
			})
			return
		} else if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch access token: %v", err))
			return
		}
		if client != nil && client.ID != token.Client {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid refresh token",
			})
			return
		}
		currentClient, err = db.FetchClient(ctx, token.Client)
		if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
			return
		}
		// Refresh tokens of confidential clients may only be used by the
		// authenticated client itself.
		if !currentClient.IsPublic() && client == nil {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid client secret",
			})
			return
		}
		if scope != "" && scope != token.Scope {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid scope",
			})
			return
		}
	default:
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeUnsupportedGrantType,
			Description: "Unsupported grant type",
		})
		// Without this return, execution continued with a nil token.
		return
	}
	secret, err := token.Generate(accessTokenExpiration)
	if err != nil {
		oauthError(w, err)
		return
	}
	// A refresh token is only issued for grants that include offline_access;
	// otherwise any previous refresh credentials on the token are cleared.
	tokenScopes := parseScopes(token.Scope)
	issueRefresh := containsScope(tokenScopes, scopeOfflineAccess)
	var refreshSecret string
	if issueRefresh {
		refreshSecret, err = token.GenerateRefresh()
		if err != nil {
			oauthError(w, err)
			return
		}
	} else {
		token.RefreshHash = nil
		token.RefreshExpiresAt = time.Time{}
	}
	if err := db.StoreAccessToken(ctx, token); err != nil {
		oauthError(w, fmt.Errorf("failed to create access token: %v", err))
		return
	}
	accessTokenValue := MarshalSecret(token.ID, SecretKindAccessToken, secret)
	// The ID token's auth_time falls back to the issue time when unknown.
	if token.AuthTime.IsZero() {
		token.AuthTime = token.IssuedAt
	}
	var idToken string
	if containsScope(tokenScopes, scopeOpenID) {
		if currentClient == nil {
			currentClient, err = db.FetchClient(ctx, token.Client)
			if err != nil {
				oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
				return
			}
		}
		user, err := db.FetchUser(ctx, token.User)
		if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch user: %v", err))
			return
		}
		issuer := getIssuer(req)
		oidcProvider := oidcProviderFromContext(ctx)
		idToken, err = oidcProvider.MintIDToken(issuer, currentClient, user, token, tokenScopes, nonceValue, accessTokenValue, token.AuthTime)
		if err != nil {
			oauthError(w, fmt.Errorf("failed to mint ID token: %v", err))
			return
		}
	}
	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Cache-Control", "no-store")
	resp := oidcTokenResponse{
		AccessToken: accessTokenValue,
		TokenType:   oauth2.TokenTypeBearer,
		ExpiresIn:   int64(time.Until(token.ExpiresAt).Seconds()),
		Scope:       token.Scope,
		IDToken:     idToken,
	}
	if issueRefresh {
		refreshTokenValue := MarshalSecret(token.ID, SecretKindRefreshToken, refreshSecret)
		resp.RefreshToken = refreshTokenValue
	}
	json.NewEncoder(w).Encode(&resp)
}
// introspectToken implements the token introspection endpoint (RFC 7662).
// Callers may authenticate with HTTP Basic client credentials. Without them,
// only tokens belonging to public clients can be introspected. An unknown
// token, or one whose secret does not verify, produces the zero-valued
// (inactive) IntrospectionResp.
func introspectToken(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)

	form, err := parseRequestBody(req)
	if err != nil {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: err.Error(),
		})
		return
	}
	client, err := maybeAuthenticateClient(w, req)
	if err != nil {
		oauthError(w, err)
		return
	}

	tokenID, tokenSecret, _ := UnmarshalSecret[*AccessToken](form.Get("token"))
	token, err := db.FetchAccessToken(ctx, tokenID)

	var resp oauth2.IntrospectionResp
	switch {
	case err == errNoDBRows:
		// Unknown token: answer with an inactive response.
	case err != nil:
		oauthError(w, fmt.Errorf("failed to fetch access token: %v", err))
		return
	case !token.VerifySecret(tokenSecret):
		// Wrong secret: treated exactly like an unknown token.
	default:
		if client == nil {
			// Unauthenticated callers may only inspect public clients' tokens.
			client, err = db.FetchClient(ctx, token.Client)
			if err != nil {
				oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
				return
			}
			if !client.IsPublic() {
				oauthError(w, &oauth2.Error{
					Code:        oauth2.ErrorCodeInvalidClient,
					Description: "Missing client ID and secret",
				})
				return
			}
		}
		if client.ID != token.Client {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidClient,
				Description: "Invalid client ID or secret",
			})
			return
		}
		user, err := db.FetchUser(ctx, token.User)
		if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch user: %v", err))
			return
		}
		resp = oauth2.IntrospectionResp{
			Active:    true,
			TokenType: oauth2.TokenTypeBearer,
			ExpiresAt: token.ExpiresAt,
			IssuedAt:  token.IssuedAt,
			ClientID:  client.ClientID,
			Username:  user.Username,
		}
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(&resp)
}
// revokeToken implements the token revocation endpoint (RFC 7009). The
// "token" parameter may hold either an access token or the refresh token
// issued with it. Both are marshaled with the same AccessToken ID, and the
// stored AccessToken is deleted, so either one revokes both. RFC 7009
// requires refresh token revocation to be supported.
//
// Unknown tokens and tokens whose secret does not verify are ignored with an
// empty success response. Without Basic client credentials, only tokens of
// public clients can be revoked.
func revokeToken(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)
	values, err := parseRequestBody(req)
	if err != nil {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: err.Error(),
		})
		return
	}
	client, err := maybeAuthenticateClient(w, req)
	if err != nil {
		oauthError(w, err)
		return
	}
	tokenID, secret, _ := UnmarshalSecret[*AccessToken](values.Get("token"))
	token, err := db.FetchAccessToken(ctx, tokenID)
	if err == errNoDBRows {
		return // ignore unknown tokens
	} else if err != nil {
		oauthError(w, fmt.Errorf("failed to fetch access token: %v", err))
		return
	}
	// Accept the access token secret or the refresh token secret. Previously
	// only the former was checked, so refresh tokens could not be revoked.
	if !token.VerifySecret(secret) && !token.VerifyRefreshSecret(secret) {
		return // ignore
	}
	if client == nil {
		client, err = db.FetchClient(ctx, token.Client)
		if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
			return
		}
		if !client.IsPublic() {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidClient,
				Description: "Missing client ID and secret",
			})
			return
		}
	}
	if client.ID != token.Client {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidClient,
			Description: "Invalid client ID or secret",
		})
		return
	}
	if err := db.DeleteAccessToken(ctx, token.ID); err != nil {
		oauthError(w, err)
		return
	}
}
// parseRequestBody reads an application/x-www-form-urlencoded request body
// and returns its parameters. A missing Content-Type header is tolerated;
// any other media type is rejected.
//
// Bodies larger than 10 MiB are rejected rather than silently truncated and
// parsed. A nil body, which is possible for requests built in code, yields
// empty values.
func parseRequestBody(req *http.Request) (url.Values, error) {
	const maxBodySize = 10 << 20

	if ct := req.Header.Get("Content-Type"); ct != "" {
		mimeType, _, err := mime.ParseMediaType(ct)
		if err != nil {
			return nil, fmt.Errorf("malformed Content-Type header field")
		}
		if mimeType != "application/x-www-form-urlencoded" {
			return nil, fmt.Errorf("unsupported request content type")
		}
	}
	if req.Body == nil {
		return make(url.Values), nil
	}
	// Read one byte past the limit so an oversized body is detected instead
	// of being parsed in truncated form.
	b, err := io.ReadAll(io.LimitReader(req.Body, maxBodySize+1))
	if err != nil {
		return nil, fmt.Errorf("failed to read request body: %v", err)
	}
	if len(b) > maxBodySize {
		return nil, fmt.Errorf("request body too large")
	}
	values, err := url.ParseQuery(string(b))
	if err != nil {
		return nil, fmt.Errorf("failed to parse request body: %v", err)
	}
	return values, nil
}
// oauthError writes err as an OAuth 2.0 JSON error response (RFC 6749
// section 5.2). An error that does not wrap an *oauth2.Error is logged and
// sent to the client only as a generic server_error. The HTTP status code is
// chosen from the error code; unknown codes map to 500.
func oauthError(w http.ResponseWriter, err error) {
	var oauthErr *oauth2.Error
	if !errors.As(err, &oauthErr) {
		log.Print(err)
		oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
	}

	var status int
	switch oauthErr.Code {
	case oauth2.ErrorCodeInvalidRequest,
		oauth2.ErrorCodeUnsupportedResponseType,
		oauth2.ErrorCodeInvalidScope,
		oauth2.ErrorCodeInvalidClient,
		oauth2.ErrorCodeInvalidGrant,
		oauth2.ErrorCodeUnsupportedGrantType:
		status = http.StatusBadRequest
	case oauth2.ErrorCodeUnauthorizedClient, oauth2.ErrorCodeAccessDenied:
		status = http.StatusForbidden
	case oauth2.ErrorCodeTemporarilyUnavailable:
		status = http.StatusServiceUnavailable
	default:
		status = http.StatusInternalServerError
	}

	header := w.Header()
	header.Set("Content-Type", "application/json")
	w.WriteHeader(status)
	json.NewEncoder(w).Encode(oauthErr)
}
// redirectClient sends a 302 redirect to redirectURI with values merged into
// its existing query. Keys in values replace same-named existing parameters.
// The "iss" parameter is always set to this server's issuer (RFC 9207).
// redirectURI itself is not modified.
func redirectClient(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, values url.Values) {
	target := *redirectURI
	params := target.Query()
	for key, vals := range values {
		params[key] = vals
	}
	params.Set("iss", getIssuer(req))
	target.RawQuery = params.Encode()
	http.Redirect(w, req, target.String(), http.StatusFound)
}
// redirectClientError reports err to the client by redirecting to
// redirectURI with the error, error_description and error_uri parameters
// (the latter two only when non-empty). The original state is included when
// the authorization request carried one. Errors that are not *oauth2.Error
// are logged and reported as server_error.
func redirectClientError(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, state string, stateProvided bool, err error) {
	var oauthErr *oauth2.Error
	if !errors.As(err, &oauthErr) {
		log.Print(err)
		oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
	}

	params := url.Values{"error": {string(oauthErr.Code)}}
	if desc := oauthErr.Description; desc != "" {
		params.Set("error_description", desc)
	}
	if uri := oauthErr.URI; uri != "" {
		params.Set("error_uri", uri)
	}
	if stateProvided {
		params.Set("state", state)
	}
	redirectClient(w, req, redirectURI, params)
}
// validateRedirectURI reports whether u is an acceptable redirect URI for a
// client whose registered URIs are allowedURIs. URIs are compared by their
// string form.
//
// An exact match with a registered URI is always accepted. In addition, for
// the loopback interface redirection of RFC 8252 section 7.3, an "http" URI
// whose host is a loopback IP literal (127.0.0.1, [::1], ...) is rewritten
// to host "localhost" with no port before comparing. A registered
// "http://localhost/cb" therefore matches "http://127.0.0.1:<any port>/cb".
func validateRedirectURI(u *url.URL, allowedURIs []*url.URL) bool {
	matches := func(candidate *url.URL) bool {
		s := candidate.String()
		for _, allowed := range allowedURIs {
			if s == allowed.String() {
				return true
			}
		}
		return false
	}
	// Exact match first, so registrations that spell out a loopback IP
	// literal still work; they could never match after normalization.
	if matches(u) {
		return true
	}
	// Hostname (unlike SplitHostPort) also handles hosts without a port,
	// e.g. "http://127.0.0.1/cb", and strips IPv6 brackets.
	if u.Scheme == "http" && net.ParseIP(u.Hostname()).IsLoopback() {
		uu := *u
		uu.Host = "localhost"
		return matches(&uu)
	}
	return false
}
// maybeAuthenticateClient authenticates the calling client from HTTP Basic
// credentials, if any. It returns:
//   - (nil, nil) when the request carries no Basic Authorization header;
//   - an invalid_client *oauth2.Error when the client ID is unknown or the
//     secret does not verify;
//   - a plain error for any other database failure.
//
// w is not used.
func maybeAuthenticateClient(w http.ResponseWriter, req *http.Request) (*Client, error) {
	ctx := req.Context()
	db := dbFromContext(ctx)
	id, secret, hasCredentials := req.BasicAuth()
	if !hasCredentials {
		return nil, nil
	}
	client, err := db.FetchClientByClientID(ctx, id)
	switch {
	case err == errNoDBRows, err == nil && !client.VerifySecret(secret):
		return nil, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidClient,
			Description: "Invalid client ID or secret",
		}
	case err != nil:
		return nil, fmt.Errorf("failed to fetch client: %v", err)
	default:
		return client, nil
	}
}