Lindenii Project Forge
Login

server

Vireo IdP server

Hi… I am well aware that this diff view is very suboptimal. It will be fixed when the refactored server comes along!

Commit info
ID
69434566ea5a8a0d39802a62cc6b254ea8bd667c
Author
Runxi Yu <me@runxiyu.org>
Author date
Thu, 25 Sep 2025 19:41:57 +0800
Committer
Runxi Yu <me@runxiyu.org>
Committer date
Thu, 25 Sep 2025 20:21:33 +0800
Actions
Implement basic OIDC and some fixes
package main

import (
	"fmt"
	"net/http"
	"net/url"
	"strings"

	"github.com/go-chi/chi/v5"
)

func manageClient(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)
	tpl := templateFromContext(ctx)

	loginToken := loginTokenFromContext(ctx)
	if loginToken == nil {
		http.Redirect(w, req, "/login", http.StatusFound)
		return
	}

	me, err := db.FetchUser(ctx, loginToken.User)
	if err != nil {
		httpError(w, err)
		return
	} else if !me.Admin {
		http.Error(w, "Access denied", http.StatusForbidden)
		return
	}

	client := &Client{Owner: loginToken.User}
	if idStr := chi.URLParam(req, "id"); idStr != "" {
		id, err := ParseID[*Client](idStr)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}

		client, err = db.FetchClient(ctx, id)
		if err != nil {
			httpError(w, err)
			return
		}
	}

	if normalized, err := normalizeClientPKCERequirement(client.PKCERequirement); err == nil {
		client.PKCERequirement = normalized
	} else {
		client.PKCERequirement = pkceRequirementNone
	}

	if req.Method != http.MethodPost {
		data := struct {
			TemplateBaseData
			Client *Client
		}{
			Client: client,
		}
		tpl.MustExecuteTemplate(w, "manage-client.html", &data)
		tpl.MustExecuteTemplate(req.Context(), w, "manage-client.html", &data)
		return
	}

	_ = req.ParseForm()
	if _, ok := req.PostForm["delete"]; ok {
		if err := db.DeleteClient(ctx, client.ID); err != nil {
			httpError(w, err)
			return
		}
		http.Redirect(w, req, "/", http.StatusFound)
		return
	}

	_, rotate := req.PostForm["rotate"]

	var isPublic bool
	if client.ID != 0 {
		isPublic = client.IsPublic()
	} else {
		isPublic = req.PostFormValue("client_type") == "public"
	}

	if !rotate {
		client.ClientName = req.PostFormValue("client_name")
		client.ClientURI = req.PostFormValue("client_uri")
		client.RedirectURIs = req.PostFormValue("redirect_uris")

		pkceRequirement := req.PostFormValue("pkce_requirement")
		if !isPublic {
			pkceRequirement = pkceRequirementNone
		}
		normalizedRequirement, err := normalizeClientPKCERequirement(pkceRequirement)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		client.PKCERequirement = normalizedRequirement

		if err := validateAllowedRedirectURIs(client.RedirectURIs); err != nil {
			// TODO: nicer error message
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
	}

	var clientSecret string
	if client.ID == 0 || rotate {
		clientSecret, err = client.Generate(isPublic)
		if err != nil {
			httpError(w, err)
			return
		}
	}

	if err := db.StoreClient(ctx, client); err != nil {
		httpError(w, err)
		return
	}

	if clientSecret == "" {
		http.Redirect(w, req, "/", http.StatusFound)
		return
	}

	data := struct {
		TemplateBaseData
		ClientID     string
		ClientSecret string
	}{
		ClientID:     client.ClientID,
		ClientSecret: clientSecret,
	}
	tpl.MustExecuteTemplate(w, "client-secret.html", &data)
	tpl.MustExecuteTemplate(req.Context(), w, "client-secret.html", &data)
}

func validateAllowedRedirectURIs(rawRedirectURIs string) error {
	for _, s := range strings.Split(rawRedirectURIs, "\n") {
		if s == "" {
			continue
		}
		u, err := url.Parse(s)
		if err != nil {
			// TODO: nicer error message
			return fmt.Errorf("Invalid redirect URI %q: %v", s, err)
		}
		switch u.Scheme {
		case "https":
			// ok
		case "http":
			if u.Host != "localhost" {
				return fmt.Errorf("Only http://localhost is allowed for insecure HTTP URIs")
			}
			// insecure but let's just trust the admin
		default:
			if !strings.Contains(u.Scheme, ".") {
				return fmt.Errorf("Only private-use URIs referring to domain names are allowed")
			}
		}
	}
	return nil
}

func revokeClient(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)

	id, err := ParseID[*Client](chi.URLParam(req, "id"))
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	loginToken := loginTokenFromContext(ctx)
	if loginToken == nil {
		http.Redirect(w, req, "/login", http.StatusFound)
		return
	}

	if err := db.RevokeClientUser(ctx, id, loginToken.User); err != nil {
		httpError(w, err)
		return
	}

	http.Redirect(w, req, "/", http.StatusFound)
}
package main

import (
	"context"
	"crypto/subtle"
	"net/http"
	"strings"
)

const (
	csrfCookieName = "vireo-csrf"
	csrfFormField  = "_csrf"
)

func csrfTokenFromContext(ctx context.Context) string {
	if ctx == nil {
		return ""
	}
	if token, ok := ctx.Value(contextKeyCSRFToken).(string); ok {
		return token
	}
	return ""
}

func ensureCSRFToken(w http.ResponseWriter, req *http.Request) (string, error) {
	if cookie, err := req.Cookie(csrfCookieName); err == nil && cookie.Value != "" {
		return cookie.Value, nil
	}

	token, err := generateUID()
	if err != nil {
		return "", err
	}

	http.SetCookie(w, &http.Cookie{
		Name:     csrfCookieName,
		Value:    token,
		Path:     "/",
		HttpOnly: true,
		SameSite: http.SameSiteStrictMode,
		Secure:   isForwardedHTTPS(req),
	})

	return token, nil
}

func csrfMiddleware(next http.Handler) http.Handler {
	cop := http.NewCrossOriginProtection()
	cop.SetDenyHandler(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		http.Error(w, "Forbidden", http.StatusForbidden)
	}))

	handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		if isCSRFBypass(req) || req.Header.Get("Authorization") != "" {
			next.ServeHTTP(w, req)
			return
		}
		token, err := ensureCSRFToken(w, req)
		if err != nil {
			httpError(w, err)
			return
		}

		ctx := context.WithValue(req.Context(), contextKeyCSRFToken, token)
		req = req.WithContext(ctx)

		switch req.Method {
		case http.MethodGet, http.MethodHead, http.MethodOptions:
			next.ServeHTTP(w, req)
			return
		}

		if err := req.ParseForm(); err != nil {
			http.Error(w, "Bad request", http.StatusBadRequest)
			return
		}

		formToken := req.PostFormValue(csrfFormField)
		if formToken == "" {
			formToken = req.Header.Get("X-CSRF-Token")
		}
		if formToken == "" || subtle.ConstantTimeCompare([]byte(formToken), []byte(token)) != 1 {
			http.Error(w, "Invalid CSRF token", http.StatusForbidden)
			return
		}

		next.ServeHTTP(w, req)
	})

	return cop.Handler(handler)
}

func isCSRFBypass(req *http.Request) bool {
	path := req.URL.Path
	switch {
	case path == "/token",
		path == "/introspect",
		path == "/revoke",
		path == "/userinfo",
		strings.HasPrefix(path, "/static/"),
		strings.HasPrefix(path, "/.well-known/"),
		path == "/favicon.ico":
		return true
	}
	return false
}
package main

import (
	"context"
	"database/sql"
	_ "embed"
	"fmt"
	"time"

	"github.com/mattn/go-sqlite3"
)

//go:embed schema.sql
var schema string

var migrations = []string{
	"", // migration #0 is reserved for schema initialization
	`
		ALTER TABLE AccessToken ADD COLUMN refresh_hash BLOB;
		ALTER TABLE AccessToken ADD COLUMN refresh_expires_at datetime;
		CREATE UNIQUE INDEX access_token_refresh_hash ON AccessToken(refresh_hash);
	`,
	`
		ALTER TABLE AuthCode ADD COLUMN nonce TEXT;
		CREATE TABLE IF NOT EXISTS SigningKey (
			id INTEGER PRIMARY KEY,
			kid TEXT NOT NULL UNIQUE,
			algorithm TEXT NOT NULL,
			private_key BLOB NOT NULL,
			created_at datetime NOT NULL
		);
	`,
	`
		ALTER TABLE AccessToken ADD COLUMN auth_time datetime;
		ALTER TABLE SigningKey RENAME TO SigningKey_old;
		CREATE TABLE SigningKey (
			id INTEGER PRIMARY KEY,
			kid TEXT NOT NULL UNIQUE,
			algorithm TEXT NOT NULL,
			private_key BLOB NOT NULL,
			created_at datetime NOT NULL
		);
		INSERT INTO SigningKey(id, kid, algorithm, private_key, created_at)
			SELECT id, kid, algorithm, private_key, created_at FROM SigningKey_old;
		DROP TABLE SigningKey_old;
		CREATE INDEX IF NOT EXISTS signing_key_created_at ON SigningKey(created_at);
	`,
	`
		ALTER TABLE User ADD COLUMN email TEXT;
	`,
	`
		ALTER TABLE User ADD COLUMN name TEXT;
	`,
	`
		ALTER TABLE AuthCode ADD COLUMN code_challenge TEXT;
		ALTER TABLE AuthCode ADD COLUMN code_challenge_method TEXT;
		ALTER TABLE Client ADD COLUMN pkce_requirement TEXT;
	`,
}

var errNoDBRows = sql.ErrNoRows

type DB struct {
	db *sql.DB
}

func openDB(filename string) (*DB, error) {
	sqlDB, err := sql.Open("sqlite3", filename+"?cache=shared&_foreign_keys=1")
	if err != nil {
		return nil, err
	}
	sqlDB.SetMaxOpenConns(1)

	db := &DB{sqlDB}
	if err := db.init(context.TODO()); err != nil {
		db.Close()
		return nil, err
	}

	return db, nil
}

func (db *DB) init(ctx context.Context) error {
	version, err := db.upgrade(ctx)
	if err != nil {
		return err
	}

	if version > 0 {
		return nil
	}

	// TODO: drop this
	defaultUser := User{Username: "root", Admin: true}
	defaultUser := User{Username: "root", Name: "Root User", Email: "root@example.invalid", Admin: true}
	if err := defaultUser.SetPassword("root"); err != nil {
		return err
	}
	return db.StoreUser(ctx, &defaultUser)
}

