Lindenii Project Forge
Login

server

Vireo IdP server

Hi… I am well aware that this diff view is very suboptimal. It will be fixed when the refactored server comes along!

Commit info
ID
476a35955f666ccb1d01fdacb0a52b80115bf214
Author
Author date
Mon, 26 Feb 2024 13:21:50 +0100
Committer
Committer date
Mon, 26 Feb 2024 13:21:50 +0100
Actions
Add support for refresh tokens

Closes: https://todo.sr.ht/~emersion/sinwon/21
package main

import (
	"context"
	"database/sql"
	_ "embed"
	"fmt"
	"time"

	"github.com/mattn/go-sqlite3"
)

//go:embed schema.sql
var schema string

var migrations = []string{
	"", // migration #0 is reserved for schema initialization
	`
		ALTER TABLE AccessToken ADD COLUMN refresh_hash BLOB;
		ALTER TABLE AccessToken ADD COLUMN refresh_expires_at datetime;
		CREATE UNIQUE INDEX access_token_refresh_hash ON AccessToken(refresh_hash);
	`,
}

var errNoDBRows = sql.ErrNoRows

type DB struct {
	db *sql.DB
}

func openDB(filename string) (*DB, error) {
	sqlDB, err := sql.Open("sqlite3", filename)
	if err != nil {
		return nil, err
	}

	db := &DB{sqlDB}
	if err := db.init(context.TODO()); err != nil {
		db.Close()
		return nil, err
	}

	return db, nil
}

func (db *DB) init(ctx context.Context) error {
	version, err := db.upgrade(ctx)
	if err != nil {
		return err
	}

	if version > 0 {
		return nil
	}

	// TODO: drop this
	defaultUser := User{Username: "root", Admin: true}
	if err := defaultUser.SetPassword("root"); err != nil {
		return err
	}
	return db.StoreUser(ctx, &defaultUser)
}

func (db *DB) upgrade(ctx context.Context) (version int, err error) {
	if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
		return 0, fmt.Errorf("failed to query schema version: %v", err)
	}

	if version == len(migrations) {
		return version, nil
	} else if version > len(migrations) {
		return version, fmt.Errorf("sinwon (version %d) older than schema (version %d)", len(migrations), version)
	}

	tx, err := db.db.Begin()
	if err != nil {
		return version, err
	}
	defer tx.Rollback()

	if version == 0 {
		if _, err := tx.Exec(schema); err != nil {
			return version, fmt.Errorf("failed to initialize schema: %v", err)
		}
	} else {
		for i := version; i < len(migrations); i++ {
			if _, err := tx.Exec(migrations[i]); err != nil {
				return version, fmt.Errorf("failed to execute migration #%v: %v", i, err)
			}
		}
	}

	// For some reason prepared statements don't work here
	_, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
	if err != nil {
		return version, fmt.Errorf("failed to bump schema version: %v", err)
	}

	return version, tx.Commit()
}

func (db *DB) Close() error {
	return db.db.Close()
}

func (db *DB) FetchUser(ctx context.Context, id ID[*User]) (*User, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM User WHERE id = ?", id)
	if err != nil {
		return nil, err
	}
	var user User
	err = scanRow(&user, rows)
	return &user, err
}

func (db *DB) FetchUserByUsername(ctx context.Context, username string) (*User, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM User WHERE username = ?", username)
	if err != nil {
		return nil, err
	}
	var user User
	err = scanRow(&user, rows)
	return &user, err
}

func (db *DB) StoreUser(ctx context.Context, user *User) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO User(id, username, password_hash, admin)
		VALUES (:id, :username, :password_hash, :admin)
		ON CONFLICT(id) DO UPDATE SET
			username = :username,
			password_hash = :password_hash,
			admin = :admin
		RETURNING id
	`, entityArgs(user)...).Scan(&user.ID)
}

func (db *DB) ListUsers(ctx context.Context) ([]User, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM User")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var l []User
	for rows.Next() {
		var user User
		if err := scan(&user, rows); err != nil {
			return nil, err
		}
		l = append(l, user)
	}

	return l, rows.Close()
}

func (db *DB) FetchClient(ctx context.Context, id ID[*Client]) (*Client, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE id = ?", id)
	if err != nil {
		return nil, err
	}
	var client Client
	err = scanRow(&client, rows)
	return &client, err
}

func (db *DB) FetchClientByClientID(ctx context.Context, clientID string) (*Client, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE client_id = ?", clientID)
	if err != nil {
		return nil, err
	}
	var client Client
	err = scanRow(&client, rows)
	return &client, err
}

func (db *DB) StoreClient(ctx context.Context, client *Client) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO Client(id, client_id, client_secret_hash, owner,
			redirect_uris, client_name, client_uri)
		VALUES (:id, :client_id, :client_secret_hash, :owner,
			:redirect_uris, :client_name, :client_uri)
		ON CONFLICT(id) DO UPDATE SET
			client_id = :client_id,
			client_secret_hash = :client_secret_hash,
			owner = :owner,
			redirect_uris = :redirect_uris,
			client_name = :client_name,
			client_uri = :client_uri
		RETURNING id
	`, entityArgs(client)...).Scan(&client.ID)
}

func (db *DB) ListClients(ctx context.Context, owner ID[*User]) ([]Client, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE owner IS ?", owner)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var l []Client
	for rows.Next() {
		var client Client
		if err := scan(&client, rows); err != nil {
			return nil, err
		}
		l = append(l, client)
	}

	return l, rows.Close()
}

func (db *DB) ListAuthorizedClients(ctx context.Context, user ID[*User]) ([]AuthorizedClient, error) {
	rows, err := db.db.QueryContext(ctx, `
		SELECT id, client_id, client_name, client_uri, token.expires_at
		FROM Client,
		(
			SELECT client, MAX(expires_at) as expires_at
			SELECT client, MAX(COALESCE(refresh_expires_at, expires_at)) as expires_at
			FROM AccessToken
			WHERE user = ?
			GROUP BY client
		) AS token
		WHERE Client.id = token.client
	`, user)
	if err != nil {
		return nil, err
	}

	var l []AuthorizedClient
	for rows.Next() {
		var authClient AuthorizedClient
		columns := authClient.Client.columns()
		var expiresAt string
		err := rows.Scan(columns["id"], columns["client_id"], columns["client_name"], columns["client_uri"], &expiresAt)
		if err != nil {
			return nil, err
		}
		authClient.ExpiresAt, err = time.Parse(sqlite3.SQLiteTimestampFormats[0], expiresAt)
		if err != nil {
			return nil, err
		}
		l = append(l, authClient)
	}

	return l, rows.Close()
}

