Warning: Due to various recent migrations, viewing non-HEAD refs may be broken.
/entity.go (raw)
package main
import (
"crypto/rand"
"crypto/sha512"
"crypto/subtle"
"database/sql"
"database/sql/driver"
"encoding/base64"
"fmt"
"reflect"
"strconv"
"strings"
"time"
"golang.org/x/crypto/bcrypt"
)
// Lifetimes of the credentials handed out by the server.
const (
	// authCodeExpiration bounds how long after its creation an AuthCode can
	// still be redeemed (checked in AuthCode.VerifySecret).
	authCodeExpiration = 10 * time.Minute

	// accessTokenExpiration is thirty days.
	accessTokenExpiration = 24 * time.Hour * 30

	// refreshTokenExpiration is twice the access-token lifetime; it is
	// applied by AccessToken.GenerateRefresh.
	refreshTokenExpiration = accessTokenExpiration * 2
)
// entity is implemented by every type persisted in the database.
//
// columns maps each column name to the storage bound to the corresponding
// field: either a pointer to the field or a nullValue wrapper (which
// implements both sql.Scanner and driver.Valuer).
type entity interface {
	columns() map[string]any
}
// Compile-time checks that every persisted type implements entity.
var (
	_ entity = (*AccessToken)(nil)
	_ entity = (*AuthCode)(nil)
	_ entity = (*Client)(nil)
	_ entity = (*SigningKey)(nil)
	_ entity = (*User)(nil)
)
// ID is the integer primary key of an entity of type E. The type parameter
// is not used in the representation; it only keeps IDs of different entity
// types from being mixed up at compile time. The zero ID means "no ID" and is
// written to the database as NULL.
type ID[E entity] int64
// Compile-time checks that ID can be used directly as a scan destination and
// as a query argument. *User is just one instantiation; the methods are the
// same for every entity type.
var (
	_ driver.Valuer = ID[*User](0)
	_ sql.Scanner   = (*ID[*User])(nil)
)
// ParseID parses s as the ID of an entity of type T.
//
// s must be a positive decimal integer that fits in an int64, with no sign,
// whitespace or digit separators. Zero, malformed input and out-of-range
// values are all rejected with an "invalid ID" error. Oversized values are
// rejected rather than clamped to the int64 maximum.
func ParseID[T entity](s string) (ID[T], error) {
	// bitSize 63 caps the result at 2^63-1, so converting to the
	// int64-based ID cannot overflow.
	u, err := strconv.ParseUint(s, 10, 63)
	// On a range error ParseUint returns the maximum value together with
	// the error, so the error must be checked and not inferred from u.
	if err != nil || u == 0 {
		return 0, fmt.Errorf("invalid ID")
	}
	return ID[T](u), nil
}
// Scan implements sql.Scanner. SQL NULL scans as the zero ID, an int64
// becomes the ID, and any other driver value is rejected without modifying
// the receiver.
func (id *ID[T]) Scan(v interface{}) error {
	switch raw := v.(type) {
	case nil:
		*id = 0
	case int64:
		*id = ID[T](raw)
	default:
		return fmt.Errorf("cannot scan ID from %T", v)
	}
	return nil
}
// Value implements driver.Valuer. A non-zero ID is written as its int64
// value; the zero ID is written as SQL NULL.
func (id ID[T]) Value() (driver.Value, error) {
	if id != 0 {
		return int64(id), nil
	}
	return nil, nil
}
// nullValue adapts a pointer to an ordinary Go field so that the field's
// zero value is written as SQL NULL and NULL is scanned back as the zero
// value. ptr must be a non-nil pointer: Scan and Value dereference it
// through reflection.
type nullValue struct {
	// ptr points at the wrapped field (e.g. a *string or *time.Time).
	ptr any
}
// Compile-time checks that nullValue works both as a scan destination and as
// a query argument.
var (
	_ driver.Valuer = nullValue{}
	_ sql.Scanner   = nullValue{}
)
// Scan implements sql.Scanner. SQL NULL resets the wrapped field to its
// zero value. Any other driver value must have exactly the field's type
// (no conversion is attempted) and is then stored as-is.
func (nv nullValue) Scan(v interface{}) error {
	target := reflect.ValueOf(nv.ptr).Elem()
	if v == nil {
		target.SetZero()
		return nil
	}
	incoming := reflect.ValueOf(v)
	if want, got := target.Type(), incoming.Type(); got != want {
		return fmt.Errorf("cannot scan %v into %v", got, want)
	}
	target.Set(incoming)
	return nil
}
// Value implements driver.Valuer. If the wrapped field holds its zero value,
// SQL NULL is written. Otherwise the field's value is passed through
// unchanged.
func (nv nullValue) Value() (driver.Value, error) {
	if field := reflect.ValueOf(nv.ptr).Elem(); !field.IsZero() {
		return field.Interface(), nil
	}
	return nil, nil
}
// User is an account that can authenticate with a password and own clients.
type User struct {
	ID       ID[*User]
	Username string
	// Name is stored as NULL when empty.
	Name string
	// Email is stored as NULL when empty.
	Email string
	// PasswordHash is a bcrypt hash produced by SetPassword; stored as NULL
	// when empty.
	PasswordHash string
	Admin        bool
}
// columns implements entity for User. Name, Email and PasswordHash go
// through nullValue, so empty strings are written as SQL NULL and NULL reads
// back as "".
func (user *User) columns() map[string]interface{} {
	cols := make(map[string]interface{}, 6)
	cols["id"] = &user.ID
	cols["username"] = &user.Username
	cols["admin"] = &user.Admin
	cols["name"] = nullValue{&user.Name}
	cols["email"] = nullValue{&user.Email}
	cols["password_hash"] = nullValue{&user.PasswordHash}
	return cols
}
// VerifyPassword checks password against the stored bcrypt hash. It returns
// nil on a match and the bcrypt error otherwise.
func (user *User) VerifyPassword(password string) error {
	stored, candidate := []byte(user.PasswordHash), []byte(password)
	return bcrypt.CompareHashAndPassword(stored, candidate)
}
// SetPassword replaces the stored hash with a bcrypt hash of password
// computed at bcrypt.DefaultCost. If hashing fails, the user is left
// unchanged and the error is returned.
func (user *User) SetPassword(password string) error {
	hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err == nil {
		user.PasswordHash = string(hash)
	}
	return err
}
// PasswordNeedsRehash reports whether the stored hash was made with a cost
// other than bcrypt.DefaultCost. The error from bcrypt.Cost is deliberately
// ignored. An unreadable hash (e.g. an empty one) yields cost 0, which differs
// from the default and therefore also reports true.
func (user *User) PasswordNeedsRehash() bool {
	currentCost, _ := bcrypt.Cost([]byte(user.PasswordHash))
	return bcrypt.DefaultCost != currentCost
}
// Client is a registered OAuth client application.
type Client struct {
	ID ID[*Client]
	// ClientID is the random public identifier assigned by Generate.
	ClientID string
	// ClientSecretHash is the SHA-512 hash of the client secret. It is nil
	// for public clients (see IsPublic).
	ClientSecretHash []byte
	// Owner is the user the client belongs to.
	Owner ID[*User]
	// RedirectURIs holds the registered redirect URIs as a single string
	// (the separator is not visible here — TODO confirm). Stored as NULL
	// when empty.
	RedirectURIs string
	// ClientName and ClientURI are optional; stored as NULL when empty.
	ClientName string
	ClientURI  string
	// PKCERequirement is stored as NULL when empty. NOTE(review): its
	// allowed values are defined outside this file — confirm.
	PKCERequirement string
}
// Generate assigns the client a fresh random ClientID. Unless isPublic is
// true, it also generates a client secret and stores that secret's SHA-512
// hash in ClientSecretHash. It returns the plaintext secret, which is empty
// for public clients. Only the hash is kept on the client.
//
// The client is modified only after all random values have been generated
// successfully, so on error it is left untouched. Returned errors wrap the
// underlying cause (usable with errors.Is / errors.As).
func (client *Client) Generate(isPublic bool) (secret string, err error) {
	id, err := generateUID()
	if err != nil {
		return "", fmt.Errorf("failed to generate client ID: %w", err)
	}
	if !isPublic {
		var hash []byte
		secret, hash, err = generateSecret()
		if err != nil {
			return "", fmt.Errorf("failed to generate client secret: %w", err)
		}
		client.ClientSecretHash = hash
	}
	client.ClientID = id
	return secret, nil
}
// columns implements entity for Client. The optional text fields go through
// nullValue, so empty strings are written as SQL NULL.
func (client *Client) columns() map[string]interface{} {
	cols := make(map[string]interface{}, 8)
	cols["id"] = &client.ID
	cols["client_id"] = &client.ClientID
	cols["client_secret_hash"] = &client.ClientSecretHash
	cols["owner"] = &client.Owner
	cols["redirect_uris"] = nullValue{&client.RedirectURIs}
	cols["client_name"] = nullValue{&client.ClientName}
	cols["client_uri"] = nullValue{&client.ClientURI}
	cols["pkce_requirement"] = nullValue{&client.PKCERequirement}
	return cols
}
// VerifySecret reports whether secret (the plaintext returned by Generate)
// hashes to ClientSecretHash. For public clients the stored hash is nil, so
// the length-sensitive constant-time comparison in verifyHash never succeeds.
func (client *Client) VerifySecret(secret string) bool {
	matched := verifyHash(client.ClientSecretHash, secret)
	return matched
}
// IsPublic reports whether the client has no secret, i.e. ClientSecretHash
// is nil. Note that an empty but non-nil hash counts as confidential.
func (client *Client) IsPublic() bool {
	hash := client.ClientSecretHash
	return hash == nil
}
// AccessToken is an issued access token, optionally paired with a refresh
// token. Only SHA-512 hashes of the secrets are stored (see Generate and
// GenerateRefresh).
type AccessToken struct {
	ID ID[*AccessToken]
	// Hash is the SHA-512 hash of the access-token secret.
	Hash   []byte
	User   ID[*User]
	Client ID[*Client]
	// Scope is stored as NULL when empty.
	Scope     string
	IssuedAt  time.Time
	ExpiresAt time.Time
	// AuthTime is copied from AuthCode.CreatedAt by
	// NewAccessTokenFromAuthCode; stored as NULL when zero.
	AuthTime time.Time
	// RefreshHash is the SHA-512 hash of the refresh secret, set by
	// GenerateRefresh.
	RefreshHash []byte
	// RefreshExpiresAt is stored as NULL when zero.
	RefreshExpiresAt time.Time
}
// Generate gives the token a fresh random secret and stores its SHA-512 hash
// in Hash. It sets IssuedAt to now and ExpiresAt to now+expiration. Both
// timestamps come from a single clock reading, so ExpiresAt-IssuedAt equals
// expiration exactly.
//
// It returns the plaintext secret; only its hash is kept on the token. On
// error the token is left unchanged, and the returned error wraps the cause.
func (token *AccessToken) Generate(expiration time.Duration) (secret string, err error) {
	secret, hash, err := generateSecret()
	if err != nil {
		return "", fmt.Errorf("failed to generate access token secret: %w", err)
	}
	now := time.Now()
	token.Hash = hash
	token.IssuedAt = now
	token.ExpiresAt = now.Add(expiration)
	return secret, nil
}
// GenerateRefresh gives the token a fresh random refresh secret and stores
// its SHA-512 hash in RefreshHash. It sets RefreshExpiresAt to
// refreshTokenExpiration from now.
//
// It returns the plaintext refresh secret; only its hash is kept on the
// token. On error the token is left unchanged, and the returned error wraps
// the cause.
func (token *AccessToken) GenerateRefresh() (secret string, err error) {
	secret, hash, err := generateSecret()
	if err != nil {
		return "", fmt.Errorf("failed to generate refresh token secret: %w", err)
	}
	token.RefreshHash = hash
	token.RefreshExpiresAt = time.Now().Add(refreshTokenExpiration)
	return secret, nil
}
// NewAccessTokenFromAuthCode returns a new access token for the same user,
// client and scope as authCode. It records the code's creation time as
// AuthTime. The ID, secret hashes and issue/expiry timestamps are left zero;
// Generate and GenerateRefresh fill in the secrets and timestamps.
func NewAccessTokenFromAuthCode(authCode *AuthCode) *AccessToken {
	token := new(AccessToken)
	token.User = authCode.User
	token.Client = authCode.Client
	token.Scope = authCode.Scope
	token.AuthTime = authCode.CreatedAt
	return token
}
// columns implements entity for AccessToken. scope, auth_time and
// refresh_expires_at go through nullValue, so an empty scope or a zero time
// is written as SQL NULL.
func (token *AccessToken) columns() map[string]interface{} {
	cols := make(map[string]interface{}, 10)
	cols["id"] = &token.ID
	cols["hash"] = &token.Hash
	cols["user"] = &token.User
	cols["client"] = &token.Client
	cols["issued_at"] = &token.IssuedAt
	cols["expires_at"] = &token.ExpiresAt
	cols["refresh_hash"] = &token.RefreshHash
	cols["scope"] = nullValue{&token.Scope}
	cols["auth_time"] = nullValue{&token.AuthTime}
	cols["refresh_expires_at"] = nullValue{&token.RefreshExpiresAt}
	return cols
}
// VerifySecret reports whether secret matches the stored access-token hash
// and ExpiresAt has not yet passed. The expiry is only checked once the hash
// matches.
func (token *AccessToken) VerifySecret(secret string) bool {
	if !verifyHash(token.Hash, secret) {
		return false
	}
	return verifyExpiration(token.ExpiresAt)
}
// VerifyRefreshSecret reports whether secret matches the stored refresh hash
// and RefreshExpiresAt has not yet passed. A token with no refresh hash never
// verifies, because the hash comparison is length-sensitive.
func (token *AccessToken) VerifyRefreshSecret(secret string) bool {
	if !verifyHash(token.RefreshHash, secret) {
		return false
	}
	return verifyExpiration(token.RefreshExpiresAt)
}
// AuthorizedClient pairs a client with an expiry time. NOTE(review): how
// ExpiresAt is derived (presumably from the tokens a user granted this
// client) is not visible in this file — confirm.
type AuthorizedClient struct {
	Client    Client
	ExpiresAt time.Time
}
// AuthCode is an issued authorization code. Only the SHA-512 hash of its
// secret is stored (see Generate). The code is valid for authCodeExpiration
// after CreatedAt.
type AuthCode struct {
	ID ID[*AuthCode]
	// Hash is the SHA-512 hash of the code's secret.
	Hash      []byte
	CreatedAt time.Time
	User      ID[*User]
	Client    ID[*Client]
	// The remaining fields are stored as NULL when empty. NOTE(review):
	// their protocol meaning (e.g. PKCE parameters, request nonce) is
	// inferred from names only — confirm against the handlers.
	Scope               string
	RedirectURI         string
	Nonce               string
	CodeChallenge       string
	CodeChallengeMethod string
}
// Generate gives the code a fresh random secret and stores its SHA-512 hash
// in Hash. It stamps CreatedAt with the current time, which starts the
// authCodeExpiration window checked by VerifySecret.
//
// It returns the plaintext secret; only its hash is kept on the code. On
// error the code is left unchanged, and the returned error wraps the cause.
func (code *AuthCode) Generate() (secret string, err error) {
	secret, hash, err := generateSecret()
	if err != nil {
		return "", fmt.Errorf("failed to generate authentication code secret: %w", err)
	}
	code.Hash = hash
	code.CreatedAt = time.Now()
	return secret, nil
}
// columns implements entity for AuthCode. The optional text fields go
// through nullValue, so empty strings are written as SQL NULL.
func (code *AuthCode) columns() map[string]interface{} {
	cols := make(map[string]interface{}, 10)
	cols["id"] = &code.ID
	cols["hash"] = &code.Hash
	cols["created_at"] = &code.CreatedAt
	cols["user"] = &code.User
	cols["client"] = &code.Client
	cols["scope"] = nullValue{&code.Scope}
	cols["redirect_uri"] = nullValue{&code.RedirectURI}
	cols["nonce"] = nullValue{&code.Nonce}
	cols["code_challenge"] = nullValue{&code.CodeChallenge}
	cols["code_challenge_method"] = nullValue{&code.CodeChallengeMethod}
	return cols
}
// VerifySecret reports whether secret matches the stored hash and the code
// is still within authCodeExpiration of its creation.
func (code *AuthCode) VerifySecret(secret string) bool {
	deadline := code.CreatedAt.Add(authCodeExpiration)
	return verifyHash(code.Hash, secret) && verifyExpiration(deadline)
}
// SecretKind is the single-byte prefix of a marshaled secret (see
// MarshalSecret). It identifies which kind of credential the secret belongs
// to.
type SecretKind byte

