Warning: Due to various recent migrations, viewing non-HEAD refs may be broken.
/entity.go (raw)
package main
import (
	"crypto/rand"
	"crypto/sha512"
	"crypto/subtle"
	"database/sql"
	"database/sql/driver"
	"encoding/base64"
	"fmt"
	"reflect"
	"strconv"
	"strings"
	"time"
	"golang.org/x/crypto/bcrypt"
)
// Lifetimes applied to freshly generated credentials.
const (
	// authCodeExpiration is how long an authorization code stays redeemable
	// after its CreatedAt timestamp (checked by AuthCode.VerifySecret).
	authCodeExpiration = 10 * time.Minute

	// accessTokenExpiration is 30 days.
	accessTokenExpiration = 30 * 24 * time.Hour

	// refreshTokenExpiration is twice the access-token lifetime (60 days); it
	// is applied by AccessToken.GenerateRefresh.
	refreshTokenExpiration = accessTokenExpiration * 2
)
// entity is implemented by every type stored through the column-mapping layer.
// columns returns the type's database column names, each mapped to the
// destination that holds that column's value: either a pointer to a field or a
// wrapper (such as nullValue or *ID) implementing sql.Scanner/driver.Valuer.
type entity interface {
	columns() map[string]any
}
// Compile-time assertions: each persisted type must implement entity.
var _ entity = (*User)(nil)
var _ entity = (*Client)(nil)
var _ entity = (*AccessToken)(nil)
var _ entity = (*AuthCode)(nil)
var _ entity = (*SigningKey)(nil)
// ID is a database row identifier tagged with the entity type it refers to, so
// that, for example, an ID[*User] cannot be used where an ID[*Client] is
// expected. The zero ID means "no ID" and is written to the database as NULL.
type ID[T entity] int64

// Compile-time checks that ID implements the database/sql scanning interfaces.
var (
	_ sql.Scanner   = (*ID[*User])(nil)
	_ driver.Valuer = ID[*User](0)
)

// ParseID parses s as a positive decimal ID. It returns an error for empty or
// non-numeric input, for zero, and for values that do not fit in a
// non-negative int64.
func ParseID[T entity](s string) (ID[T], error) {
	u, err := strconv.ParseUint(s, 10, 63)
	// On overflow ParseUint returns the largest 63-bit value together with
	// ErrRange, so the error must be checked rather than relying on u == 0.
	if err != nil || u == 0 {
		return 0, fmt.Errorf("invalid ID")
	}
	return ID[T](u), nil
}

// Scan implements sql.Scanner. SQL NULL scans to the zero ID; any other source
// value must be an int64.
func (ptr *ID[T]) Scan(v interface{}) error {
	switch x := v.(type) {
	case nil:
		*ptr = 0
	case int64:
		*ptr = ID[T](x)
	default:
		return fmt.Errorf("cannot scan ID from %T", v)
	}
	return nil
}

// Value implements driver.Valuer. The zero ID is written as NULL; any other ID
// is written as an int64.
func (id ID[T]) Value() (driver.Value, error) {
	if id == 0 {
		return nil, nil
	}
	return int64(id), nil
}
// nullValue adapts a pointer to a plain Go value (e.g. *string, *time.Time) to
// database/sql so that SQL NULL corresponds to the pointee's zero value and
// vice versa. ptr must be a non-nil pointer; a nil ptr panics in Scan/Value.
type nullValue struct {
	ptr interface{}
}

// Compile-time checks that nullValue implements the database/sql interfaces.
var (
	_ sql.Scanner   = nullValue{nil}
	_ driver.Valuer = nullValue{nil}
)

// Scan implements sql.Scanner. NULL resets the pointee to its zero value.
// Otherwise v must have exactly the pointee's type, except that a []byte is
// also accepted for string-kinded destinations, as database/sql does for
// drivers that return text columns as bytes.
func (nv nullValue) Scan(v interface{}) error {
	out := reflect.ValueOf(nv.ptr).Elem()
	if v == nil {
		out.SetZero()
		return nil
	}
	rv := reflect.ValueOf(v)
	if rv.Type() == out.Type() {
		out.Set(rv)
		return nil
	}
	// The driver may reuse its []byte buffer after Scan returns; string(b)
	// copies the data, so nothing aliases it.
	if b, ok := v.([]byte); ok && out.Kind() == reflect.String {
		out.SetString(string(b))
		return nil
	}
	return fmt.Errorf("cannot scan %v into %v", rv.Type(), out.Type())
}

