Warning: Due to various recent migrations, viewing non-HEAD refs may be broken.
/oidc.go (raw)
package main
import (
	"context"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha256"
	"crypto/x509"
	"encoding/base64"
	"encoding/json"
	"encoding/pem"
	"fmt"
	"math/big"
	"net/http"
	"sort"
	"strconv"
	"strings"
	"time"
	oauth2 "github.com/emersion/go-oauth2"
	"github.com/golang-jwt/jwt/v5"
)
// OIDCProvider holds the signing keys used to mint ID tokens and to build the
// published JWKS document.
type OIDCProvider struct {
	// signingKeys lists every loaded key. The first entry is the one
	// currentSigningKey returns; all entries are exposed by JWKS.
	signingKeys []*oidcSigningKey
}
// oidcSigningKey bundles a stored signing key record with the RSA private key
// parsed from it and the public JWK derived from that key.
type oidcSigningKey struct {
	// key is the stored record (key ID, algorithm, PEM-encoded private key).
	key       *SigningKey
	// private is the parsed RSA key used to sign tokens.
	private   *rsa.PrivateKey
	// publicJWK is the public half of the key in JWK form.
	publicJWK jwk
}
// jwk is the JSON shape of a single public key entry in the JWKS document.
// toOIDCSigningKey fills it for RSA signing keys.
type jwk struct {
	// Kty is the key type ("RSA" for keys built in this file).
	Kty string `json:"kty"`
	// Use is the intended key use ("sig" for keys built in this file).
	Use string `json:"use,omitempty"`
	// Alg is the signing algorithm copied from the stored key record.
	Alg string `json:"alg,omitempty"`
	// Kid is the key ID copied from the stored key record.
	Kid string `json:"kid,omitempty"`
	// N is the RSA modulus, big-endian bytes in unpadded base64url.
	N   string `json:"n,omitempty"`
	// E is the RSA public exponent, big-endian bytes in unpadded base64url.
	E   string `json:"e,omitempty"`
}
// jwks is the JSON document returned by the JWKS endpoint: a list of public
// keys under the "keys" member.
type jwks struct {
	Keys []jwk `json:"keys"`
}
// idTokenTTL is the maximum lifetime of a minted ID token. MintIDToken uses a
// shorter lifetime when the access token expires sooner.
const idTokenTTL = 15 * time.Minute
// newOIDCProvider builds an OIDCProvider from the signing keys stored in db.
//
// When the store holds no signing key yet — reported either as errNoDBRows or
// as an empty result with a nil error — a new key is generated and persisted
// first, so the returned provider always has at least one key. Any other fetch
// error, a failure to persist the generated key, or a stored key that cannot
// be converted aborts construction with an error.
//
// Keys keep the order returned by FetchSigningKeys; the first one is the key
// currentSigningKey will pick for signing.
func newOIDCProvider(ctx context.Context, db *DB) (*OIDCProvider, error) {
	signingRecords, err := db.FetchSigningKeys(ctx)
	if err != nil && err != errNoDBRows {
		return nil, fmt.Errorf("failed to fetch signing keys: %w", err)
	}

	// Bootstrap on an empty result too: a provider without any key could
	// never mint an ID token.
	if err == errNoDBRows || len(signingRecords) == 0 {
		generated, genErr := generateSigningKey()
		if genErr != nil {
			return nil, genErr
		}
		if storeErr := db.StoreSigningKey(ctx, generated); storeErr != nil {
			return nil, fmt.Errorf("failed to persist signing key: %w", storeErr)
		}
		signingRecords = []SigningKey{*generated}
	}

	signingKeys := make([]*oidcSigningKey, 0, len(signingRecords))
	for i := range signingRecords {
		// Take the address of the slice element so each oidcSigningKey
		// points at its own record.
		material, convErr := toOIDCSigningKey(&signingRecords[i])
		if convErr != nil {
			return nil, convErr
		}
		signingKeys = append(signingKeys, material)
	}
	return &OIDCProvider{signingKeys: signingKeys}, nil
}
// generateSigningKey creates a fresh 2048-bit RSA key and returns it as a
// SigningKey record: the private key is PKCS#1-encoded inside an
// "RSA PRIVATE KEY" PEM block, the key ID comes from generateUID, the
// algorithm is RS256 and CreatedAt is the current time.
func generateSigningKey() (*SigningKey, error) {
	rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, fmt.Errorf("failed to generate signing key: %w", err)
	}

	block := pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(rsaKey),
	}
	encoded := pem.EncodeToMemory(&block)
	if encoded == nil {
		return nil, fmt.Errorf("failed to encode signing key")
	}

	keyID, err := generateUID()
	if err != nil {
		return nil, fmt.Errorf("failed to generate signing key ID: %w", err)
	}

	record := new(SigningKey)
	record.KID = keyID
	record.Algorithm = "RS256"
	record.PrivateKey = encoded
	record.CreatedAt = time.Now()
	return record, nil
}
// toOIDCSigningKey parses the PEM-encoded PKCS#1 private key held by signing
// and pairs it with the public JWK derived from it.
//
// It returns an error when the PEM cannot be decoded or the key bytes are not
// a valid PKCS#1 RSA private key. The returned value keeps a reference to
// signing rather than a copy.
func toOIDCSigningKey(signing *SigningKey) (*oidcSigningKey, error) {
	pemBlock, _ := pem.Decode(signing.PrivateKey)
	if pemBlock == nil {
		return nil, fmt.Errorf("failed to decode signing key PEM")
	}

	rsaKey, err := x509.ParsePKCS1PrivateKey(pemBlock.Bytes)
	if err != nil {
		return nil, fmt.Errorf("failed to parse signing key: %w", err)
	}

	// JWK carries modulus and exponent as big-endian bytes in unpadded
	// base64url.
	enc := base64.RawURLEncoding
	modulus := rsaKey.N.Bytes()
	exponent := big.NewInt(int64(rsaKey.E)).Bytes()

	public := jwk{
		Kty: "RSA",
		Use: "sig",
		Alg: signing.Algorithm,
		Kid: signing.KID,
		N:   enc.EncodeToString(modulus),
		E:   enc.EncodeToString(exponent),
	}
	return &oidcSigningKey{key: signing, private: rsaKey, publicJWK: public}, nil
}
// currentSigningKey returns the key used for signing new tokens, which is the
// first loaded key, or nil when the provider has no keys.
func (op *OIDCProvider) currentSigningKey() *oidcSigningKey {
	if len(op.signingKeys) > 0 {
		return op.signingKeys[0]
	}
	return nil
}
// signingMethod returns the JWT signing method matching the current signing
// key, together with that key. It fails when no key is configured or when the
// key's algorithm is anything other than RS256.
func (op *OIDCProvider) signingMethod() (*jwt.SigningMethodRSA, *oidcSigningKey, error) {
	current := op.currentSigningKey()
	if current == nil {
		return nil, nil, fmt.Errorf("no signing key configured")
	}
	if alg := current.key.Algorithm; alg != "RS256" {
		return nil, nil, fmt.Errorf("unsupported signing algorithm %q", alg)
	}
	return jwt.SigningMethodRS256, current, nil
}
// MintIDToken builds and signs an ID token for user, addressed to client.
//
// The token always carries iss, sub, aud, exp and iat. auth_time is added when
// authTime is non-zero, nonce when nonce is non-empty, and at_hash when
// accessToken is non-empty. The profile scope adds preferred_username and
// name (falling back to the username when the user has no display name); the
// email scope adds email and email_verified (always false) when the user has
// an email address.
//
// The lifetime is idTokenTTL, shortened to the access token's expiry when that
// comes first and never earlier than the issue time. The JWT header's kid is
// set to the current signing key's ID. An error is returned when no usable
// signing key is available or signing fails.
func (op *OIDCProvider) MintIDToken(issuer string, client *Client, user *User, token *AccessToken, scopes []string, nonce string, accessToken string, authTime time.Time) (string, error) {
	method, signer, err := op.signingMethod()
	if err != nil {
		return "", err
	}

	issuedAt := time.Now()
	expiry := issuedAt.Add(idTokenTTL)
	// The ID token must not outlive the access token it was issued with.
	if accessExpiry := token.ExpiresAt; accessExpiry.Before(expiry) {
		expiry = accessExpiry
	}
	// NOTE(review): a zero token.ExpiresAt ends up clamped to issuedAt here,
	// producing an already-expired ID token — confirm access tokens always
	// carry an expiry.
	if expiry.Before(issuedAt) {
		expiry = issuedAt
	}

	claims := make(jwt.MapClaims)
	claims["iss"] = issuer
	claims["sub"] = subjectForUser(user)
	claims["aud"] = client.ClientID
	claims["exp"] = jwt.NewNumericDate(expiry)
	claims["iat"] = jwt.NewNumericDate(issuedAt)

	if !authTime.IsZero() {
		claims["auth_time"] = jwt.NewNumericDate(authTime)
	}
	if nonce != "" {
		claims["nonce"] = nonce
	}
	if accessToken != "" {
		claims["at_hash"] = computeAtHash(accessToken)
	}

	if containsScope(scopes, scopeProfile) {
		claims["preferred_username"] = user.Username
		if user.Name != "" {
			claims["name"] = user.Name
		} else {
			claims["name"] = user.Username
		}
	}
	if containsScope(scopes, scopeEmail) && user.Email != "" {
		claims["email"] = user.Email
		claims["email_verified"] = false
	}

	idToken := jwt.NewWithClaims(method, claims)
	idToken.Header["kid"] = signer.key.KID
	return idToken.SignedString(signer.private)
}
// JWKS returns the public JWK of every loaded signing key, in load order. The
// Keys slice is never nil, so it encodes as a JSON array even when empty.
func (op *OIDCProvider) JWKS() jwks {
	set := jwks{Keys: make([]jwk, len(op.signingKeys))}
	for i, signingKey := range op.signingKeys {
		set.Keys[i] = signingKey.publicJWK
	}
	return set
}
// subjectForUser returns the "sub" value for user: the user's ID formatted as
// a base-10 integer.
func subjectForUser(user *User) string {
	id := int64(user.ID)
	return strconv.FormatInt(id, 10)
}
// computeAtHash returns the at_hash value for accessToken: the left half of
// its SHA-256 digest, encoded as unpadded base64url.
func computeAtHash(accessToken string) string {
	digest := sha256.Sum256([]byte(accessToken))
	leftHalf := digest[:sha256.Size/2]
	return base64.RawURLEncoding.EncodeToString(leftHalf)
}
// containsScope reports whether scope appears in scopes, comparing entries
// case-insensitively.
//
// NOTE(review): RFC 6749 section 3.3 defines scope values as case-sensitive;
// confirm the case-insensitive match is intended.
func containsScope(scopes []string, scope string) bool {
	found := false
	for i := 0; i < len(scopes) && !found; i++ {
		found = strings.EqualFold(scopes[i], scope)
	}
	return found
}
// getOpenIDConfiguration serves the OpenID Connect discovery document as JSON.
//
// Endpoint URLs are built by appending fixed paths to the issuer derived from
// the request. The advertised scopes are the sorted keys of allowedScopes. The
// advertised ID token signing algorithm is that of the current signing key,
// falling back to RS256 when there is no key or its algorithm is empty.
func getOpenIDConfiguration(w http.ResponseWriter, req *http.Request) {
	provider := oidcProviderFromContext(req.Context())
	issuer := getIssuer(req)

	signingAlgs := []string{"RS256"}
	if current := provider.currentSigningKey(); current != nil && current.key.Algorithm != "" {
		signingAlgs = []string{current.key.Algorithm}
	}

	// Map iteration order is random; sort for a stable document.
	supportedScopes := make([]string, 0, len(allowedScopes))
	for name := range allowedScopes {
		supportedScopes = append(supportedScopes, name)
	}
	sort.Strings(supportedScopes)

	doc := make(map[string]any, 20)

	// Issuer and endpoints.
	doc["issuer"] = issuer
	doc["authorization_endpoint"] = issuer + "/authorize"
	doc["token_endpoint"] = issuer + "/token"
	doc["userinfo_endpoint"] = issuer + "/userinfo"
	doc["jwks_uri"] = issuer + "/.well-known/jwks.json"
	doc["introspection_endpoint"] = issuer + "/introspect"
	doc["revocation_endpoint"] = issuer + "/revoke"

	// Supported protocol features.
	doc["response_types_supported"] = []string{string(oauth2.ResponseTypeCode)}
	doc["response_modes_supported"] = []string{string(oauth2.ResponseModeQuery)}
	doc["grant_types_supported"] = []string{
		string(oauth2.GrantTypeAuthorizationCode),
		string(oauth2.GrantTypeRefreshToken),
	}
	doc["subject_types_supported"] = []string{"public"}
	doc["id_token_signing_alg_values_supported"] = signingAlgs
	doc["scopes_supported"] = supportedScopes
	doc["claims_supported"] = []string{"sub", "preferred_username", "name", "email", "email_verified"}
	doc["token_endpoint_auth_methods_supported"] = []string{
		string(oauth2.AuthMethodNone),
		string(oauth2.AuthMethodClientSecretBasic),
	}
	doc["code_challenge_methods_supported"] = []string{pkceMethodPlain, pkceMethodS256}

	// Feature flags.
	doc["authorization_response_iss_parameter_supported"] = true
	doc["claims_parameter_supported"] = false
	doc["request_parameter_supported"] = false
	doc["request_uri_parameter_supported"] = false

	w.Header().Set("Content-Type", "application/json")
	// The encode error is ignored, as in the other JSON handlers in this file.
	_ = json.NewEncoder(w).Encode(doc)
}
// getOIDCJWKS serves the provider's public signing keys as a JWKS JSON
// document, cacheable by clients and intermediaries for five minutes.
func getOIDCJWKS(w http.ResponseWriter, req *http.Request) {
	provider := oidcProviderFromContext(req.Context())

	header := w.Header()
	header.Set("Content-Type", "application/json")
	header.Set("Cache-Control", "public, max-age=300")

	// The encode error is ignored, as in the other JSON handlers in this file.
	_ = json.NewEncoder(w).Encode(provider.JWKS())
}
// userInfo serves the UserInfo endpoint.
//
// Only GET and POST are accepted (405 otherwise). The request must carry a
// Bearer access token; a missing, malformed, unknown or secret-mismatched
// token yields a 401 Bearer challenge with error "invalid_token". A token
// without the openid scope yields 403 "insufficient_scope". Database failures
// are reported through httpError.
//
// The response always contains sub. The profile scope adds preferred_username
// and name (the username when the user has no display name); the email scope
// adds email and email_verified (always false) when the user has an email.
func userInfo(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)

	switch req.Method {
	case http.MethodGet, http.MethodPost:
		// Accepted methods.
	default:
		w.Header().Set("Allow", "GET, POST")
		http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
		return
	}

	rawToken, err := bearerTokenFromRequest(req)
	if err != nil {
		writeBearerError(w, http.StatusUnauthorized, "invalid_token", err.Error())
		return
	}

	tokenID, secret, err := UnmarshalSecret[*AccessToken](rawToken)
	if err != nil {
		writeBearerError(w, http.StatusUnauthorized, "invalid_token", "Malformed access token")
		return
	}

	// NOTE(review): expiry is not checked in this handler; presumably
	// FetchAccessToken or VerifySecret rejects expired tokens — confirm.
	token, err := db.FetchAccessToken(ctx, tokenID)
	switch {
	case err == errNoDBRows, err == nil && !token.VerifySecret(secret):
		// Unknown token and wrong secret get the same response.
		writeBearerError(w, http.StatusUnauthorized, "invalid_token", "Invalid access token")
		return
	case err != nil:
		httpError(w, fmt.Errorf("failed to fetch access token: %v", err))
		return
	}

	granted := parseScopes(token.Scope)
	if !containsScope(granted, scopeOpenID) {
		writeBearerError(w, http.StatusForbidden, "insufficient_scope", "Scope openid missing")
		return
	}

	user, err := db.FetchUser(ctx, token.User)
	if err != nil {
		httpError(w, fmt.Errorf("failed to fetch user: %v", err))
		return
	}

	claims := make(map[string]any)
	claims["sub"] = subjectForUser(user)
	if containsScope(granted, scopeProfile) {
		claims["preferred_username"] = user.Username
		if user.Name != "" {
			claims["name"] = user.Name
		} else {
			claims["name"] = user.Username
		}
	}
	if containsScope(granted, scopeEmail) && user.Email != "" {
		claims["email"] = user.Email
		claims["email_verified"] = false
	}

	w.Header().Set("Content-Type", "application/json")
	// The encode error is ignored, as in the other JSON handlers in this file.
	_ = json.NewEncoder(w).Encode(claims)
}
// bearerTokenFromRequest extracts the access token from the request's
// Authorization header.
//
// The scheme must be "Bearer" followed by a space, matched
// case-insensitively; the remainder is returned with surrounding whitespace
// trimmed. An error is returned when the header is absent, uses another
// scheme, or carries no token. The error text is sent to clients as the
// error_description of the Bearer challenge by userInfo.
func bearerTokenFromRequest(req *http.Request) (string, error) {
	const scheme = "Bearer "

	header := req.Header.Get("Authorization")
	switch {
	case header == "":
		return "", fmt.Errorf("Authorization header missing")
	case len(header) < len(scheme), !strings.EqualFold(header[:len(scheme)], scheme):
		return "", fmt.Errorf("Unsupported authorization scheme")
	}

	if value := strings.TrimSpace(header[len(scheme):]); value != "" {
		return value, nil
	}
	return "", fmt.Errorf("Missing access token")
}
// writeBearerError replies with the given HTTP status and a Bearer
// WWW-Authenticate challenge.
//
// code and description, when non-empty, become the error and
// error_description auth-params. Both are emitted as quoted strings, so they
// are first restricted to the characters RFC 6750 section 3 permits for these
// values (printable ASCII without '"' and '\'): a double quote is replaced by
// an apostrophe and any other disallowed character is dropped. This keeps the
// challenge well-formed even when the description comes from an error message.
// The response body is the standard status text for status.
func writeBearerError(w http.ResponseWriter, status int, code, description string) {
	sanitize := func(s string) string {
		return strings.Map(func(r rune) rune {
			switch {
			case r == '"':
				return '\''
			case r == '\\', r < 0x20, r > 0x7e:
				return -1 // dropped by strings.Map
			}
			return r
		}, s)
	}

	params := make([]string, 0, 2)
	if code = sanitize(code); code != "" {
		params = append(params, `error="`+code+`"`)
	}
	if description = sanitize(description); description != "" {
		params = append(params, `error_description="`+description+`"`)
	}

	challenge := "Bearer"
	if len(params) > 0 {
		challenge += " " + strings.Join(params, ", ")
	}

	w.Header().Set("WWW-Authenticate", challenge)
	http.Error(w, http.StatusText(status), status)
}