// The secret kinds accepted by MarshalSecret and UnmarshalSecret.
const (
	SecretKindAccessToken  SecretKind = 'a' // access token secret
	SecretKindRefreshToken SecretKind = 'r' // refresh token secret
	SecretKindAuthCode     SecretKind = 'c' // authorization code secret
)
// SigningKey is a persisted private key identified by KID.
// NOTE(review): the encoding of PrivateKey and the set of Algorithm values
// are defined outside this file — confirm.
type SigningKey struct {
	ID         ID[*SigningKey]
	KID        string
	Algorithm  string
	PrivateKey []byte
	CreatedAt  time.Time
}
// columns implements entity for SigningKey. Every column maps directly to a
// field pointer; none is NULL-mapped.
func (key *SigningKey) columns() map[string]interface{} {
	cols := make(map[string]interface{}, 5)
	cols["id"] = &key.ID
	cols["kid"] = &key.KID
	cols["algorithm"] = &key.Algorithm
	cols["private_key"] = &key.PrivateKey
	cols["created_at"] = &key.CreatedAt
	return cols
}
// UnmarshalSecret parses a string of the form "<kind>.<id>.<secret>", as
// produced by MarshalSecret. It returns the ID and the opaque secret part.
//
// kind must be a single byte naming a SecretKind that is valid for T:
//   - the access- and refresh-token kinds for ID[*AccessToken];
//   - the auth-code kind for ID[*AuthCode].
//
// Any other kind byte is rejected for every T. The kind itself is not
// returned; the caller's choice of T and of verification method decides how
// the secret is checked. On any error the returned ID is 0 and the secret is
// empty.
func UnmarshalSecret[T entity](s string) (id ID[T], secret string, err error) {
	kind, rest, _ := strings.Cut(s, ".")
	idStr, secret, ok := strings.Cut(rest, ".")
	if !ok || len(kind) != 1 {
		return 0, "", fmt.Errorf("malformed secret")
	}
	switch SecretKind(kind[0]) {
	case SecretKindAccessToken, SecretKindRefreshToken:
		_, ok = any(id).(ID[*AccessToken])
	case SecretKindAuthCode:
		_, ok = any(id).(ID[*AuthCode])
	default:
		// ok still holds strings.Cut's result here. Without this reset,
		// every unknown kind byte would be accepted.
		ok = false
	}
	if !ok {
		return 0, "", fmt.Errorf("invalid secret kind %q", kind)
	}
	id, err = ParseID[T](idStr)
	if err != nil {
		return 0, "", err
	}
	return id, secret, nil
}
// MarshalSecret encodes id, kind and secret as "<kind>.<id>.<secret>", with
// the ID in decimal; UnmarshalSecret parses this format back.
//
// It panics if id is zero, or if kind does not belong to the ID's entity
// type. The access- and refresh-token kinds belong to ID[*AccessToken] and
// the auth-code kind to ID[*AuthCode]. Both cases are programming errors in
// the caller.
func MarshalSecret[T entity](id ID[T], kind SecretKind, secret string) string {
	if id == 0 {
		panic("cannot marshal zero ID")
	}
	valid := false
	if _, isToken := any(id).(ID[*AccessToken]); isToken {
		valid = kind == SecretKindAccessToken || kind == SecretKindRefreshToken
	} else if _, isCode := any(id).(ID[*AuthCode]); isCode {
		valid = kind == SecretKindAuthCode
	}
	if !valid {
		panic(fmt.Sprintf("unsupported secret kind %q for ID type %T", string(kind), id))
	}
	return string(rune(kind)) + "." + strconv.FormatInt(int64(id), 10) + "." + secret
}
// generateUID returns 32 bytes from crypto/rand, encoded as unpadded
// URL-safe base64 (43 characters). Any error from the random source is
// returned unchanged.
func generateUID() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(raw[:]), nil
}
// generateSecret draws 32 random bytes from crypto/rand. It returns them as
// unpadded URL-safe base64 (secret, 43 characters) together with the
// SHA-512 digest of the raw bytes (hash, 64 bytes). verifyHash checks a
// presented secret against that digest.
func generateSecret() (secret string, hash []byte, err error) {
	var raw [32]byte
	if _, err = rand.Read(raw[:]); err != nil {
		return "", nil, err
	}
	digest := sha512.Sum512(raw[:])
	return base64.RawURLEncoding.EncodeToString(raw[:]), digest[:], nil
}
// verifyHash reports whether secret decodes to bytes whose SHA-512 digest
// equals hash. secret is expected in the unpadded URL-safe base64 form
// produced by generateSecret. The digest comparison is constant-time and
// fails whenever the lengths differ, so a nil hash never matches.
//
// A secret that is not valid base64 is rejected outright. Ignoring the
// decode error would hash only the partially decoded prefix instead; for
// example, any undecodable string would then match the digest of the empty
// input.
func verifyHash(hash []byte, secret string) bool {
	raw, err := base64.RawURLEncoding.DecodeString(secret)
	if err != nil {
		return false
	}
	digest := sha512.Sum512(raw)
	return subtle.ConstantTimeCompare(hash, digest[:]) == 1
}
// verifyExpiration reports whether the deadline t still lies strictly in
// the future. At the exact deadline instant, or for the zero time, it
// reports false.
func verifyExpiration(t time.Time) bool {
	return t.After(time.Now())
}