func (db *DB) DeleteClient(ctx context.Context, id ID[*Client]) error {
	_, err := db.db.ExecContext(ctx, "DELETE FROM Client WHERE id = ?", id)
	return err
}

func (db *DB) FetchAccessToken(ctx context.Context, id ID[*AccessToken]) (*AccessToken, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM AccessToken WHERE id = ?", id)
	if err != nil {
		return nil, err
	}
	var token AccessToken
	err = scanRow(&token, rows)
	return &token, err
}

func (db *DB) CreateAccessToken(ctx context.Context, token *AccessToken) error {
func (db *DB) StoreAccessToken(ctx context.Context, token *AccessToken) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO AccessToken(hash, user, client, scope, issued_at, expires_at)
		VALUES (:hash, :user, :client, :scope, :issued_at, :expires_at)
		INSERT INTO AccessToken(id, hash, user, client, scope, issued_at,
			expires_at, refresh_hash, refresh_expires_at)
		VALUES (:id, :hash, :user, :client, :scope, :issued_at, :expires_at,
			:refresh_hash, :refresh_expires_at)
		ON CONFLICT(id) DO UPDATE SET
			hash = :hash,
			user = :user,
			client = :client,
			scope = :scope,
			issued_at = :issued_at,
			expires_at = :expires_at,
			refresh_hash = :refresh_hash,
			refresh_expires_at = :refresh_expires_at
		RETURNING id
	`, entityArgs(token)...).Scan(&token.ID)
}

func (db *DB) DeleteAccessToken(ctx context.Context, id ID[*AccessToken]) error {
	_, err := db.db.ExecContext(ctx, "DELETE FROM AccessToken WHERE id = ?", id)
	return err
}

func (db *DB) RevokeAccessTokens(ctx context.Context, clientID ID[*Client], userID ID[*User]) error {
	_, err := db.db.ExecContext(ctx, `
		DELETE FROM AccessToken
		WHERE client = ? AND user = ?
	`, clientID, userID)
	return err
}

func (db *DB) CreateAuthCode(ctx context.Context, code *AuthCode) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO AuthCode(hash, created_at, user, client, scope, redirect_uri)
		VALUES (:hash, :created_at, :user, :client, :scope, :redirect_uri)
		RETURNING id
	`, entityArgs(code)...).Scan(&code.ID)
}

func (db *DB) PopAuthCode(ctx context.Context, id ID[*AuthCode]) (*AuthCode, error) {
	rows, err := db.db.QueryContext(ctx, `
		DELETE FROM AuthCode
		WHERE id = ?
		RETURNING *
	`, id)
	if err != nil {
		return nil, err
	}
	var authCode AuthCode
	err = scanRow(&authCode, rows)
	return &authCode, err
}

func (db *DB) Maintain(ctx context.Context) error {
	_, err := db.db.ExecContext(ctx, `
		DELETE FROM AccessToken
		WHERE timediff('now', expires_at) > 0
		WHERE timediff('now', COALESCE(refresh_expires_at, expires_at)) > 0
	`)
	if err != nil {
		return err
	}

	_, err = db.db.ExecContext(ctx, `
		DELETE FROM AuthCode
		WHERE timediff(?, created_at) > 0
	`, time.Now().Add(-authCodeExpiration))
	if err != nil {
		return err
	}

	return nil
}

func scan(e entity, rows *sql.Rows) error {
	columns := e.columns()

	keys, err := rows.Columns()
	if err != nil {
		panic(err)
	}
	out := make([]interface{}, len(keys))
	for i, k := range keys {
		v, ok := columns[k]
		if !ok {
			panic(fmt.Errorf("unknown column %q", k))
		}
		out[i] = v
	}

	return rows.Scan(out...)
}

func scanRow(e entity, rows *sql.Rows) error {
	if !rows.Next() {
		return sql.ErrNoRows
	}
	if err := scan(e, rows); err != nil {
		return err
	}
	return rows.Close()
}

func entityArgs(e entity) []interface{} {
	columns := e.columns()

	l := make([]interface{}, 0, len(columns))
	for k, v := range columns {
		l = append(l, sql.Named(k, v))
	}

	return l
}
package main

import (
	"crypto/rand"
	"crypto/sha512"
	"crypto/subtle"
	"database/sql"
	"database/sql/driver"
	"encoding/base64"
	"fmt"
	"reflect"
	"strconv"
	"strings"
	"time"

	"golang.org/x/crypto/bcrypt"
)

const (
	accessTokenExpiration = 30 * 24 * time.Hour
	authCodeExpiration    = 10 * time.Minute
	accessTokenExpiration  = 30 * 24 * time.Hour
	refreshTokenExpiration = 2 * accessTokenExpiration
	authCodeExpiration     = 10 * time.Minute
)

type entity interface {
	columns() map[string]interface{}
}

var (
	_ entity = (*User)(nil)
	_ entity = (*Client)(nil)
	_ entity = (*AccessToken)(nil)
	_ entity = (*AuthCode)(nil)
)

type ID[T entity] int64

var (
	_ sql.Scanner   = (*ID[*User])(nil)
	_ driver.Valuer = ID[*User](0)
)

func ParseID[T entity](s string) (ID[T], error) {
	u, _ := strconv.ParseUint(s, 10, 63)
	if u == 0 {
		return 0, fmt.Errorf("invalid ID")
	}
	return ID[T](u), nil
}

func (ptr *ID[T]) Scan(v interface{}) error {
	if v == nil {
		*ptr = 0
		return nil
	}
	id, ok := v.(int64)
	if !ok {
		return fmt.Errorf("cannot scan ID from %T", v)
	}
	*ptr = ID[T](id)
	return nil
}

func (id ID[T]) Value() (driver.Value, error) {
	if id == 0 {
		return nil, nil
	} else {
		return int64(id), nil
	}
}

