From 69434566ea5a8a0d39802a62cc6b254ea8bd667c Mon Sep 17 00:00:00 2001 From: Runxi Yu Date: Thu, 25 Sep 2025 19:41:57 +0800 Subject: [PATCH] Implement basic OIDC and some fixes --- client.go | 25 ++++++++++++++++++++----- csrf.go | 106 +++++++++++++++++++++++++++++++++++++++++++++++++++++ db.go | 96 +++++++++++++++++++++++++++++++++++++++++++++++------ entity.go | 68 ++++++++++++++++++++++++++++++++++++++++-------------- go.mod | 1 + go.sum | 8 ++------ main.go | 12 ++++++++++-- middleware.go | 25 +++++++++++++++++++------ oauth2.go | 384 +++++++++++++++++++++++++++++++++++++++++++++++------ oidc.go | 373 +++++++++++++++++++++++++++++++++++++++++++++++++++++ pkce.go | 40 ++++++++++++++++++++++++++++++++++++++++ schema.sql | 21 +++++++++++++++++++-- static/style.css | 8 ++++---- template/authorize.html | 1 + template/head.html | 6 +++--- template/index.html | 12 +++++++++--- template/login.html | 1 + template/manage-client.html | 10 ++++++++++ template/manage-user.html | 9 +++++++++ user.go | 18 ++++++++++++------ diff --git a/client.go b/client.go index 63a914e76ef0393db9d2f8d4f8f7ccec55c568df..6d1c72ce6caad48fcbeef9220c4d0b4e03abc709 100644 --- a/client.go +++ b/client.go @@ -44,6 +44,12 @@ return } } + if normalized, err := normalizeClientPKCERequirement(client.PKCERequirement); err == nil { + client.PKCERequirement = normalized + } else { + client.PKCERequirement = pkceRequirementNone + } + if req.Method != http.MethodPost { data := struct { TemplateBaseData @@ -51,7 +57,7 @@ Client *Client }{ Client: client, } - tpl.MustExecuteTemplate(w, "manage-client.html", &data) + tpl.MustExecuteTemplate(req.Context(), w, "manage-client.html", &data) return } @@ -79,6 +85,17 @@ client.ClientName = req.PostFormValue("client_name") client.ClientURI = req.PostFormValue("client_uri") client.RedirectURIs = req.PostFormValue("redirect_uris") + pkceRequirement := req.PostFormValue("pkce_requirement") + if !isPublic { + pkceRequirement = pkceRequirementNone + } + normalizedRequirement, err := normalizeClientPKCERequirement(pkceRequirement) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + client.PKCERequirement = normalizedRequirement + if err := validateAllowedRedirectURIs(client.RedirectURIs); err != nil { // TODO: nicer error message http.Error(w, err.Error(), http.StatusBadRequest) @@ -113,7 +130,7 @@ }{ ClientID: client.ClientID, ClientSecret: clientSecret, } - tpl.MustExecuteTemplate(w, "client-secret.html", &data) + tpl.MustExecuteTemplate(req.Context(), w, "client-secret.html", &data) } func validateAllowedRedirectURIs(rawRedirectURIs string) error { @@ -130,9 +147,7 @@ switch u.Scheme { case "https": // ok case "http": - if u.Host != "localhost" { - return fmt.Errorf("Only http://localhost is allowed for insecure HTTP URIs") - } + // insecure but let's just trust the admin default: if !strings.Contains(u.Scheme, ".") { return fmt.Errorf("Only private-use URIs referring to domain names are allowed") diff --git a/csrf.go b/csrf.go new file mode 100644 index 0000000000000000000000000000000000000000..538484504a78aca42a99a5f8c1dee95bed5bab06 --- /dev/null +++ b/csrf.go @@ -0,0 +1,106 @@ +package main + +import ( + "context" + "crypto/subtle" + "net/http" + "strings" +) + +const ( + csrfCookieName = "vireo-csrf" + csrfFormField = "_csrf" +) + +func csrfTokenFromContext(ctx context.Context) string { + if ctx == nil { + return "" + } + if token, ok := ctx.Value(contextKeyCSRFToken).(string); ok { + return token + } + return "" +} + +func ensureCSRFToken(w http.ResponseWriter, req *http.Request) (string, error) { + if cookie, err := req.Cookie(csrfCookieName); err == nil && cookie.Value != "" { + return cookie.Value, nil + } + + token, err := generateUID() + if err != nil { + return "", err + } + + http.SetCookie(w, &http.Cookie{ + Name: csrfCookieName, + Value: token, + Path: "/", + HttpOnly: true, + SameSite: http.SameSiteStrictMode, + Secure: isForwardedHTTPS(req), + }) + + return token, nil +} + +func csrfMiddleware(next http.Handler) http.Handler { + cop := http.NewCrossOriginProtection() + cop.SetDenyHandler(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + http.Error(w, "Forbidden", http.StatusForbidden) + })) + + handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if isCSRFBypass(req) || req.Header.Get("Authorization") != "" { + next.ServeHTTP(w, req) + return + } + token, err := ensureCSRFToken(w, req) + if err != nil { + httpError(w, err) + return + } + + ctx := context.WithValue(req.Context(), contextKeyCSRFToken, token) + req = req.WithContext(ctx) + + switch req.Method { + case http.MethodGet, http.MethodHead, http.MethodOptions: + next.ServeHTTP(w, req) + return + } + + if err := req.ParseForm(); err != nil { + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + + formToken := req.PostFormValue(csrfFormField) + if formToken == "" { + formToken = req.Header.Get("X-CSRF-Token") + } + if formToken == "" || subtle.ConstantTimeCompare([]byte(formToken), []byte(token)) != 1 { + http.Error(w, "Invalid CSRF token", http.StatusForbidden) + return + } + + next.ServeHTTP(w, req) + }) + + return cop.Handler(handler) +} + +func isCSRFBypass(req *http.Request) bool { + path := req.URL.Path + switch { + case path == "/token", + path == "/introspect", + path == "/revoke", + path == "/userinfo", + strings.HasPrefix(path, "/static/"), + strings.HasPrefix(path, "/.well-known/"), + path == "/favicon.ico": + return true + } + return false +} diff --git a/db.go b/db.go index 186f709b73904de64ba14652cd60a9bd9c6a1cd3..d42aaa8c22ce0a2b3f8667b2c748eaf8cd70aa8e 100644 --- a/db.go +++ b/db.go @@ -20,6 +20,42 @@ ALTER TABLE AccessToken ADD COLUMN refresh_hash BLOB; ALTER TABLE AccessToken ADD COLUMN refresh_expires_at datetime; CREATE UNIQUE INDEX access_token_refresh_hash ON AccessToken(refresh_hash); `, + ` + ALTER TABLE AuthCode ADD COLUMN nonce TEXT; + CREATE TABLE IF NOT EXISTS SigningKey ( + id INTEGER PRIMARY KEY, + kid TEXT NOT NULL UNIQUE, + algorithm TEXT NOT NULL, + private_key BLOB NOT NULL, + created_at datetime NOT NULL + ); + `, + ` + ALTER TABLE AccessToken ADD COLUMN auth_time datetime; + ALTER TABLE SigningKey RENAME TO SigningKey_old; + CREATE TABLE SigningKey ( + id INTEGER PRIMARY KEY, + kid TEXT NOT NULL UNIQUE, + algorithm TEXT NOT NULL, + private_key BLOB NOT NULL, + created_at datetime NOT NULL + ); + INSERT INTO SigningKey(id, kid, algorithm, private_key, created_at) + SELECT id, kid, algorithm, private_key, created_at FROM SigningKey_old; + DROP TABLE SigningKey_old; + CREATE INDEX IF NOT EXISTS signing_key_created_at ON SigningKey(created_at); + `, + ` + ALTER TABLE User ADD COLUMN email TEXT; + `, + ` + ALTER TABLE User ADD COLUMN name TEXT; + `, + ` + ALTER TABLE AuthCode ADD COLUMN code_challenge TEXT; + ALTER TABLE AuthCode ADD COLUMN code_challenge_method TEXT; + ALTER TABLE Client ADD COLUMN pkce_requirement TEXT; + `, } var errNoDBRows = sql.ErrNoRows @@ -55,7 +91,7 @@ return nil } // TODO: drop this - defaultUser := User{Username: "root", Admin: true} + defaultUser := User{Username: "root", Name: "Root User", Email: "root@example.invalid", Admin: true} if err := defaultUser.SetPassword("root"); err != nil { return err } @@ -126,10 +162,12 @@ } func (db *DB) StoreUser(ctx context.Context, user *User) error { return db.db.QueryRowContext(ctx, ` - INSERT INTO User(id, username, password_hash, admin) - VALUES (:id, :username, :password_hash, :admin) + INSERT INTO User(id, username, name, email, password_hash, admin) + VALUES (:id, :username, :name, :email, :password_hash, :admin) ON CONFLICT(id) DO UPDATE SET username = :username, + name = :name, + email = :email, password_hash = :password_hash, admin = :admin RETURNING id @@ -178,16 +216,17 @@ func (db *DB) StoreClient(ctx context.Context, client *Client) error { return db.db.QueryRowContext(ctx, ` INSERT INTO Client(id, client_id, client_secret_hash, owner, - redirect_uris, client_name, client_uri) + redirect_uris, client_name, client_uri, pkce_requirement) VALUES (:id, :client_id, :client_secret_hash, :owner, - :redirect_uris, :client_name, :client_uri) + :redirect_uris, :client_name, :client_uri, :pkce_requirement) ON CONFLICT(id) DO UPDATE SET client_id = :client_id, client_secret_hash = :client_secret_hash, owner = :owner, redirect_uris = :redirect_uris, client_name = :client_name, - client_uri = :client_uri + client_uri = :client_uri, + pkce_requirement = :pkce_requirement RETURNING id `, entityArgs(client)...).Scan(&client.ID) } @@ -264,9 +303,9 @@ func (db *DB) StoreAccessToken(ctx context.Context, token *AccessToken) error { return db.db.QueryRowContext(ctx, ` INSERT INTO AccessToken(id, hash, user, client, scope, issued_at, - expires_at, refresh_hash, refresh_expires_at) + expires_at, auth_time, refresh_hash, refresh_expires_at) VALUES (:id, :hash, :user, :client, :scope, :issued_at, :expires_at, - :refresh_hash, :refresh_expires_at) + :auth_time, :refresh_hash, :refresh_expires_at) ON CONFLICT(id) DO UPDATE SET hash = :hash, user = :user, @@ -274,6 +313,7 @@ client = :client, scope = :scope, issued_at = :issued_at, expires_at = :expires_at, + auth_time = :auth_time, refresh_hash = :refresh_hash, refresh_expires_at = :refresh_expires_at RETURNING id @@ -313,8 +353,8 @@ } func (db *DB) CreateAuthCode(ctx context.Context, code *AuthCode) error { return db.db.QueryRowContext(ctx, ` - INSERT INTO AuthCode(hash, created_at, user, client, scope, redirect_uri) - VALUES (:hash, :created_at, :user, :client, :scope, :redirect_uri) + INSERT INTO AuthCode(hash, created_at, user, client, scope, redirect_uri, nonce, code_challenge, code_challenge_method) + VALUES (:hash, :created_at, :user, :client, :scope, :redirect_uri, :nonce, :code_challenge, :code_challenge_method) RETURNING id `, entityArgs(code)...).Scan(&code.ID) } @@ -331,6 +371,42 @@ } var authCode AuthCode err = scanRow(&authCode, rows) return &authCode, err +} + +func (db *DB) FetchSigningKeys(ctx context.Context) ([]SigningKey, error) { + rows, err := db.db.QueryContext(ctx, ` + SELECT * FROM SigningKey + ORDER BY created_at DESC + `) + if err != nil { + return nil, err + } + defer rows.Close() + + var keys []SigningKey + for rows.Next() { + var key SigningKey + if err := scan(&key, rows); err != nil { + return nil, err + } + keys = append(keys, key) + } + + if err := rows.Err(); err != nil { + return nil, err + } + if len(keys) == 0 { + return nil, errNoDBRows + } + return keys, nil +} + +func (db *DB) StoreSigningKey(ctx context.Context, key *SigningKey) error { + return db.db.QueryRowContext(ctx, ` + INSERT INTO SigningKey(kid, algorithm, private_key, created_at) + VALUES (:kid, :algorithm, :private_key, :created_at) + RETURNING id + `, sql.Named("kid", key.KID), sql.Named("algorithm", key.Algorithm), sql.Named("private_key", key.PrivateKey), sql.Named("created_at", key.CreatedAt)).Scan(&key.ID) } func (db *DB) Maintain(ctx context.Context) error { diff --git a/entity.go b/entity.go index eceb92320ca9d2b89d4e4d28a0d5beb69072f4b0..4edc14a929046db58aa6984ad8992a7df196ea43 100644 --- a/entity.go +++ b/entity.go @@ -31,6 +31,7 @@ _ entity = (*User)(nil) _ entity = (*Client)(nil) _ entity = (*AccessToken)(nil) _ entity = (*AuthCode)(nil) + _ entity = (*SigningKey)(nil) ) type ID[T entity] int64 @@ -105,6 +106,8 @@ type User struct { ID ID[*User] Username string + Name string + Email string PasswordHash string Admin bool } @@ -113,6 +116,8 @@ func (user *User) columns() map[string]interface{} { return map[string]interface{}{ "id": &user.ID, "username": &user.Username, + "name": nullValue{&user.Name}, + "email": nullValue{&user.Email}, "password_hash": nullValue{&user.PasswordHash}, "admin": &user.Admin, } @@ -144,6 +149,7 @@ Owner ID[*User] RedirectURIs string ClientName string ClientURI string + PKCERequirement string } func (client *Client) Generate(isPublic bool) (secret string, err error) { @@ -174,6 +180,7 @@ "owner": &client.Owner, "redirect_uris": nullValue{&client.RedirectURIs}, "client_name": nullValue{&client.ClientName}, "client_uri": nullValue{&client.ClientURI}, + "pkce_requirement": nullValue{&client.PKCERequirement}, } } @@ -193,6 +200,7 @@ Client ID[*Client] Scope string IssuedAt time.Time ExpiresAt time.Time + AuthTime time.Time RefreshHash []byte RefreshExpiresAt time.Time @@ -221,9 +229,10 @@ } func NewAccessTokenFromAuthCode(authCode *AuthCode) *AccessToken { return &AccessToken{ - User: authCode.User, - Client: authCode.Client, - Scope: authCode.Scope, + User: authCode.User, + Client: authCode.Client, + Scope: authCode.Scope, + AuthTime: authCode.CreatedAt, } } @@ -236,6 +245,7 @@ "client": &token.Client, "scope": nullValue{&token.Scope}, "issued_at": &token.IssuedAt, "expires_at": &token.ExpiresAt, + "auth_time": nullValue{&token.AuthTime}, "refresh_hash": &token.RefreshHash, "refresh_expires_at": nullValue{&token.RefreshExpiresAt}, } @@ -255,13 +265,16 @@ ExpiresAt time.Time } type AuthCode struct { - ID ID[*AuthCode] - Hash []byte - CreatedAt time.Time - User ID[*User] - Client ID[*Client] - Scope string - RedirectURI string + ID ID[*AuthCode] + Hash []byte + CreatedAt time.Time + User ID[*User] + Client ID[*Client] + Scope string + RedirectURI string + Nonce string + CodeChallenge string + CodeChallengeMethod string } func (code *AuthCode) Generate() (secret string, err error) { @@ -276,13 +289,16 @@ } func (code *AuthCode) columns() map[string]interface{} { return map[string]interface{}{ - "id": &code.ID, - "hash": &code.Hash, - "created_at": &code.CreatedAt, - "user": &code.User, - "client": &code.Client, - "scope": nullValue{&code.Scope}, - "redirect_uri": nullValue{&code.RedirectURI}, + "id": &code.ID, + "hash": &code.Hash, + "created_at": &code.CreatedAt, + "user": &code.User, + "client": &code.Client, + "scope": nullValue{&code.Scope}, + "redirect_uri": nullValue{&code.RedirectURI}, + "nonce": nullValue{&code.Nonce}, + "code_challenge": nullValue{&code.CodeChallenge}, + "code_challenge_method": nullValue{&code.CodeChallengeMethod}, } } @@ -297,6 +313,24 @@ SecretKindAccessToken = SecretKind('a') SecretKindRefreshToken = SecretKind('r') SecretKindAuthCode = SecretKind('c') ) + +type SigningKey struct { + ID ID[*SigningKey] + KID string + Algorithm string + PrivateKey []byte + CreatedAt time.Time +} + +func (key *SigningKey) columns() map[string]interface{} { + return map[string]interface{}{ + "id": &key.ID, + "kid": &key.KID, + "algorithm": &key.Algorithm, + "private_key": &key.PrivateKey, + "created_at": &key.CreatedAt, + } +} func UnmarshalSecret[T entity](s string) (id ID[T], secret string, err error) { kind, s, _ := strings.Cut(s, ".") diff --git a/go.mod b/go.mod index 6d54256cb8f5dc70ee73ef044dc0835ad24dd4e5..995cb68508c65a99e25856df5bc6fd1b25007551 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( codeberg.org/emersion/go-scfg v0.1.0 github.com/emersion/go-oauth2 v0.0.0-20250228145955-eaead4148315 github.com/go-chi/chi/v5 v5.2.3 + github.com/golang-jwt/jwt/v5 v5.2.1 github.com/mattn/go-sqlite3 v1.14.32 golang.org/x/crypto v0.42.0 ) diff --git a/go.sum b/go.sum index 3b0c6e5fc5d6e7caca72a293a9f8cb6ede24fa96..30e4827711921e981d2e0bab25d419ff99b2e1c9 100644 --- a/go.sum +++ b/go.sum @@ -4,15 +4,11 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/emersion/go-oauth2 v0.0.0-20250228145955-eaead4148315 h1:sXzwA8yItbg3ji0UuTLkuO4NKPqQJjC035hPoZI40h8= github.com/emersion/go-oauth2 v0.0.0-20250228145955-eaead4148315/go.mod h1:pSj8CBn/jb+ynRxt/ESIJisazza/Sh2DjwUn31l2tI0= -github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8= -github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE= github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= -github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= -github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= -golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8= diff --git a/main.go b/main.go index ee309209629eba3ae13ae71a4661c74fd89d8238..c9b938fbb4c571105430a3a6b3a4116e90441468 100644 --- a/main.go +++ b/main.go @@ -50,6 +50,11 @@ if err != nil { log.Fatalf("Failed to load template: %v", err) } + oidcProvider, err := newOIDCProvider(context.Background(), db) + if err != nil { + log.Fatalf("Failed to initialize OpenID Connect provider: %v", err) + } + mux := chi.NewRouter() mux.Handle("/static/*", http.FileServer(http.FS(staticFS))) mux.Get("/", index) @@ -61,18 +66,21 @@ mux.Post("/client/{id}/revoke", revokeClient) mux.HandleFunc("/user/new", manageUser) mux.HandleFunc("/user/{id}", manageUser) mux.Get("/.well-known/oauth-authorization-server", getOAuthServerMetadata) + mux.Get("/.well-known/openid-configuration", getOpenIDConfiguration) + mux.Get("/.well-known/jwks.json", getOIDCJWKS) mux.HandleFunc("/authorize", authorize) mux.Post("/token", exchangeToken) mux.Post("/introspect", introspectToken) mux.Post("/revoke", revokeToken) + mux.HandleFunc("/userinfo", userInfo) go maintainDBLoop(db) server := http.Server{ Addr: listenAddr, - Handler: loginTokenMiddleware(mux), + Handler: csrfMiddleware(loginTokenMiddleware(mux)), BaseContext: func(net.Listener) context.Context { - return newBaseContext(db, tpl) + return newBaseContext(db, tpl, oidcProvider) }, } log.Printf("OAuth server listening on %v", server.Addr) diff --git a/middleware.go b/middleware.go index 3f6132095db3707811d56785f845d1e20bdf1d9f..ca550ffad78caf32ee9391db7a07cd27083c70f9 100644 --- a/middleware.go +++ b/middleware.go @@ -21,6 +21,8 @@ const ( contextKeyDB = "db" contextKeyTemplate = "template" contextKeyLoginToken = "login-token" + contextKeyOIDC = "oidc" + contextKeyCSRFToken = "csrf-token" ) func dbFromContext(ctx context.Context) *DB { @@ -31,6 +33,10 @@ func templateFromContext(ctx context.Context) *Template { return ctx.Value(contextKeyTemplate).(*Template) } +func oidcProviderFromContext(ctx context.Context) *OIDCProvider { + return ctx.Value(contextKeyOIDC).(*OIDCProvider) +} + func loginTokenFromContext(ctx context.Context) *AccessToken { v := ctx.Value(contextKeyLoginToken) if v == nil { @@ -39,10 +45,11 @@ } return v.(*AccessToken) } -func newBaseContext(db *DB, tpl *Template) context.Context { +func newBaseContext(db *DB, tpl *Template, oidc *OIDCProvider) context.Context { ctx := context.Background() ctx = context.WithValue(ctx, contextKeyDB, db) ctx = context.WithValue(ctx, contextKeyTemplate, tpl) + ctx = context.WithValue(ctx, contextKeyOIDC, oidc) return ctx } @@ -51,7 +58,7 @@ http.SetCookie(w, &http.Cookie{ Name: loginCookieName, Value: MarshalSecret(token.ID, SecretKindAccessToken, secret), HttpOnly: true, - SameSite: http.SameSiteStrictMode, + SameSite: http.SameSiteLaxMode, Secure: isForwardedHTTPS(req), }) } @@ -60,7 +67,7 @@ func unsetLoginTokenCookie(w http.ResponseWriter, req *http.Request) { http.SetCookie(w, &http.Cookie{ Name: loginCookieName, HttpOnly: true, - SameSite: http.SameSiteStrictMode, + SameSite: http.SameSiteLaxMode, Secure: isForwardedHTTPS(req), MaxAge: -1, }) @@ -114,6 +121,7 @@ } type TemplateBaseData struct { ServerName string + CSRFToken string } func (data *TemplateBaseData) Base() *TemplateBaseData { @@ -137,11 +145,16 @@ } return &Template{tpl: tpl, baseData: baseData}, nil } -func (tpl *Template) MustExecuteTemplate(w io.Writer, filename string, data TemplateData) { +func (tpl *Template) MustExecuteTemplate(ctx context.Context, w io.Writer, filename string, data TemplateData) { + baseCopy := *tpl.baseData + if token := csrfTokenFromContext(ctx); token != "" { + baseCopy.CSRFToken = token + } if data == nil { - data = tpl.baseData + base := baseCopy + data = &base } else { - *data.Base() = *tpl.baseData + *data.Base() = baseCopy } if err := tpl.tpl.ExecuteTemplate(w, filename, data); err != nil { panic(err) diff --git a/oauth2.go b/oauth2.go index 4654df2e785a464559ae4729b711219d2430e058..cd4fdcbde59bbf91548db913b330e0df97c1ecb1 100644 --- a/oauth2.go +++ b/oauth2.go @@ -1,6 +1,8 @@ package main import ( + "crypto/sha256" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -16,22 +18,151 @@ "github.com/emersion/go-oauth2" ) +const ( + scopeOpenID = "openid" + scopeProfile = "profile" + scopeEmail = "email" + scopeOfflineAccess = "offline_access" + pkceMethodPlain = "plain" + pkceMethodS256 = "S256" +) + +var allowedScopes = map[string]struct{}{ + scopeOpenID: {}, + scopeProfile: {}, + scopeEmail: {}, + scopeOfflineAccess: {}, +} + +type oidcTokenResponse struct { + AccessToken string `json:"access_token"` + TokenType oauth2.TokenType `json:"token_type"` + ExpiresIn int64 `json:"expires_in,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + Scope string `json:"scope,omitempty"` + IDToken string `json:"id_token,omitempty"` +} + +func parseScopes(scope string) []string { + if scope == "" { + return nil + } + parts := strings.Fields(scope) + var scopes []string + seen := make(map[string]struct{}, len(parts)) + for _, p := range parts { + if p == "" { + continue + } + p = strings.ToLower(p) + if _, ok := seen[p]; ok { + continue + } + seen[p] = struct{}{} + scopes = append(scopes, p) + } + return scopes +} + +func normalizeScope(scope string) (string, []string) { + scopes := parseScopes(scope) + if len(scopes) == 0 { + return "", nil + } + return strings.Join(scopes, " "), scopes +} + +func validateScopes(scopes []string) error { + for _, scope := range scopes { + if _, ok := allowedScopes[scope]; !ok { + return fmt.Errorf("unsupported scope %q", scope) + } + } + return nil +} + +func normalizeCodeChallengeMethod(method string) (string, error) { + if method == "" { + return pkceMethodPlain, nil + } + switch { + case strings.EqualFold(method, pkceMethodPlain): + return pkceMethodPlain, nil + case strings.EqualFold(method, pkceMethodS256): + return pkceMethodS256, nil + default: + return "", fmt.Errorf("unsupported code_challenge_method") + } +} + +func validateCodeVerifier(verifier string) error { + if verifier == "" { + return fmt.Errorf("missing code_verifier") + } + if len(verifier) < 43 || len(verifier) > 128 { + return fmt.Errorf("invalid code_verifier length") + } + for _, r := range verifier { + if !(r >= 'A' && r <= 'Z' || r >= 'a' && r <= 'z' || r >= '0' && r <= '9' || r == '-' || r == '.' || r == '_' || r == '~') { + return fmt.Errorf("invalid character in code_verifier") + } + } + return nil +} + +func validateCodeChallenge(method, challenge string) error { + if challenge == "" { + return fmt.Errorf("missing code_challenge") + } + if err := validateCodeVerifier(challenge); err != nil { + return err + } + if method != pkceMethodPlain && method != pkceMethodS256 { + return fmt.Errorf("unsupported code_challenge_method") + } + return nil +} + +func verifyCodeVerifier(method, challenge, verifier string) error { + if err := validateCodeVerifier(verifier); err != nil { + return err + } + switch method { + case "", pkceMethodPlain: + if challenge != verifier { + return fmt.Errorf("code_verifier mismatch") + } + case pkceMethodS256: + hash := sha256.Sum256([]byte(verifier)) + expected := base64.RawURLEncoding.EncodeToString(hash[:]) + if expected != challenge { + return fmt.Errorf("code_verifier mismatch") + } + default: + return fmt.Errorf("unsupported code_challenge_method") + } + return nil +} + func getOAuthServerMetadata(w http.ResponseWriter, req *http.Request) { issuer := getIssuer(req) w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(&oauth2.ServerMetadata{ - Issuer: issuer, - AuthorizationEndpoint: issuer + "/authorize", - TokenEndpoint: issuer + "/token", - IntrospectionEndpoint: issuer + "/introspect", - RevocationEndpoint: issuer + "/revoke", - ResponseTypesSupported: []oauth2.ResponseType{oauth2.ResponseTypeCode}, - ResponseModesSupported: []oauth2.ResponseMode{oauth2.ResponseModeQuery}, - GrantTypesSupported: []oauth2.GrantType{oauth2.GrantTypeAuthorizationCode}, - TokenEndpointAuthMethodsSupported: []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic}, + Issuer: issuer, + AuthorizationEndpoint: issuer + "/authorize", + TokenEndpoint: issuer + "/token", + IntrospectionEndpoint: issuer + "/introspect", + RevocationEndpoint: issuer + "/revoke", + JWKSURI: issuer + "/.well-known/jwks.json", + ScopesSupported: []string{scopeOpenID, scopeProfile, scopeEmail, scopeOfflineAccess}, + ResponseTypesSupported: []oauth2.ResponseType{oauth2.ResponseTypeCode}, + ResponseModesSupported: []oauth2.ResponseMode{oauth2.ResponseModeQuery}, + GrantTypesSupported: []oauth2.GrantType{oauth2.GrantTypeAuthorizationCode, oauth2.GrantTypeRefreshToken}, + TokenEndpointAuthMethodsSupported: []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic}, IntrospectionEndpointAuthMethodsSupported: []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic}, RevocationEndpointAuthMethodsSupported: []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic}, + CodeChallengeMethodsSupported: []string{pkceMethodPlain, pkceMethodS256}, AuthorizationResponseIssParameterSupported: true, }) } @@ -65,6 +196,12 @@ clientID := q.Get("client_id") rawRedirectURI := q.Get("redirect_uri") scope := q.Get("scope") state := q.Get("state") + _, stateProvided := q["state"] + codeChallenge := q.Get("code_challenge") + codeChallengeMethod := q.Get("code_challenge_method") + + var normalizedCodeChallengeMethod string + nonce := q.Get("nonce") if clientID == "" { http.Error(w, "Missing client ID", http.StatusBadRequest) @@ -80,6 +217,12 @@ httpError(w, fmt.Errorf("failed to fetch client: %v", err)) return } + requiredPKCE, err := normalizeClientPKCERequirement(client.PKCERequirement) + if err != nil { + httpError(w, fmt.Errorf("invalid PKCE requirement configuration: %v", err)) + return + } + var allowedRedirectURIs []*url.URL for _, s := range strings.Split(client.RedirectURIs, "\n") { if s == "" { @@ -112,20 +255,98 @@ } redirectURI = allowedRedirectURIs[0] } + if codeChallenge != "" { + method, err := normalizeCodeChallengeMethod(codeChallengeMethod) + if err != nil { + redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{ + Code: oauth2.ErrorCodeInvalidRequest, + Description: err.Error(), + }) + return + } + if err := validateCodeChallenge(method, codeChallenge); err != nil { + redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{ + Code: oauth2.ErrorCodeInvalidRequest, + Description: err.Error(), + }) + return + } + normalizedCodeChallengeMethod = method + } + + if codeChallenge == "" && codeChallengeMethod != "" { + redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{ + Code: oauth2.ErrorCodeInvalidRequest, + Description: "code_challenge_method without code_challenge", + }) + return + } + + switch requiredPKCE { + case pkceRequirementPlain: + if codeChallenge == "" { + redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{ + Code: oauth2.ErrorCodeInvalidRequest, + Description: "PKCE is required", + }) + return + } + if !allowPKCERequirement(normalizedCodeChallengeMethod, pkceRequirementPlain) { + redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{ + Code: oauth2.ErrorCodeInvalidRequest, + Description: "PKCE method does not satisfy requirement", + }) + return + } + case pkceRequirementS256: + if codeChallenge == "" { + redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{ + Code: oauth2.ErrorCodeInvalidRequest, + Description: "PKCE (S256) is required", + }) + return + } + if !allowPKCERequirement(normalizedCodeChallengeMethod, pkceRequirementS256) { + redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{ + Code: oauth2.ErrorCodeInvalidRequest, + Description: "PKCE (S256) is required", + }) + return + } + } + + codeChallengeMethod = normalizedCodeChallengeMethod + if respType != oauth2.ResponseTypeCode { - redirectClientError(w, req, redirectURI, state, &oauth2.Error{ + redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{ Code: oauth2.ErrorCodeUnsupportedResponseType, }) return } - // TODO: add support for scope - if scope != "" { - redirectClientError(w, req, redirectURI, state, &oauth2.Error{ - Code: oauth2.ErrorCodeInvalidScope, + normalizedScope, scopes := normalizeScope(scope) + if len(scopes) == 0 { + redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{ + Code: oauth2.ErrorCodeInvalidScope, + Description: "Missing required openid scope", }) return } + if err := validateScopes(scopes); err != nil { + redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{ + Code: oauth2.ErrorCodeInvalidScope, + Description: err.Error(), + }) + return + } + if !containsScope(scopes, scopeOpenID) { + redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{ + Code: oauth2.ErrorCodeInvalidScope, + Description: "Scope openid is required", + }) + return + } + scope = normalizedScope loginToken := loginTokenFromContext(ctx) if loginToken == nil { @@ -141,7 +362,7 @@ } _ = req.ParseForm() if _, ok := req.PostForm["deny"]; ok { - redirectClientError(w, req, redirectURI, state, &oauth2.Error{ + redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{ Code: oauth2.ErrorCodeAccessDenied, }) return @@ -153,15 +374,18 @@ Client *Client }{ Client: client, } - tpl.MustExecuteTemplate(w, "authorize.html", &data) + tpl.MustExecuteTemplate(req.Context(), w, "authorize.html", &data) return } authCode := AuthCode{ - User: loginToken.User, - Client: client.ID, - Scope: scope, - RedirectURI: rawRedirectURI, + User: loginToken.User, + Client: client.ID, + Scope: scope, + RedirectURI: rawRedirectURI, + Nonce: nonce, + CodeChallenge: codeChallenge, + CodeChallengeMethod: codeChallengeMethod, } secret, err := authCode.Generate() if err != nil { @@ -178,7 +402,7 @@ code := MarshalSecret(authCode.ID, SecretKindAuthCode, secret) values := make(url.Values) values.Set("code", code) - if state != "" { + if stateProvided { values.Set("state", state) } redirectClient(w, req, redirectURI, values) @@ -200,6 +424,7 @@ clientID := values.Get("client_id") grantType := oauth2.GrantType(values.Get("grant_type")) scope := values.Get("scope") + codeVerifier := values.Get("code_verifier") authClientID, clientSecret, _ := req.BasicAuth() if clientID == "" { @@ -235,7 +460,13 @@ return } } - var token *AccessToken + var ( + token *AccessToken + authorizationCode *AuthCode + currentClient *Client + nonceValue string + ) + switch grantType { case oauth2.GrantTypeAuthorizationCode: if client == nil { @@ -247,8 +478,8 @@ return } codeID, codeSecret, _ := UnmarshalSecret[*AuthCode](values.Get("code")) - authCode, err := db.PopAuthCode(ctx, codeID) - if err == errNoDBRows || (err == nil && !authCode.VerifySecret(codeSecret)) || authCode.Client != client.ID { + authorizationCode, err = db.PopAuthCode(ctx, codeID) + if err == errNoDBRows || (err == nil && !authorizationCode.VerifySecret(codeSecret)) || authorizationCode.Client != client.ID { oauthError(w, &oauth2.Error{ Code: oauth2.ErrorCodeAccessDenied, Description: "Invalid authorization code", @@ -259,22 +490,46 @@ oauthError(w, fmt.Errorf("failed to fetch authorization code: %v", err)) return } - if scope != authCode.Scope { + if scope != "" && scope != authorizationCode.Scope { oauthError(w, &oauth2.Error{ Code: oauth2.ErrorCodeAccessDenied, Description: "Invalid scope", }) return } - if values.Get("redirect_uri") != authCode.RedirectURI { + if values.Get("redirect_uri") != authorizationCode.RedirectURI { oauthError(w, &oauth2.Error{ Code: oauth2.ErrorCodeAccessDenied, Description: "Invalid redirect URI", }) return } + if authorizationCode.CodeChallenge != "" { + if codeVerifier == "" { + oauthError(w, &oauth2.Error{ + Code: oauth2.ErrorCodeInvalidRequest, + Description: "Missing code_verifier", + }) + return + } + if err := verifyCodeVerifier(authorizationCode.CodeChallengeMethod, authorizationCode.CodeChallenge, codeVerifier); err != nil { + oauthError(w, &oauth2.Error{ + Code: oauth2.ErrorCodeInvalidGrant, + Description: "Invalid code_verifier", + }) + return + } + } else if codeVerifier != "" { + oauthError(w, &oauth2.Error{ + Code: oauth2.ErrorCodeInvalidRequest, + Description: "Unexpected code_verifier", + }) + return + } - token = NewAccessTokenFromAuthCode(authCode) + token = NewAccessTokenFromAuthCode(authorizationCode) + currentClient = client + nonceValue = authorizationCode.Nonce case oauth2.GrantTypeRefreshToken: tokenID, refreshSecret, _ := UnmarshalSecret[*AccessToken](values.Get("refresh_token")) token, err = db.FetchAccessToken(ctx, tokenID) @@ -297,13 +552,13 @@ }) return } - tokenClient, err := db.FetchClient(ctx, token.Client) + currentClient, err = db.FetchClient(ctx, token.Client) if err != nil { oauthError(w, fmt.Errorf("failed to fetch client: %v", err)) return } - if !tokenClient.IsPublic() && client == nil { + if !currentClient.IsPublic() && client == nil { oauthError(w, &oauth2.Error{ Code: oauth2.ErrorCodeAccessDenied, Description: "Invalid client secret", @@ -311,7 +566,7 @@ }) return } - if scope != token.Scope { + if scope != "" && scope != token.Scope { oauthError(w, &oauth2.Error{ Code: oauth2.ErrorCodeAccessDenied, Description: "Invalid scope", @@ -330,10 +585,19 @@ if err != nil { oauthError(w, err) return } - refreshSecret, err := token.GenerateRefresh() - if err != nil { - oauthError(w, err) - return + + tokenScopes := parseScopes(token.Scope) + issueRefresh := containsScope(tokenScopes, scopeOfflineAccess) + var refreshSecret string + if issueRefresh { + refreshSecret, err = token.GenerateRefresh() + if err != nil { + oauthError(w, err) + return + } + } else { + token.RefreshHash = nil + token.RefreshExpiresAt = time.Time{} } if err := db.StoreAccessToken(ctx, token); err != nil { @@ -341,15 +605,49 @@ oauthError(w, fmt.Errorf("failed to create access token: %v", err)) return } + accessTokenValue := MarshalSecret(token.ID, SecretKindAccessToken, secret) + if token.AuthTime.IsZero() { + token.AuthTime = token.IssuedAt + } + + var idToken string + if containsScope(tokenScopes, scopeOpenID) { + if currentClient == nil { + currentClient, err = db.FetchClient(ctx, token.Client) + if err != nil { + oauthError(w, fmt.Errorf("failed to fetch client: %v", err)) + return + } + } + user, err := db.FetchUser(ctx, token.User) + if err != nil { + oauthError(w, fmt.Errorf("failed to fetch user: %v", err)) + return + } + + issuer := getIssuer(req) + oidcProvider := oidcProviderFromContext(ctx) + idToken, err = oidcProvider.MintIDToken(issuer, currentClient, user, token, tokenScopes, nonceValue, accessTokenValue, token.AuthTime) + if err != nil { + oauthError(w, fmt.Errorf("failed to mint ID token: %v", err)) + return + } + } + w.Header().Set("Content-Type", "application/json") w.Header().Set("Cache-Control", "no-store") - json.NewEncoder(w).Encode(&oauth2.TokenResp{ - AccessToken: MarshalSecret(token.ID, SecretKindAccessToken, secret), - TokenType: oauth2.TokenTypeBearer, - ExpiresIn: time.Until(token.ExpiresAt), - Scope: strings.Split(token.Scope, " "), - RefreshToken: MarshalSecret(token.ID, SecretKindRefreshToken, refreshSecret), - }) + resp := oidcTokenResponse{ + AccessToken: accessTokenValue, + TokenType: oauth2.TokenTypeBearer, + ExpiresIn: int64(time.Until(token.ExpiresAt).Seconds()), + Scope: token.Scope, + IDToken: idToken, + } + if issueRefresh { + refreshTokenValue := MarshalSecret(token.ID, SecretKindRefreshToken, refreshSecret) + resp.RefreshToken = refreshTokenValue + } + json.NewEncoder(w).Encode(&resp) } func introspectToken(w http.ResponseWriter, req *http.Request) { @@ -542,7 +840,7 @@ http.Redirect(w, req, u.String(), http.StatusFound) } -func redirectClientError(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, state string, err error) { +func redirectClientError(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, state string, stateProvided bool, err error) { var oauthErr *oauth2.Error if !errors.As(err, &oauthErr) { oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError} @@ -557,7 +855,7 @@ } if oauthErr.URI != "" { values.Set("error_uri", oauthErr.URI) } - if state != "" { + if stateProvided { values.Set("state", state) } redirectClient(w, req, redirectURI, values) diff --git a/oidc.go b/oidc.go new file mode 100644 index 0000000000000000000000000000000000000000..3dfbe9eeb09ece37b41b67b38d6c28244e1d7ee1 --- /dev/null +++ b/oidc.go @@ -0,0 +1,373 @@ +package main + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "fmt" + "math/big" + "net/http" + "sort" + "strconv" + "strings" + "time" + + oauth2 "github.com/emersion/go-oauth2" + "github.com/golang-jwt/jwt/v5" +) + +type OIDCProvider struct { + signingKeys []*oidcSigningKey +} + +type oidcSigningKey struct { + key *SigningKey + private *rsa.PrivateKey + publicJWK jwk +} + +type jwk struct { + Kty string `json:"kty"` + Use string `json:"use,omitempty"` + Alg string `json:"alg,omitempty"` + Kid string `json:"kid,omitempty"` + N string `json:"n,omitempty"` + E string `json:"e,omitempty"` +} + +type jwks struct { + Keys []jwk `json:"keys"` +} + +const idTokenTTL = 15 * time.Minute + +func newOIDCProvider(ctx context.Context, db *DB) (*OIDCProvider, error) { + signingRecords, err := db.FetchSigningKeys(ctx) + if err == errNoDBRows { + generated, genErr := generateSigningKey() + if genErr != nil { + return nil, genErr + } + if storeErr := db.StoreSigningKey(ctx, generated); storeErr != nil { + return nil, fmt.Errorf("failed to persist signing key: %w", storeErr) + } + signingRecords = []SigningKey{*generated} + } else if err != nil { + return nil, fmt.Errorf("failed to fetch signing keys: %w", err) + } + + signingKeys := make([]*oidcSigningKey, 0, len(signingRecords)) + for i := range signingRecords { + material, convErr := toOIDCSigningKey(&signingRecords[i]) + if convErr != nil { + return nil, convErr + } + signingKeys = append(signingKeys, material) + } + + return &OIDCProvider{signingKeys: signingKeys}, nil +} + +func generateSigningKey() (*SigningKey, error) { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, fmt.Errorf("failed to generate signing key: %w", err) + } + + pemBlock := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(priv), + }) + if pemBlock == nil { + return nil, fmt.Errorf("failed to encode signing key") + } + + kid, err := generateUID() + if err != nil { + return nil, fmt.Errorf("failed to generate signing key ID: %w", err) + } + return &SigningKey{ + KID: kid, + Algorithm: "RS256", + PrivateKey: pemBlock, + CreatedAt: time.Now(), + }, nil +} + +func toOIDCSigningKey(signing *SigningKey) (*oidcSigningKey, error) { + block, _ := pem.Decode(signing.PrivateKey) + if block == nil { + return nil, fmt.Errorf("failed to decode signing key PEM") + } + priv, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse signing key: %w", err) + } + + jwk := jwk{ + Kty: "RSA", + Use: "sig", + Alg: signing.Algorithm, + Kid: signing.KID, + N: base64.RawURLEncoding.EncodeToString(priv.N.Bytes()), + } + + e := big.NewInt(int64(priv.E)).Bytes() + jwk.E = base64.RawURLEncoding.EncodeToString(e) + + return &oidcSigningKey{ + key: signing, + private: priv, + publicJWK: jwk, + }, nil +} + +func (op *OIDCProvider) currentSigningKey() *oidcSigningKey { + if len(op.signingKeys) == 0 { + return nil + } + return op.signingKeys[0] +} + +func (op *OIDCProvider) signingMethod() (*jwt.SigningMethodRSA, *oidcSigningKey, error) { + key := op.currentSigningKey() + if key == nil { + return nil, nil, fmt.Errorf("no signing key configured") + } + + switch key.key.Algorithm { + case "RS256": + return jwt.SigningMethodRS256, key, nil + default: + return nil, nil, fmt.Errorf("unsupported signing algorithm %q", key.key.Algorithm) + } +} + +func (op *OIDCProvider) MintIDToken(issuer string, client *Client, user *User, token *AccessToken, scopes []string, nonce string, accessToken string, authTime time.Time) (string, error) { + method, signingKey, err := op.signingMethod() + if err != nil { + return "", err + } + + now := time.Now() + expiresAt := now.Add(idTokenTTL) + if token.ExpiresAt.Before(expiresAt) { + expiresAt = token.ExpiresAt + } + if expiresAt.Before(now) { + expiresAt = now + } + + claims := jwt.MapClaims{ + "iss": issuer, + "sub": subjectForUser(user), + "aud": client.ClientID, + "exp": jwt.NewNumericDate(expiresAt), + "iat": jwt.NewNumericDate(now), + } + if !authTime.IsZero() { + claims["auth_time"] = jwt.NewNumericDate(authTime) + } + if nonce != "" { + claims["nonce"] = nonce + } + if accessToken != "" { + claims["at_hash"] = computeAtHash(accessToken) + } + if containsScope(scopes, scopeProfile) { + displayName := user.Name + if displayName == "" { + displayName = user.Username + } + claims["preferred_username"] = user.Username + claims["name"] = displayName + } + if containsScope(scopes, scopeEmail) && user.Email != "" { + claims["email"] = user.Email + claims["email_verified"] = false + } + + tokenJWT := jwt.NewWithClaims(method, claims) + tokenJWT.Header["kid"] = signingKey.key.KID + + return tokenJWT.SignedString(signingKey.private) +} + +func (op *OIDCProvider) JWKS() jwks { + keys := make([]jwk, 0, len(op.signingKeys)) + for _, key := range op.signingKeys { + keys = append(keys, key.publicJWK) + } + return jwks{Keys: keys} +} + +func subjectForUser(user *User) string { + return strconv.FormatInt(int64(user.ID), 10) +} + +func computeAtHash(accessToken string) string { + sum := sha256.Sum256([]byte(accessToken)) + return base64.RawURLEncoding.EncodeToString(sum[:len(sum)/2]) +} + +func containsScope(scopes []string, scope string) bool { + for _, s := range scopes { + if strings.EqualFold(s, scope) { + return true + } + } + return false +} + +func getOpenIDConfiguration(w http.ResponseWriter, req *http.Request) { + ctx := req.Context() + oidc := oidcProviderFromContext(ctx) + issuer := getIssuer(req) + currentKey := oidc.currentSigningKey() + + scopes := make([]string, 0, len(allowedScopes)) + for scope := range allowedScopes { + scopes = append(scopes, scope) + } + sort.Strings(scopes) + + idTokenAlgs := []string{"RS256"} + if currentKey != nil && currentKey.key.Algorithm != "" { + idTokenAlgs = []string{currentKey.key.Algorithm} + } + + config := map[string]interface{}{ + "issuer": issuer, + "authorization_endpoint": issuer + "/authorize", + "token_endpoint": issuer + "/token", + "userinfo_endpoint": issuer + "/userinfo", + "jwks_uri": issuer + "/.well-known/jwks.json", + "response_types_supported": []string{string(oauth2.ResponseTypeCode)}, + "response_modes_supported": []string{string(oauth2.ResponseModeQuery)}, + "grant_types_supported": []string{string(oauth2.GrantTypeAuthorizationCode), string(oauth2.GrantTypeRefreshToken)}, + "subject_types_supported": []string{"public"}, + "id_token_signing_alg_values_supported": idTokenAlgs, + "scopes_supported": scopes, + "claims_supported": []string{"sub", "preferred_username", "name", "email", "email_verified"}, + "token_endpoint_auth_methods_supported": []string{string(oauth2.AuthMethodNone), string(oauth2.AuthMethodClientSecretBasic)}, + "introspection_endpoint": issuer + "/introspect", + "revocation_endpoint": issuer + "/revoke", + "authorization_response_iss_parameter_supported": true, + "claims_parameter_supported": false, + "request_parameter_supported": false, + "request_uri_parameter_supported": false, + "code_challenge_methods_supported": []string{pkceMethodPlain, pkceMethodS256}, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(config) +} + +func getOIDCJWKS(w http.ResponseWriter, req *http.Request) { + ctx := req.Context() + oidc := oidcProviderFromContext(ctx) + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "public, max-age=300") + json.NewEncoder(w).Encode(oidc.JWKS()) +} + +func userInfo(w http.ResponseWriter, req *http.Request) { + ctx := req.Context() + db := dbFromContext(ctx) + + if req.Method != http.MethodGet && req.Method != http.MethodPost { + w.Header().Set("Allow", "GET, POST") + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + return + } + + tokenValue, err := bearerTokenFromRequest(req) + if err != nil { + writeBearerError(w, http.StatusUnauthorized, "invalid_token", err.Error()) + return + } + + tokenID, secret, err := UnmarshalSecret[*AccessToken](tokenValue) + if err != nil { + writeBearerError(w, http.StatusUnauthorized, "invalid_token", "Malformed access token") + return + } + + token, err := db.FetchAccessToken(ctx, tokenID) + if err == errNoDBRows || (err == nil && !token.VerifySecret(secret)) { + writeBearerError(w, http.StatusUnauthorized, "invalid_token", "Invalid access token") + return + } else if err != nil { + httpError(w, fmt.Errorf("failed to fetch access token: %v", err)) + return + } + + scopes := parseScopes(token.Scope) + if !containsScope(scopes, scopeOpenID) { + writeBearerError(w, http.StatusForbidden, "insufficient_scope", "Scope openid missing") + return + } + + user, err := db.FetchUser(ctx, token.User) + if err != nil { + httpError(w, fmt.Errorf("failed to fetch user: %v", err)) + return + } + + resp := map[string]interface{}{ + "sub": subjectForUser(user), + } + if containsScope(scopes, scopeProfile) { + displayName := user.Name + if displayName == "" { + displayName = user.Username + } + resp["preferred_username"] = user.Username + resp["name"] = displayName + } + if containsScope(scopes, scopeEmail) && user.Email != "" { + resp["email"] = user.Email + resp["email_verified"] = false + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) +} + +func bearerTokenFromRequest(req *http.Request) (string, error) { + authz := req.Header.Get("Authorization") + if authz == "" { + return "", fmt.Errorf("Authorization header missing") + } + if len(authz) < 7 || !strings.EqualFold(authz[:7], "Bearer ") { + return "", fmt.Errorf("Unsupported authorization scheme") + } + token := strings.TrimSpace(authz[7:]) + if token == "" { + return "", fmt.Errorf("Missing access token") + } + return token, nil +} + +func writeBearerError(w http.ResponseWriter, status int, code, description string) { + challenge := "Bearer" + if code != "" { + challenge += fmt.Sprintf(" error=\"%s\"", code) + } + if description != "" { + if code == "" { + challenge += " " + } else { + challenge += ", " + } + challenge += fmt.Sprintf("error_description=\"%s\"", description) + } + w.Header().Set("WWW-Authenticate", challenge) + http.Error(w, http.StatusText(status), status) +} diff --git a/pkce.go b/pkce.go new file mode 100644 index 0000000000000000000000000000000000000000..a04486f8df385de65ce5f6dd4043f2f4821dabff --- /dev/null +++ b/pkce.go @@ -0,0 +1,40 @@ +package main + +import ( + "fmt" + "strings" +) + +const ( + pkceRequirementNone = "" + pkceRequirementPlain = pkceMethodPlain + pkceRequirementS256 = pkceMethodS256 +) + +func normalizeClientPKCERequirement(value string) (string, error) { + switch strings.ToUpper(strings.TrimSpace(value)) { + case "", "NONE": + return pkceRequirementNone, nil + case strings.ToUpper(pkceRequirementPlain): + return pkceRequirementPlain, nil + case pkceRequirementS256: + return pkceRequirementS256, nil + default: + return "", fmt.Errorf("invalid PKCE requirement") + } +} + +func allowPKCERequirement(method, requirement string) bool { + requirement = strings.ToUpper(requirement) + method = strings.ToUpper(method) + switch requirement { + case "", "NONE": + return true + case strings.ToUpper(pkceRequirementPlain): + return method == strings.ToUpper(pkceMethodPlain) || method == strings.ToUpper(pkceMethodS256) + case pkceRequirementS256: + return method == strings.ToUpper(pkceMethodS256) + default: + return false + } +} diff --git a/schema.sql b/schema.sql index 9ebc93035d8c818bf60d07295bffbd50931d82d7..92307df080034579691633de40ec67a82e0fe0c0 100644 --- a/schema.sql +++ b/schema.sql @@ -1,6 +1,8 @@ CREATE TABLE User ( id INTEGER PRIMARY KEY, username TEXT NOT NULL UNIQUE, + name TEXT, + email TEXT, password_hash TEXT, admin INTEGER NOT NULL DEFAULT 0 ); @@ -12,7 +14,8 @@ client_secret_hash BLOB, owner INTEGER REFERENCES User(id) ON DELETE CASCADE, redirect_uris TEXT, client_name TEXT, - client_uri TEXT + client_uri TEXT, + pkce_requirement TEXT ); CREATE TABLE AccessToken ( @@ -23,6 +26,7 @@ client INTEGER REFERENCES Client(id) ON DELETE CASCADE, scope TEXT, issued_at datetime NOT NULL, expires_at datetime NOT NULL, + auth_time datetime, refresh_hash BLOB UNIQUE, refresh_expires_at datetime ); @@ -34,5 +38,18 @@ created_at datetime NOT NULL, user INTEGER NOT NULL REFERENCES User(id) ON DELETE CASCADE, client INTEGER NOT NULL REFERENCES Client(id) ON DELETE CASCADE, redirect_uri TEXT, - scope TEXT + scope TEXT, + nonce TEXT, + code_challenge TEXT, + code_challenge_method TEXT +); + +CREATE TABLE SigningKey ( + id INTEGER PRIMARY KEY, + kid TEXT NOT NULL UNIQUE, + algorithm TEXT NOT NULL, + private_key BLOB NOT NULL, + created_at datetime NOT NULL ); + +CREATE INDEX signing_key_created_at ON SigningKey(created_at); diff --git a/static/style.css b/static/style.css index 209f408813aff10ec4295ee3b550263a64c897aa..b14a861e763fe813d68b232f299d57039edb62d0 100644 --- a/static/style.css +++ b/static/style.css @@ -59,14 +59,14 @@ background-color: rgb(0, 150, 0); border-color: rgb(0, 150, 0); } -input[type="text"], input[type="password"], input[type="url"], textarea { +input[type="text"], input[type="email"], input[type="password"], input[type="url"], textarea { border: 1px solid rgb(208, 210, 215); border-radius: 4px; padding: 6px; margin: 4px 0; color: #444; } -input[type="text"]:focus, input[type="password"]:focus, input[type="url"]:focus, textarea:focus { +input[type="email"]:focus, input[type="text"]:focus, input[type="password"]:focus, input[type="url"]:focus, textarea:focus { outline: none; border-color: rgb(0, 128, 0); } @@ -75,7 +75,7 @@ label { display: block; margin: 15px 0; } -label input[type="text"], label input[type="password"], label input[type="url"] { +label input[type="email"], label input[type="text"], label input[type="password"], label input[type="url"] { display: block; width: 100%; max-width: 350px; @@ -118,7 +118,7 @@ button:hover { background-color: rgba(255, 255, 255, 0.02); } - input[type="text"], input[type="password"], input[type="url"], textarea { + input[type="email"], input[type="text"], input[type="password"], input[type="url"], textarea { background-color: rgba(255, 255, 255, 0.05); color: inherit; } diff --git a/template/authorize.html b/template/authorize.html index 3146dc9877f1e2e44a5b2e59e1fa8458d5cc33f0..a8ebfe916591bb4c5a6fbba0bae7031ed74fafce 100644 --- a/template/authorize.html +++ b/template/authorize.html @@ -21,6 +21,7 @@ ?

+
diff --git a/template/head.html b/template/head.html index a36f23450c00bee71bba1c80d9883e155ef8025f..702cff01cd395108fe807766facca2b14714d700 100644 --- a/template/head.html +++ b/template/head.html @@ -1,9 +1,9 @@ - + {{ .ServerName }} - - + + diff --git a/template/index.html b/template/index.html index ddd00bbec5810994b5910c58f2bf17a9e4948d63..48dd2a577615fc1d82c3a48341ce4e9f02427f54 100644 --- a/template/index.html +++ b/template/index.html @@ -6,6 +6,7 @@

Welcome, {{ .Me.Username }}!

+
@@ -38,9 +39,10 @@ {{ end }} {{ .ExpiresAt }} -
- -
+
+ + +
{{ end }} @@ -82,11 +84,15 @@ + + {{ range .Users }} + +
UsernameNameEmail Role
{{ .Username }}{{ .Name }}{{ .Email }} {{ if .Admin }} Administrator diff --git a/template/login.html b/template/login.html index babcf43987c51756d117469441d3b5e712e5e911..2ccc52ecf2cd6206b7b7c6ac8ff9aa06912a9d5a 100644 --- a/template/login.html +++ b/template/login.html @@ -5,6 +5,7 @@

{{ .ServerName }}

+