Lindenii Project Forge
Login

server

Vireo IdP server

Hi… I am well aware that this diff view is very suboptimal. It will be fixed when the refactored server comes along!

Commit info
ID
9a40194633ff36b6de968ba5a8829245bcbfd6bc
Author
Author date
Mon, 19 Feb 2024 16:33:25 +0100
Committer
Committer date
Mon, 19 Feb 2024 16:33:25 +0100
Actions
Drop outdated TODO
package main

import (
	"crypto/rand"
	"crypto/sha512"
	"crypto/subtle"
	"database/sql"
	"database/sql/driver"
	"encoding/base64"
	"fmt"
	"strconv"
	"strings"
	"time"

	"golang.org/x/crypto/bcrypt"
)

const (
	accessTokenExpiration = 30 * 24 * time.Hour
	authCodeExpiration    = 10 * time.Minute
)

type entity interface {
	columns() map[string]interface{}
}

var (
	_ entity = (*User)(nil)
	_ entity = (*Client)(nil)
	_ entity = (*AccessToken)(nil)
	_ entity = (*AuthCode)(nil)
)

type ID[T entity] int64

var (
	_ sql.Scanner   = (*ID[*User])(nil)
	_ driver.Valuer = ID[*User](0)
)

func ParseID[T entity](s string) (ID[T], error) {
	u, _ := strconv.ParseUint(s, 10, 63)
	if u == 0 {
		return 0, fmt.Errorf("invalid ID")
	}
	return ID[T](u), nil
}

func (ptr *ID[T]) Scan(v interface{}) error {
	if v == nil {
		*ptr = 0
		return nil
	}
	id, ok := v.(int64)
	if !ok {
		return fmt.Errorf("cannot scan ID from %T", v)
	}
	*ptr = ID[T](id)
	return nil
}

func (id ID[T]) Value() (driver.Value, error) {
	if id == 0 {
		return nil, nil
	} else {
		return int64(id), nil
	}
}

type nullString string

var (
	_ sql.Scanner   = (*nullString)(nil)
	_ driver.Valuer = (*nullString)(nil)
)

func (ptr *nullString) Scan(v interface{}) error {
	if v == nil {
		*ptr = ""
		return nil
	}
	s, ok := v.(string)
	if !ok {
		return fmt.Errorf("cannot scan nullStringPtr from %T", v)
	}
	*ptr = nullString(s)
	return nil
}

func (ptr *nullString) Value() (driver.Value, error) {
	if *ptr == "" {
		return nil, nil
	} else {
		return string(*ptr), nil
	}
}

type User struct {
	ID           ID[*User]
	Username     string
	PasswordHash string
	Admin        bool
}

func (user *User) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":            &user.ID,
		"username":      &user.Username,
		"password_hash": (*nullString)(&user.PasswordHash),
		"admin":         &user.Admin,
	}
}

func (user *User) VerifyPassword(password string) error {
	// TODO: upgrade hash
	return bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
}

func (user *User) SetPassword(password string) error {
	hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return err
	}
	user.PasswordHash = string(hash)
	return nil
}

func (user *User) PasswordNeedsRehash() bool {
	cost, _ := bcrypt.Cost([]byte(user.PasswordHash))
	return cost != bcrypt.DefaultCost
}

type Client struct {
	ID               ID[*Client]
	ClientID         string
	ClientSecretHash []byte
	Owner            ID[*User]
	RedirectURIs     string
	ClientName       string
	ClientURI        string
}

func (client *Client) Generate(isPublic bool) (secret string, err error) {
	id, err := generateUID()
	if err != nil {
		return "", fmt.Errorf("failed to generate client ID: %v", err)
	}
	client.ClientID = id

	if !isPublic {
		var hash []byte
		secret, hash, err = generateSecret()
		if err != nil {
			return "", fmt.Errorf("failed to generate client secret: %v", err)
		}
		client.ClientSecretHash = hash
	}

	return secret, nil
}

func (client *Client) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":                 &client.ID,
		"client_id":          &client.ClientID,
		"client_secret_hash": &client.ClientSecretHash,
		"owner":              &client.Owner,
		"redirect_uris":      (*nullString)(&client.RedirectURIs),
		"client_name":        (*nullString)(&client.ClientName),
		"client_uri":         (*nullString)(&client.ClientURI),
	}
}

func (client *Client) VerifySecret(secret string) bool {
	return verifyHash(client.ClientSecretHash, secret)
}

func (client *Client) IsPublic() bool {
	return client.ClientSecretHash == nil
}