type nullString string
type nullValue struct {
	ptr interface{}
}

var (
	_ sql.Scanner   = (*nullString)(nil)
	_ driver.Valuer = (*nullString)(nil)
	_ sql.Scanner   = nullValue{nil}
	_ driver.Valuer = nullValue{nil}
)

func (ptr *nullString) Scan(v interface{}) error {
func (nv nullValue) Scan(v interface{}) error {
	out := reflect.ValueOf(nv.ptr).Elem()
	if v == nil {
		*ptr = ""
		out.SetZero()
		return nil
	}
	s, ok := v.(string)
	if !ok {
		return fmt.Errorf("cannot scan nullStringPtr from %T", v)

	rv := reflect.ValueOf(v)
	if rv.Type() != out.Type() {
		return fmt.Errorf("cannot scan %v into %v", rv.Type(), out.Type())
	}
	*ptr = nullString(s)

	out.Set(rv)
	return nil
}

func (ptr *nullString) Value() (driver.Value, error) {
	if *ptr == "" {
func (nv nullValue) Value() (driver.Value, error) {
	in := reflect.ValueOf(nv.ptr).Elem()
	if in.IsZero() {
		return nil, nil
	} else {
		return string(*ptr), nil
	}
	return in.Interface(), nil
}

type User struct {
	ID           ID[*User]
	Username     string
	PasswordHash string
	Admin        bool
}

func (user *User) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":            &user.ID,
		"username":      &user.Username,
		"password_hash": (*nullString)(&user.PasswordHash),
		"password_hash": nullValue{&user.PasswordHash},
		"admin":         &user.Admin,
	}
}

func (user *User) VerifyPassword(password string) error {
	return bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
}

func (user *User) SetPassword(password string) error {
	hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return err
	}
	user.PasswordHash = string(hash)
	return nil
}

func (user *User) PasswordNeedsRehash() bool {
	cost, _ := bcrypt.Cost([]byte(user.PasswordHash))
	return cost != bcrypt.DefaultCost
}

type Client struct {
	ID               ID[*Client]
	ClientID         string
	ClientSecretHash []byte
	Owner            ID[*User]
	RedirectURIs     string
	ClientName       string
	ClientURI        string
}

func (client *Client) Generate(isPublic bool) (secret string, err error) {
	id, err := generateUID()
	if err != nil {
		return "", fmt.Errorf("failed to generate client ID: %v", err)
	}
	client.ClientID = id

	if !isPublic {
		var hash []byte
		secret, hash, err = generateSecret()
		if err != nil {
			return "", fmt.Errorf("failed to generate client secret: %v", err)
		}
		client.ClientSecretHash = hash
	}

	return secret, nil
}

func (client *Client) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":                 &client.ID,
		"client_id":          &client.ClientID,
		"client_secret_hash": &client.ClientSecretHash,
		"owner":              &client.Owner,
		"redirect_uris":      (*nullString)(&client.RedirectURIs),
		"client_name":        (*nullString)(&client.ClientName),
		"client_uri":         (*nullString)(&client.ClientURI),
		"redirect_uris":      nullValue{&client.RedirectURIs},
		"client_name":        nullValue{&client.ClientName},
		"client_uri":         nullValue{&client.ClientURI},
	}
}

func (client *Client) VerifySecret(secret string) bool {
	return verifyHash(client.ClientSecretHash, secret)
}

func (client *Client) IsPublic() bool {
	return client.ClientSecretHash == nil
}

type AccessToken struct {
	ID        ID[*AccessToken]
	Hash      []byte
	User      ID[*User]
	Client    ID[*Client]
	Scope     string
	IssuedAt  time.Time
	ExpiresAt time.Time

	RefreshHash      []byte
	RefreshExpiresAt time.Time
}

func (token *AccessToken) Generate(expiration time.Duration) (secret string, err error) {
	secret, hash, err := generateSecret()
	if err != nil {
		return "", fmt.Errorf("failed to generate access token secret: %v", err)
	}
	token.Hash = hash
	token.IssuedAt = time.Now()
	token.ExpiresAt = time.Now().Add(expiration)
	return secret, nil
}

func NewAccessTokenFromAuthCode(authCode *AuthCode) (token *AccessToken, secret string, err error) {
	token = &AccessToken{
func (token *AccessToken) GenerateRefresh() (secret string, err error) {
	secret, hash, err := generateSecret()
	if err != nil {
		return "", fmt.Errorf("failed to generate refresh token secret: %v", err)
	}
	token.RefreshHash = hash
	token.RefreshExpiresAt = time.Now().Add(refreshTokenExpiration)
	return secret, nil
}

func NewAccessTokenFromAuthCode(authCode *AuthCode) *AccessToken {
	return &AccessToken{
		User:   authCode.User,
		Client: authCode.Client,
		Scope:  authCode.Scope,
	}
	secret, err = token.Generate(accessTokenExpiration)
	return token, secret, err
}

func (token *AccessToken) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":         &token.ID,
		"hash":       &token.Hash,
		"user":       &token.User,
		"client":     &token.Client,
		"scope":      (*nullString)(&token.Scope),
		"issued_at":  &token.IssuedAt,
		"expires_at": &token.ExpiresAt,
		"id":                 &token.ID,
		"hash":               &token.Hash,
		"user":               &token.User,
		"client":             &token.Client,
		"scope":              nullValue{&token.Scope},
		"issued_at":          &token.IssuedAt,
		"expires_at":         &token.ExpiresAt,
		"refresh_hash":       &token.RefreshHash,
		"refresh_expires_at": nullValue{&token.RefreshExpiresAt},
	}
}

func (token *AccessToken) VerifySecret(secret string) bool {
	return verifyHash(token.Hash, secret) && verifyExpiration(token.ExpiresAt)
}

func (token *AccessToken) VerifyRefreshSecret(secret string) bool {
	return verifyHash(token.RefreshHash, secret) && verifyExpiration(token.RefreshExpiresAt)
}

type AuthorizedClient struct {
	Client    Client
	ExpiresAt time.Time
}

type AuthCode struct {
	ID          ID[*AuthCode]
	Hash        []byte
	CreatedAt   time.Time
	User        ID[*User]
	Client      ID[*Client]
	Scope       string
	RedirectURI string
}

