Hi… I am well aware that this diff view is very suboptimal. It will be fixed when the refactored server comes along!
sinwon -> vireo in various parts of the codebase
package main
import (
	"context"
	"database/sql"
	_ "embed"
	"fmt"
	"time"
	"github.com/mattn/go-sqlite3"
)
//go:embed schema.sql
var schema string
var migrations = []string{
	"", // migration #0 is reserved for schema initialization
	`
		ALTER TABLE AccessToken ADD COLUMN refresh_hash BLOB;
		ALTER TABLE AccessToken ADD COLUMN refresh_expires_at datetime;
		CREATE UNIQUE INDEX access_token_refresh_hash ON AccessToken(refresh_hash);
	`,
}
var errNoDBRows = sql.ErrNoRows
type DB struct {
	db *sql.DB
}
func openDB(filename string) (*DB, error) {
	sqlDB, err := sql.Open("sqlite3", filename+"?cache=shared&_foreign_keys=1")
	if err != nil {
		return nil, err
	}
	sqlDB.SetMaxOpenConns(1)
	db := &DB{sqlDB}
	if err := db.init(context.TODO()); err != nil {
		db.Close()
		return nil, err
	}
	return db, nil
}
func (db *DB) init(ctx context.Context) error {
	version, err := db.upgrade(ctx)
	if err != nil {
		return err
	}
	if version > 0 {
		return nil
	}
	// TODO: drop this
	defaultUser := User{Username: "root", Admin: true}
	if err := defaultUser.SetPassword("root"); err != nil {
		return err
	}
	return db.StoreUser(ctx, &defaultUser)
}
func (db *DB) upgrade(ctx context.Context) (version int, err error) {
	if err := db.db.QueryRowContext(ctx, "PRAGMA user_version").Scan(&version); err != nil {
		return 0, fmt.Errorf("failed to query schema version: %v", err)
	}
	if version == len(migrations) {
		return version, nil
	} else if version > len(migrations) {
		return version, fmt.Errorf("sinwon (version %d) older than schema (version %d)", len(migrations), version)
		return version, fmt.Errorf("vireo (version %d) older than schema (version %d)", len(migrations), version)
	}
	tx, err := db.db.BeginTx(ctx, nil)
	if err != nil {
		return version, err
	}
	defer tx.Rollback()
	if version == 0 {
		if _, err := tx.ExecContext(ctx, schema); err != nil {
			return version, fmt.Errorf("failed to initialize schema: %v", err)
		}
	} else {
		for i := version; i < len(migrations); i++ {
			if _, err := tx.ExecContext(ctx, migrations[i]); err != nil {
				return version, fmt.Errorf("failed to execute migration #%v: %v", i, err)
			}
		}
	}
	// For some reason prepared statements don't work here
	_, err = tx.ExecContext(ctx, fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
	if err != nil {
		return version, fmt.Errorf("failed to bump schema version: %v", err)
	}
	return version, tx.Commit()
}
func (db *DB) Close() error {
	return db.db.Close()
}
func (db *DB) FetchUser(ctx context.Context, id ID[*User]) (*User, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM User WHERE id = ?", id)
	if err != nil {
		return nil, err
	}
	var user User
	err = scanRow(&user, rows)
	return &user, err
}
func (db *DB) FetchUserByUsername(ctx context.Context, username string) (*User, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM User WHERE username = ?", username)
	if err != nil {
		return nil, err
	}
	var user User
	err = scanRow(&user, rows)
	return &user, err
}
func (db *DB) StoreUser(ctx context.Context, user *User) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO User(id, username, password_hash, admin)
		VALUES (:id, :username, :password_hash, :admin)
		ON CONFLICT(id) DO UPDATE SET
			username = :username,
			password_hash = :password_hash,
			admin = :admin
		RETURNING id
	`, entityArgs(user)...).Scan(&user.ID)
}
func (db *DB) ListUsers(ctx context.Context) ([]User, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM User")
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var l []User
	for rows.Next() {
		var user User
		if err := scan(&user, rows); err != nil {
			return nil, err
		}
		l = append(l, user)
	}
	return l, rows.Close()
}
func (db *DB) FetchClient(ctx context.Context, id ID[*Client]) (*Client, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE id = ?", id)
	if err != nil {
		return nil, err
	}
	var client Client
	err = scanRow(&client, rows)
	return &client, err
}
func (db *DB) FetchClientByClientID(ctx context.Context, clientID string) (*Client, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE client_id = ?", clientID)
	if err != nil {
		return nil, err
	}
	var client Client
	err = scanRow(&client, rows)
	return &client, err
}
func (db *DB) StoreClient(ctx context.Context, client *Client) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO Client(id, client_id, client_secret_hash, owner,
			redirect_uris, client_name, client_uri)
		VALUES (:id, :client_id, :client_secret_hash, :owner,
			:redirect_uris, :client_name, :client_uri)
		ON CONFLICT(id) DO UPDATE SET
			client_id = :client_id,
			client_secret_hash = :client_secret_hash,
			owner = :owner,
			redirect_uris = :redirect_uris,
			client_name = :client_name,
			client_uri = :client_uri
		RETURNING id
	`, entityArgs(client)...).Scan(&client.ID)
}
func (db *DB) ListClients(ctx context.Context, owner ID[*User]) ([]Client, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE owner IS ?", owner)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var l []Client
	for rows.Next() {
		var client Client
		if err := scan(&client, rows); err != nil {
			return nil, err
		}
		l = append(l, client)
	}
	return l, rows.Close()
}
func (db *DB) ListAuthorizedClients(ctx context.Context, user ID[*User]) ([]AuthorizedClient, error) {
	rows, err := db.db.QueryContext(ctx, `
		SELECT id, client_id, client_name, client_uri, token.expires_at
		FROM Client,
		(
			SELECT client, MAX(COALESCE(refresh_expires_at, expires_at)) as expires_at
			FROM AccessToken
			WHERE user = ?
			GROUP BY client
		) AS token
		WHERE Client.id = token.client
	`, user)
	if err != nil {
		return nil, err
	}
	var l []AuthorizedClient
	for rows.Next() {
		var authClient AuthorizedClient
		columns := authClient.Client.columns()
		var expiresAt string
		err := rows.Scan(columns["id"], columns["client_id"], columns["client_name"], columns["client_uri"], &expiresAt)
		if err != nil {
			return nil, err
		}
		authClient.ExpiresAt, err = time.Parse(sqlite3.SQLiteTimestampFormats[0], expiresAt)
		if err != nil {
			return nil, err
		}
		l = append(l, authClient)
	}
	return l, rows.Close()
}
func (db *DB) DeleteClient(ctx context.Context, id ID[*Client]) error {
	_, err := db.db.ExecContext(ctx, "DELETE FROM Client WHERE id = ?", id)
	return err
}
func (db *DB) FetchAccessToken(ctx context.Context, id ID[*AccessToken]) (*AccessToken, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM AccessToken WHERE id = ?", id)
	if err != nil {
		return nil, err
	}
	var token AccessToken
	err = scanRow(&token, rows)
	return &token, err
}
func (db *DB) StoreAccessToken(ctx context.Context, token *AccessToken) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO AccessToken(id, hash, user, client, scope, issued_at,
			expires_at, refresh_hash, refresh_expires_at)
		VALUES (:id, :hash, :user, :client, :scope, :issued_at, :expires_at,
			:refresh_hash, :refresh_expires_at)
		ON CONFLICT(id) DO UPDATE SET
			hash = :hash,
			user = :user,
			client = :client,
			scope = :scope,
			issued_at = :issued_at,
			expires_at = :expires_at,
			refresh_hash = :refresh_hash,
			refresh_expires_at = :refresh_expires_at
		RETURNING id
	`, entityArgs(token)...).Scan(&token.ID)
}
func (db *DB) DeleteAccessToken(ctx context.Context, id ID[*AccessToken]) error {
	_, err := db.db.ExecContext(ctx, "DELETE FROM AccessToken WHERE id = ?", id)
	return err
}
func (db *DB) RevokeClientUser(ctx context.Context, clientID ID[*Client], userID ID[*User]) error {
	tx, err := db.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	defer tx.Rollback()
	_, err = tx.ExecContext(ctx, `
		DELETE FROM AccessToken
		WHERE client = ? AND user = ?
	`, clientID, userID)
	if err != nil {
		return err
	}
	_, err = tx.ExecContext(ctx, `
		DELETE FROM AuthCode
		WHERE client = ? AND user = ?
	`, clientID, userID)
	if err != nil {
		return err
	}
	return tx.Commit()
}
func (db *DB) CreateAuthCode(ctx context.Context, code *AuthCode) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO AuthCode(hash, created_at, user, client, scope, redirect_uri)
		VALUES (:hash, :created_at, :user, :client, :scope, :redirect_uri)
		RETURNING id
	`, entityArgs(code)...).Scan(&code.ID)
}
func (db *DB) PopAuthCode(ctx context.Context, id ID[*AuthCode]) (*AuthCode, error) {
	rows, err := db.db.QueryContext(ctx, `
		DELETE FROM AuthCode
		WHERE id = ?
		RETURNING *
	`, id)
	if err != nil {
		return nil, err
	}
	var authCode AuthCode
	err = scanRow(&authCode, rows)
	return &authCode, err
}
func (db *DB) Maintain(ctx context.Context) error {
	_, err := db.db.ExecContext(ctx, `
		DELETE FROM AccessToken
		WHERE timediff('now', COALESCE(refresh_expires_at, expires_at)) > 0
	`)
	if err != nil {
		return err
	}
	_, err = db.db.ExecContext(ctx, `
		DELETE FROM AuthCode
		WHERE timediff(?, created_at) > 0
	`, time.Now().Add(-authCodeExpiration))
	if err != nil {
		return err
	}
	return nil
}
func scan(e entity, rows *sql.Rows) error {
	columns := e.columns()
	keys, err := rows.Columns()
	if err != nil {
		panic(err)
	}
	out := make([]interface{}, len(keys))
	for i, k := range keys {
		v, ok := columns[k]
		if !ok {
			panic(fmt.Errorf("unknown column %q", k))
		}
		out[i] = v
	}
	return rows.Scan(out...)
}
func scanRow(e entity, rows *sql.Rows) error {
	if !rows.Next() {
		return sql.ErrNoRows
	}
	if err := scan(e, rows); err != nil {
		return err
	}
	return rows.Close()
}
func entityArgs(e entity) []interface{} {
	columns := e.columns()
	l := make([]interface{}, 0, len(columns))
	for k, v := range columns {
		l = append(l, sql.Named(k, v))
	}
	return l
}
package main
import (
	"context"
	"embed"
	"flag"
	"log"
	"net"
	"net/http"
	"time"
	"github.com/go-chi/chi/v5"
)
var (
	//go:embed template
	templateFS embed.FS
	//go:embed static
	staticFS embed.FS
)
func main() {
	var configFilename string
flag.StringVar(&configFilename, "config", "/etc/sinwon/config", "Configuration filename")
flag.StringVar(&configFilename, "config", "/etc/lindenii/vireo/config", "Configuration filename")
	flag.Parse()
	cfg, err := loadConfig(configFilename)
	if err != nil {
		log.Fatalf("Failed to load config file: %v", err)
	}
	listenAddr := cfg.Listen
	if cfg.Database == "" {
		log.Fatalf("Missing database configuration")
	}
	db, err := openDB(cfg.Database)
	if err != nil {
		log.Fatalf("Failed to open DB: %v", err)
	}
	tplBaseData := &TemplateBaseData{
		ServerName: cfg.ServerName,
	}
	if tplBaseData.ServerName == "" {
tplBaseData.ServerName = "sinwon"
tplBaseData.ServerName = "vireo"
	}
	tpl, err := loadTemplate(templateFS, "template/*.html", tplBaseData)
	if err != nil {
		log.Fatalf("Failed to load template: %v", err)
	}
	mux := chi.NewRouter()
	mux.Handle("/static/*", http.FileServer(http.FS(staticFS)))
	mux.Get("/", index)
	mux.HandleFunc("/login", login)
	mux.Post("/logout", logout)
	mux.HandleFunc("/client/new", manageClient)
	mux.HandleFunc("/client/{id}", manageClient)
	mux.Post("/client/{id}/revoke", revokeClient)
	mux.HandleFunc("/user/new", manageUser)
	mux.HandleFunc("/user/{id}", manageUser)
	mux.Get("/.well-known/oauth-authorization-server", getOAuthServerMetadata)
	mux.HandleFunc("/authorize", authorize)
	mux.Post("/token", exchangeToken)
	mux.Post("/introspect", introspectToken)
	mux.Post("/revoke", revokeToken)
	go maintainDBLoop(db)
	server := http.Server{
		Addr:    listenAddr,
		Handler: loginTokenMiddleware(mux),
		BaseContext: func(net.Listener) context.Context {
			return newBaseContext(db, tpl)
		},
	}
	log.Printf("OAuth server listening on %v", server.Addr)
	if err := server.ListenAndServe(); err != nil {
		log.Fatalf("Failed to listen and serve: %v", err)
	}
}
func httpError(w http.ResponseWriter, err error) {
	log.Print(err)
	http.Error(w, "Internal server error", http.StatusInternalServerError)
}
func maintainDBLoop(db *DB) {
	ticker := time.NewTicker(15 * time.Minute)
	defer ticker.Stop()
	for range ticker.C {
		ctx := context.Background()
		ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
		if err := db.Maintain(ctx); err != nil {
			log.Printf("Failed to perform database maintenance: %v", err)
		}
		cancel()
	}
}
package main import ( "context" "fmt" "html/template" "io" "io/fs" "mime" "net/http" ) const (
loginCookieName = "sinwon-login" internalTokenScope = "_sinwon"
loginCookieName = "vireo-login" internalTokenScope = "_vireo"
)
type contextKey string
const (
	contextKeyDB         = "db"
	contextKeyTemplate   = "template"
	contextKeyLoginToken = "login-token"
)
func dbFromContext(ctx context.Context) *DB {
	return ctx.Value(contextKeyDB).(*DB)
}
func templateFromContext(ctx context.Context) *Template {
	return ctx.Value(contextKeyTemplate).(*Template)
}
func loginTokenFromContext(ctx context.Context) *AccessToken {
	v := ctx.Value(contextKeyLoginToken)
	if v == nil {
		return nil
	}
	return v.(*AccessToken)
}
func newBaseContext(db *DB, tpl *Template) context.Context {
	ctx := context.Background()
	ctx = context.WithValue(ctx, contextKeyDB, db)
	ctx = context.WithValue(ctx, contextKeyTemplate, tpl)
	return ctx
}
func setLoginTokenCookie(w http.ResponseWriter, req *http.Request, token *AccessToken, secret string) {
	http.SetCookie(w, &http.Cookie{
		Name:     loginCookieName,
		Value:    MarshalSecret(token.ID, SecretKindAccessToken, secret),
		HttpOnly: true,
		SameSite: http.SameSiteStrictMode,
		Secure:   isForwardedHTTPS(req),
	})
}
func unsetLoginTokenCookie(w http.ResponseWriter, req *http.Request) {
	http.SetCookie(w, &http.Cookie{
		Name:     loginCookieName,
		HttpOnly: true,
		SameSite: http.SameSiteStrictMode,
		Secure:   isForwardedHTTPS(req),
		MaxAge:   -1,
	})
}
func isForwardedHTTPS(req *http.Request) bool {
	if forwarded := req.Header.Get("Forwarded"); forwarded != "" {
		_, params, _ := mime.ParseMediaType("_; " + forwarded)
		return params["proto"] == "https"
	}
	if forwardedProto := req.Header.Get("X-Forwarded-Proto"); forwardedProto != "" {
		return forwardedProto == "https"
	}
	return false
}
func loginTokenMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		cookie, _ := req.Cookie(loginCookieName)
		if cookie == nil {
			next.ServeHTTP(w, req)
			return
		}
		ctx := req.Context()
		db := dbFromContext(ctx)
		tokenID, tokenSecret, _ := UnmarshalSecret[*AccessToken](cookie.Value)
		token, err := db.FetchAccessToken(ctx, tokenID)
		if err == errNoDBRows || (err == nil && !token.VerifySecret(tokenSecret)) {
			unsetLoginTokenCookie(w, req)
			next.ServeHTTP(w, req)
			return
		} else if err != nil {
			httpError(w, fmt.Errorf("failed to fetch access token: %v", err))
			return
		}
		if token.Scope != internalTokenScope {
			http.Error(w, "Invalid login token scope", http.StatusForbidden)
			return
		}
		if token.User == 0 {
			panic("login token with zero user ID")
		}
		ctx = context.WithValue(ctx, contextKeyLoginToken, token)
		req = req.WithContext(ctx)
		next.ServeHTTP(w, req)
	})
}
type TemplateBaseData struct {
	ServerName string
}
func (data *TemplateBaseData) Base() *TemplateBaseData {
	return data
}
type TemplateData interface {
	Base() *TemplateBaseData
}
type Template struct {
	tpl      *template.Template
	baseData *TemplateBaseData
}
func loadTemplate(fs fs.FS, pattern string, baseData *TemplateBaseData) (*Template, error) {
	tpl, err := template.ParseFS(fs, pattern)
	if err != nil {
		return nil, err
	}
	return &Template{tpl: tpl, baseData: baseData}, nil
}
func (tpl *Template) MustExecuteTemplate(w io.Writer, filename string, data TemplateData) {
	if data == nil {
		data = tpl.baseData
	} else {
		*data.Base() = *tpl.baseData
	}
	if err := tpl.tpl.ExecuteTemplate(w, filename, data); err != nil {
		panic(err)
	}
}