// Value implements driver.Valuer. A zero pointee is written as NULL; otherwise
// the pointee itself is returned.
func (nv nullValue) Value() (driver.Value, error) {
	in := reflect.ValueOf(nv.ptr).Elem()
	// Prefer the type's own IsZero method (time.Time documents it as the zero
	// test). reflect.Value.IsZero compares raw fields, so a zero time.Time that
	// carries a non-nil *Location would otherwise be stored as a timestamp.
	if z, ok := in.Interface().(interface{ IsZero() bool }); ok {
		if z.IsZero() {
			return nil, nil
		}
	} else if in.IsZero() {
		return nil, nil
	}
	return in.Interface(), nil
}
// User is an account record. Name, Email and PasswordHash are mapped through
// nullValue in columns, so an empty value is stored as NULL.
type User struct {
	ID           ID[*User]
	Username     string
	Name         string
	Email        string
	PasswordHash string // bcrypt hash written by SetPassword
	Admin        bool
}

// columns maps each database column of a User to the field it is read into and
// written from.
func (user *User) columns() map[string]interface{} {
	cols := make(map[string]interface{}, 6)
	cols["id"] = &user.ID
	cols["username"] = &user.Username
	cols["admin"] = &user.Admin
	// Optional text columns: empty string <-> NULL.
	cols["name"] = nullValue{&user.Name}
	cols["email"] = nullValue{&user.Email}
	cols["password_hash"] = nullValue{&user.PasswordHash}
	return cols
}

// VerifyPassword compares password against the stored bcrypt hash. It returns
// nil on a match and the bcrypt error otherwise.
func (user *User) VerifyPassword(password string) error {
	stored := []byte(user.PasswordHash)
	return bcrypt.CompareHashAndPassword(stored, []byte(password))
}

// SetPassword replaces PasswordHash with a bcrypt hash of password computed at
// bcrypt.DefaultCost. On error PasswordHash is left unchanged.
func (user *User) SetPassword(password string) error {
	hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err == nil {
		user.PasswordHash = string(hashed)
	}
	return err
}

// PasswordNeedsRehash reports whether the stored hash's bcrypt cost differs
// from bcrypt.DefaultCost.
func (user *User) PasswordNeedsRehash() bool {
	// NOTE(review): the bcrypt.Cost error is discarded; presumably an unreadable
	// hash yields cost 0 and therefore reports true — confirm this is intended.
	if cost, _ := bcrypt.Cost([]byte(user.PasswordHash)); cost == bcrypt.DefaultCost {
		return false
	}
	return true
}
// Client is a registered client application (client ID, optional secret,
// redirect URIs, owner). A client without a secret hash is public; see
// IsPublic.
type Client struct {
	ID               ID[*Client]
	ClientID         string // random identifier assigned by Generate
	ClientSecretHash []byte // SHA-512 digest from generateSecret; nil for public clients
	Owner            ID[*User]
	RedirectURIs     string
	ClientName       string
	ClientURI        string
	PKCERequirement  string
}

// Generate assigns a fresh random ClientID and, unless isPublic is true, a new
// secret whose digest is stored in ClientSecretHash. The plaintext secret is
// returned; it is empty when isPublic is true, in which case ClientSecretHash
// is not touched. On error the client is left unmodified and the returned error
// wraps the underlying cause.
func (client *Client) Generate(isPublic bool) (secret string, err error) {
	id, err := generateUID()
	if err != nil {
		return "", fmt.Errorf("failed to generate client ID: %w", err)
	}
	var hash []byte
	if !isPublic {
		secret, hash, err = generateSecret()
		if err != nil {
			return "", fmt.Errorf("failed to generate client secret: %w", err)
		}
	}
	// Commit only once every credential has been generated, so a failure
	// above never leaves a half-updated client behind.
	client.ClientID = id
	if !isPublic {
		client.ClientSecretHash = hash
	}
	return secret, nil
}