func (code *AuthCode) Generate() (secret string, err error) {
	secret, hash, err := generateSecret()
	if err != nil {
		return "", fmt.Errorf("failed to generate authentication code secret: %v", err)
	}
	code.Hash = hash
	code.CreatedAt = time.Now()
	return secret, nil
}

func (code *AuthCode) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":           &code.ID,
		"hash":         &code.Hash,
		"created_at":   &code.CreatedAt,
		"user":         &code.User,
		"client":       &code.Client,
		"scope":        (*nullString)(&code.Scope),
		"redirect_uri": (*nullString)(&code.RedirectURI),
		"scope":        nullValue{&code.Scope},
		"redirect_uri": nullValue{&code.RedirectURI},
	}
}

func (code *AuthCode) VerifySecret(secret string) bool {
	return verifyHash(code.Hash, secret) && verifyExpiration(code.CreatedAt.Add(authCodeExpiration))
}

type SecretKind byte

const (
	SecretKindAccessToken = SecretKind('a')
	SecretKindAuthCode    = SecretKind('c')
	SecretKindAccessToken  = SecretKind('a')
	SecretKindRefreshToken = SecretKind('r')
	SecretKindAuthCode     = SecretKind('c')
)

func UnmarshalSecret[T entity](s string) (id ID[T], secret string, err error) {
	kind, s, _ := strings.Cut(s, ".")
	idStr, secret, ok := strings.Cut(s, ".")
	if !ok || len(kind) != 1 {
		return 0, "", fmt.Errorf("malformed secret")
	}

	switch SecretKind(kind[0]) {
	case SecretKindAccessToken:
	case SecretKindAccessToken, SecretKindRefreshToken:
		_, ok = interface{}(id).(ID[*AccessToken])
	case SecretKindAuthCode:
		_, ok = interface{}(id).(ID[*AuthCode])
	}
	if !ok {
		return 0, "", fmt.Errorf("invalid secret kind %q", kind)
	}

	id, err = ParseID[T](idStr)
	return id, secret, err
}

func MarshalSecret[T entity](id ID[T], secret string) string {
func MarshalSecret[T entity](id ID[T], kind SecretKind, secret string) string {
	if id == 0 {
		panic("cannot marshal zero ID")
	}

	var kind SecretKind
	var ok bool
	switch interface{}(id).(type) {
	case ID[*AccessToken]:
		kind = SecretKindAccessToken
		ok = kind == SecretKindAccessToken || kind == SecretKindRefreshToken
	case ID[*AuthCode]:
		kind = SecretKindAuthCode
	default:
		panic(fmt.Sprintf("unsupported secret kind for ID type %T", id))
		ok = kind == SecretKindAuthCode
	}
	if !ok {
		panic(fmt.Sprintf("unsupported secret kind %q for ID type %T", string(kind), id))
	}

	return fmt.Sprintf("%v.%v.%v", string(kind), int64(id), secret)
}

func generateUID() (string, error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(b), nil
}

func generateSecret() (secret string, hash []byte, err error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", nil, err
	}
	secret = base64.RawURLEncoding.EncodeToString(b)
	h := sha512.Sum512(b)
	return secret, h[:], nil
}

func verifyHash(hash []byte, secret string) bool {
	b, _ := base64.RawURLEncoding.DecodeString(secret)
	h := sha512.Sum512(b)
	return subtle.ConstantTimeCompare(hash, h[:]) == 1
}

func verifyExpiration(t time.Time) bool {
	return time.Now().Before(t)
}
package main

import (
	"context"
	"fmt"
	"html/template"
	"io"
	"io/fs"
	"mime"
	"net/http"
)

const (
	loginCookieName    = "sinwon-login"
	internalTokenScope = "_sinwon"
)

type contextKey string

const (
	contextKeyDB         = "db"
	contextKeyTemplate   = "template"
	contextKeyLoginToken = "login-token"
)

func dbFromContext(ctx context.Context) *DB {
	return ctx.Value(contextKeyDB).(*DB)
}

func templateFromContext(ctx context.Context) *Template {
	return ctx.Value(contextKeyTemplate).(*Template)
}

func loginTokenFromContext(ctx context.Context) *AccessToken {
	v := ctx.Value(contextKeyLoginToken)
	if v == nil {
		return nil
	}
	return v.(*AccessToken)
}

func newBaseContext(db *DB, tpl *Template) context.Context {
	ctx := context.Background()
	ctx = context.WithValue(ctx, contextKeyDB, db)
	ctx = context.WithValue(ctx, contextKeyTemplate, tpl)
	return ctx
}

func setLoginTokenCookie(w http.ResponseWriter, req *http.Request, token *AccessToken, secret string) {
	http.SetCookie(w, &http.Cookie{
		Name:     loginCookieName,
		Value:    MarshalSecret(token.ID, secret),
		Value:    MarshalSecret(token.ID, SecretKindAccessToken, secret),
		HttpOnly: true,
		SameSite: http.SameSiteStrictMode,
		Secure:   isForwardedHTTPS(req),
	})
}

func unsetLoginTokenCookie(w http.ResponseWriter, req *http.Request) {
	http.SetCookie(w, &http.Cookie{
		Name:     loginCookieName,
		HttpOnly: true,
		SameSite: http.SameSiteStrictMode,
		Secure:   isForwardedHTTPS(req),
		MaxAge:   -1,
	})
}

func isForwardedHTTPS(req *http.Request) bool {
	if forwarded := req.Header.Get("Forwarded"); forwarded != "" {
		_, params, _ := mime.ParseMediaType("_; " + forwarded)
		return params["proto"] == "https"
	}
	if forwardedProto := req.Header.Get("X-Forwarded-Proto"); forwardedProto != "" {
		return forwardedProto == "https"
	}
	return false
}

func loginTokenMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		cookie, _ := req.Cookie(loginCookieName)
		if cookie == nil {
			next.ServeHTTP(w, req)
			return
		}

		ctx := req.Context()
		db := dbFromContext(ctx)
		tokenID, tokenSecret, _ := UnmarshalSecret[*AccessToken](cookie.Value)
		token, err := db.FetchAccessToken(ctx, tokenID)
		if err == errNoDBRows || (err == nil && !token.VerifySecret(tokenSecret)) {
			unsetLoginTokenCookie(w, req)
			next.ServeHTTP(w, req)
			return
		} else if err != nil {
			httpError(w, fmt.Errorf("failed to fetch access token: %v", err))
			return
		}

		if token.Scope != internalTokenScope {
			http.Error(w, "Invalid login token scope", http.StatusForbidden)
			return
		}
		if token.User == 0 {
			panic("login token with zero user ID")
		}

		ctx = context.WithValue(ctx, contextKeyLoginToken, token)
		req = req.WithContext(ctx)
		next.ServeHTTP(w, req)
	})
}

type TemplateBaseData struct {
	ServerName string
}

func (data *TemplateBaseData) Base() *TemplateBaseData {
	return data
}

type TemplateData interface {
	Base() *TemplateBaseData
}

type Template struct {
	tpl      *template.Template
	baseData *TemplateBaseData
}

func loadTemplate(fs fs.FS, pattern string, baseData *TemplateBaseData) (*Template, error) {
	tpl, err := template.ParseFS(fs, pattern)
	if err != nil {
		return nil, err
	}
	return &Template{tpl: tpl, baseData: baseData}, nil
}

func (tpl *Template) MustExecuteTemplate(w io.Writer, filename string, data TemplateData) {
	if data == nil {
		data = tpl.baseData
	} else {
		*data.Base() = *tpl.baseData
	}
	if err := tpl.tpl.ExecuteTemplate(w, filename, data); err != nil {
		panic(err)
	}
}
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"mime"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"git.sr.ht/~emersion/go-oauth2"
)

func getOAuthServerMetadata(w http.ResponseWriter, req *http.Request) {
	issuer := getIssuer(req)

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(&oauth2.ServerMetadata{
		Issuer:                                     issuer,
		AuthorizationEndpoint:                      issuer + "/authorize",
		TokenEndpoint:                              issuer + "/token",
		IntrospectionEndpoint:                      issuer + "/introspect",
		RevocationEndpoint:                         issuer + "/revoke",
		ResponseTypesSupported:                     []oauth2.ResponseType{oauth2.ResponseTypeCode},
		ResponseModesSupported:                     []oauth2.ResponseMode{oauth2.ResponseModeQuery},
		GrantTypesSupported:                        []oauth2.GrantType{oauth2.GrantTypeAuthorizationCode},
		TokenEndpointAuthMethodsSupported:          []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic},
		IntrospectionEndpointAuthMethodsSupported:  []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic},
		RevocationEndpointAuthMethodsSupported:     []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic},
		AuthorizationResponseIssParameterSupported: true,
	})
}

func getIssuer(req *http.Request) string {
	issuerURL := url.URL{
		Scheme: "https",
		Host:   req.Host,
	}
	if !isForwardedHTTPS(req) && isLoopback(req) {
		// TODO: add config option for allowed reverse proxy IPs
		issuerURL.Scheme = "http"
	}
	return issuerURL.String()
}

func isLoopback(req *http.Request) bool {
	host, _, _ := net.SplitHostPort(req.RemoteAddr)
	ip := net.ParseIP(host)
	return ip.IsLoopback()
}

func authorize(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)
	tpl := templateFromContext(ctx)

	q := req.URL.Query()
	respType := oauth2.ResponseType(q.Get("response_type"))
	clientID := q.Get("client_id")
	rawRedirectURI := q.Get("redirect_uri")
	scope := q.Get("scope")
	state := q.Get("state")

	if clientID == "" {
		http.Error(w, "Missing client ID", http.StatusBadRequest)
		return
	}

	client, err := db.FetchClientByClientID(ctx, clientID)
	if err == errNoDBRows {
		http.Error(w, "Invalid client ID", http.StatusForbidden)
		return
	} else if err != nil {
		httpError(w, fmt.Errorf("failed to fetch client: %v", err))
		return
	}

	var allowedRedirectURIs []*url.URL
	for _, s := range strings.Split(client.RedirectURIs, "\n") {
		if s == "" {
			continue
		}
		u, err := url.Parse(s)
		if err != nil {
			httpError(w, fmt.Errorf("failed to parse client redirect URI"))
			return
		}
		allowedRedirectURIs = append(allowedRedirectURIs, u)
	}

	var redirectURI *url.URL
	if rawRedirectURI != "" {
		redirectURI, err = url.Parse(rawRedirectURI)
		if err != nil {
			http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
			return
		}
		if !validateRedirectURI(redirectURI, allowedRedirectURIs) {
			http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
			return
		}
	} else {
		if len(allowedRedirectURIs) == 0 {
			http.Error(w, "Missing redirect URI", http.StatusBadRequest)
			return
		}
		redirectURI = allowedRedirectURIs[0]
	}

	if respType != oauth2.ResponseTypeCode {
		redirectClientError(w, req, redirectURI, state, &oauth2.Error{
			Code: oauth2.ErrorCodeUnsupportedResponseType,
		})
		return
	}

	// TODO: add support for scope
	if scope != "" {
		redirectClientError(w, req, redirectURI, state, &oauth2.Error{
			Code: oauth2.ErrorCodeInvalidScope,
		})
		return
	}

	loginToken := loginTokenFromContext(ctx)
	if loginToken == nil {
		q := make(url.Values)
		q.Set("redirect_uri", req.URL.String())
		u := url.URL{
			Path:     "/login",
			RawQuery: q.Encode(),
		}
		http.Redirect(w, req, u.String(), http.StatusFound)
		return
	}

	_ = req.ParseForm()
	if _, ok := req.PostForm["deny"]; ok {
		redirectClientError(w, req, redirectURI, state, &oauth2.Error{
			Code: oauth2.ErrorCodeAccessDenied,
		})
		return
	}
	if _, ok := req.PostForm["authorize"]; !ok {
		data := struct {
			TemplateBaseData
			Client *Client
		}{
			Client: client,
		}
		tpl.MustExecuteTemplate(w, "authorize.html", &data)
		return
	}

	authCode := AuthCode{
		User:        loginToken.User,
		Client:      client.ID,
		Scope:       scope,
		RedirectURI: rawRedirectURI,
	}
	secret, err := authCode.Generate()
	if err != nil {
		httpError(w, fmt.Errorf("failed to generate authentication code: %v", err))
		return
	}

	if err := db.CreateAuthCode(ctx, &authCode); err != nil {
		httpError(w, fmt.Errorf("failed to create authentication code: %v", err))
		return
	}

	code := MarshalSecret(authCode.ID, secret)
	code := MarshalSecret(authCode.ID, SecretKindAuthCode, secret)

	values := make(url.Values)
	values.Set("code", code)
	if state != "" {
		values.Set("state", state)
	}
	redirectClient(w, req, redirectURI, values)
}

