From 476a35955f666ccb1d01fdacb0a52b80115bf214 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Mon, 26 Feb 2024 13:21:50 +0100 Subject: [PATCH] Add support for refresh tokens Closes: https://todo.sr.ht/~emersion/sinwon/21 --- db.go | 26 +++++++++++++++++++++----- entity.go | 108 +++++++++++++++++++++++++++++++++-------------------- middleware.go | 2 +- oauth2.go | 166 +++++++++++++++++++++++++++++++++++------------------ schema.sql | 2 ++ user.go | 2 +- diff --git a/db.go b/db.go index f1dbab190c309899e6d590b034659e2fda2f8369..cf405ab7dc70c155036c51b41847655c74b5c73d 100644 --- a/db.go +++ b/db.go @@ -15,6 +15,11 @@ var schema string var migrations = []string{ "", // migration #0 is reserved for schema initialization + ` + ALTER TABLE AccessToken ADD COLUMN refresh_hash BLOB; + ALTER TABLE AccessToken ADD COLUMN refresh_expires_at datetime; + CREATE UNIQUE INDEX access_token_refresh_hash ON AccessToken(refresh_hash); + `, } var errNoDBRows = sql.ErrNoRows @@ -210,7 +215,7 @@ rows, err := db.db.QueryContext(ctx, ` SELECT id, client_id, client_name, client_uri, token.expires_at FROM Client, ( - SELECT client, MAX(expires_at) as expires_at + SELECT client, MAX(COALESCE(refresh_expires_at, expires_at)) as expires_at FROM AccessToken WHERE user = ? GROUP BY client @@ -255,10 +260,21 @@ err = scanRow(&token, rows) return &token, err } -func (db *DB) CreateAccessToken(ctx context.Context, token *AccessToken) error { +func (db *DB) StoreAccessToken(ctx context.Context, token *AccessToken) error { return db.db.QueryRowContext(ctx, ` - INSERT INTO AccessToken(hash, user, client, scope, issued_at, expires_at) - VALUES (:hash, :user, :client, :scope, :issued_at, :expires_at) + INSERT INTO AccessToken(id, hash, user, client, scope, issued_at, + expires_at, refresh_hash, refresh_expires_at) + VALUES (:id, :hash, :user, :client, :scope, :issued_at, :expires_at, + :refresh_hash, :refresh_expires_at) + ON CONFLICT(id) DO UPDATE SET + hash = :hash, + user = :user, + client = :client, + scope = :scope, + issued_at = :issued_at, + expires_at = :expires_at, + refresh_hash = :refresh_hash, + refresh_expires_at = :refresh_expires_at RETURNING id `, entityArgs(token)...).Scan(&token.ID) } @@ -301,7 +317,7 @@ func (db *DB) Maintain(ctx context.Context) error { _, err := db.db.ExecContext(ctx, ` DELETE FROM AccessToken - WHERE timediff('now', expires_at) > 0 + WHERE timediff('now', COALESCE(refresh_expires_at, expires_at)) > 0 `) if err != nil { return err diff --git a/entity.go b/entity.go index 27ba61d7d53092d631c6d0c1b561f62fadcc2674..eceb92320ca9d2b89d4e4d28a0d5beb69072f4b0 100644 --- a/entity.go +++ b/entity.go @@ -8,6 +8,7 @@ "database/sql" "database/sql/driver" "encoding/base64" "fmt" + "reflect" "strconv" "strings" "time" @@ -16,8 +17,9 @@ "golang.org/x/crypto/bcrypt" ) const ( - accessTokenExpiration = 30 * 24 * time.Hour - authCodeExpiration = 10 * time.Minute + accessTokenExpiration = 30 * 24 * time.Hour + refreshTokenExpiration = 2 * accessTokenExpiration + authCodeExpiration = 10 * time.Minute ) type entity interface { @@ -67,32 +69,37 @@ return int64(id), nil } } -type nullString string +type nullValue struct { + ptr interface{} +} var ( - _ sql.Scanner = (*nullString)(nil) - _ driver.Valuer = (*nullString)(nil) + _ sql.Scanner = nullValue{nil} + _ driver.Valuer = nullValue{nil} ) -func (ptr *nullString) Scan(v interface{}) error { +func (nv nullValue) Scan(v interface{}) error { + out := reflect.ValueOf(nv.ptr).Elem() if v == nil { - *ptr = "" + out.SetZero() return nil } - s, ok := v.(string) - if !ok { - return fmt.Errorf("cannot scan nullStringPtr from %T", v) + + rv := reflect.ValueOf(v) + if rv.Type() != out.Type() { + return fmt.Errorf("cannot scan %v into %v", rv.Type(), out.Type()) } - *ptr = nullString(s) + + out.Set(rv) return nil } -func (ptr *nullString) Value() (driver.Value, error) { - if *ptr == "" { +func (nv nullValue) Value() (driver.Value, error) { + in := reflect.ValueOf(nv.ptr).Elem() + if in.IsZero() { return nil, nil - } else { - return string(*ptr), nil } + return in.Interface(), nil } type User struct { @@ -106,7 +113,7 @@ func (user *User) columns() map[string]interface{} { return map[string]interface{}{ "id": &user.ID, "username": &user.Username, - "password_hash": (*nullString)(&user.PasswordHash), + "password_hash": nullValue{&user.PasswordHash}, "admin": &user.Admin, } } @@ -164,9 +171,9 @@ "id": &client.ID, "client_id": &client.ClientID, "client_secret_hash": &client.ClientSecretHash, "owner": &client.Owner, - "redirect_uris": (*nullString)(&client.RedirectURIs), - "client_name": (*nullString)(&client.ClientName), - "client_uri": (*nullString)(&client.ClientURI), + "redirect_uris": nullValue{&client.RedirectURIs}, + "client_name": nullValue{&client.ClientName}, + "client_uri": nullValue{&client.ClientURI}, } } @@ -186,6 +193,9 @@ Client ID[*Client] Scope string IssuedAt time.Time ExpiresAt time.Time + + RefreshHash []byte + RefreshExpiresAt time.Time } func (token *AccessToken) Generate(expiration time.Duration) (secret string, err error) { @@ -199,25 +209,35 @@ token.ExpiresAt = time.Now().Add(expiration) return secret, nil } -func NewAccessTokenFromAuthCode(authCode *AuthCode) (token *AccessToken, secret string, err error) { - token = &AccessToken{ +func (token *AccessToken) GenerateRefresh() (secret string, err error) { + secret, hash, err := generateSecret() + if err != nil { + return "", fmt.Errorf("failed to generate refresh token secret: %v", err) + } + token.RefreshHash = hash + token.RefreshExpiresAt = time.Now().Add(refreshTokenExpiration) + return secret, nil +} + +func NewAccessTokenFromAuthCode(authCode *AuthCode) *AccessToken { + return &AccessToken{ User: authCode.User, Client: authCode.Client, Scope: authCode.Scope, } - secret, err = token.Generate(accessTokenExpiration) - return token, secret, err } func (token *AccessToken) columns() map[string]interface{} { return map[string]interface{}{ - "id": &token.ID, - "hash": &token.Hash, - "user": &token.User, - "client": &token.Client, - "scope": (*nullString)(&token.Scope), - "issued_at": &token.IssuedAt, - "expires_at": &token.ExpiresAt, + "id": &token.ID, + "hash": &token.Hash, + "user": &token.User, + "client": &token.Client, + "scope": nullValue{&token.Scope}, + "issued_at": &token.IssuedAt, + "expires_at": &token.ExpiresAt, + "refresh_hash": &token.RefreshHash, + "refresh_expires_at": nullValue{&token.RefreshExpiresAt}, } } @@ -225,6 +245,10 @@ func (token *AccessToken) VerifySecret(secret string) bool { return verifyHash(token.Hash, secret) && verifyExpiration(token.ExpiresAt) } +func (token *AccessToken) VerifyRefreshSecret(secret string) bool { + return verifyHash(token.RefreshHash, secret) && verifyExpiration(token.RefreshExpiresAt) +} + type AuthorizedClient struct { Client Client ExpiresAt time.Time @@ -257,8 +281,8 @@ "hash": &code.Hash, "created_at": &code.CreatedAt, "user": &code.User, "client": &code.Client, - "scope": (*nullString)(&code.Scope), - "redirect_uri": (*nullString)(&code.RedirectURI), + "scope": nullValue{&code.Scope}, + "redirect_uri": nullValue{&code.RedirectURI}, } } @@ -269,8 +293,9 @@ type SecretKind byte const ( - SecretKindAccessToken = SecretKind('a') - SecretKindAuthCode = SecretKind('c') + SecretKindAccessToken = SecretKind('a') + SecretKindRefreshToken = SecretKind('r') + SecretKindAuthCode = SecretKind('c') ) func UnmarshalSecret[T entity](s string) (id ID[T], secret string, err error) { @@ -281,7 +306,7 @@ return 0, "", fmt.Errorf("malformed secret") } switch SecretKind(kind[0]) { - case SecretKindAccessToken: + case SecretKindAccessToken, SecretKindRefreshToken: _, ok = interface{}(id).(ID[*AccessToken]) case SecretKindAuthCode: _, ok = interface{}(id).(ID[*AuthCode]) @@ -294,19 +319,20 @@ id, err = ParseID[T](idStr) return id, secret, err } -func MarshalSecret[T entity](id ID[T], secret string) string { +func MarshalSecret[T entity](id ID[T], kind SecretKind, secret string) string { if id == 0 { panic("cannot marshal zero ID") } - var kind SecretKind + var ok bool switch interface{}(id).(type) { case ID[*AccessToken]: - kind = SecretKindAccessToken + ok = kind == SecretKindAccessToken || kind == SecretKindRefreshToken case ID[*AuthCode]: - kind = SecretKindAuthCode - default: - panic(fmt.Sprintf("unsupported secret kind for ID type %T", id)) + ok = kind == SecretKindAuthCode + } + if !ok { + panic(fmt.Sprintf("unsupported secret kind %q for ID type %T", string(kind), id)) } return fmt.Sprintf("%v.%v.%v", string(kind), int64(id), secret) diff --git a/middleware.go b/middleware.go index 09810b6e71ec1878a8173a68ba709cfd18f2f3f6..50cc2abbb1b5fb15c10d3ffac55a83a40f922b34 100644 --- a/middleware.go +++ b/middleware.go @@ -49,7 +49,7 @@ func setLoginTokenCookie(w http.ResponseWriter, req *http.Request, token *AccessToken, secret string) { http.SetCookie(w, &http.Cookie{ Name: loginCookieName, - Value: MarshalSecret(token.ID, secret), + Value: MarshalSecret(token.ID, SecretKindAccessToken, secret), HttpOnly: true, SameSite: http.SameSiteStrictMode, Secure: isForwardedHTTPS(req), diff --git a/oauth2.go b/oauth2.go index 20e8d25484f675bb1924834394884c225de308d7..058eaf4c3bc11b0042f855f6f3d77fe7f7506e02 100644 --- a/oauth2.go +++ b/oauth2.go @@ -174,7 +174,7 @@ httpError(w, fmt.Errorf("failed to create authentication code: %v", err)) return } - code := MarshalSecret(authCode.ID, secret) + code := MarshalSecret(authCode.ID, SecretKindAuthCode, secret) values := make(url.Values) values.Set("code", code) @@ -200,16 +200,9 @@ clientID := values.Get("client_id") grantType := oauth2.GrantType(values.Get("grant_type")) scope := values.Get("scope") - redirectURI := values.Get("redirect_uri") authClientID, clientSecret, _ := req.BasicAuth() - if clientID == "" && authClientID == "" { - oauthError(w, &oauth2.Error{ - Code: oauth2.ErrorCodeInvalidRequest, - Description: "Missing client ID", - }) - return - } else if clientID == "" { + if clientID == "" { clientID = authClientID } else if clientID != authClientID { oauthError(w, &oauth2.Error{ @@ -219,20 +212,21 @@ }) return } - client, err := db.FetchClientByClientID(ctx, clientID) - if err == errNoDBRows { - oauthError(w, &oauth2.Error{ - Code: oauth2.ErrorCodeInvalidClient, - Description: "Invalid client ID", - }) - return - } else if err != nil { - oauthError(w, fmt.Errorf("failed to fetch client: %v", err)) - return - } + var client *Client + if clientID != "" { + client, err = db.FetchClientByClientID(ctx, clientID) + if err == errNoDBRows { + oauthError(w, &oauth2.Error{ + Code: oauth2.ErrorCodeInvalidClient, + Description: "Invalid client ID", + }) + return + } else if err != nil { + oauthError(w, fmt.Errorf("failed to fetch client: %v", err)) + return + } - if !client.IsPublic() { - if !client.VerifySecret(clientSecret) { + if !client.IsPublic() && !client.VerifySecret(clientSecret) { oauthError(w, &oauth2.Error{ Code: oauth2.ErrorCodeAccessDenied, Description: "Invalid client secret", @@ -241,49 +235,108 @@ return } } - if grantType != oauth2.GrantTypeAuthorizationCode { + var token *AccessToken + switch grantType { + case oauth2.GrantTypeAuthorizationCode: + if client == nil { + oauthError(w, &oauth2.Error{ + Code: oauth2.ErrorCodeInvalidRequest, + Description: "Missing client ID", + }) + return + } + + codeID, codeSecret, _ := UnmarshalSecret[*AuthCode](values.Get("code")) + authCode, err := db.PopAuthCode(ctx, codeID) + if err == errNoDBRows || (err == nil && !authCode.VerifySecret(codeSecret)) || authCode.Client != client.ID { + oauthError(w, &oauth2.Error{ + Code: oauth2.ErrorCodeAccessDenied, + Description: "Invalid authorization code", + }) + return + } else if err != nil { + oauthError(w, fmt.Errorf("failed to fetch authorization code: %v", err)) + return + } + + if scope != authCode.Scope { + oauthError(w, &oauth2.Error{ + Code: oauth2.ErrorCodeAccessDenied, + Description: "Invalid scope", + }) + return + } + if values.Get("redirect_uri") != authCode.RedirectURI { + oauthError(w, &oauth2.Error{ + Code: oauth2.ErrorCodeAccessDenied, + Description: "Invalid redirect URI", + }) + return + } + + token = NewAccessTokenFromAuthCode(authCode) + case oauth2.GrantTypeRefreshToken: + tokenID, refreshSecret, _ := UnmarshalSecret[*AccessToken](values.Get("refresh_token")) + token, err = db.FetchAccessToken(ctx, tokenID) + if err == errNoDBRows || (err == nil && !token.VerifyRefreshSecret(refreshSecret)) { + oauthError(w, &oauth2.Error{ + Code: oauth2.ErrorCodeAccessDenied, + Description: "Invalid refresh token", + }) + return + } else if err != nil { + oauthError(w, fmt.Errorf("failed to fetch access token: %v", err)) + return + } + + if client != nil && client.ID != token.Client { + oauthError(w, &oauth2.Error{ + Code: oauth2.ErrorCodeAccessDenied, + Description: "Invalid refresh token", + }) + return + } + + tokenClient, err := db.FetchClient(ctx, token.Client) + if err != nil { + oauthError(w, fmt.Errorf("failed to fetch client: %v", err)) + return + } + + if !tokenClient.IsPublic() && client == nil { + oauthError(w, &oauth2.Error{ + Code: oauth2.ErrorCodeAccessDenied, + Description: "Invalid client secret", + }) + return + } + + if scope != token.Scope { + oauthError(w, &oauth2.Error{ + Code: oauth2.ErrorCodeAccessDenied, + Description: "Invalid scope", + }) + return + } + default: oauthError(w, &oauth2.Error{ Code: oauth2.ErrorCodeUnsupportedGrantType, Description: "Unsupported grant type", }) - return } - codeID, codeSecret, _ := UnmarshalSecret[*AuthCode](values.Get("code")) - authCode, err := db.PopAuthCode(ctx, codeID) - if err == errNoDBRows || (err == nil && !authCode.VerifySecret(codeSecret)) || authCode.Client != client.ID { - oauthError(w, &oauth2.Error{ - Code: oauth2.ErrorCodeAccessDenied, - Description: "Invalid authorization code", - }) - return - } else if err != nil { - oauthError(w, fmt.Errorf("failed to fetch authorization code: %v", err)) + secret, err := token.Generate(accessTokenExpiration) + if err != nil { + oauthError(w, err) return } - - if scope != authCode.Scope { - oauthError(w, &oauth2.Error{ - Code: oauth2.ErrorCodeAccessDenied, - Description: "Invalid scope", - }) - return - } - if redirectURI != authCode.RedirectURI { - oauthError(w, &oauth2.Error{ - Code: oauth2.ErrorCodeAccessDenied, - Description: "Invalid redirect URI", - }) - return - } - - token, secret, err := NewAccessTokenFromAuthCode(authCode) + refreshSecret, err := token.GenerateRefresh() if err != nil { oauthError(w, err) return } - if err := db.CreateAccessToken(ctx, token); err != nil { + if err := db.StoreAccessToken(ctx, token); err != nil { oauthError(w, fmt.Errorf("failed to create access token: %v", err)) return } @@ -291,10 +344,11 @@ w.Header().Set("Content-Type", "application/json") w.Header().Set("Cache-Control", "no-store") json.NewEncoder(w).Encode(&oauth2.TokenResp{ - AccessToken: MarshalSecret(token.ID, secret), - TokenType: oauth2.TokenTypeBearer, - ExpiresIn: time.Until(token.ExpiresAt), - Scope: strings.Split(token.Scope, " "), + AccessToken: MarshalSecret(token.ID, SecretKindAccessToken, secret), + TokenType: oauth2.TokenTypeBearer, + ExpiresIn: time.Until(token.ExpiresAt), + Scope: strings.Split(token.Scope, " "), + RefreshToken: MarshalSecret(token.ID, SecretKindRefreshToken, refreshSecret), }) } diff --git a/schema.sql b/schema.sql index 86a892f57152029daf1ab4dfde6ae56971ba85fe..ee9302d709d4d11384157284779f90b51fae4dd1 100644 --- a/schema.sql +++ b/schema.sql @@ -24,6 +24,8 @@ client INTEGER, scope TEXT, issued_at datetime NOT NULL, expires_at datetime NOT NULL, + refresh_hash BLOB UNIQUE, + refresh_expires_at datetime, FOREIGN KEY(user) REFERENCES User(id), FOREIGN KEY(client) REFERENCES Client(id) ); diff --git a/user.go b/user.go index e91e64e58363ce84356717d8b18eb6b40b0162e5..d74b76d6058798bc44bcfd172adfe10bedff0e7d 100644 --- a/user.go +++ b/user.go @@ -127,7 +127,7 @@ if err != nil { httpError(w, fmt.Errorf("failed to generate access token: %v", err)) return } - if err := db.CreateAccessToken(ctx, &token); err != nil { + if err := db.StoreAccessToken(ctx, &token); err != nil { httpError(w, fmt.Errorf("failed to create access token: %v", err)) return } -- 2.48.1