func (db *DB) upgrade(ctx context.Context) (version int, err error) {
	if err := db.db.QueryRowContext(ctx, "PRAGMA user_version").Scan(&version); err != nil {
		return 0, fmt.Errorf("failed to query schema version: %v", err)
	}

	if version == len(migrations) {
		return version, nil
	} else if version > len(migrations) {
		return version, fmt.Errorf("vireo (version %d) older than schema (version %d)", len(migrations), version)
	}

	tx, err := db.db.BeginTx(ctx, nil)
	if err != nil {
		return version, err
	}
	defer tx.Rollback()

	if version == 0 {
		if _, err := tx.ExecContext(ctx, schema); err != nil {
			return version, fmt.Errorf("failed to initialize schema: %v", err)
		}
	} else {
		for i := version; i < len(migrations); i++ {
			if _, err := tx.ExecContext(ctx, migrations[i]); err != nil {
				return version, fmt.Errorf("failed to execute migration #%v: %v", i, err)
			}
		}
	}

	// For some reason prepared statements don't work here
	_, err = tx.ExecContext(ctx, fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
	if err != nil {
		return version, fmt.Errorf("failed to bump schema version: %v", err)
	}

	return version, tx.Commit()
}

func (db *DB) Close() error {
	return db.db.Close()
}

func (db *DB) FetchUser(ctx context.Context, id ID[*User]) (*User, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM User WHERE id = ?", id)
	if err != nil {
		return nil, err
	}
	var user User
	err = scanRow(&user, rows)
	return &user, err
}

func (db *DB) FetchUserByUsername(ctx context.Context, username string) (*User, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM User WHERE username = ?", username)
	if err != nil {
		return nil, err
	}
	var user User
	err = scanRow(&user, rows)
	return &user, err
}

func (db *DB) StoreUser(ctx context.Context, user *User) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO User(id, username, password_hash, admin)
		VALUES (:id, :username, :password_hash, :admin)
		INSERT INTO User(id, username, name, email, password_hash, admin)
		VALUES (:id, :username, :name, :email, :password_hash, :admin)
		ON CONFLICT(id) DO UPDATE SET
			username = :username,
			name = :name,
			email = :email,
			password_hash = :password_hash,
			admin = :admin
		RETURNING id
	`, entityArgs(user)...).Scan(&user.ID)
}

func (db *DB) ListUsers(ctx context.Context) ([]User, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM User")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var l []User
	for rows.Next() {
		var user User
		if err := scan(&user, rows); err != nil {
			return nil, err
		}
		l = append(l, user)
	}

	return l, rows.Close()
}

func (db *DB) FetchClient(ctx context.Context, id ID[*Client]) (*Client, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE id = ?", id)
	if err != nil {
		return nil, err
	}
	var client Client
	err = scanRow(&client, rows)
	return &client, err
}

func (db *DB) FetchClientByClientID(ctx context.Context, clientID string) (*Client, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE client_id = ?", clientID)
	if err != nil {
		return nil, err
	}
	var client Client
	err = scanRow(&client, rows)
	return &client, err
}

func (db *DB) StoreClient(ctx context.Context, client *Client) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO Client(id, client_id, client_secret_hash, owner,
			redirect_uris, client_name, client_uri)
			redirect_uris, client_name, client_uri, pkce_requirement)
		VALUES (:id, :client_id, :client_secret_hash, :owner,
			:redirect_uris, :client_name, :client_uri)
			:redirect_uris, :client_name, :client_uri, :pkce_requirement)
		ON CONFLICT(id) DO UPDATE SET
			client_id = :client_id,
			client_secret_hash = :client_secret_hash,
			owner = :owner,
			redirect_uris = :redirect_uris,
			client_name = :client_name,
			client_uri = :client_uri
			client_uri = :client_uri,
			pkce_requirement = :pkce_requirement
		RETURNING id
	`, entityArgs(client)...).Scan(&client.ID)
}

func (db *DB) ListClients(ctx context.Context, owner ID[*User]) ([]Client, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE owner IS ?", owner)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var l []Client
	for rows.Next() {
		var client Client
		if err := scan(&client, rows); err != nil {
			return nil, err
		}
		l = append(l, client)
	}

	return l, rows.Close()
}

func (db *DB) ListAuthorizedClients(ctx context.Context, user ID[*User]) ([]AuthorizedClient, error) {
	rows, err := db.db.QueryContext(ctx, `
		SELECT id, client_id, client_name, client_uri, token.expires_at
		FROM Client,
		(
			SELECT client, MAX(COALESCE(refresh_expires_at, expires_at)) as expires_at
			FROM AccessToken
			WHERE user = ?
			GROUP BY client
		) AS token
		WHERE Client.id = token.client
	`, user)
	if err != nil {
		return nil, err
	}

	var l []AuthorizedClient
	for rows.Next() {
		var authClient AuthorizedClient
		columns := authClient.Client.columns()
		var expiresAt string
		err := rows.Scan(columns["id"], columns["client_id"], columns["client_name"], columns["client_uri"], &expiresAt)
		if err != nil {
			return nil, err
		}
		authClient.ExpiresAt, err = time.Parse(sqlite3.SQLiteTimestampFormats[0], expiresAt)
		if err != nil {
			return nil, err
		}
		l = append(l, authClient)
	}

	return l, rows.Close()
}

func (db *DB) DeleteClient(ctx context.Context, id ID[*Client]) error {
	_, err := db.db.ExecContext(ctx, "DELETE FROM Client WHERE id = ?", id)
	return err
}

func (db *DB) FetchAccessToken(ctx context.Context, id ID[*AccessToken]) (*AccessToken, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM AccessToken WHERE id = ?", id)
	if err != nil {
		return nil, err
	}
	var token AccessToken
	err = scanRow(&token, rows)
	return &token, err
}

func (db *DB) StoreAccessToken(ctx context.Context, token *AccessToken) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO AccessToken(id, hash, user, client, scope, issued_at,
			expires_at, refresh_hash, refresh_expires_at)
			expires_at, auth_time, refresh_hash, refresh_expires_at)
		VALUES (:id, :hash, :user, :client, :scope, :issued_at, :expires_at,
			:refresh_hash, :refresh_expires_at)
			:auth_time, :refresh_hash, :refresh_expires_at)
		ON CONFLICT(id) DO UPDATE SET
			hash = :hash,
			user = :user,
			client = :client,
			scope = :scope,
			issued_at = :issued_at,
			expires_at = :expires_at,
			auth_time = :auth_time,
			refresh_hash = :refresh_hash,
			refresh_expires_at = :refresh_expires_at
		RETURNING id
	`, entityArgs(token)...).Scan(&token.ID)
}

func (db *DB) DeleteAccessToken(ctx context.Context, id ID[*AccessToken]) error {
	_, err := db.db.ExecContext(ctx, "DELETE FROM AccessToken WHERE id = ?", id)
	return err
}

func (db *DB) RevokeClientUser(ctx context.Context, clientID ID[*Client], userID ID[*User]) error {
	tx, err := db.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	defer tx.Rollback()

	_, err = tx.ExecContext(ctx, `
		DELETE FROM AccessToken
		WHERE client = ? AND user = ?
	`, clientID, userID)
	if err != nil {
		return err
	}

	_, err = tx.ExecContext(ctx, `
		DELETE FROM AuthCode
		WHERE client = ? AND user = ?
	`, clientID, userID)
	if err != nil {
		return err
	}

	return tx.Commit()
}

func (db *DB) CreateAuthCode(ctx context.Context, code *AuthCode) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO AuthCode(hash, created_at, user, client, scope, redirect_uri)
		VALUES (:hash, :created_at, :user, :client, :scope, :redirect_uri)
		INSERT INTO AuthCode(hash, created_at, user, client, scope, redirect_uri, nonce, code_challenge, code_challenge_method)
		VALUES (:hash, :created_at, :user, :client, :scope, :redirect_uri, :nonce, :code_challenge, :code_challenge_method)
		RETURNING id
	`, entityArgs(code)...).Scan(&code.ID)
}

func (db *DB) PopAuthCode(ctx context.Context, id ID[*AuthCode]) (*AuthCode, error) {
	rows, err := db.db.QueryContext(ctx, `
		DELETE FROM AuthCode
		WHERE id = ?
		RETURNING *
	`, id)
	if err != nil {
		return nil, err
	}
	var authCode AuthCode
	err = scanRow(&authCode, rows)
	return &authCode, err
}