func exchangeToken(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)

	values, err := parseRequestBody(req)
	if err != nil {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: err.Error(),
		})
		return
	}

	clientID := values.Get("client_id")
	grantType := oauth2.GrantType(values.Get("grant_type"))
	scope := values.Get("scope")
	redirectURI := values.Get("redirect_uri")

	authClientID, clientSecret, _ := req.BasicAuth()
	if clientID == "" && authClientID == "" {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: "Missing client ID",
		})
		return
	} else if clientID == "" {
	if clientID == "" {
		clientID = authClientID
	} else if clientID != authClientID {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: "Client ID in request body doesn't match Authorization header field",
		})
		return
	}

	client, err := db.FetchClientByClientID(ctx, clientID)
	if err == errNoDBRows {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidClient,
			Description: "Invalid client ID",
		})
		return
	} else if err != nil {
		oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
		return
	}
	var client *Client
	if clientID != "" {
		client, err = db.FetchClientByClientID(ctx, clientID)
		if err == errNoDBRows {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidClient,
				Description: "Invalid client ID",
			})
			return
		} else if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
			return
		}

	if !client.IsPublic() {
		if !client.VerifySecret(clientSecret) {
		if !client.IsPublic() && !client.VerifySecret(clientSecret) {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid client secret",
			})
			return
		}
	}

	if grantType != oauth2.GrantTypeAuthorizationCode {
	var token *AccessToken
	switch grantType {
	case oauth2.GrantTypeAuthorizationCode:
		if client == nil {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidRequest,
				Description: "Missing client ID",
			})
			return
		}

		codeID, codeSecret, _ := UnmarshalSecret[*AuthCode](values.Get("code"))
		authCode, err := db.PopAuthCode(ctx, codeID)
		if err == errNoDBRows || (err == nil && !authCode.VerifySecret(codeSecret)) || authCode.Client != client.ID {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid authorization code",
			})
			return
		} else if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch authorization code: %v", err))
			return
		}

		if scope != authCode.Scope {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid scope",
			})
			return
		}
		if values.Get("redirect_uri") != authCode.RedirectURI {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid redirect URI",
			})
			return
		}

		token = NewAccessTokenFromAuthCode(authCode)
	case oauth2.GrantTypeRefreshToken:
		tokenID, refreshSecret, _ := UnmarshalSecret[*AccessToken](values.Get("refresh_token"))
		token, err = db.FetchAccessToken(ctx, tokenID)
		if err == errNoDBRows || (err == nil && !token.VerifyRefreshSecret(refreshSecret)) {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid refresh token",
			})
			return
		} else if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch access token: %v", err))
			return
		}

		if client != nil && client.ID != token.Client {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid refresh token",
			})
			return
		}

		tokenClient, err := db.FetchClient(ctx, token.Client)
		if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
			return
		}

		if !tokenClient.IsPublic() && client == nil {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid client secret",
			})
			return
		}

		if scope != token.Scope {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid scope",
			})
			return
		}
	default:
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeUnsupportedGrantType,
			Description: "Unsupported grant type",
		})
		return
	}

	codeID, codeSecret, _ := UnmarshalSecret[*AuthCode](values.Get("code"))
	authCode, err := db.PopAuthCode(ctx, codeID)
	if err == errNoDBRows || (err == nil && !authCode.VerifySecret(codeSecret)) || authCode.Client != client.ID {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeAccessDenied,
			Description: "Invalid authorization code",
		})
		return
	} else if err != nil {
		oauthError(w, fmt.Errorf("failed to fetch authorization code: %v", err))
	secret, err := token.Generate(accessTokenExpiration)
	if err != nil {
		oauthError(w, err)
		return
	}

	if scope != authCode.Scope {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeAccessDenied,
			Description: "Invalid scope",
		})
		return
	}
	if redirectURI != authCode.RedirectURI {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeAccessDenied,
			Description: "Invalid redirect URI",
		})
		return
	}

	token, secret, err := NewAccessTokenFromAuthCode(authCode)
	refreshSecret, err := token.GenerateRefresh()
	if err != nil {
		oauthError(w, err)
		return
	}

	if err := db.CreateAccessToken(ctx, token); err != nil {
	if err := db.StoreAccessToken(ctx, token); err != nil {
		oauthError(w, fmt.Errorf("failed to create access token: %v", err))
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Cache-Control", "no-store")
	json.NewEncoder(w).Encode(&oauth2.TokenResp{
		AccessToken: MarshalSecret(token.ID, secret),
		TokenType:   oauth2.TokenTypeBearer,
		ExpiresIn:   time.Until(token.ExpiresAt),
		Scope:       strings.Split(token.Scope, " "),
		AccessToken:  MarshalSecret(token.ID, SecretKindAccessToken, secret),
		TokenType:    oauth2.TokenTypeBearer,
		ExpiresIn:    time.Until(token.ExpiresAt),
		Scope:        strings.Split(token.Scope, " "),
		RefreshToken: MarshalSecret(token.ID, SecretKindRefreshToken, refreshSecret),
	})
}

func introspectToken(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)

	values, err := parseRequestBody(req)
	if err != nil {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: err.Error(),
		})
		return
	}

	client, err := maybeAuthenticateClient(w, req)
	if err != nil {
		oauthError(w, err)
		return
	}

	tokenID, secret, _ := UnmarshalSecret[*AccessToken](values.Get("token"))
	token, err := db.FetchAccessToken(ctx, tokenID)
	if err == errNoDBRows || (err == nil && !token.VerifySecret(secret)) {
		token = nil
	} else if err != nil {
		oauthError(w, fmt.Errorf("failed to fetch access token: %v", err))
		return
	}

	var resp oauth2.IntrospectionResp
	if token != nil {
		if client == nil {
			client, err = db.FetchClient(ctx, token.Client)
			if err != nil {
				oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
				return
			}

			if !client.IsPublic() {
				oauthError(w, &oauth2.Error{
					Code:        oauth2.ErrorCodeInvalidClient,
					Description: "Missing client ID and secret",
				})
				return
			}
		}

		if client.ID != token.Client {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidClient,
				Description: "Invalid client ID or secret",
			})
			return
		}

		user, err := db.FetchUser(ctx, token.User)
		if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch user: %v", err))
			return
		}

		resp.Active = true
		resp.TokenType = oauth2.TokenTypeBearer
		resp.ExpiresAt = token.ExpiresAt
		resp.IssuedAt = token.IssuedAt
		resp.ClientID = client.ClientID
		resp.Username = user.Username
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(&resp)
}

