Warning: Due to various recent migrations, viewing non-HEAD refs may be broken.
/oidc.go (raw)
package main
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"net/http"
"sort"
"strconv"
"strings"
"time"
oauth2 "github.com/emersion/go-oauth2"
"github.com/golang-jwt/jwt/v5"
)
type (
	// OIDCProvider holds the keys used to sign OpenID Connect ID tokens and to
	// publish the provider's JSON Web Key Set. The first element of
	// signingKeys is the active signing key (see currentSigningKey); every
	// element is listed in the JWKS.
	OIDCProvider struct {
		signingKeys []*oidcSigningKey
	}

	// oidcSigningKey pairs a signing key record with its parsed RSA private
	// key and the public JWK derived from that key.
	oidcSigningKey struct {
		key       *SigningKey     // record as stored/loaded via the DB (KID, Algorithm, PEM key)
		private   *rsa.PrivateKey // parsed from key.PrivateKey
		publicJWK jwk             // public half, advertised in the JWKS
	}
)
type (
	// jwk is a JSON Web Key (RFC 7517) carrying only the members needed to
	// publish an RSA public key. N and E hold the modulus and public exponent
	// as unpadded base64url-encoded big-endian integers (RFC 7518,
	// section 6.3.1). Empty optional members are omitted when encoded; the
	// field order below is the JSON member order.
	jwk struct {
		Kty string `json:"kty"`
		Use string `json:"use,omitempty"`
		Alg string `json:"alg,omitempty"`
		Kid string `json:"kid,omitempty"`
		N   string `json:"n,omitempty"`
		E   string `json:"e,omitempty"`
	}

	// jwks is a JSON Web Key Set document (RFC 7517, section 5).
	jwks struct {
		Keys []jwk `json:"keys"`
	}
)
// idTokenTTL is the longest lifetime given to a newly minted ID token.
// MintIDToken may shorten it so the ID token does not outlive its access token.
const idTokenTTL time.Duration = 15 * time.Minute
// newOIDCProvider loads the persisted signing keys from db and prepares them
// for signing ID tokens and publishing the JWKS.
//
// If no key is stored yet, a fresh one is generated and persisted first. This
// happens both when FetchSigningKeys reports errNoDBRows and when it returns
// an empty list without error; either way there would be nothing to sign with.
// The first record returned becomes the active signing key (see
// currentSigningKey).
//
// Errors from fetching, generating, persisting or parsing a key are returned.
func newOIDCProvider(ctx context.Context, db *DB) (*OIDCProvider, error) {
	signingRecords, err := db.FetchSigningKeys(ctx)
	switch {
	case err == errNoDBRows:
		signingRecords = nil
	case err != nil:
		return nil, fmt.Errorf("failed to fetch signing keys: %w", err)
	}

	if len(signingRecords) == 0 {
		generated, genErr := generateSigningKey()
		if genErr != nil {
			return nil, genErr
		}
		if storeErr := db.StoreSigningKey(ctx, generated); storeErr != nil {
			return nil, fmt.Errorf("failed to persist signing key: %w", storeErr)
		}
		signingRecords = []SigningKey{*generated}
	}

	signingKeys := make([]*oidcSigningKey, 0, len(signingRecords))
	for i := range signingRecords {
		material, convErr := toOIDCSigningKey(&signingRecords[i])
		if convErr != nil {
			return nil, convErr
		}
		signingKeys = append(signingKeys, material)
	}
	return &OIDCProvider{signingKeys: signingKeys}, nil
}
// generateSigningKey creates a new 2048-bit RSA key and wraps it in an RS256
// SigningKey record. The private key is stored as a PEM "RSA PRIVATE KEY"
// block holding its PKCS#1 DER encoding, under a freshly generated key ID.
// The record is only built here, not persisted.
func generateSigningKey() (*SigningKey, error) {
	const modulusBits = 2048

	privateKey, err := rsa.GenerateKey(rand.Reader, modulusBits)
	if err != nil {
		return nil, fmt.Errorf("failed to generate signing key: %w", err)
	}

	der := x509.MarshalPKCS1PrivateKey(privateKey)
	encoded := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: der})
	if encoded == nil {
		return nil, fmt.Errorf("failed to encode signing key")
	}

	keyID, err := generateUID()
	if err != nil {
		return nil, fmt.Errorf("failed to generate signing key ID: %w", err)
	}

	record := &SigningKey{
		KID:        keyID,
		Algorithm:  "RS256",
		PrivateKey: encoded,
		CreatedAt:  time.Now(),
	}
	return record, nil
}
// toOIDCSigningKey decodes the PEM-wrapped PKCS#1 RSA private key held in
// signing and derives the public JWK advertised for it.
//
// The JWK's alg and kid are copied from signing. n and e are the RSA modulus
// and public exponent, encoded as unpadded base64url big-endian integers
// (RFC 7518, section 6.3.1). An error is returned if the PEM block cannot be
// decoded or the key cannot be parsed as PKCS#1.
func toOIDCSigningKey(signing *SigningKey) (*oidcSigningKey, error) {
	block, _ := pem.Decode(signing.PrivateKey)
	if block == nil {
		return nil, fmt.Errorf("failed to decode signing key PEM")
	}
	priv, err := x509.ParsePKCS1PrivateKey(block.Bytes)
	if err != nil {
		return nil, fmt.Errorf("failed to parse signing key: %w", err)
	}

	// Build the JWK inline instead of in a local named "jwk": such a local
	// would shadow the jwk type for the rest of the function.
	return &oidcSigningKey{
		key:     signing,
		private: priv,
		publicJWK: jwk{
			Kty: "RSA",
			Use: "sig",
			Alg: signing.Algorithm,
			Kid: signing.KID,
			N:   base64.RawURLEncoding.EncodeToString(priv.N.Bytes()),
			E:   base64.RawURLEncoding.EncodeToString(big.NewInt(int64(priv.E)).Bytes()),
		},
	}, nil
}
// currentSigningKey returns the key used to sign new ID tokens, which is the
// first loaded key. It returns nil when no key is loaded.
func (op *OIDCProvider) currentSigningKey() *oidcSigningKey {
	var active *oidcSigningKey
	if len(op.signingKeys) > 0 {
		active = op.signingKeys[0]
	}
	return active
}
// signingMethod returns the JWT signing method for the current signing key,
// together with that key. It fails when no key is loaded, or when the key's
// algorithm is anything other than "RS256".
func (op *OIDCProvider) signingMethod() (*jwt.SigningMethodRSA, *oidcSigningKey, error) {
	active := op.currentSigningKey()
	if active == nil {
		return nil, nil, fmt.Errorf("no signing key configured")
	}
	if alg := active.key.Algorithm; alg != "RS256" {
		return nil, nil, fmt.Errorf("unsupported signing algorithm %q", alg)
	}
	return jwt.SigningMethodRS256, active, nil
}
// MintIDToken issues a signed OpenID Connect ID token for user, addressed to
// client and issued by issuer.
//
// The token always carries iss, sub, aud, iat and exp. It also carries:
//   - auth_time, when authTime is non-zero;
//   - nonce, when nonce is non-empty;
//   - at_hash, computed from accessToken, when accessToken is non-empty;
//   - preferred_username and name, when scopes include the profile scope;
//   - email and email_verified (always false), when scopes include the email
//     scope and the user has an email address.
//
// exp is now+idTokenTTL, capped at token.ExpiresAt so the ID token does not
// outlive its access token, and never earlier than now. A nil token, or a
// token with a zero ExpiresAt, imposes no cap. The JWT header names the
// signing key through "kid". An error is returned when no usable signing key
// is configured or signing fails.
func (op *OIDCProvider) MintIDToken(issuer string, client *Client, user *User, token *AccessToken, scopes []string, nonce string, accessToken string, authTime time.Time) (string, error) {
	method, signingKey, err := op.signingMethod()
	if err != nil {
		return "", err
	}

	now := time.Now()
	expiresAt := now.Add(idTokenTTL)
	// A zero ExpiresAt would otherwise "cap" exp at year 1, which the clamp
	// below turns into an ID token that is already expired when issued.
	if token != nil && !token.ExpiresAt.IsZero() && token.ExpiresAt.Before(expiresAt) {
		expiresAt = token.ExpiresAt
	}
	if expiresAt.Before(now) {
		expiresAt = now
	}

	claims := jwt.MapClaims{
		"iss": issuer,
		"sub": subjectForUser(user),
		"aud": client.ClientID,
		"exp": jwt.NewNumericDate(expiresAt),
		"iat": jwt.NewNumericDate(now),
	}
	if !authTime.IsZero() {
		claims["auth_time"] = jwt.NewNumericDate(authTime)
	}
	if nonce != "" {
		claims["nonce"] = nonce
	}
	if accessToken != "" {
		claims["at_hash"] = computeAtHash(accessToken)
	}
	if containsScope(scopes, scopeProfile) {
		displayName := user.Name
		if displayName == "" {
			displayName = user.Username
		}
		claims["preferred_username"] = user.Username
		claims["name"] = displayName
	}
	if containsScope(scopes, scopeEmail) && user.Email != "" {
		claims["email"] = user.Email
		claims["email_verified"] = false
	}

	tokenJWT := jwt.NewWithClaims(method, claims)
	tokenJWT.Header["kid"] = signingKey.key.KID
	return tokenJWT.SignedString(signingKey.private)
}
// JWKS returns the JSON Web Key Set listing the public JWK of every loaded
// signing key, in load order. With no keys, Keys is an empty, non-nil slice.
func (op *OIDCProvider) JWKS() jwks {
	set := jwks{Keys: make([]jwk, len(op.signingKeys))}
	for i, loaded := range op.signingKeys {
		set.Keys[i] = loaded.publicJWK
	}
	return set
}
// subjectForUser returns the "sub" claim identifying user: the user's ID
// converted to int64 and formatted in base 10.
func subjectForUser(user *User) string {
	id := int64(user.ID)
	return strconv.FormatInt(id, 10)
}
// computeAtHash returns the OpenID Connect "at_hash" value for accessToken:
// the left-most half of its SHA-256 digest, base64url-encoded without padding
// (OpenID Connect Core 1.0, section 3.1.3.6). SHA-256 is the hash that pairs
// with RS256, the only algorithm signingMethod accepts.
func computeAtHash(accessToken string) string {
	digest := sha256.Sum256([]byte(accessToken))
	leftHalf := digest[:sha256.Size/2]
	return base64.RawURLEncoding.EncodeToString(leftHalf)
}
// containsScope reports whether scope appears in scopes, comparing
// case-insensitively with strings.EqualFold. A nil or empty scopes never
// matches.
//
// NOTE(review): RFC 6749 section 3.3 defines scope tokens as case-sensitive;
// confirm the case-insensitive match is intended.
func containsScope(scopes []string, scope string) bool {
	found := false
	for i := 0; i < len(scopes) && !found; i++ {
		found = strings.EqualFold(scopes[i], scope)
	}
	return found
}
// getOpenIDConfiguration serves the OpenID Provider metadata document
// (OpenID Connect Discovery 1.0, section 3) as JSON. Every endpoint URL is
// built by appending a fixed path to the issuer that getIssuer derives from
// req.
func getOpenIDConfiguration(w http.ResponseWriter, req *http.Request) {
	provider := oidcProviderFromContext(req.Context())
	issuer := getIssuer(req)

	// Advertise the active key's algorithm. Fall back to RS256 when no key is
	// loaded or its algorithm is unset.
	signingAlg := "RS256"
	if active := provider.currentSigningKey(); active != nil && active.key.Algorithm != "" {
		signingAlg = active.key.Algorithm
	}

	// Sort the scope names so the document is deterministic.
	supportedScopes := make([]string, 0, len(allowedScopes))
	for name := range allowedScopes {
		supportedScopes = append(supportedScopes, name)
	}
	sort.Strings(supportedScopes)

	metadata := map[string]any{
		// Identity and endpoints.
		"issuer":                 issuer,
		"authorization_endpoint": issuer + "/authorize",
		"token_endpoint":         issuer + "/token",
		"userinfo_endpoint":      issuer + "/userinfo",
		"jwks_uri":               issuer + "/.well-known/jwks.json",
		"introspection_endpoint": issuer + "/introspect",
		"revocation_endpoint":    issuer + "/revoke",

		// Supported flows and client authentication.
		"response_types_supported":              []string{string(oauth2.ResponseTypeCode)},
		"response_modes_supported":              []string{string(oauth2.ResponseModeQuery)},
		"grant_types_supported":                 []string{string(oauth2.GrantTypeAuthorizationCode), string(oauth2.GrantTypeRefreshToken)},
		"token_endpoint_auth_methods_supported": []string{string(oauth2.AuthMethodNone), string(oauth2.AuthMethodClientSecretBasic)},
		"code_challenge_methods_supported":      []string{pkceMethodPlain, pkceMethodS256},

		// ID tokens, scopes and claims.
		"subject_types_supported":               []string{"public"},
		"id_token_signing_alg_values_supported": []string{signingAlg},
		"scopes_supported":                      supportedScopes,
		"claims_supported":                      []string{"sub", "preferred_username", "name", "email", "email_verified"},

		// Feature flags.
		"authorization_response_iss_parameter_supported": true,
		"claims_parameter_supported":                     false,
		"request_parameter_supported":                    false,
		"request_uri_parameter_supported":                false,
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(metadata)
}
// getOIDCJWKS serves the provider's JSON Web Key Set: the public keys clients
// use to verify ID-token signatures. The response is marked cacheable by
// anyone for 300 seconds.
func getOIDCJWKS(w http.ResponseWriter, req *http.Request) {
	set := oidcProviderFromContext(req.Context()).JWKS()

	header := w.Header()
	header.Set("Content-Type", "application/json")
	header.Set("Cache-Control", "public, max-age=300")
	json.NewEncoder(w).Encode(set)
}
// userInfo implements the OpenID Connect UserInfo endpoint (OpenID Connect
// Core 1.0, section 5.3). It accepts GET and POST, authenticates the request
// with a bearer access token, and responds with the user's claims as JSON.
//
// The access token must exist, match its secret and not be expired. These
// failures, and a token whose user no longer exists, produce 401
// invalid_token. A token without the openid scope produces 403
// insufficient_scope. The response always contains "sub"; profile and email
// claims are added according to the token's scopes.
func userInfo(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)

	if req.Method != http.MethodGet && req.Method != http.MethodPost {
		w.Header().Set("Allow", "GET, POST")
		http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
		return
	}

	tokenValue, err := bearerTokenFromRequest(req)
	if err != nil {
		writeBearerError(w, http.StatusUnauthorized, "invalid_token", err.Error())
		return
	}
	tokenID, secret, err := UnmarshalSecret[*AccessToken](tokenValue)
	if err != nil {
		writeBearerError(w, http.StatusUnauthorized, "invalid_token", "Malformed access token")
		return
	}

	token, err := db.FetchAccessToken(ctx, tokenID)
	if err == errNoDBRows || (err == nil && !token.VerifySecret(secret)) {
		writeBearerError(w, http.StatusUnauthorized, "invalid_token", "Invalid access token")
		return
	} else if err != nil {
		httpError(w, fmt.Errorf("failed to fetch access token: %v", err))
		return
	}
	// Enforce expiry here instead of relying on expired rows having been
	// purged from the database. A zero ExpiresAt is treated as "no expiry".
	if !token.ExpiresAt.IsZero() && time.Now().After(token.ExpiresAt) {
		writeBearerError(w, http.StatusUnauthorized, "invalid_token", "Access token expired")
		return
	}

	scopes := parseScopes(token.Scope)
	if !containsScope(scopes, scopeOpenID) {
		writeBearerError(w, http.StatusForbidden, "insufficient_scope", "Scope openid missing")
		return
	}

	user, err := db.FetchUser(ctx, token.User)
	if err == errNoDBRows {
		// The token refers to a user that no longer exists, so it cannot
		// identify anyone. This is a client-side condition, not a server fault.
		writeBearerError(w, http.StatusUnauthorized, "invalid_token", "Invalid access token")
		return
	} else if err != nil {
		httpError(w, fmt.Errorf("failed to fetch user: %v", err))
		return
	}

	resp := map[string]interface{}{
		"sub": subjectForUser(user),
	}
	if containsScope(scopes, scopeProfile) {
		displayName := user.Name
		if displayName == "" {
			displayName = user.Username
		}
		resp["preferred_username"] = user.Username
		resp["name"] = displayName
	}
	if containsScope(scopes, scopeEmail) && user.Email != "" {
		resp["email"] = user.Email
		resp["email_verified"] = false
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(resp)
}
// bearerTokenFromRequest extracts the access token from an
// "Authorization: Bearer <token>" request header (RFC 6750, section 2.1).
// The scheme name matches case-insensitively and must be followed by a
// space. Surrounding whitespace is trimmed from the token. The error text is
// meant to be used as an error_description.
func bearerTokenFromRequest(req *http.Request) (string, error) {
	header := req.Header.Get("Authorization")
	if header == "" {
		return "", fmt.Errorf("Authorization header missing")
	}

	scheme, credentials, hasSpace := strings.Cut(header, " ")
	if !hasSpace || !strings.EqualFold(scheme, "Bearer") {
		return "", fmt.Errorf("Unsupported authorization scheme")
	}

	if token := strings.TrimSpace(credentials); token != "" {
		return token, nil
	}
	return "", fmt.Errorf("Missing access token")
}
// writeBearerError sends an error response for a bearer-token-protected
// resource (RFC 6750, section 3). It sets a WWW-Authenticate "Bearer"
// challenge carrying the optional error code and error_description, then
// writes status with its standard status text as the body.
//
// RFC 6750 restricts both values to %x20-21 / %x23-5B / %x5D-7E, which is
// printable ASCII without '"' and '\'. Any other character is dropped so the
// quoted-string stays well-formed. A parameter that is empty after filtering
// is omitted.
func writeBearerError(w http.ResponseWriter, status int, code, description string) {
	code = bearerChallengeValue(code)
	description = bearerChallengeValue(description)

	params := make([]string, 0, 2)
	if code != "" {
		params = append(params, `error="`+code+`"`)
	}
	if description != "" {
		params = append(params, `error_description="`+description+`"`)
	}

	challenge := "Bearer"
	if len(params) > 0 {
		challenge += " " + strings.Join(params, ", ")
	}

	w.Header().Set("WWW-Authenticate", challenge)
	http.Error(w, http.StatusText(status), status)
}

// bearerChallengeValue removes characters that RFC 6750 forbids in the error
// and error_description parameters of a Bearer challenge.
func bearerChallengeValue(s string) string {
	return strings.Map(func(r rune) rune {
		if r < 0x20 || r > 0x7e || r == '"' || r == '\\' {
			return -1
		}
		return r
	}, s)
}