// columns maps each database column of a Client to its field. The optional
// text columns go through nullValue so an empty string is stored as NULL.
func (client *Client) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":                 &client.ID,
		"client_id":          &client.ClientID,
		"client_secret_hash": &client.ClientSecretHash,
		"owner":              &client.Owner,
		"redirect_uris":      nullValue{&client.RedirectURIs},
		"client_name":        nullValue{&client.ClientName},
		"client_uri":         nullValue{&client.ClientURI},
		"pkce_requirement":   nullValue{&client.PKCERequirement},
	}
}

// VerifySecret reports whether secret hashes to ClientSecretHash, using a
// constant-time comparison.
func (client *Client) VerifySecret(secret string) bool {
	return verifyHash(client.ClientSecretHash, secret)
}

// IsPublic reports whether the client has no secret (ClientSecretHash is nil).
func (client *Client) IsPublic() bool {
	return client.ClientSecretHash == nil
}
// AccessToken is an issued access token for User via Client, optionally paired
// with a refresh token. Only SHA-512 digests of the secrets are kept (Hash,
// RefreshHash); the plaintext secrets are returned once by Generate and
// GenerateRefresh.
type AccessToken struct {
	ID               ID[*AccessToken]
	Hash             []byte
	User             ID[*User]
	Client           ID[*Client]
	Scope            string
	IssuedAt         time.Time
	ExpiresAt        time.Time
	AuthTime         time.Time
	RefreshHash      []byte
	RefreshExpiresAt time.Time
}

// Generate creates a new access-token secret, stores its digest in Hash, and
// sets IssuedAt to the current time and ExpiresAt to IssuedAt+expiration. It
// returns the plaintext secret. On error the token is left unmodified.
func (token *AccessToken) Generate(expiration time.Duration) (secret string, err error) {
	secret, hash, err := generateSecret()
	if err != nil {
		return "", fmt.Errorf("failed to generate access token secret: %w", err)
	}
	// Read the clock once so that ExpiresAt-IssuedAt is exactly expiration.
	now := time.Now()
	token.Hash = hash
	token.IssuedAt = now
	token.ExpiresAt = now.Add(expiration)
	return secret, nil
}

// GenerateRefresh creates a new refresh-token secret, stores its digest in
// RefreshHash, and sets RefreshExpiresAt to now+refreshTokenExpiration. It
// returns the plaintext secret. On error the token is left unmodified.
func (token *AccessToken) GenerateRefresh() (secret string, err error) {
	secret, hash, err := generateSecret()
	if err != nil {
		return "", fmt.Errorf("failed to generate refresh token secret: %w", err)
	}
	token.RefreshHash = hash
	token.RefreshExpiresAt = time.Now().Add(refreshTokenExpiration)
	return secret, nil
}

// NewAccessTokenFromAuthCode returns a token carrying authCode's User, Client
// and Scope, with AuthTime set to the code's CreatedAt. Hash and expiry fields
// are left zero; call Generate (and optionally GenerateRefresh) to fill them.
func NewAccessTokenFromAuthCode(authCode *AuthCode) *AccessToken {
	return &AccessToken{
		User:     authCode.User,
		Client:   authCode.Client,
		Scope:    authCode.Scope,
		AuthTime: authCode.CreatedAt,
	}
}

// columns maps each database column of an AccessToken to its field. Scope,
// AuthTime and RefreshExpiresAt go through nullValue so zero values are stored
// as NULL.
func (token *AccessToken) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":                 &token.ID,
		"hash":               &token.Hash,
		"user":               &token.User,
		"client":             &token.Client,
		"scope":              nullValue{&token.Scope},
		"issued_at":          &token.IssuedAt,
		"expires_at":         &token.ExpiresAt,
		"auth_time":          nullValue{&token.AuthTime},
		"refresh_hash":       &token.RefreshHash,
		"refresh_expires_at": nullValue{&token.RefreshExpiresAt},
	}
}