func revokeToken(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)

	values, err := parseRequestBody(req)
	if err != nil {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: err.Error(),
		})
		return
	}

	client, err := maybeAuthenticateClient(w, req)
	if err != nil {
		oauthError(w, err)
		return
	}

	tokenID, secret, _ := UnmarshalSecret[*AccessToken](values.Get("token"))
	token, err := db.FetchAccessToken(ctx, tokenID)
	if err == errNoDBRows || (err == nil && !token.VerifySecret(secret)) {
		return // ignore
	} else if err != nil {
		oauthError(w, fmt.Errorf("failed to fetch access token: %v", err))
		return
	}

	if client == nil {
		client, err = db.FetchClient(ctx, token.Client)
		if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
			return
		}

		if !client.IsPublic() {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidClient,
				Description: "Missing client ID and secret",
			})
			return
		}
	}

	if client.ID != token.Client {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidClient,
			Description: "Invalid client ID or secret",
		})
		return
	}

	if err := db.DeleteAccessToken(ctx, token.ID); err != nil {
		oauthError(w, err)
		return
	}
}

func parseRequestBody(req *http.Request) (url.Values, error) {
	ct := req.Header.Get("Content-Type")
	if ct != "" {
		mimeType, _, err := mime.ParseMediaType(req.Header.Get("Content-Type"))
		if err != nil {
			return nil, fmt.Errorf("malformed Content-Type header field")
		} else if mimeType != "application/x-www-form-urlencoded" {
			return nil, fmt.Errorf("unsupported request content type")
		}
	}

	r := io.LimitReader(req.Body, 10<<20)
	b, err := io.ReadAll(r)
	if err != nil {
		return nil, fmt.Errorf("failed to read request body: %v", err)
	}

	values, err := url.ParseQuery(string(b))
	if err != nil {
		return nil, fmt.Errorf("failed to parse request body: %v", err)
	}

	return values, nil
}

func oauthError(w http.ResponseWriter, err error) {
	var oauthErr *oauth2.Error
	if !errors.As(err, &oauthErr) {
		oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
		log.Print(err)
	}

	statusCode := http.StatusInternalServerError
	switch oauthErr.Code {
	case oauth2.ErrorCodeInvalidRequest, oauth2.ErrorCodeUnsupportedResponseType, oauth2.ErrorCodeInvalidScope, oauth2.ErrorCodeInvalidClient, oauth2.ErrorCodeInvalidGrant, oauth2.ErrorCodeUnsupportedGrantType:
		statusCode = http.StatusBadRequest
	case oauth2.ErrorCodeUnauthorizedClient, oauth2.ErrorCodeAccessDenied:
		statusCode = http.StatusForbidden
	case oauth2.ErrorCodeTemporarilyUnavailable:
		statusCode = http.StatusServiceUnavailable
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)
	json.NewEncoder(w).Encode(oauthErr)
}

func redirectClient(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, values url.Values) {
	q := redirectURI.Query()
	for k, v := range values {
		q[k] = v
	}
	q.Set("iss", getIssuer(req))

	u := *redirectURI
	u.RawQuery = q.Encode()

	http.Redirect(w, req, u.String(), http.StatusFound)
}

func redirectClientError(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, state string, err error) {
	var oauthErr *oauth2.Error
	if !errors.As(err, &oauthErr) {
		oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
		log.Print(err)
	}

	values := make(url.Values)
	values.Set("error", string(oauthErr.Code))
	if oauthErr.Description != "" {
		values.Set("error_description", oauthErr.Description)
	}
	if oauthErr.URI != "" {
		values.Set("error_uri", oauthErr.URI)
	}
	if state != "" {
		values.Set("state", state)
	}
	redirectClient(w, req, redirectURI, values)
}

func validateRedirectURI(u *url.URL, allowedURIs []*url.URL) bool {
	// Loopback interface, see RFC 8252 section 7.3
	host, _, _ := net.SplitHostPort(u.Host)
	ip := net.ParseIP(host)
	if u.Scheme == "http" && ip.IsLoopback() {
		uu := *u
		uu.Host = "localhost"
		u = &uu
	}

	for _, allowed := range allowedURIs {
		if u.String() == allowed.String() {
			return true
		}
	}

	return false
}

func maybeAuthenticateClient(w http.ResponseWriter, req *http.Request) (*Client, error) {
	ctx := req.Context()
	db := dbFromContext(ctx)

	clientID, clientSecret, ok := req.BasicAuth()
	if !ok {
		return nil, nil
	}

	client, err := db.FetchClientByClientID(ctx, clientID)
	if err == errNoDBRows || (err == nil && !client.VerifySecret(clientSecret)) {
		return nil, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidClient,
			Description: "Invalid client ID or secret",
		}
	} else if err != nil {
		return nil, fmt.Errorf("failed to fetch client: %v", err)
	}

	return client, nil
}
CREATE TABLE User (
	id INTEGER PRIMARY KEY,
	username TEXT NOT NULL UNIQUE,
	password_hash TEXT,
	admin INTEGER NOT NULL DEFAULT 0
);