type AccessToken struct {
	ID        ID[*AccessToken]
	Hash      []byte
	User      ID[*User]
	Client    ID[*Client]
	Scope     string
	IssuedAt  time.Time
	ExpiresAt time.Time
}

func (token *AccessToken) Generate(expiration time.Duration) (secret string, err error) {
	secret, hash, err := generateSecret()
	if err != nil {
		return "", fmt.Errorf("failed to generate access token secret: %v", err)
	}
	token.Hash = hash
	token.IssuedAt = time.Now()
	token.ExpiresAt = time.Now().Add(expiration)
	return secret, nil
}

func NewAccessTokenFromAuthCode(authCode *AuthCode) (token *AccessToken, secret string, err error) {
	token = &AccessToken{
		User:   authCode.User,
		Client: authCode.Client,
		Scope:  authCode.Scope,
	}
	secret, err = token.Generate(accessTokenExpiration)
	return token, secret, err
}

func (token *AccessToken) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":         &token.ID,
		"hash":       &token.Hash,
		"user":       &token.User,
		"client":     &token.Client,
		"scope":      (*nullString)(&token.Scope),
		"issued_at":  &token.IssuedAt,
		"expires_at": &token.ExpiresAt,
	}
}

func (token *AccessToken) VerifySecret(secret string) bool {
	return verifyHash(token.Hash, secret) && verifyExpiration(token.ExpiresAt)
}

type AuthorizedClient struct {
	Client    Client
	ExpiresAt time.Time
}

type AuthCode struct {
	ID        ID[*AuthCode]
	Hash      []byte
	CreatedAt time.Time
	User      ID[*User]
	Client    ID[*Client]
	Scope     string
}

func NewAuthCode(user ID[*User], client ID[*Client], scope string) (code *AuthCode, secret string, err error) {
	secret, hash, err := generateSecret()
	if err != nil {
		return nil, "", fmt.Errorf("failed to generate authentication code secret: %v", err)
	}
	code = &AuthCode{
		Hash:      hash,
		CreatedAt: time.Now(),
		User:      user,
		Client:    client,
		Scope:     scope,
	}
	return code, secret, nil
}

func (code *AuthCode) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":         &code.ID,
		"hash":       &code.Hash,
		"created_at": &code.CreatedAt,
		"user":       &code.User,
		"client":     &code.Client,
		"scope":      (*nullString)(&code.Scope),
	}
}

func (code *AuthCode) VerifySecret(secret string) bool {
	return verifyHash(code.Hash, secret) && verifyExpiration(code.CreatedAt.Add(authCodeExpiration))
}

type SecretKind byte

const (
	SecretKindAccessToken = SecretKind('a')
	SecretKindAuthCode    = SecretKind('c')
)

func UnmarshalSecret[T entity](s string) (id ID[T], secret string, err error) {
	kind, s, _ := strings.Cut(s, ".")
	idStr, secret, ok := strings.Cut(s, ".")
	if !ok || len(kind) != 1 {
		return 0, "", fmt.Errorf("malformed secret")
	}

	switch SecretKind(kind[0]) {
	case SecretKindAccessToken:
		_, ok = interface{}(id).(ID[*AccessToken])
	case SecretKindAuthCode:
		_, ok = interface{}(id).(ID[*AuthCode])
	}
	if !ok {
		return 0, "", fmt.Errorf("invalid secret kind %q", kind)
	}

	id, err = ParseID[T](idStr)
	return id, secret, err
}

func MarshalSecret[T entity](id ID[T], secret string) string {
	if id == 0 {
		panic("cannot marshal zero ID")
	}

	var kind SecretKind
	switch interface{}(id).(type) {
	case ID[*AccessToken]:
		kind = SecretKindAccessToken
	case ID[*AuthCode]:
		kind = SecretKindAuthCode
	default:
		panic(fmt.Sprintf("unsupported secret kind for ID type %T", id))
	}

	return fmt.Sprintf("%v.%v.%v", string(kind), int64(id), secret)
}

func generateUID() (string, error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(b), nil
}

func generateSecret() (secret string, hash []byte, err error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", nil, err
	}
	secret = base64.RawURLEncoding.EncodeToString(b)
	h := sha512.Sum512(b)
	return secret, h[:], nil
}

func verifyHash(hash []byte, secret string) bool {
	b, _ := base64.RawURLEncoding.DecodeString(secret)
	h := sha512.Sum512(b)
	return subtle.ConstantTimeCompare(hash, h[:]) == 1
}

func verifyExpiration(t time.Time) bool {
	return time.Now().Before(t)
}