// VerifySecret reports whether secret matches Hash and ExpiresAt has not yet
// passed.
func (token *AccessToken) VerifySecret(secret string) bool {
	return verifyHash(token.Hash, secret) && verifyExpiration(token.ExpiresAt)
}

// VerifyRefreshSecret reports whether secret matches RefreshHash and
// RefreshExpiresAt has not yet passed.
func (token *AccessToken) VerifyRefreshSecret(secret string) bool {
	return verifyHash(token.RefreshHash, secret) && verifyExpiration(token.RefreshExpiresAt)
}
// AuthorizedClient pairs a Client with an expiry time.
// NOTE(review): not referenced in this file; presumably ExpiresAt is when the
// user's authorization of Client lapses — confirm against callers.
type AuthorizedClient struct {
	Client    Client
	ExpiresAt time.Time
}
// AuthCode is a short-lived authorization code. Only the SHA-512 digest of its
// secret is kept (Hash). The code is accepted by VerifySecret for
// authCodeExpiration after CreatedAt.
type AuthCode struct {
	ID          ID[*AuthCode]
	Hash        []byte
	CreatedAt   time.Time
	User        ID[*User]
	Client      ID[*Client]
	Scope       string
	RedirectURI string
	Nonce       string
	// Presumably the PKCE challenge and its method (RFC 7636) — confirm.
	CodeChallenge       string
	CodeChallengeMethod string
}

// Generate creates a new secret for the code, stores its digest in Hash and
// stamps CreatedAt with the current time. It returns the plaintext secret.
func (code *AuthCode) Generate() (secret string, err error) {
	var digest []byte
	if secret, digest, err = generateSecret(); err != nil {
		return "", fmt.Errorf("failed to generate authentication code secret: %v", err)
	}
	code.Hash, code.CreatedAt = digest, time.Now()
	return secret, nil
}

// columns maps each database column of an AuthCode to its field. The optional
// text columns go through nullValue so an empty string is stored as NULL.
func (code *AuthCode) columns() map[string]interface{} {
	cols := make(map[string]interface{}, 10)
	cols["id"] = &code.ID
	cols["hash"] = &code.Hash
	cols["created_at"] = &code.CreatedAt
	cols["user"] = &code.User
	cols["client"] = &code.Client
	// Optional columns: empty string <-> NULL.
	cols["scope"] = nullValue{&code.Scope}
	cols["redirect_uri"] = nullValue{&code.RedirectURI}
	cols["nonce"] = nullValue{&code.Nonce}
	cols["code_challenge"] = nullValue{&code.CodeChallenge}
	cols["code_challenge_method"] = nullValue{&code.CodeChallengeMethod}
	return cols
}

// VerifySecret reports whether secret matches Hash and the code is still
// within authCodeExpiration of CreatedAt.
func (code *AuthCode) VerifySecret(secret string) bool {
	deadline := code.CreatedAt.Add(authCodeExpiration)
	return verifyHash(code.Hash, secret) && verifyExpiration(deadline)
}
// SecretKind is the single-byte prefix that identifies what a marshaled secret
// string refers to (see MarshalSecret and UnmarshalSecret).
type SecretKind byte

// Known secret kinds.
const (
	SecretKindAccessToken  SecretKind = 'a'
	SecretKindRefreshToken SecretKind = 'r'
	SecretKindAuthCode     SecretKind = 'c'
)
// SigningKey is a stored private key identified by KID.
// NOTE(review): the encoding of PrivateKey and the set of valid Algorithm
// values are not visible here — confirm against the code that creates keys.
type SigningKey struct {
	ID         ID[*SigningKey]
	KID        string
	Algorithm  string
	PrivateKey []byte
	CreatedAt  time.Time
}