CREATE TABLE Client (
	id INTEGER PRIMARY KEY,
	client_id TEXT NOT NULL UNIQUE,
	client_secret_hash BLOB,
	owner INTEGER,
	redirect_uris TEXT,
	client_name TEXT,
	client_uri TEXT,
	FOREIGN KEY(owner) REFERENCES User(id)
);

CREATE TABLE AccessToken (
	id INTEGER PRIMARY KEY,
	hash BLOB NOT NULL UNIQUE,
	user INTEGER NOT NULL,
	client INTEGER,
	scope TEXT,
	issued_at datetime NOT NULL,
	expires_at datetime NOT NULL,
	refresh_hash BLOB UNIQUE,
	refresh_expires_at datetime,
	FOREIGN KEY(user) REFERENCES User(id),
	FOREIGN KEY(client) REFERENCES Client(id)
);

CREATE TABLE AuthCode (
	id INTEGER PRIMARY KEY,
	hash BLOB NOT NULL UNIQUE,
	created_at datetime NOT NULL,
	user INTEGER NOT NULL,
	client INTEGER NOT NULL,
	redirect_uri TEXT,
	scope TEXT,
	FOREIGN KEY(user) REFERENCES User(id),
	FOREIGN KEY(client) REFERENCES Client(id)
);
package main

import (
	"fmt"
	"log"
	"net/http"
	"net/url"
	"time"

	"github.com/go-chi/chi/v5"
)

func index(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)
	tpl := templateFromContext(ctx)

	loginToken := loginTokenFromContext(ctx)
	if loginToken == nil {
		http.Redirect(w, req, "/login", http.StatusFound)
		return
	}

	me, err := db.FetchUser(ctx, loginToken.User)
	if err != nil {
		httpError(w, err)
		return
	}

	authorizedClients, err := db.ListAuthorizedClients(ctx, loginToken.User)
	if err != nil {
		httpError(w, err)
		return
	}

	clients, err := db.ListClients(ctx, loginToken.User)
	if err != nil {
		httpError(w, err)
		return
	}

	var users []User
	if me.Admin {
		users, err = db.ListUsers(ctx)
		if err != nil {
			httpError(w, err)
			return
		}
	}

	data := struct {
		TemplateBaseData
		Me                *User
		AuthorizedClients []AuthorizedClient
		Clients           []Client
		Users             []User
	}{
		Me:                me,
		AuthorizedClients: authorizedClients,
		Clients:           clients,
		Users:             users,
	}
	tpl.MustExecuteTemplate(w, "index.html", &data)
}

func login(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)
	tpl := templateFromContext(ctx)

	q := req.URL.Query()
	rawRedirectURI := q.Get("redirect_uri")
	if rawRedirectURI == "" {
		rawRedirectURI = "/"
	}

	redirectURI, err := url.Parse(rawRedirectURI)
	if err != nil || redirectURI.Scheme != "" || redirectURI.Opaque != "" || redirectURI.User != nil || redirectURI.Host != "" {
		http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
		return
	}

	if loginTokenFromContext(ctx) != nil {
		http.Redirect(w, req, redirectURI.String(), http.StatusFound)
		return
	}

	username := req.PostFormValue("username")
	password := req.PostFormValue("password")
	if username == "" {
		tpl.MustExecuteTemplate(w, "login.html", nil)
		return
	}

	user, err := db.FetchUserByUsername(ctx, username)
	if err != nil && err != errNoDBRows {
		httpError(w, fmt.Errorf("failed to fetch user: %v", err))
		return
	}
	if err == nil {
		err = user.VerifyPassword(password)
	}
	if err != nil {
		log.Printf("login failed for user %q: %v", username, err)
		// TODO: show error message
		tpl.MustExecuteTemplate(w, "login.html", nil)
		return
	}

	if user.PasswordNeedsRehash() {
		if err := user.SetPassword(password); err != nil {
			httpError(w, fmt.Errorf("failed to rehash password: %v", err))
			return
		}
		if err := db.StoreUser(ctx, user); err != nil {
			httpError(w, fmt.Errorf("failed to store user: %v", err))
			return
		}
	}

	token := AccessToken{
		User:  user.ID,
		Scope: internalTokenScope,
	}
	secret, err := token.Generate(4 * time.Hour)
	if err != nil {
		httpError(w, fmt.Errorf("failed to generate access token: %v", err))
		return
	}
	if err := db.CreateAccessToken(ctx, &token); err != nil {
	if err := db.StoreAccessToken(ctx, &token); err != nil {
		httpError(w, fmt.Errorf("failed to create access token: %v", err))
		return
	}

	setLoginTokenCookie(w, req, &token, secret)
	http.Redirect(w, req, redirectURI.String(), http.StatusFound)
}

func logout(w http.ResponseWriter, req *http.Request) {
	unsetLoginTokenCookie(w, req)
	http.Redirect(w, req, "/login", http.StatusFound)
}

func manageUser(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)
	tpl := templateFromContext(ctx)

	user := new(User)
	if idStr := chi.URLParam(req, "id"); idStr != "" {
		id, err := ParseID[*User](idStr)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}

		user, err = db.FetchUser(ctx, id)
		if err != nil {
			httpError(w, err)
			return
		}
	}

	loginToken := loginTokenFromContext(ctx)
	if loginToken == nil {
		http.Redirect(w, req, "/login", http.StatusFound)
		return
	}

	me, err := db.FetchUser(ctx, loginToken.User)
	if err != nil {
		httpError(w, err)
		return
	} else if loginToken.User != user.ID && !me.Admin {
		http.Error(w, "Access denied", http.StatusForbidden)
		return
	}

	username := req.PostFormValue("username")
	password := req.PostFormValue("password")
	admin := req.PostFormValue("admin") == "on"
	if username == "" {
		data := struct {
			TemplateBaseData
			User *User
			Me   *User
		}{
			User: user,
			Me:   me,
		}
		tpl.MustExecuteTemplate(w, "manage-user.html", &data)
		return
	}

	user.Username = username
	if me.Admin && user.ID != me.ID {
		user.Admin = admin
	}
	if password != "" {
		if err := user.SetPassword(password); err != nil {
			httpError(w, err)
			return
		}
	}

	if err := db.StoreUser(ctx, user); err != nil {
		httpError(w, err)
		return
	}

	http.Redirect(w, req, "/", http.StatusFound)
}