func (db *DB) FetchSigningKeys(ctx context.Context) ([]SigningKey, error) {
	rows, err := db.db.QueryContext(ctx, `
		SELECT * FROM SigningKey
		ORDER BY created_at DESC
	`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var keys []SigningKey
	for rows.Next() {
		var key SigningKey
		if err := scan(&key, rows); err != nil {
			return nil, err
		}
		keys = append(keys, key)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}
	if len(keys) == 0 {
		return nil, errNoDBRows
	}
	return keys, nil
}

func (db *DB) StoreSigningKey(ctx context.Context, key *SigningKey) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO SigningKey(kid, algorithm, private_key, created_at)
		VALUES (:kid, :algorithm, :private_key, :created_at)
		RETURNING id
	`, sql.Named("kid", key.KID), sql.Named("algorithm", key.Algorithm), sql.Named("private_key", key.PrivateKey), sql.Named("created_at", key.CreatedAt)).Scan(&key.ID)
}

func (db *DB) Maintain(ctx context.Context) error {
	_, err := db.db.ExecContext(ctx, `
		DELETE FROM AccessToken
		WHERE timediff('now', COALESCE(refresh_expires_at, expires_at)) > 0
	`)
	if err != nil {
		return err
	}

	_, err = db.db.ExecContext(ctx, `
		DELETE FROM AuthCode
		WHERE timediff(?, created_at) > 0
	`, time.Now().Add(-authCodeExpiration))
	if err != nil {
		return err
	}

	return nil
}

func scan(e entity, rows *sql.Rows) error {
	columns := e.columns()

	keys, err := rows.Columns()
	if err != nil {
		panic(err)
	}
	out := make([]interface{}, len(keys))
	for i, k := range keys {
		v, ok := columns[k]
		if !ok {
			panic(fmt.Errorf("unknown column %q", k))
		}
		out[i] = v
	}

	return rows.Scan(out...)
}

func scanRow(e entity, rows *sql.Rows) error {
	if !rows.Next() {
		return sql.ErrNoRows
	}
	if err := scan(e, rows); err != nil {
		return err
	}
	return rows.Close()
}

func entityArgs(e entity) []interface{} {
	columns := e.columns()

	l := make([]interface{}, 0, len(columns))
	for k, v := range columns {
		l = append(l, sql.Named(k, v))
	}

	return l
}
package main

import (
	"crypto/rand"
	"crypto/sha512"
	"crypto/subtle"
	"database/sql"
	"database/sql/driver"
	"encoding/base64"
	"fmt"
	"reflect"
	"strconv"
	"strings"
	"time"

	"golang.org/x/crypto/bcrypt"
)

const (
	accessTokenExpiration  = 30 * 24 * time.Hour
	refreshTokenExpiration = 2 * accessTokenExpiration
	authCodeExpiration     = 10 * time.Minute
)

type entity interface {
	columns() map[string]interface{}
}

var (
	_ entity = (*User)(nil)
	_ entity = (*Client)(nil)
	_ entity = (*AccessToken)(nil)
	_ entity = (*AuthCode)(nil)
	_ entity = (*SigningKey)(nil)
)

type ID[T entity] int64

var (
	_ sql.Scanner   = (*ID[*User])(nil)
	_ driver.Valuer = ID[*User](0)
)

func ParseID[T entity](s string) (ID[T], error) {
	u, _ := strconv.ParseUint(s, 10, 63)
	if u == 0 {
		return 0, fmt.Errorf("invalid ID")
	}
	return ID[T](u), nil
}

func (ptr *ID[T]) Scan(v interface{}) error {
	if v == nil {
		*ptr = 0
		return nil
	}
	id, ok := v.(int64)
	if !ok {
		return fmt.Errorf("cannot scan ID from %T", v)
	}
	*ptr = ID[T](id)
	return nil
}

func (id ID[T]) Value() (driver.Value, error) {
	if id == 0 {
		return nil, nil
	} else {
		return int64(id), nil
	}
}

type nullValue struct {
	ptr interface{}
}

var (
	_ sql.Scanner   = nullValue{nil}
	_ driver.Valuer = nullValue{nil}
)

func (nv nullValue) Scan(v interface{}) error {
	out := reflect.ValueOf(nv.ptr).Elem()
	if v == nil {
		out.SetZero()
		return nil
	}

	rv := reflect.ValueOf(v)
	if rv.Type() != out.Type() {
		return fmt.Errorf("cannot scan %v into %v", rv.Type(), out.Type())
	}

	out.Set(rv)
	return nil
}

func (nv nullValue) Value() (driver.Value, error) {
	in := reflect.ValueOf(nv.ptr).Elem()
	if in.IsZero() {
		return nil, nil
	}
	return in.Interface(), nil
}

type User struct {
	ID           ID[*User]
	Username     string
	Name         string
	Email        string
	PasswordHash string
	Admin        bool
}

func (user *User) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":            &user.ID,
		"username":      &user.Username,
		"name":          nullValue{&user.Name},
		"email":         nullValue{&user.Email},
		"password_hash": nullValue{&user.PasswordHash},
		"admin":         &user.Admin,
	}
}

func (user *User) VerifyPassword(password string) error {
	return bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
}

func (user *User) SetPassword(password string) error {
	hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return err
	}
	user.PasswordHash = string(hash)
	return nil
}

func (user *User) PasswordNeedsRehash() bool {
	cost, _ := bcrypt.Cost([]byte(user.PasswordHash))
	return cost != bcrypt.DefaultCost
}

type Client struct {
	ID               ID[*Client]
	ClientID         string
	ClientSecretHash []byte
	Owner            ID[*User]
	RedirectURIs     string
	ClientName       string
	ClientURI        string
	PKCERequirement  string
}

func (client *Client) Generate(isPublic bool) (secret string, err error) {
	id, err := generateUID()
	if err != nil {
		return "", fmt.Errorf("failed to generate client ID: %v", err)
	}
	client.ClientID = id

	if !isPublic {
		var hash []byte
		secret, hash, err = generateSecret()
		if err != nil {
			return "", fmt.Errorf("failed to generate client secret: %v", err)
		}
		client.ClientSecretHash = hash
	}

	return secret, nil
}

func (client *Client) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":                 &client.ID,
		"client_id":          &client.ClientID,
		"client_secret_hash": &client.ClientSecretHash,
		"owner":              &client.Owner,
		"redirect_uris":      nullValue{&client.RedirectURIs},
		"client_name":        nullValue{&client.ClientName},
		"client_uri":         nullValue{&client.ClientURI},
		"pkce_requirement":   nullValue{&client.PKCERequirement},
	}
}

func (client *Client) VerifySecret(secret string) bool {
	return verifyHash(client.ClientSecretHash, secret)
}

func (client *Client) IsPublic() bool {
	return client.ClientSecretHash == nil
}

type AccessToken struct {
	ID        ID[*AccessToken]
	Hash      []byte
	User      ID[*User]
	Client    ID[*Client]
	Scope     string
	IssuedAt  time.Time
	ExpiresAt time.Time
	AuthTime  time.Time

	RefreshHash      []byte
	RefreshExpiresAt time.Time
}

func (token *AccessToken) Generate(expiration time.Duration) (secret string, err error) {
	secret, hash, err := generateSecret()
	if err != nil {
		return "", fmt.Errorf("failed to generate access token secret: %v", err)
	}
	token.Hash = hash
	token.IssuedAt = time.Now()
	token.ExpiresAt = time.Now().Add(expiration)
	return secret, nil
}

func (token *AccessToken) GenerateRefresh() (secret string, err error) {
	secret, hash, err := generateSecret()
	if err != nil {
		return "", fmt.Errorf("failed to generate refresh token secret: %v", err)
	}
	token.RefreshHash = hash
	token.RefreshExpiresAt = time.Now().Add(refreshTokenExpiration)
	return secret, nil
}

func NewAccessTokenFromAuthCode(authCode *AuthCode) *AccessToken {
	return &AccessToken{
		User:   authCode.User,
		Client: authCode.Client,
		Scope:  authCode.Scope,
		User:     authCode.User,
		Client:   authCode.Client,
		Scope:    authCode.Scope,
		AuthTime: authCode.CreatedAt,
	}
}

func (token *AccessToken) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":                 &token.ID,
		"hash":               &token.Hash,
		"user":               &token.User,
		"client":             &token.Client,
		"scope":              nullValue{&token.Scope},
		"issued_at":          &token.IssuedAt,
		"expires_at":         &token.ExpiresAt,
		"auth_time":          nullValue{&token.AuthTime},
		"refresh_hash":       &token.RefreshHash,
		"refresh_expires_at": nullValue{&token.RefreshExpiresAt},
	}
}

func (token *AccessToken) VerifySecret(secret string) bool {
	return verifyHash(token.Hash, secret) && verifyExpiration(token.ExpiresAt)
}

func (token *AccessToken) VerifyRefreshSecret(secret string) bool {
	return verifyHash(token.RefreshHash, secret) && verifyExpiration(token.RefreshExpiresAt)
}

type AuthorizedClient struct {
	Client    Client
	ExpiresAt time.Time
}

type AuthCode struct {
	ID          ID[*AuthCode]
	Hash        []byte
	CreatedAt   time.Time
	User        ID[*User]
	Client      ID[*Client]
	Scope       string
	RedirectURI string
	ID                  ID[*AuthCode]
	Hash                []byte
	CreatedAt           time.Time
	User                ID[*User]
	Client              ID[*Client]
	Scope               string
	RedirectURI         string
	Nonce               string
	CodeChallenge       string
	CodeChallengeMethod string
}

func (code *AuthCode) Generate() (secret string, err error) {
	secret, hash, err := generateSecret()
	if err != nil {
		return "", fmt.Errorf("failed to generate authentication code secret: %v", err)
	}
	code.Hash = hash
	code.CreatedAt = time.Now()
	return secret, nil
}

func (code *AuthCode) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":           &code.ID,
		"hash":         &code.Hash,
		"created_at":   &code.CreatedAt,
		"user":         &code.User,
		"client":       &code.Client,
		"scope":        nullValue{&code.Scope},
		"redirect_uri": nullValue{&code.RedirectURI},
		"id":                    &code.ID,
		"hash":                  &code.Hash,
		"created_at":            &code.CreatedAt,
		"user":                  &code.User,
		"client":                &code.Client,
		"scope":                 nullValue{&code.Scope},
		"redirect_uri":          nullValue{&code.RedirectURI},
		"nonce":                 nullValue{&code.Nonce},
		"code_challenge":        nullValue{&code.CodeChallenge},
		"code_challenge_method": nullValue{&code.CodeChallengeMethod},
	}
}

func (code *AuthCode) VerifySecret(secret string) bool {
	return verifyHash(code.Hash, secret) && verifyExpiration(code.CreatedAt.Add(authCodeExpiration))
}

type SecretKind byte

const (
	SecretKindAccessToken  = SecretKind('a')
	SecretKindRefreshToken = SecretKind('r')
	SecretKindAuthCode     = SecretKind('c')
)

type SigningKey struct {
	ID         ID[*SigningKey]
	KID        string
	Algorithm  string
	PrivateKey []byte
	CreatedAt  time.Time
}

func (key *SigningKey) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":          &key.ID,
		"kid":         &key.KID,
		"algorithm":   &key.Algorithm,
		"private_key": &key.PrivateKey,
		"created_at":  &key.CreatedAt,
	}
}

func UnmarshalSecret[T entity](s string) (id ID[T], secret string, err error) {
	kind, s, _ := strings.Cut(s, ".")
	idStr, secret, ok := strings.Cut(s, ".")
	if !ok || len(kind) != 1 {
		return 0, "", fmt.Errorf("malformed secret")
	}

	switch SecretKind(kind[0]) {
	case SecretKindAccessToken, SecretKindRefreshToken:
		_, ok = interface{}(id).(ID[*AccessToken])
	case SecretKindAuthCode:
		_, ok = interface{}(id).(ID[*AuthCode])
	}
	if !ok {
		return 0, "", fmt.Errorf("invalid secret kind %q", kind)
	}

	id, err = ParseID[T](idStr)
	return id, secret, err
}

func MarshalSecret[T entity](id ID[T], kind SecretKind, secret string) string {
	if id == 0 {
		panic("cannot marshal zero ID")
	}

	var ok bool
	switch interface{}(id).(type) {
	case ID[*AccessToken]:
		ok = kind == SecretKindAccessToken || kind == SecretKindRefreshToken
	case ID[*AuthCode]:
		ok = kind == SecretKindAuthCode
	}
	if !ok {
		panic(fmt.Sprintf("unsupported secret kind %q for ID type %T", string(kind), id))
	}

	return fmt.Sprintf("%v.%v.%v", string(kind), int64(id), secret)
}

func generateUID() (string, error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(b), nil
}

func generateSecret() (secret string, hash []byte, err error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", nil, err
	}
	secret = base64.RawURLEncoding.EncodeToString(b)
	h := sha512.Sum512(b)
	return secret, h[:], nil
}

func verifyHash(hash []byte, secret string) bool {
	b, _ := base64.RawURLEncoding.DecodeString(secret)
	h := sha512.Sum512(b)
	return subtle.ConstantTimeCompare(hash, h[:]) == 1
}

func verifyExpiration(t time.Time) bool {
	return time.Now().Before(t)
}
module vireo

go 1.24.0

require (
	codeberg.org/emersion/go-scfg v0.1.0
	github.com/emersion/go-oauth2 v0.0.0-20250228145955-eaead4148315
	github.com/go-chi/chi/v5 v5.2.3
	github.com/golang-jwt/jwt/v5 v5.2.1
	github.com/mattn/go-sqlite3 v1.14.32
	golang.org/x/crypto v0.42.0
)
codeberg.org/emersion/go-scfg v0.1.0 h1:6dnGU0ZI4gX+O5rMjwhoaySItzHG710eXL5TIQKl+uM=
codeberg.org/emersion/go-scfg v0.1.0/go.mod h1:0nooW1ufBB4SlJEdTtiVN9Or+bnNM1icOkQ6Tbrq6O0=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/emersion/go-oauth2 v0.0.0-20250228145955-eaead4148315 h1:sXzwA8yItbg3ji0UuTLkuO4NKPqQJjC035hPoZI40h8=
github.com/emersion/go-oauth2 v0.0.0-20250228145955-eaead4148315/go.mod h1:pSj8CBn/jb+ynRxt/ESIJisazza/Sh2DjwUn31l2tI0=
github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8=
github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE=
github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
package main

import (
	"context"
	"embed"
	"flag"
	"log"
	"net"
	"net/http"
	"time"

	"github.com/go-chi/chi/v5"
)

var (
	//go:embed template
	templateFS embed.FS
	//go:embed static
	staticFS embed.FS
)

func main() {
	var configFilename string
	flag.StringVar(&configFilename, "config", "/etc/lindenii/vireo/config", "Configuration filename")
	flag.Parse()

	cfg, err := loadConfig(configFilename)
	if err != nil {
		log.Fatalf("Failed to load config file: %v", err)
	}

	listenAddr := cfg.Listen
	if cfg.Database == "" {
		log.Fatalf("Missing database configuration")
	}

	db, err := openDB(cfg.Database)
	if err != nil {
		log.Fatalf("Failed to open DB: %v", err)
	}

	tplBaseData := &TemplateBaseData{
		ServerName: cfg.ServerName,
	}
	if tplBaseData.ServerName == "" {
		tplBaseData.ServerName = "vireo"
	}
	tpl, err := loadTemplate(templateFS, "template/*.html", tplBaseData)
	if err != nil {
		log.Fatalf("Failed to load template: %v", err)
	}

	oidcProvider, err := newOIDCProvider(context.Background(), db)
	if err != nil {
		log.Fatalf("Failed to initialize OpenID Connect provider: %v", err)
	}

	mux := chi.NewRouter()
	mux.Handle("/static/*", http.FileServer(http.FS(staticFS)))
	mux.Get("/", index)
	mux.HandleFunc("/login", login)
	mux.Post("/logout", logout)
	mux.HandleFunc("/client/new", manageClient)
	mux.HandleFunc("/client/{id}", manageClient)
	mux.Post("/client/{id}/revoke", revokeClient)
	mux.HandleFunc("/user/new", manageUser)
	mux.HandleFunc("/user/{id}", manageUser)
	mux.Get("/.well-known/oauth-authorization-server", getOAuthServerMetadata)
	mux.Get("/.well-known/openid-configuration", getOpenIDConfiguration)
	mux.Get("/.well-known/jwks.json", getOIDCJWKS)
	mux.HandleFunc("/authorize", authorize)
	mux.Post("/token", exchangeToken)
	mux.Post("/introspect", introspectToken)
	mux.Post("/revoke", revokeToken)
	mux.HandleFunc("/userinfo", userInfo)

	go maintainDBLoop(db)

	server := http.Server{
		Addr:    listenAddr,
		Handler: loginTokenMiddleware(mux),
		Handler: csrfMiddleware(loginTokenMiddleware(mux)),
		BaseContext: func(net.Listener) context.Context {
			return newBaseContext(db, tpl)
			return newBaseContext(db, tpl, oidcProvider)
		},
	}
	log.Printf("OAuth server listening on %v", server.Addr)
	if err := server.ListenAndServe(); err != nil {
		log.Fatalf("Failed to listen and serve: %v", err)
	}
}

func httpError(w http.ResponseWriter, err error) {
	log.Print(err)
	http.Error(w, "Internal server error", http.StatusInternalServerError)
}

func maintainDBLoop(db *DB) {
	ticker := time.NewTicker(15 * time.Minute)
	defer ticker.Stop()

	for range ticker.C {
		ctx := context.Background()
		ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
		if err := db.Maintain(ctx); err != nil {
			log.Printf("Failed to perform database maintenance: %v", err)
		}
		cancel()
	}
}
package main

import (
	"context"
	"fmt"
	"html/template"
	"io"
	"io/fs"
	"mime"
	"net/http"
)

const (
	loginCookieName    = "vireo-login"
	internalTokenScope = "_vireo"
)

type contextKey string

const (
	contextKeyDB         = "db"
	contextKeyTemplate   = "template"
	contextKeyLoginToken = "login-token"
	contextKeyOIDC       = "oidc"
	contextKeyCSRFToken  = "csrf-token"
)

func dbFromContext(ctx context.Context) *DB {
	return ctx.Value(contextKeyDB).(*DB)
}

func templateFromContext(ctx context.Context) *Template {
	return ctx.Value(contextKeyTemplate).(*Template)
}

func oidcProviderFromContext(ctx context.Context) *OIDCProvider {
	return ctx.Value(contextKeyOIDC).(*OIDCProvider)
}

func loginTokenFromContext(ctx context.Context) *AccessToken {
	v := ctx.Value(contextKeyLoginToken)
	if v == nil {
		return nil
	}
	return v.(*AccessToken)
}

func newBaseContext(db *DB, tpl *Template) context.Context {
func newBaseContext(db *DB, tpl *Template, oidc *OIDCProvider) context.Context {
	ctx := context.Background()
	ctx = context.WithValue(ctx, contextKeyDB, db)
	ctx = context.WithValue(ctx, contextKeyTemplate, tpl)
	ctx = context.WithValue(ctx, contextKeyOIDC, oidc)
	return ctx
}

func setLoginTokenCookie(w http.ResponseWriter, req *http.Request, token *AccessToken, secret string) {
	http.SetCookie(w, &http.Cookie{
		Name:     loginCookieName,
		Value:    MarshalSecret(token.ID, SecretKindAccessToken, secret),
		HttpOnly: true,
		SameSite: http.SameSiteStrictMode,
		SameSite: http.SameSiteLaxMode,
		Secure:   isForwardedHTTPS(req),
	})
}

func unsetLoginTokenCookie(w http.ResponseWriter, req *http.Request) {
	http.SetCookie(w, &http.Cookie{
		Name:     loginCookieName,
		HttpOnly: true,
		SameSite: http.SameSiteStrictMode,
		SameSite: http.SameSiteLaxMode,
		Secure:   isForwardedHTTPS(req),
		MaxAge:   -1,
	})
}

func isForwardedHTTPS(req *http.Request) bool {
	if forwarded := req.Header.Get("Forwarded"); forwarded != "" {
		_, params, _ := mime.ParseMediaType("_; " + forwarded)
		return params["proto"] == "https"
	}
	if forwardedProto := req.Header.Get("X-Forwarded-Proto"); forwardedProto != "" {
		return forwardedProto == "https"
	}
	return false
}

func loginTokenMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		cookie, _ := req.Cookie(loginCookieName)
		if cookie == nil {
			next.ServeHTTP(w, req)
			return
		}

		ctx := req.Context()
		db := dbFromContext(ctx)
		tokenID, tokenSecret, _ := UnmarshalSecret[*AccessToken](cookie.Value)
		token, err := db.FetchAccessToken(ctx, tokenID)
		if err == errNoDBRows || (err == nil && !token.VerifySecret(tokenSecret)) {
			unsetLoginTokenCookie(w, req)
			next.ServeHTTP(w, req)
			return
		} else if err != nil {
			httpError(w, fmt.Errorf("failed to fetch access token: %v", err))
			return
		}

		if token.Scope != internalTokenScope {
			http.Error(w, "Invalid login token scope", http.StatusForbidden)
			return
		}
		if token.User == 0 {
			panic("login token with zero user ID")
		}

		ctx = context.WithValue(ctx, contextKeyLoginToken, token)
		req = req.WithContext(ctx)
		next.ServeHTTP(w, req)
	})
}

type TemplateBaseData struct {
	ServerName string
	CSRFToken  string
}

func (data *TemplateBaseData) Base() *TemplateBaseData {
	return data
}

type TemplateData interface {
	Base() *TemplateBaseData
}

type Template struct {
	tpl      *template.Template
	baseData *TemplateBaseData
}

func loadTemplate(fs fs.FS, pattern string, baseData *TemplateBaseData) (*Template, error) {
	tpl, err := template.ParseFS(fs, pattern)
	if err != nil {
		return nil, err
	}
	return &Template{tpl: tpl, baseData: baseData}, nil
}

func (tpl *Template) MustExecuteTemplate(w io.Writer, filename string, data TemplateData) {
func (tpl *Template) MustExecuteTemplate(ctx context.Context, w io.Writer, filename string, data TemplateData) {
	baseCopy := *tpl.baseData
	if token := csrfTokenFromContext(ctx); token != "" {
		baseCopy.CSRFToken = token
	}
	if data == nil {
		data = tpl.baseData
		base := baseCopy
		data = &base
	} else {
		*data.Base() = *tpl.baseData
		*data.Base() = baseCopy
	}
	if err := tpl.tpl.ExecuteTemplate(w, filename, data); err != nil {
		panic(err)
	}
}
package main

import (
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"mime"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/emersion/go-oauth2"
)

const (
	scopeOpenID        = "openid"
	scopeProfile       = "profile"
	scopeEmail         = "email"
	scopeOfflineAccess = "offline_access"
	pkceMethodPlain    = "plain"
	pkceMethodS256     = "S256"
)

var allowedScopes = map[string]struct{}{
	scopeOpenID:        {},
	scopeProfile:       {},
	scopeEmail:         {},
	scopeOfflineAccess: {},
}

type oidcTokenResponse struct {
	AccessToken  string           `json:"access_token"`
	TokenType    oauth2.TokenType `json:"token_type"`
	ExpiresIn    int64            `json:"expires_in,omitempty"`
	RefreshToken string           `json:"refresh_token,omitempty"`
	Scope        string           `json:"scope,omitempty"`
	IDToken      string           `json:"id_token,omitempty"`
}

func parseScopes(scope string) []string {
	if scope == "" {
		return nil
	}
	parts := strings.Fields(scope)
	var scopes []string
	seen := make(map[string]struct{}, len(parts))
	for _, p := range parts {
		if p == "" {
			continue
		}
		p = strings.ToLower(p)
		if _, ok := seen[p]; ok {
			continue
		}
		seen[p] = struct{}{}
		scopes = append(scopes, p)
	}
	return scopes
}

func normalizeScope(scope string) (string, []string) {
	scopes := parseScopes(scope)
	if len(scopes) == 0 {
		return "", nil
	}
	return strings.Join(scopes, " "), scopes
}

func validateScopes(scopes []string) error {
	for _, scope := range scopes {
		if _, ok := allowedScopes[scope]; !ok {
			return fmt.Errorf("unsupported scope %q", scope)
		}
	}
	return nil
}

func normalizeCodeChallengeMethod(method string) (string, error) {
	if method == "" {
		return pkceMethodPlain, nil
	}
	switch {
	case strings.EqualFold(method, pkceMethodPlain):
		return pkceMethodPlain, nil
	case strings.EqualFold(method, pkceMethodS256):
		return pkceMethodS256, nil
	default:
		return "", fmt.Errorf("unsupported code_challenge_method")
	}
}

func validateCodeVerifier(verifier string) error {
	if verifier == "" {
		return fmt.Errorf("missing code_verifier")
	}
	if len(verifier) < 43 || len(verifier) > 128 {
		return fmt.Errorf("invalid code_verifier length")
	}
	for _, r := range verifier {
		if !(r >= 'A' && r <= 'Z' || r >= 'a' && r <= 'z' || r >= '0' && r <= '9' || r == '-' || r == '.' || r == '_' || r == '~') {
			return fmt.Errorf("invalid character in code_verifier")
		}
	}
	return nil
}

func validateCodeChallenge(method, challenge string) error {
	if challenge == "" {
		return fmt.Errorf("missing code_challenge")
	}
	if err := validateCodeVerifier(challenge); err != nil {
		return err
	}
	if method != pkceMethodPlain && method != pkceMethodS256 {
		return fmt.Errorf("unsupported code_challenge_method")
	}
	return nil
}

func verifyCodeVerifier(method, challenge, verifier string) error {
	if err := validateCodeVerifier(verifier); err != nil {
		return err
	}
	switch method {
	case "", pkceMethodPlain:
		if challenge != verifier {
			return fmt.Errorf("code_verifier mismatch")
		}
	case pkceMethodS256:
		hash := sha256.Sum256([]byte(verifier))
		expected := base64.RawURLEncoding.EncodeToString(hash[:])
		if expected != challenge {
			return fmt.Errorf("code_verifier mismatch")
		}
	default:
		return fmt.Errorf("unsupported code_challenge_method")
	}
	return nil
}

func getOAuthServerMetadata(w http.ResponseWriter, req *http.Request) {
	issuer := getIssuer(req)

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(&oauth2.ServerMetadata{
		Issuer:                                     issuer,
		AuthorizationEndpoint:                      issuer + "/authorize",
		TokenEndpoint:                              issuer + "/token",
		IntrospectionEndpoint:                      issuer + "/introspect",
		RevocationEndpoint:                         issuer + "/revoke",
		ResponseTypesSupported:                     []oauth2.ResponseType{oauth2.ResponseTypeCode},
		ResponseModesSupported:                     []oauth2.ResponseMode{oauth2.ResponseModeQuery},
		GrantTypesSupported:                        []oauth2.GrantType{oauth2.GrantTypeAuthorizationCode},
		TokenEndpointAuthMethodsSupported:          []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic},
		Issuer:                            issuer,
		AuthorizationEndpoint:             issuer + "/authorize",
		TokenEndpoint:                     issuer + "/token",
		IntrospectionEndpoint:             issuer + "/introspect",
		RevocationEndpoint:                issuer + "/revoke",
		JWKSURI:                           issuer + "/.well-known/jwks.json",
		ScopesSupported:                   []string{scopeOpenID, scopeProfile, scopeEmail, scopeOfflineAccess},
		ResponseTypesSupported:            []oauth2.ResponseType{oauth2.ResponseTypeCode},
		ResponseModesSupported:            []oauth2.ResponseMode{oauth2.ResponseModeQuery},
		GrantTypesSupported:               []oauth2.GrantType{oauth2.GrantTypeAuthorizationCode, oauth2.GrantTypeRefreshToken},
		TokenEndpointAuthMethodsSupported: []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic},
		IntrospectionEndpointAuthMethodsSupported:  []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic},
		RevocationEndpointAuthMethodsSupported:     []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic},
		CodeChallengeMethodsSupported:              []string{pkceMethodPlain, pkceMethodS256},
		AuthorizationResponseIssParameterSupported: true,
	})
}

func getIssuer(req *http.Request) string {
	issuerURL := url.URL{
		Scheme: "https",
		Host:   req.Host,
	}
	if !isForwardedHTTPS(req) && isLoopback(req) {
		// TODO: add config option for allowed reverse proxy IPs
		issuerURL.Scheme = "http"
	}
	return issuerURL.String()
}

func isLoopback(req *http.Request) bool {
	host, _, _ := net.SplitHostPort(req.RemoteAddr)
	ip := net.ParseIP(host)
	return ip.IsLoopback()
}

func authorize(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)
	tpl := templateFromContext(ctx)

	q := req.URL.Query()
	respType := oauth2.ResponseType(q.Get("response_type"))
	clientID := q.Get("client_id")
	rawRedirectURI := q.Get("redirect_uri")
	scope := q.Get("scope")
	state := q.Get("state")
	_, stateProvided := q["state"]
	codeChallenge := q.Get("code_challenge")
	codeChallengeMethod := q.Get("code_challenge_method")

	var normalizedCodeChallengeMethod string
	nonce := q.Get("nonce")

	if clientID == "" {
		http.Error(w, "Missing client ID", http.StatusBadRequest)
		return
	}

	client, err := db.FetchClientByClientID(ctx, clientID)
	if err == errNoDBRows {
		http.Error(w, "Invalid client ID", http.StatusForbidden)
		return
	} else if err != nil {
		httpError(w, fmt.Errorf("failed to fetch client: %v", err))
		return
	}

	requiredPKCE, err := normalizeClientPKCERequirement(client.PKCERequirement)
	if err != nil {
		httpError(w, fmt.Errorf("invalid PKCE requirement configuration: %v", err))
		return
	}

	var allowedRedirectURIs []*url.URL
	for _, s := range strings.Split(client.RedirectURIs, "\n") {
		if s == "" {
			continue
		}
		u, err := url.Parse(s)
		if err != nil {
			httpError(w, fmt.Errorf("failed to parse client redirect URI"))
			return
		}
		allowedRedirectURIs = append(allowedRedirectURIs, u)
	}

	var redirectURI *url.URL
	if rawRedirectURI != "" {
		redirectURI, err = url.Parse(rawRedirectURI)
		if err != nil {
			http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
			return
		}
		if !validateRedirectURI(redirectURI, allowedRedirectURIs) {
			http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
			return
		}
	} else {
		if len(allowedRedirectURIs) == 0 {
			http.Error(w, "Missing redirect URI", http.StatusBadRequest)
			return
		}
		redirectURI = allowedRedirectURIs[0]
	}

	if codeChallenge != "" {
		method, err := normalizeCodeChallengeMethod(codeChallengeMethod)
		if err != nil {
			redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidRequest,
				Description: err.Error(),
			})
			return
		}
		if err := validateCodeChallenge(method, codeChallenge); err != nil {
			redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidRequest,
				Description: err.Error(),
			})
			return
		}
		normalizedCodeChallengeMethod = method
	}

	if codeChallenge == "" && codeChallengeMethod != "" {
		redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: "code_challenge_method without code_challenge",
		})
		return
	}

	switch requiredPKCE {
	case pkceRequirementPlain:
		if codeChallenge == "" {
			redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidRequest,
				Description: "PKCE is required",
			})
			return
		}
		if !allowPKCERequirement(normalizedCodeChallengeMethod, pkceRequirementPlain) {
			redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidRequest,
				Description: "PKCE method does not satisfy requirement",
			})
			return
		}
	case pkceRequirementS256:
		if codeChallenge == "" {
			redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidRequest,
				Description: "PKCE (S256) is required",
			})
			return
		}
		if !allowPKCERequirement(normalizedCodeChallengeMethod, pkceRequirementS256) {
			redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidRequest,
				Description: "PKCE (S256) is required",
			})
			return
		}
	}

	codeChallengeMethod = normalizedCodeChallengeMethod

	if respType != oauth2.ResponseTypeCode {
		redirectClientError(w, req, redirectURI, state, &oauth2.Error{
		redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
			Code: oauth2.ErrorCodeUnsupportedResponseType,
		})
		return
	}

	// TODO: add support for scope
	if scope != "" {
		redirectClientError(w, req, redirectURI, state, &oauth2.Error{
			Code: oauth2.ErrorCodeInvalidScope,
	normalizedScope, scopes := normalizeScope(scope)
	if len(scopes) == 0 {
		redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidScope,
			Description: "Missing required openid scope",
		})
		return
	}
	if err := validateScopes(scopes); err != nil {
		redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidScope,
			Description: err.Error(),
		})
		return
	}
	if !containsScope(scopes, scopeOpenID) {
		redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidScope,
			Description: "Scope openid is required",
		})
		return
	}
	scope = normalizedScope

	loginToken := loginTokenFromContext(ctx)
	if loginToken == nil {
		q := make(url.Values)
		q.Set("redirect_uri", req.URL.String())
		u := url.URL{
			Path:     "/login",
			RawQuery: q.Encode(),
		}
		http.Redirect(w, req, u.String(), http.StatusFound)
		return
	}

	_ = req.ParseForm()
	if _, ok := req.PostForm["deny"]; ok {
		redirectClientError(w, req, redirectURI, state, &oauth2.Error{
		redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
			Code: oauth2.ErrorCodeAccessDenied,
		})
		return
	}
	if _, ok := req.PostForm["authorize"]; !ok {
		data := struct {
			TemplateBaseData
			Client *Client
		}{
			Client: client,
		}
		tpl.MustExecuteTemplate(w, "authorize.html", &data)
		tpl.MustExecuteTemplate(req.Context(), w, "authorize.html", &data)
		return
	}

	authCode := AuthCode{
		User:        loginToken.User,
		Client:      client.ID,
		Scope:       scope,
		RedirectURI: rawRedirectURI,
		User:                loginToken.User,
		Client:              client.ID,
		Scope:               scope,
		RedirectURI:         rawRedirectURI,
		Nonce:               nonce,
		CodeChallenge:       codeChallenge,
		CodeChallengeMethod: codeChallengeMethod,
	}
	secret, err := authCode.Generate()
	if err != nil {
		httpError(w, fmt.Errorf("failed to generate authentication code: %v", err))
		return
	}

	if err := db.CreateAuthCode(ctx, &authCode); err != nil {
		httpError(w, fmt.Errorf("failed to create authentication code: %v", err))
		return
	}

	code := MarshalSecret(authCode.ID, SecretKindAuthCode, secret)

	values := make(url.Values)
	values.Set("code", code)
	if state != "" {
	if stateProvided {
		values.Set("state", state)
	}
	redirectClient(w, req, redirectURI, values)
}

func exchangeToken(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)

	values, err := parseRequestBody(req)
	if err != nil {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: err.Error(),
		})
		return
	}

	clientID := values.Get("client_id")
	grantType := oauth2.GrantType(values.Get("grant_type"))
	scope := values.Get("scope")
	codeVerifier := values.Get("code_verifier")

	authClientID, clientSecret, _ := req.BasicAuth()
	if clientID == "" {
		clientID = authClientID
	} else if clientID != authClientID {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: "Client ID in request body doesn't match Authorization header field",
		})
		return
	}

	var client *Client
	if clientID != "" {
		client, err = db.FetchClientByClientID(ctx, clientID)
		if err == errNoDBRows {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidClient,
				Description: "Invalid client ID",
			})
			return
		} else if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
			return
		}

		if !client.IsPublic() && !client.VerifySecret(clientSecret) {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid client secret",
			})
			return
		}
	}

	var token *AccessToken
	var (
		token             *AccessToken
		authorizationCode *AuthCode
		currentClient     *Client
		nonceValue        string
	)

	switch grantType {
	case oauth2.GrantTypeAuthorizationCode:
		if client == nil {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidRequest,
				Description: "Missing client ID",
			})
			return
		}

		codeID, codeSecret, _ := UnmarshalSecret[*AuthCode](values.Get("code"))
		authCode, err := db.PopAuthCode(ctx, codeID)
		if err == errNoDBRows || (err == nil && !authCode.VerifySecret(codeSecret)) || authCode.Client != client.ID {
		authorizationCode, err = db.PopAuthCode(ctx, codeID)
		if err == errNoDBRows || (err == nil && !authorizationCode.VerifySecret(codeSecret)) || authorizationCode.Client != client.ID {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid authorization code",
			})
			return
		} else if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch authorization code: %v", err))
			return
		}

		if scope != authCode.Scope {
		if scope != "" && scope != authorizationCode.Scope {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid scope",
			})
			return
		}
		if values.Get("redirect_uri") != authCode.RedirectURI {
		if values.Get("redirect_uri") != authorizationCode.RedirectURI {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid redirect URI",
			})
			return
		}
		if authorizationCode.CodeChallenge != "" {
			if codeVerifier == "" {
				oauthError(w, &oauth2.Error{
					Code:        oauth2.ErrorCodeInvalidRequest,
					Description: "Missing code_verifier",
				})
				return
			}
			if err := verifyCodeVerifier(authorizationCode.CodeChallengeMethod, authorizationCode.CodeChallenge, codeVerifier); err != nil {
				oauthError(w, &oauth2.Error{
					Code:        oauth2.ErrorCodeInvalidGrant,
					Description: "Invalid code_verifier",
				})
				return
			}
		} else if codeVerifier != "" {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidRequest,
				Description: "Unexpected code_verifier",
			})
			return
		}

		token = NewAccessTokenFromAuthCode(authCode)
		token = NewAccessTokenFromAuthCode(authorizationCode)
		currentClient = client
		nonceValue = authorizationCode.Nonce
	case oauth2.GrantTypeRefreshToken:
		tokenID, refreshSecret, _ := UnmarshalSecret[*AccessToken](values.Get("refresh_token"))
		token, err = db.FetchAccessToken(ctx, tokenID)
		if err == errNoDBRows || (err == nil && !token.VerifyRefreshSecret(refreshSecret)) {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid refresh token",
			})
			return
		} else if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch access token: %v", err))
			return
		}

		if client != nil && client.ID != token.Client {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid refresh token",
			})
			return
		}

		tokenClient, err := db.FetchClient(ctx, token.Client)
		currentClient, err = db.FetchClient(ctx, token.Client)
		if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
			return
		}

		if !tokenClient.IsPublic() && client == nil {
		if !currentClient.IsPublic() && client == nil {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid client secret",
			})
			return
		}

		if scope != token.Scope {
		if scope != "" && scope != token.Scope {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid scope",
			})
			return
		}
	default:
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeUnsupportedGrantType,
			Description: "Unsupported grant type",
		})
	}

	secret, err := token.Generate(accessTokenExpiration)
	if err != nil {
		oauthError(w, err)
		return
	}
	refreshSecret, err := token.GenerateRefresh()
	if err != nil {
		oauthError(w, err)
		return

	tokenScopes := parseScopes(token.Scope)
	issueRefresh := containsScope(tokenScopes, scopeOfflineAccess)
	var refreshSecret string
	if issueRefresh {
		refreshSecret, err = token.GenerateRefresh()
		if err != nil {
			oauthError(w, err)
			return
		}
	} else {
		token.RefreshHash = nil
		token.RefreshExpiresAt = time.Time{}
	}

	if err := db.StoreAccessToken(ctx, token); err != nil {
		oauthError(w, fmt.Errorf("failed to create access token: %v", err))
		return
	}

	accessTokenValue := MarshalSecret(token.ID, SecretKindAccessToken, secret)
	if token.AuthTime.IsZero() {
		token.AuthTime = token.IssuedAt
	}

	var idToken string
	if containsScope(tokenScopes, scopeOpenID) {
		if currentClient == nil {
			currentClient, err = db.FetchClient(ctx, token.Client)
			if err != nil {
				oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
				return
			}
		}
		user, err := db.FetchUser(ctx, token.User)
		if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch user: %v", err))
			return
		}

		issuer := getIssuer(req)
		oidcProvider := oidcProviderFromContext(ctx)
		idToken, err = oidcProvider.MintIDToken(issuer, currentClient, user, token, tokenScopes, nonceValue, accessTokenValue, token.AuthTime)
		if err != nil {
			oauthError(w, fmt.Errorf("failed to mint ID token: %v", err))
			return
		}
	}

	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Cache-Control", "no-store")
	json.NewEncoder(w).Encode(&oauth2.TokenResp{
		AccessToken:  MarshalSecret(token.ID, SecretKindAccessToken, secret),
		TokenType:    oauth2.TokenTypeBearer,
		ExpiresIn:    time.Until(token.ExpiresAt),
		Scope:        strings.Split(token.Scope, " "),
		RefreshToken: MarshalSecret(token.ID, SecretKindRefreshToken, refreshSecret),
	})
	resp := oidcTokenResponse{
		AccessToken: accessTokenValue,
		TokenType:   oauth2.TokenTypeBearer,
		ExpiresIn:   int64(time.Until(token.ExpiresAt).Seconds()),
		Scope:       token.Scope,
		IDToken:     idToken,
	}
	if issueRefresh {
		refreshTokenValue := MarshalSecret(token.ID, SecretKindRefreshToken, refreshSecret)
		resp.RefreshToken = refreshTokenValue
	}
	json.NewEncoder(w).Encode(&resp)
}

func introspectToken(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)

	values, err := parseRequestBody(req)
	if err != nil {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: err.Error(),
		})
		return
	}

	client, err := maybeAuthenticateClient(w, req)
	if err != nil {
		oauthError(w, err)
		return
	}

	tokenID, secret, _ := UnmarshalSecret[*AccessToken](values.Get("token"))
	token, err := db.FetchAccessToken(ctx, tokenID)
	if err == errNoDBRows || (err == nil && !token.VerifySecret(secret)) {
		token = nil
	} else if err != nil {
		oauthError(w, fmt.Errorf("failed to fetch access token: %v", err))
		return
	}

	var resp oauth2.IntrospectionResp
	if token != nil {
		if client == nil {
			client, err = db.FetchClient(ctx, token.Client)
			if err != nil {
				oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
				return
			}

			if !client.IsPublic() {
				oauthError(w, &oauth2.Error{
					Code:        oauth2.ErrorCodeInvalidClient,
					Description: "Missing client ID and secret",
				})
				return
			}
		}

		if client.ID != token.Client {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidClient,
				Description: "Invalid client ID or secret",
			})
			return
		}

		user, err := db.FetchUser(ctx, token.User)
		if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch user: %v", err))
			return
		}

		resp.Active = true
		resp.TokenType = oauth2.TokenTypeBearer
		resp.ExpiresAt = token.ExpiresAt
		resp.IssuedAt = token.IssuedAt
		resp.ClientID = client.ClientID
		resp.Username = user.Username
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(&resp)
}

func revokeToken(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)

	values, err := parseRequestBody(req)
	if err != nil {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: err.Error(),
		})
		return
	}

	client, err := maybeAuthenticateClient(w, req)
	if err != nil {
		oauthError(w, err)
		return
	}

	tokenID, secret, _ := UnmarshalSecret[*AccessToken](values.Get("token"))
	token, err := db.FetchAccessToken(ctx, tokenID)
	if err == errNoDBRows || (err == nil && !token.VerifySecret(secret)) {
		return // ignore
	} else if err != nil {
		oauthError(w, fmt.Errorf("failed to fetch access token: %v", err))
		return
	}

	if client == nil {
		client, err = db.FetchClient(ctx, token.Client)
		if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
			return
		}

		if !client.IsPublic() {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidClient,
				Description: "Missing client ID and secret",
			})
			return
		}
	}

	if client.ID != token.Client {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidClient,
			Description: "Invalid client ID or secret",
		})
		return
	}

	if err := db.DeleteAccessToken(ctx, token.ID); err != nil {
		oauthError(w, err)
		return
	}
}

func parseRequestBody(req *http.Request) (url.Values, error) {
	ct := req.Header.Get("Content-Type")
	if ct != "" {
		mimeType, _, err := mime.ParseMediaType(req.Header.Get("Content-Type"))
		if err != nil {
			return nil, fmt.Errorf("malformed Content-Type header field")
		} else if mimeType != "application/x-www-form-urlencoded" {
			return nil, fmt.Errorf("unsupported request content type")
		}
	}

	r := io.LimitReader(req.Body, 10<<20)
	b, err := io.ReadAll(r)
	if err != nil {
		return nil, fmt.Errorf("failed to read request body: %v", err)
	}

	values, err := url.ParseQuery(string(b))
	if err != nil {
		return nil, fmt.Errorf("failed to parse request body: %v", err)
	}

	return values, nil
}

func oauthError(w http.ResponseWriter, err error) {
	var oauthErr *oauth2.Error
	if !errors.As(err, &oauthErr) {
		oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
		log.Print(err)
	}

	statusCode := http.StatusInternalServerError
	switch oauthErr.Code {
	case oauth2.ErrorCodeInvalidRequest, oauth2.ErrorCodeUnsupportedResponseType, oauth2.ErrorCodeInvalidScope, oauth2.ErrorCodeInvalidClient, oauth2.ErrorCodeInvalidGrant, oauth2.ErrorCodeUnsupportedGrantType:
		statusCode = http.StatusBadRequest
	case oauth2.ErrorCodeUnauthorizedClient, oauth2.ErrorCodeAccessDenied:
		statusCode = http.StatusForbidden
	case oauth2.ErrorCodeTemporarilyUnavailable:
		statusCode = http.StatusServiceUnavailable
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)
	json.NewEncoder(w).Encode(oauthErr)
}

func redirectClient(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, values url.Values) {
	q := redirectURI.Query()
	for k, v := range values {
		q[k] = v
	}
	q.Set("iss", getIssuer(req))

	u := *redirectURI
	u.RawQuery = q.Encode()

	http.Redirect(w, req, u.String(), http.StatusFound)
}

func redirectClientError(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, state string, err error) {
func redirectClientError(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, state string, stateProvided bool, err error) {
	var oauthErr *oauth2.Error
	if !errors.As(err, &oauthErr) {
		oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
		log.Print(err)
	}

	values := make(url.Values)
	values.Set("error", string(oauthErr.Code))
	if oauthErr.Description != "" {
		values.Set("error_description", oauthErr.Description)
	}
	if oauthErr.URI != "" {
		values.Set("error_uri", oauthErr.URI)
	}
	if state != "" {
	if stateProvided {
		values.Set("state", state)
	}
	redirectClient(w, req, redirectURI, values)
}

func validateRedirectURI(u *url.URL, allowedURIs []*url.URL) bool {
	// Loopback interface, see RFC 8252 section 7.3
	host, _, _ := net.SplitHostPort(u.Host)
	ip := net.ParseIP(host)
	if u.Scheme == "http" && ip.IsLoopback() {
		uu := *u
		uu.Host = "localhost"
		u = &uu
	}

	for _, allowed := range allowedURIs {
		if u.String() == allowed.String() {
			return true
		}
	}

	return false
}

func maybeAuthenticateClient(w http.ResponseWriter, req *http.Request) (*Client, error) {
	ctx := req.Context()
	db := dbFromContext(ctx)

	clientID, clientSecret, ok := req.BasicAuth()
	if !ok {
		return nil, nil
	}

	client, err := db.FetchClientByClientID(ctx, clientID)
	if err == errNoDBRows || (err == nil && !client.VerifySecret(clientSecret)) {
		return nil, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidClient,
			Description: "Invalid client ID or secret",
		}
	} else if err != nil {
		return nil, fmt.Errorf("failed to fetch client: %v", err)
	}

	return client, nil
}
package main

import (
	"context"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha256"
	"crypto/x509"
	"encoding/base64"
	"encoding/json"
	"encoding/pem"
	"fmt"
	"math/big"
	"net/http"
	"sort"
	"strconv"
	"strings"
	"time"

	oauth2 "github.com/emersion/go-oauth2"
	"github.com/golang-jwt/jwt/v5"
)

type OIDCProvider struct {
	signingKeys []*oidcSigningKey
}

type oidcSigningKey struct {
	key       *SigningKey
	private   *rsa.PrivateKey
	publicJWK jwk
}

type jwk struct {
	Kty string `json:"kty"`
	Use string `json:"use,omitempty"`
	Alg string `json:"alg,omitempty"`
	Kid string `json:"kid,omitempty"`
	N   string `json:"n,omitempty"`
	E   string `json:"e,omitempty"`
}

type jwks struct {
	Keys []jwk `json:"keys"`
}

const idTokenTTL = 15 * time.Minute

func newOIDCProvider(ctx context.Context, db *DB) (*OIDCProvider, error) {
	signingRecords, err := db.FetchSigningKeys(ctx)
	if err == errNoDBRows {
		generated, genErr := generateSigningKey()
		if genErr != nil {
			return nil, genErr
		}
		if storeErr := db.StoreSigningKey(ctx, generated); storeErr != nil {
			return nil, fmt.Errorf("failed to persist signing key: %w", storeErr)
		}
		signingRecords = []SigningKey{*generated}
	} else if err != nil {
		return nil, fmt.Errorf("failed to fetch signing keys: %w", err)
	}

	signingKeys := make([]*oidcSigningKey, 0, len(signingRecords))
	for i := range signingRecords {
		material, convErr := toOIDCSigningKey(&signingRecords[i])
		if convErr != nil {
			return nil, convErr
		}
		signingKeys = append(signingKeys, material)
	}

	return &OIDCProvider{signingKeys: signingKeys}, nil
}

func generateSigningKey() (*SigningKey, error) {
	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, fmt.Errorf("failed to generate signing key: %w", err)
	}

	pemBlock := pem.EncodeToMemory(&pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(priv),
	})
	if pemBlock == nil {
		return nil, fmt.Errorf("failed to encode signing key")
	}

	kid, err := generateUID()
	if err != nil {
		return nil, fmt.Errorf("failed to generate signing key ID: %w", err)
	}
	return &SigningKey{
		KID:        kid,
		Algorithm:  "RS256",
		PrivateKey: pemBlock,
		CreatedAt:  time.Now(),
	}, nil
}

func toOIDCSigningKey(signing *SigningKey) (*oidcSigningKey, error) {
	block, _ := pem.Decode(signing.PrivateKey)
	if block == nil {
		return nil, fmt.Errorf("failed to decode signing key PEM")
	}
	priv, err := x509.ParsePKCS1PrivateKey(block.Bytes)
	if err != nil {
		return nil, fmt.Errorf("failed to parse signing key: %w", err)
	}

	jwk := jwk{
		Kty: "RSA",
		Use: "sig",
		Alg: signing.Algorithm,
		Kid: signing.KID,
		N:   base64.RawURLEncoding.EncodeToString(priv.N.Bytes()),
	}

	e := big.NewInt(int64(priv.E)).Bytes()
	jwk.E = base64.RawURLEncoding.EncodeToString(e)

	return &oidcSigningKey{
		key:       signing,
		private:   priv,
		publicJWK: jwk,
	}, nil
}

func (op *OIDCProvider) currentSigningKey() *oidcSigningKey {
	if len(op.signingKeys) == 0 {
		return nil
	}
	return op.signingKeys[0]
}

func (op *OIDCProvider) signingMethod() (*jwt.SigningMethodRSA, *oidcSigningKey, error) {
	key := op.currentSigningKey()
	if key == nil {
		return nil, nil, fmt.Errorf("no signing key configured")
	}

	switch key.key.Algorithm {
	case "RS256":
		return jwt.SigningMethodRS256, key, nil
	default:
		return nil, nil, fmt.Errorf("unsupported signing algorithm %q", key.key.Algorithm)
	}
}

func (op *OIDCProvider) MintIDToken(issuer string, client *Client, user *User, token *AccessToken, scopes []string, nonce string, accessToken string, authTime time.Time) (string, error) {
	method, signingKey, err := op.signingMethod()
	if err != nil {
		return "", err
	}

	now := time.Now()
	expiresAt := now.Add(idTokenTTL)
	if token.ExpiresAt.Before(expiresAt) {
		expiresAt = token.ExpiresAt
	}
	if expiresAt.Before(now) {
		expiresAt = now
	}

	claims := jwt.MapClaims{
		"iss": issuer,
		"sub": subjectForUser(user),
		"aud": client.ClientID,
		"exp": jwt.NewNumericDate(expiresAt),
		"iat": jwt.NewNumericDate(now),
	}
	if !authTime.IsZero() {
		claims["auth_time"] = jwt.NewNumericDate(authTime)
	}
	if nonce != "" {
		claims["nonce"] = nonce
	}
	if accessToken != "" {
		claims["at_hash"] = computeAtHash(accessToken)
	}
	if containsScope(scopes, scopeProfile) {
		displayName := user.Name
		if displayName == "" {
			displayName = user.Username
		}
		claims["preferred_username"] = user.Username
		claims["name"] = displayName
	}
	if containsScope(scopes, scopeEmail) && user.Email != "" {
		claims["email"] = user.Email
		claims["email_verified"] = false
	}

	tokenJWT := jwt.NewWithClaims(method, claims)
	tokenJWT.Header["kid"] = signingKey.key.KID

	return tokenJWT.SignedString(signingKey.private)
}

func (op *OIDCProvider) JWKS() jwks {
	keys := make([]jwk, 0, len(op.signingKeys))
	for _, key := range op.signingKeys {
		keys = append(keys, key.publicJWK)
	}
	return jwks{Keys: keys}
}

func subjectForUser(user *User) string {
	return strconv.FormatInt(int64(user.ID), 10)
}

func computeAtHash(accessToken string) string {
	sum := sha256.Sum256([]byte(accessToken))
	return base64.RawURLEncoding.EncodeToString(sum[:len(sum)/2])
}

func containsScope(scopes []string, scope string) bool {
	for _, s := range scopes {
		if strings.EqualFold(s, scope) {
			return true
		}
	}
	return false
}

func getOpenIDConfiguration(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	oidc := oidcProviderFromContext(ctx)
	issuer := getIssuer(req)
	currentKey := oidc.currentSigningKey()

	scopes := make([]string, 0, len(allowedScopes))
	for scope := range allowedScopes {
		scopes = append(scopes, scope)
	}
	sort.Strings(scopes)

	idTokenAlgs := []string{"RS256"}
	if currentKey != nil && currentKey.key.Algorithm != "" {
		idTokenAlgs = []string{currentKey.key.Algorithm}
	}

	config := map[string]interface{}{
		"issuer":                                         issuer,
		"authorization_endpoint":                         issuer + "/authorize",
		"token_endpoint":                                 issuer + "/token",
		"userinfo_endpoint":                              issuer + "/userinfo",
		"jwks_uri":                                       issuer + "/.well-known/jwks.json",
		"response_types_supported":                       []string{string(oauth2.ResponseTypeCode)},
		"response_modes_supported":                       []string{string(oauth2.ResponseModeQuery)},
		"grant_types_supported":                          []string{string(oauth2.GrantTypeAuthorizationCode), string(oauth2.GrantTypeRefreshToken)},
		"subject_types_supported":                        []string{"public"},
		"id_token_signing_alg_values_supported":          idTokenAlgs,
		"scopes_supported":                               scopes,
		"claims_supported":                               []string{"sub", "preferred_username", "name", "email", "email_verified"},
		"token_endpoint_auth_methods_supported":          []string{string(oauth2.AuthMethodNone), string(oauth2.AuthMethodClientSecretBasic)},
		"introspection_endpoint":                         issuer + "/introspect",
		"revocation_endpoint":                            issuer + "/revoke",
		"authorization_response_iss_parameter_supported": true,
		"claims_parameter_supported":                     false,
		"request_parameter_supported":                    false,
		"request_uri_parameter_supported":                false,
		"code_challenge_methods_supported":               []string{pkceMethodPlain, pkceMethodS256},
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(config)
}

func getOIDCJWKS(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	oidc := oidcProviderFromContext(ctx)

	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Cache-Control", "public, max-age=300")
	json.NewEncoder(w).Encode(oidc.JWKS())
}

func userInfo(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)

	if req.Method != http.MethodGet && req.Method != http.MethodPost {
		w.Header().Set("Allow", "GET, POST")
		http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
		return
	}

	tokenValue, err := bearerTokenFromRequest(req)
	if err != nil {
		writeBearerError(w, http.StatusUnauthorized, "invalid_token", err.Error())
		return
	}

	tokenID, secret, err := UnmarshalSecret[*AccessToken](tokenValue)
	if err != nil {
		writeBearerError(w, http.StatusUnauthorized, "invalid_token", "Malformed access token")
		return
	}

	token, err := db.FetchAccessToken(ctx, tokenID)
	if err == errNoDBRows || (err == nil && !token.VerifySecret(secret)) {
		writeBearerError(w, http.StatusUnauthorized, "invalid_token", "Invalid access token")
		return
	} else if err != nil {
		httpError(w, fmt.Errorf("failed to fetch access token: %v", err))
		return
	}

	scopes := parseScopes(token.Scope)
	if !containsScope(scopes, scopeOpenID) {
		writeBearerError(w, http.StatusForbidden, "insufficient_scope", "Scope openid missing")
		return
	}

	user, err := db.FetchUser(ctx, token.User)
	if err != nil {
		httpError(w, fmt.Errorf("failed to fetch user: %v", err))
		return
	}

	resp := map[string]interface{}{
		"sub": subjectForUser(user),
	}
	if containsScope(scopes, scopeProfile) {
		displayName := user.Name
		if displayName == "" {
			displayName = user.Username
		}
		resp["preferred_username"] = user.Username
		resp["name"] = displayName
	}
	if containsScope(scopes, scopeEmail) && user.Email != "" {
		resp["email"] = user.Email
		resp["email_verified"] = false
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(resp)
}

func bearerTokenFromRequest(req *http.Request) (string, error) {
	authz := req.Header.Get("Authorization")
	if authz == "" {
		return "", fmt.Errorf("Authorization header missing")
	}
	if len(authz) < 7 || !strings.EqualFold(authz[:7], "Bearer ") {
		return "", fmt.Errorf("Unsupported authorization scheme")
	}
	token := strings.TrimSpace(authz[7:])
	if token == "" {
		return "", fmt.Errorf("Missing access token")
	}
	return token, nil
}

func writeBearerError(w http.ResponseWriter, status int, code, description string) {
	challenge := "Bearer"
	if code != "" {
		challenge += fmt.Sprintf(" error=\"%s\"", code)
	}
	if description != "" {
		if code == "" {
			challenge += " "
		} else {
			challenge += ", "
		}
		challenge += fmt.Sprintf("error_description=\"%s\"", description)
	}
	w.Header().Set("WWW-Authenticate", challenge)
	http.Error(w, http.StatusText(status), status)
}
package main

import (
	"fmt"
	"strings"
)

const (
	pkceRequirementNone  = ""
	pkceRequirementPlain = pkceMethodPlain
	pkceRequirementS256  = pkceMethodS256
)

func normalizeClientPKCERequirement(value string) (string, error) {
	switch strings.ToUpper(strings.TrimSpace(value)) {
	case "", "NONE":
		return pkceRequirementNone, nil
	case strings.ToUpper(pkceRequirementPlain):
		return pkceRequirementPlain, nil
	case pkceRequirementS256:
		return pkceRequirementS256, nil
	default:
		return "", fmt.Errorf("invalid PKCE requirement")
	}
}

func allowPKCERequirement(method, requirement string) bool {
	requirement = strings.ToUpper(requirement)
	method = strings.ToUpper(method)
	switch requirement {
	case "", "NONE":
		return true
	case strings.ToUpper(pkceRequirementPlain):
		return method == strings.ToUpper(pkceMethodPlain) || method == strings.ToUpper(pkceMethodS256)
	case pkceRequirementS256:
		return method == strings.ToUpper(pkceMethodS256)
	default:
		return false
	}
}
CREATE TABLE User (
	id INTEGER PRIMARY KEY,
	username TEXT NOT NULL UNIQUE,
	name TEXT,
	email TEXT,
	password_hash TEXT,
	admin INTEGER NOT NULL DEFAULT 0
);

CREATE TABLE Client (
	id INTEGER PRIMARY KEY,
	client_id TEXT NOT NULL UNIQUE,
	client_secret_hash BLOB,
	owner INTEGER REFERENCES User(id) ON DELETE CASCADE,
	redirect_uris TEXT,
	client_name TEXT,
	client_uri TEXT
	client_uri TEXT,
	pkce_requirement TEXT
);

CREATE TABLE AccessToken (
	id INTEGER PRIMARY KEY,
	hash BLOB NOT NULL UNIQUE,
	user INTEGER NOT NULL REFERENCES User(id) ON DELETE CASCADE,
	client INTEGER REFERENCES Client(id) ON DELETE CASCADE,
	scope TEXT,
	issued_at datetime NOT NULL,
	expires_at datetime NOT NULL,
	auth_time datetime,
	refresh_hash BLOB UNIQUE,
	refresh_expires_at datetime
);

CREATE TABLE AuthCode (
	id INTEGER PRIMARY KEY,
	hash BLOB NOT NULL UNIQUE,
	created_at datetime NOT NULL,
	user INTEGER NOT NULL REFERENCES User(id) ON DELETE CASCADE,
	client INTEGER NOT NULL REFERENCES Client(id) ON DELETE CASCADE,
	redirect_uri TEXT,
	scope TEXT
	scope TEXT,
	nonce TEXT,
	code_challenge TEXT,
	code_challenge_method TEXT
);

CREATE TABLE SigningKey (
	id INTEGER PRIMARY KEY,
	kid TEXT NOT NULL UNIQUE,
	algorithm TEXT NOT NULL,
	private_key BLOB NOT NULL,
	created_at datetime NOT NULL
);

CREATE INDEX signing_key_created_at ON SigningKey(created_at);
body {
	font-family: sans-serif;
	margin: 0;
	color: #444;
}
main, #nav-inner, footer {
	padding: 0 5px;
	max-width: 800px;
	margin: 0 auto;
}
main {
	padding-bottom: 20px;
}

nav {
	border-bottom: 1px solid #eee;
}
nav h1 {
	margin: 0;
	padding: 10px 0;
	font-size: 1.2em;
}
nav h1 a {
	color: inherit;
	text-decoration: none;
}

h2 {
	font-size: 1.2em;
}

table {
	border-collapse: collapse;
}
td, th {
	border: 1px solid rgb(208, 210, 215);
	padding: 5px;
}

button {
	border: 1px solid rgb(208, 210, 215);
	border-radius: 4px;
	padding: 6px 12px;
	margin: 4px 0;
	color: #444;
	background-color: transparent;
	cursor: pointer;
}
button:hover {
	background-color: rgba(0, 0, 0, 0.02);
}
button[type="submit"]:not(.btn-regular):first-of-type {
	background-color: rgb(0, 128, 0);
	border-color: rgb(0, 128, 0);
	color: white;
}
button[type="submit"]::not(.btn-regular):first-of-type:hover {
	background-color: rgb(0, 150, 0);
	border-color: rgb(0, 150, 0);
}

input[type="text"], input[type="password"], input[type="url"], textarea {
input[type="text"], input[type="email"], input[type="password"], input[type="url"], textarea {
	border: 1px solid rgb(208, 210, 215);
	border-radius: 4px;
	padding: 6px;
	margin: 4px 0;
	color: #444;
}
input[type="text"]:focus, input[type="password"]:focus, input[type="url"]:focus, textarea:focus {
input[type="email"]:focus, input[type="text"]:focus, input[type="password"]:focus, input[type="url"]:focus, textarea:focus {
	outline: none;
	border-color: rgb(0, 128, 0);
}

label {
	display: block;
	margin: 15px 0;
}
label input[type="text"], label input[type="password"], label input[type="url"] {
label input[type="email"], label input[type="text"], label input[type="password"], label input[type="url"] {
	display: block;
	width: 100%;
	max-width: 350px;
	box-sizing: border-box;
}
label:has(input[type="radio"]) {
	margin: 5px 0;
}
label textarea {
	display: block;
	width: 100%;
	resize: vertical;
}

main.narrow {
	max-width: 400px;
}
main.narrow input {
	max-width: 100%;
}

@media (prefers-color-scheme: dark) {
	body {
		background: #212529;
		color: #f8f9fa;
	}

	nav {
		border-color: rgba(255, 255, 2555, 0.02);
	}

	a {
		color: #809fff;
	}

	button {
		color: #f8f9fa;
	}
	button:hover {
		background-color: rgba(255, 255, 255, 0.02);
	}

	input[type="text"], input[type="password"], input[type="url"], textarea {
	input[type="email"], input[type="text"], input[type="password"], input[type="url"], textarea {
		background-color: rgba(255, 255, 255, 0.05);
		color: inherit;
	}
}
{{ template "head.html" .Base }}

<main class="narrow">

<h1>{{ .ServerName }}</h1>

<p>
	Authorize
	{{ if .Client.ClientURI }}
		<a href="{{ .Client.ClientURI }}" target="_blank">
	{{ end }}
	{{- if .Client.ClientName -}}
		{{- .Client.ClientName -}}
	{{- else -}}
		<code>{{- .Client.ClientID -}}</code>
	{{- end -}}
	{{- if .Client.ClientURI -}}
		</a>
	{{- end -}}
	?
</p>

<form method="post" action="">
	<input type="hidden" name="_csrf" value="{{ .Base.CSRFToken }}">
	<button type="submit" name="authorize">Authorize</button>
	<button type="submit" name="deny">Cancel</button>
</form>

</main>

{{ template "foot.html" }}
<!DOCTYPE HTML>
<html lang="en">
<head>
    <meta charset="utf-8"/>
	<meta charset="utf-8"/>
	<title>{{ .ServerName }}</title>
    <meta name="viewport" content="width=device-width, initial-scale=1"/>
    <link rel="stylesheet" href="/static/style.css"/>
	<meta name="viewport" content="width=device-width, initial-scale=1"/>
	<link rel="stylesheet" href="/static/style.css"/>
</head>
<body>
{{ template "head.html" .Base }}
{{ template "nav.html" .Base }}

<main>

<p>Welcome, {{ .Me.Username }}!</p>

<form method="post">
	<input type="hidden" name="_csrf" value="{{ .Base.CSRFToken }}">
	<a href="/user/{{ .Me.ID }}"><button type="button">Settings</button></a>
	<button type="submit" formaction="/logout" class="btn-regular">Logout</button>
</form>

<h2>Authorized clients</h2>

{{ with .AuthorizedClients }}
	<table>
		<tr>
			<th>Client</th>
			<th>Authorized until</th>
			<th></th>
		</tr>
		{{ range . }}
			<tr>
				<td>
					{{ with .Client }}
						{{ if .ClientURI }}
							<a href="{{ .ClientURI }}" target="_blank">
						{{ end }}
						{{ if .ClientName }}
							{{ .ClientName }}
						{{ else }}
							<code>{{ .ClientID }}</code>
						{{ end }}
						{{ if .ClientURI }}
							</a>
						{{ end }}
					{{ end }}
				</td>
				<td>{{ .ExpiresAt }}</td>
				<td>
					<form method="post" action="/client/{{ .Client.ID }}/revoke">
						<button type="submit">Revoke</button>
					</form>
				<form method="post" action="/client/{{ .Client.ID }}/revoke">
					<input type="hidden" name="_csrf" value="{{ $.Base.CSRFToken }}">
					<button type="submit">Revoke</button>
				</form>
				</td>
			</tr>
		{{ end }}
	</table>
{{ else }}
	<p>No client authorized yet.</p>
{{ end }}

{{ if .Me.Admin }}
	<h2>Registered clients</h2>

	<p>
		<a href="/client/new"><button type="button">Register new client</button></a>
	</p>

	{{ with .Clients }}
		<table>
			<tr>
				<th>Client ID</th>
				<th>Name</th>
			</tr>
			{{ range . }}
				<tr>
					<td><a href="/client/{{ .ID }}"><code>{{ .ClientID }}</code></a></td>
					<td>{{ .ClientName }}</td>
				</tr>
			{{ end }}
		</table>
	{{ else }}
		<p>No client registered yet.</p>
	{{ end }}

	<h2>Users</h2>

	<p>
		<a href="/user/new"><button type="button">Create user</button></a>
	</p>

	<table>
		<tr>
			<th>Username</th>
			<th>Name</th>
			<th>Email</th>
			<th>Role</th>
		</tr>
		{{ range .Users }}
			<tr>
				<td><a href="/user/{{ .ID }}">{{ .Username }}</a></td>
				<td>{{ .Name }}</td>
				<td>{{ .Email }}</td>
				<td>
					{{ if .Admin }}
						Administrator
					{{ else }}
						Regular user
					{{ end}}
				</td>
			</tr>
		{{ end }}
	</table>
{{ end }}

</main>

{{ template "foot.html" }}
{{ template "head.html" .Base }}

<main class="narrow">

<h1>{{ .ServerName }}</h1>

<form method="post" action="">
	<input type="hidden" name="_csrf" value="{{ .Base.CSRFToken }}">
	<label>
		Username
		<input type="text" name="username" autocomplete="username" autofocus>
	</label>
	<label>
		Password
		<input type="password" name="password">
	</label>
	<button type="submit">Login</button>
</form>

</main>

{{ template "foot.html" }}
{{ template "head.html" .Base }}
{{ template "nav.html" .Base }}

<main>

<h2>
	{{ if .Client.ID }}
		Update client
	{{ else }}
		Create client
	{{ end }}
</h2>

<form method="post" action="">
	<input type="hidden" name="_csrf" value="{{ .Base.CSRFToken }}">
	{{ if .Client.ClientID }}
		Client ID: <code>{{ .Client.ClientID }}</code><br>
	{{ end }}
	<label>
		Name
		<input type="text" name="client_name" value="{{ .Client.ClientName }}">
	</label>
	<label>
		Website
		<input type="url" name="client_uri" value="{{ .Client.ClientURI }}">
	</label>

	{{ if .Client.ID }}
		<p>
			Client type
			<br>
			<strong>
			{{ if .Client.IsPublic }}
				Public
			{{ else }}
				Confidential
			{{ end }}
			</strong>
		</p>
	{{ else }}
		Client type
		<label>
			<input type="radio" name="client_type" value="confidential" checked>
			Confidential
		</label>
		<label>
			<input type="radio" name="client_type" value="public">
			Public
		</label>
	{{ end }}

	<label>
		Redirect URIs
		<textarea name="redirect_uris" wrap="off">{{ .Client.RedirectURIs }}</textarea>
		<small>The special URI <code>http://localhost</code> matches all loopback interfaces.</small><br>
	</label>
	<label>
		PKCE requirement
		<select name="pkce_requirement" {{ if not .Client.IsPublic }}disabled{{ end }}>
			<option value="" {{ if eq .Client.PKCERequirement "" }}selected{{ end }}>Optional</option>
			<option value="plain" {{ if eq .Client.PKCERequirement "plain" }}selected{{ end }}>Require PKCE (plain or S256)</option>
			<option value="S256" {{ if eq .Client.PKCERequirement "S256" }}selected{{ end }}>Require PKCE (S256)</option>
		</select>
		<small>Applies to public clients only.</small>
	</label>

	<button type="submit">
		{{ if .Client.ID }}
			Update client
		{{ else }}
			Create client
		{{ end }}
	</button>
	{{ if .Client.ID }}
		{{ if not .Client.IsPublic }}
			<button type="submit" name="rotate">Rotate client secret</button>
		{{ end }}
		<button type="submit" name="delete">Delete client</button>
	{{ end }}
	<a href="/"><button type="button">Cancel</button></a>
</form>

</main>

{{ template "foot.html" }}
{{ template "head.html" .Base }}
{{ template "nav.html" .Base }}

<main>

<h2>
	{{ if .User.ID }}
		Update user
	{{ else }}
		Create user
	{{ end }}
</h2>

<form method="post" action="">
	<input type="hidden" name="_csrf" value="{{ .Base.CSRFToken }}">
	<label>
		Username
		<input type="text" name="username" value="{{ .User.Username }}" required>
	</label>
	<label>
		Display name
		<input type="text" name="name" value="{{ .User.Name }}">
	</label>
	<label>
		Email
		<input type="email" name="email" value="{{ .User.Email }}">
	</label>
	<label>
		Password
		<input type="password" name="password">
	</label>
	{{ if not (eq .Me.ID .User.ID) }}
		<label>
			<input type="checkbox" name="admin" {{ if .User.Admin }}checked{{ end }}>
			Administrator
		</label>
	{{ end }}

	<button type="submit">
		{{ if .User.ID }}
			Update user
		{{ else }}
			Create user
		{{ end }}
	</button>
	<a href="/"><button type="button">Cancel</button></a>
</form>

</main>

{{ template "foot.html" }}
package main

import (
	"fmt"
	"log"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/go-chi/chi/v5"
)

func index(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)
	tpl := templateFromContext(ctx)

	loginToken := loginTokenFromContext(ctx)
	if loginToken == nil {
		http.Redirect(w, req, "/login", http.StatusFound)
		return
	}

	me, err := db.FetchUser(ctx, loginToken.User)
	if err != nil {
		httpError(w, err)
		return
	}

	authorizedClients, err := db.ListAuthorizedClients(ctx, loginToken.User)
	if err != nil {
		httpError(w, err)
		return
	}

	clients, err := db.ListClients(ctx, loginToken.User)
	if err != nil {
		httpError(w, err)
		return
	}

	var users []User
	if me.Admin {
		users, err = db.ListUsers(ctx)
		if err != nil {
			httpError(w, err)
			return
		}
	}

	data := struct {
		TemplateBaseData
		Me                *User
		AuthorizedClients []AuthorizedClient
		Clients           []Client
		Users             []User
	}{
		Me:                me,
		AuthorizedClients: authorizedClients,
		Clients:           clients,
		Users:             users,
	}
	tpl.MustExecuteTemplate(w, "index.html", &data)
	tpl.MustExecuteTemplate(req.Context(), w, "index.html", &data)
}

func login(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)
	tpl := templateFromContext(ctx)

	q := req.URL.Query()
	rawRedirectURI := q.Get("redirect_uri")
	if rawRedirectURI == "" {
		rawRedirectURI = "/"
	}

	redirectURI, err := url.Parse(rawRedirectURI)
	if err != nil || redirectURI.Scheme != "" || redirectURI.Opaque != "" || redirectURI.User != nil || redirectURI.Host != "" {
		http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
		return
	}

	if loginTokenFromContext(ctx) != nil {
		http.Redirect(w, req, redirectURI.String(), http.StatusFound)
		return
	}

	username := req.PostFormValue("username")
	username := strings.TrimSpace(req.PostFormValue("username"))
	password := req.PostFormValue("password")
	if username == "" {
		tpl.MustExecuteTemplate(w, "login.html", nil)
		tpl.MustExecuteTemplate(req.Context(), w, "login.html", nil)
		return
	}

	user, err := db.FetchUserByUsername(ctx, username)
	if err != nil && err != errNoDBRows {
		httpError(w, fmt.Errorf("failed to fetch user: %v", err))
		return
	}
	if err == nil {
		err = user.VerifyPassword(password)
	}
	if err != nil {
		log.Printf("login failed for user %q: %v", username, err)
		// TODO: show error message
		tpl.MustExecuteTemplate(w, "login.html", nil)
		tpl.MustExecuteTemplate(req.Context(), w, "login.html", nil)
		return
	}

	if user.PasswordNeedsRehash() {
		if err := user.SetPassword(password); err != nil {
			httpError(w, fmt.Errorf("failed to rehash password: %v", err))
			return
		}
		if err := db.StoreUser(ctx, user); err != nil {
			httpError(w, fmt.Errorf("failed to store user: %v", err))
			return
		}
	}

	token := AccessToken{
		User:  user.ID,
		Scope: internalTokenScope,
	}
	secret, err := token.Generate(4 * time.Hour)
	if err != nil {
		httpError(w, fmt.Errorf("failed to generate access token: %v", err))
		return
	}
	token.AuthTime = token.IssuedAt
	if err := db.StoreAccessToken(ctx, &token); err != nil {
		httpError(w, fmt.Errorf("failed to create access token: %v", err))
		return
	}

	setLoginTokenCookie(w, req, &token, secret)
	http.Redirect(w, req, redirectURI.String(), http.StatusFound)
}

func logout(w http.ResponseWriter, req *http.Request) {
	unsetLoginTokenCookie(w, req)
	http.Redirect(w, req, "/login", http.StatusFound)
}

func manageUser(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)
	tpl := templateFromContext(ctx)

	user := new(User)
	if idStr := chi.URLParam(req, "id"); idStr != "" {
		id, err := ParseID[*User](idStr)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}

		user, err = db.FetchUser(ctx, id)
		if err != nil {
			httpError(w, err)
			return
		}
	}

	loginToken := loginTokenFromContext(ctx)
	if loginToken == nil {
		http.Redirect(w, req, "/login", http.StatusFound)
		return
	}

	me, err := db.FetchUser(ctx, loginToken.User)
	if err != nil {
		httpError(w, err)
		return
	} else if loginToken.User != user.ID && !me.Admin {
		http.Error(w, "Access denied", http.StatusForbidden)
		return
	}

	username := req.PostFormValue("username")
	username := strings.TrimSpace(req.PostFormValue("username"))
	name := strings.TrimSpace(req.PostFormValue("name"))
	email := strings.TrimSpace(req.PostFormValue("email"))
	password := req.PostFormValue("password")
	admin := req.PostFormValue("admin") == "on"
	if username == "" {
		data := struct {
			TemplateBaseData
			User *User
			Me   *User
		}{
			User: user,
			Me:   me,
		}
		tpl.MustExecuteTemplate(w, "manage-user.html", &data)
		tpl.MustExecuteTemplate(req.Context(), w, "manage-user.html", &data)
		return
	}

	user.Username = username
	user.Name = name
	user.Email = email
	if me.Admin && user.ID != me.ID {
		user.Admin = admin
	}
	if password != "" {
		if err := user.SetPassword(password); err != nil {
			httpError(w, err)
			return
		}
	}

	if err := db.StoreUser(ctx, user); err != nil {
		httpError(w, err)
		return
	}

	http.Redirect(w, req, "/", http.StatusFound)
}