// columns maps each database column of a SigningKey to its field; none of them
// is wrapped in nullValue.
func (key *SigningKey) columns() map[string]interface{} {
	cols := make(map[string]interface{}, 5)
	cols["id"] = &key.ID
	cols["kid"] = &key.KID
	cols["algorithm"] = &key.Algorithm
	cols["private_key"] = &key.PrivateKey
	cols["created_at"] = &key.CreatedAt
	return cols
}
// UnmarshalSecret parses a string of the form "<kind>.<id>.<secret>", as
// produced by MarshalSecret. The kind must be a single byte that is valid for
// T: SecretKindAccessToken or SecretKindRefreshToken for ID[*AccessToken], and
// SecretKindAuthCode for ID[*AuthCode]. Unknown kinds, kinds that do not match
// T, and IDs rejected by ParseID all produce an error, in which case the other
// results are zero. The kind itself is not returned, so access and refresh
// tokens are not distinguished here.
func UnmarshalSecret[T entity](s string) (id ID[T], secret string, err error) {
	kind, rest, _ := strings.Cut(s, ".")
	idStr, secret, ok := strings.Cut(rest, ".")
	if !ok || len(kind) != 1 {
		return 0, "", fmt.Errorf("malformed secret")
	}
	switch SecretKind(kind[0]) {
	case SecretKindAccessToken, SecretKindRefreshToken:
		_, ok = interface{}(id).(ID[*AccessToken])
	case SecretKindAuthCode:
		_, ok = interface{}(id).(ID[*AuthCode])
	default:
		// ok is still true from strings.Cut; unknown kinds must be rejected.
		ok = false
	}
	if !ok {
		return 0, "", fmt.Errorf("invalid secret kind %q", kind)
	}
	id, err = ParseID[T](idStr)
	if err != nil {
		return 0, "", err
	}
	return id, secret, nil
}
// MarshalSecret encodes id, kind and secret as "<kind>.<id>.<secret>", the
// format parsed by UnmarshalSecret. It panics, treating both as programming
// errors, when id is zero or when kind is not valid for the ID's entity type:
// ID[*AccessToken] accepts SecretKindAccessToken and SecretKindRefreshToken,
// ID[*AuthCode] accepts SecretKindAuthCode, and every other ID type is rejected.
func MarshalSecret[T entity](id ID[T], kind SecretKind, secret string) string {
	if id == 0 {
		panic("cannot marshal zero ID")
	}
	valid := false
	switch interface{}(id).(type) {
	case ID[*AuthCode]:
		valid = kind == SecretKindAuthCode
	case ID[*AccessToken]:
		switch kind {
		case SecretKindAccessToken, SecretKindRefreshToken:
			valid = true
		}
	}
	if !valid {
		panic(fmt.Sprintf("unsupported secret kind %q for ID type %T", string(kind), id))
	}
	return fmt.Sprintf("%s.%d.%s", string(kind), int64(id), secret)
}
// generateUID returns 32 bytes from crypto/rand encoded as unpadded URL-safe
// base64 (43 characters).
func generateUID() (string, error) {
	var buf [32]byte
	if _, err := rand.Read(buf[:]); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(buf[:]), nil
}
// generateSecret draws 32 random bytes from crypto/rand and returns them as
// unpadded URL-safe base64 (secret) along with their SHA-512 digest (hash).
// Only the digest is meant to be stored; verifyHash checks a secret against it.
func generateSecret() (secret string, hash []byte, err error) {
	var raw [32]byte
	if _, err = rand.Read(raw[:]); err != nil {
		return "", nil, err
	}
	digest := sha512.Sum512(raw[:])
	return base64.RawURLEncoding.EncodeToString(raw[:]), digest[:], nil
}
// verifyHash reports whether the SHA-512 digest of the base64url-decoded secret
// equals hash, using a constant-time comparison.
func verifyHash(hash []byte, secret string) bool {
	// The decode error is not checked: whatever bytes were decoded are hashed
	// and compared.
	raw, _ := base64.RawURLEncoding.DecodeString(secret)
	digest := sha512.Sum512(raw)
	return subtle.ConstantTimeCompare(digest[:], hash) == 1
}
// verifyExpiration reports whether the deadline t is still strictly in the
// future.
func verifyExpiration(t time.Time) bool {
	return t.After(time.Now())
}