Hi… I am well aware that this diff view is very suboptimal. It will be fixed when the refactored server comes along!
Add page to create/update user
package main
import (
	"context"
	"database/sql"
	_ "embed"
	"fmt"
	_ "github.com/mattn/go-sqlite3"
)
//go:embed schema.sql
var schema string
var errNoDBRows = sql.ErrNoRows
type DB struct {
	db *sql.DB
}
func openDB(filename string) (*DB, error) {
	sqlDB, err := sql.Open("sqlite3", filename)
	if err != nil {
		return nil, err
	}
	db := &DB{sqlDB}
	if err := db.init(context.TODO()); err != nil {
		db.Close()
		return nil, err
	}
	return db, nil
}
func (db *DB) init(ctx context.Context) error {
	var n int
	if err := db.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM sqlite_schema").Scan(&n); err != nil {
		return err
	} else if n != 0 {
		return nil
	}
	if _, err := db.db.ExecContext(ctx, schema); err != nil {
		return err
	}
	// TODO: drop this
	defaultUser := User{Username: "root"}
	if err := defaultUser.SetPassword("root"); err != nil {
		return err
	}
	return db.StoreUser(ctx, &defaultUser)
}
func (db *DB) Close() error {
	return db.db.Close()
}
func (db *DB) FetchUser(ctx context.Context, username string) (*User, error) {
func (db *DB) FetchUser(ctx context.Context, id ID[*User]) (*User, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM User WHERE id = ?", id)
	if err != nil {
		return nil, err
	}
	var user User
	err = scanRow(&user, rows)
	return &user, err
}
func (db *DB) FetchUserByUsername(ctx context.Context, username string) (*User, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM User WHERE username = ?", username)
	if err != nil {
		return nil, err
	}
	var user User
	err = scanRow(&user, rows)
	return &user, err
}
func (db *DB) StoreUser(ctx context.Context, user *User) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO User(id, username, password_hash)
		VALUES (:id, :username, :password_hash)
		ON CONFLICT(id) DO UPDATE SET
			username = :username,
			password_hash = :password_hash
		RETURNING id
	`, entityArgs(user)...).Scan(&user.ID)
}
func (db *DB) FetchClient(ctx context.Context, clientID string) (*Client, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE client_id = ?", clientID)
	if err != nil {
		return nil, err
	}
	var client Client
	err = scanRow(&client, rows)
	return &client, err
}
func (db *DB) StoreClient(ctx context.Context, client *Client) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO Client(id, client_id, client_secret_hash, owner)
		VALUES (:id, :client_id, :client_secret_hash, :owner)
		ON CONFLICT(id) DO UPDATE SET
			client_id = :client_id,
			client_secret_hash = :client_secret_hash,
			owner = :owner
		RETURNING id
	`, entityArgs(client)...).Scan(&client.ID)
}
func (db *DB) ListClients(ctx context.Context, owner ID[*User]) ([]Client, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE owner IS ?", owner)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var l []Client
	for rows.Next() {
		var client Client
		if err := scan(&client, rows); err != nil {
			return nil, err
		}
		l = append(l, client)
	}
	return l, rows.Close()
}
func (db *DB) FetchAccessToken(ctx context.Context, id ID[*AccessToken]) (*AccessToken, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM AccessToken WHERE id = ?", id)
	if err != nil {
		return nil, err
	}
	var token AccessToken
	err = scanRow(&token, rows)
	return &token, err
}
func (db *DB) CreateAccessToken(ctx context.Context, token *AccessToken) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO AccessToken(hash, user, client, scope, issued_at, expires_at)
		VALUES (:hash, :user, :client, :scope, :issued_at, :expires_at)
		RETURNING id
	`, entityArgs(token)...).Scan(&token.ID)
}
func (db *DB) CreateAuthCode(ctx context.Context, code *AuthCode) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO AuthCode(hash, created_at, user, client, scope)
		VALUES (:hash, :created_at, :user, :client, :scope)
		RETURNING id
	`, entityArgs(code)...).Scan(&code.ID)
}
func (db *DB) PopAuthCode(ctx context.Context, id ID[*AuthCode]) (*AuthCode, error) {
	rows, err := db.db.QueryContext(ctx, `
		DELETE FROM AuthCode
		WHERE id = ?
		RETURNING *
	`, id)
	if err != nil {
		return nil, err
	}
	var authCode AuthCode
	err = scanRow(&authCode, rows)
	return &authCode, err
}
func scan(e entity, rows *sql.Rows) error {
	columns := e.columns()
	keys, err := rows.Columns()
	if err != nil {
		panic(err)
	}
	out := make([]interface{}, len(keys))
	for i, k := range keys {
		v, ok := columns[k]
		if !ok {
			panic(fmt.Errorf("unknown column %q", k))
		}
		out[i] = v
	}
	return rows.Scan(out...)
}
func scanRow(e entity, rows *sql.Rows) error {
	if !rows.Next() {
		return sql.ErrNoRows
	}
	if err := scan(e, rows); err != nil {
		return err
	}
	return rows.Close()
}
func entityArgs(e entity) []interface{} {
	columns := e.columns()
	l := make([]interface{}, 0, len(columns))
	for k, v := range columns {
		l = append(l, sql.Named(k, v))
	}
	return l
}
package main
import (
	"crypto/rand"
	"crypto/sha512"
	"crypto/subtle"
	"database/sql"
	"database/sql/driver"
	"encoding/base64"
	"fmt"
	"strconv"
	"strings"
	"time"
	"golang.org/x/crypto/bcrypt"
)
type entity interface {
	columns() map[string]interface{}
}
var (
	_ entity = (*User)(nil)
	_ entity = (*Client)(nil)
	_ entity = (*AccessToken)(nil)
	_ entity = (*AuthCode)(nil)
)
type ID[T entity] int64
var (
	_ sql.Scanner   = (*ID[*User])(nil)
	_ driver.Valuer = ID[*User](0)
)
func ParseID[T entity](s string) (ID[T], error) {
	u, _ := strconv.ParseUint(s, 10, 63)
	if u == 0 {
		return 0, fmt.Errorf("invalid ID")
	}
	return ID[T](u), nil
}
func (ptr *ID[T]) Scan(v interface{}) error {
	if v == nil {
		*ptr = 0
		return nil
	}
	id, ok := v.(int64)
	if !ok {
		return fmt.Errorf("cannot scan ID from %T", v)
	}
	*ptr = ID[T](id)
	return nil
}
func (id ID[T]) Value() (driver.Value, error) {
	if id == 0 {
		return nil, nil
	} else {
		return int64(id), nil
	}
}
type User struct {
	ID           ID[*User]
	Username     string
	PasswordHash string
}
func (user *User) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":            &user.ID,
		"username":      &user.Username,
		"password_hash": &user.PasswordHash,
	}
}
func (user *User) VerifyPassword(password string) error {
	// TODO: upgrade hash
	return bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
}
func (user *User) SetPassword(password string) error {
	hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return err
	}
	user.PasswordHash = string(hash)
	return nil
}
type Client struct {
	ID               ID[*Client]
	ClientID         string
	ClientSecretHash []byte
	Owner            ID[*User]
}
func NewClient(owner ID[*User]) (client *Client, secret string, err error) {
	id, err := generateUID()
	if err != nil {
		return nil, "", fmt.Errorf("failed to generate client ID: %v", err)
	}
	secret, hash, err := generateSecret()
	if err != nil {
		return nil, "", fmt.Errorf("failed to generate client secret: %v", err)
	}
	client = &Client{
		ClientID:         id,
		ClientSecretHash: hash,
		Owner:            owner,
	}
	return client, secret, nil
}
func (client *Client) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":                 &client.ID,
		"client_id":          &client.ClientID,
		"client_secret_hash": &client.ClientSecretHash,
		"owner":              &client.Owner,
	}
}
func (client *Client) VerifySecret(secret string) bool {
	return verifyHash(client.ClientSecretHash, secret)
}
type AccessToken struct {
	ID        ID[*AccessToken]
	Hash      []byte
	User      ID[*User]
	Client    ID[*Client]
	Scope     string
	IssuedAt  time.Time
	ExpiresAt time.Time
}
func (token *AccessToken) Generate() (secret string, err error) {
	secret, hash, err := generateSecret()
	if err != nil {
		return "", fmt.Errorf("failed to generate access token secret: %v", err)
	}
	token.Hash = hash
	token.IssuedAt = time.Now()
	token.ExpiresAt = time.Now().Add(2 * time.Hour)
	return secret, nil
}
func NewAccessTokenFromAuthCode(authCode *AuthCode) (token *AccessToken, secret string, err error) {
	token = &AccessToken{
		User:   authCode.User,
		Client: authCode.Client,
		Scope:  authCode.Scope,
	}
	secret, err = token.Generate()
	return token, secret, err
}
func (token *AccessToken) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":         &token.ID,
		"hash":       &token.Hash,
		"user":       &token.User,
		"client":     &token.Client,
		"scope":      &token.Scope,
		"issued_at":  &token.IssuedAt,
		"expires_at": &token.ExpiresAt,
	}
}
func (token *AccessToken) VerifySecret(secret string) bool {
	return verifyHash(token.Hash, secret) && verifyExpiration(token.ExpiresAt)
}
type AuthCode struct {
	ID        ID[*AuthCode]
	Hash      []byte
	CreatedAt time.Time
	User      ID[*User]
	Client    ID[*Client]
	Scope     string
}
func NewAuthCode(user ID[*User], client ID[*Client], scope string) (code *AuthCode, secret string, err error) {
	secret, hash, err := generateSecret()
	if err != nil {
		return nil, "", fmt.Errorf("failed to generate authentication code secret: %v", err)
	}
	code = &AuthCode{
		Hash:      hash,
		CreatedAt: time.Now(),
		User:      user,
		Client:    client,
		Scope:     scope,
	}
	return code, secret, nil
}
func (code *AuthCode) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":         &code.ID,
		"hash":       &code.Hash,
		"created_at": &code.CreatedAt,
		"user":       &code.User,
		"client":     &code.Client,
		"scope":      &code.Scope,
	}
}
func (code *AuthCode) VerifySecret(secret string) bool {
	return verifyHash(code.Hash, secret) && verifyExpiration(code.CreatedAt.Add(10*time.Minute))
}
func UnmarshalSecret[T entity](s string) (id ID[T], secret string, err error) {
	idStr, secret, _ := strings.Cut(s, ".")
	u, _ := strconv.ParseUint(idStr, 10, 63)
	if u == 0 {
		return 0, "", fmt.Errorf("invalid ID")
	}
	return ID[T](int64(u)), secret, nil
id, err = ParseID[T](idStr) return id, secret, err
}
func MarshalSecret[T entity](id ID[T], secret string) string {
	if id == 0 {
		panic("cannot marshal zero ID")
	}
	return fmt.Sprintf("%v.%v", int64(id), secret)
}
func generateUID() (string, error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(b), nil
}
func generateSecret() (secret string, hash []byte, err error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", nil, err
	}
	secret = base64.RawURLEncoding.EncodeToString(b)
	h := sha512.Sum512(b)
	return secret, h[:], nil
}
func verifyHash(hash []byte, secret string) bool {
	b, _ := base64.RawURLEncoding.DecodeString(secret)
	h := sha512.Sum512(b)
	return subtle.ConstantTimeCompare(hash, h[:]) == 1
}
func verifyExpiration(t time.Time) bool {
	return time.Now().Before(t)
}
package main
import (
	"context"
	"embed"
	"flag"
	"html/template"
	"log"
	"net"
	"net/http"
	"github.com/go-chi/chi/v5"
)
var (
	//go:embed template
	templateFS embed.FS
	//go:embed static
	staticFS embed.FS
)
func main() {
	var listenAddr string
	flag.StringVar(&listenAddr, "listen", ":8080", "HTTP listen address")
	flag.Parse()
	tpl := template.Must(template.ParseFS(templateFS, "template/*.html"))
	db, err := openDB("sinwon.db")
	if err != nil {
		log.Fatalf("Failed to open DB: %v", err)
	}
	mux := chi.NewRouter()
	mux.Handle("/static/*", http.FileServer(http.FS(staticFS)))
	mux.Get("/", index)
	mux.Post("/client/new", createClient)
	mux.HandleFunc("/login", login)
	mux.HandleFunc("/user/new", updateUser)
	mux.HandleFunc("/user/{id}", updateUser)
	mux.HandleFunc("/authorize", authorize)
	mux.Post("/token", exchangeToken)
	server := http.Server{
		Addr:    listenAddr,
		Handler: loginTokenMiddleware(mux),
		BaseContext: func(net.Listener) context.Context {
			return newBaseContext(db, tpl)
		},
	}
	log.Printf("OAuth server listening on %v", server.Addr)
	if err := server.ListenAndServe(); err != nil {
		log.Fatalf("Failed to listen and serve: %v", err)
	}
}
func httpError(w http.ResponseWriter, err error) {
	log.Print(err)
	http.Error(w, "Internal server error", http.StatusInternalServerError)
}
{{ template "head.html" }}
<main>
<h1>sinwon</h1>
<form method="post" action="/client/new">
	<a href="/user/new"><button type="button">Create user</button></a>
	<a href="/user/{{ .Me }}"><button type="button">Edit user</button></a>
<button type="submit">Register new client</button> </form>
{{ with . }}
{{ with .Clients }}
	<p>{{ . | len }} clients registered:</p>
	<ul>
		{{ range . }}
			<li><code>{{ .ClientID }}</code></li>
		{{ end }}
	</ul>
{{ else }}
	<p>No client registered yet.</p>
{{ end }}
</main>
{{ template "foot.html" }}
{{ template "head.html" }}
<main>
<h1>sinwon</h1>
<form method="post" action="">
	Username: <input type="text" name="username" value="{{ .Username }}" required><br>
	Password: <input type="password" name="password"><br>
	<button type="submit">
		{{ if .Username }}
			Update user
		{{ else }}
			Create user
		{{ end }}
	</button>
</form>
</main>
{{ template "foot.html" }}
package main import ( "fmt" "log" "net/http" "net/url"
"github.com/go-chi/chi/v5"
)
func index(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)
	tpl := templateFromContext(ctx)
	loginToken := loginTokenFromContext(ctx)
	if loginToken == nil {
		http.Redirect(w, req, "/login", http.StatusFound)
		return
	}
	clients, err := db.ListClients(ctx, loginToken.User)
	if err != nil {
		httpError(w, err)
		return
	}
	if err := tpl.ExecuteTemplate(w, "index.html", clients); err != nil {
	data := struct {
		Clients []Client
		Me      ID[*User]
	}{
		Clients: clients,
		Me:      loginToken.User,
	}
	if err := tpl.ExecuteTemplate(w, "index.html", &data); err != nil {
		panic(err)
	}
}
func login(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)
	tpl := templateFromContext(ctx)
	q := req.URL.Query()
	rawRedirectURI := q.Get("redirect_uri")
	if rawRedirectURI == "" {
		rawRedirectURI = "/"
	}
	redirectURI, err := url.Parse(rawRedirectURI)
	if err != nil || redirectURI.Scheme != "" || redirectURI.Opaque != "" || redirectURI.User != nil || redirectURI.Host != "" {
		http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
		return
	}
	if loginTokenFromContext(ctx) != nil {
		http.Redirect(w, req, redirectURI.String(), http.StatusFound)
		return
	}
	username := req.PostFormValue("username")
	password := req.PostFormValue("password")
	if username == "" {
		if err := tpl.ExecuteTemplate(w, "login.html", nil); err != nil {
			panic(err)
		}
		return
	}
user, err := db.FetchUser(ctx, username)
user, err := db.FetchUserByUsername(ctx, username)
	if err != nil && err != errNoDBRows {
		httpError(w, fmt.Errorf("failed to fetch user: %v", err))
		return
	}
	if err == nil {
		err = user.VerifyPassword(password)
	}
	if err != nil {
		log.Printf("login failed for user %q: %v", username, err)
		// TODO: show error message
		if err := tpl.ExecuteTemplate(w, "login.html", nil); err != nil {
			panic(err)
		}
		return
	}
	token := AccessToken{
		User:  user.ID,
		Scope: internalTokenScope,
	}
	secret, err := token.Generate()
	if err != nil {
		httpError(w, fmt.Errorf("failed to generate access token: %v", err))
		return
	}
	if err := db.CreateAccessToken(ctx, &token); err != nil {
		httpError(w, fmt.Errorf("failed to create access token: %v", err))
		return
	}
	setLoginTokenCookie(w, &token, secret)
	http.Redirect(w, req, redirectURI.String(), http.StatusFound)
}
func updateUser(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)
	tpl := templateFromContext(ctx)
	user := new(User)
	if idStr := chi.URLParam(req, "id"); idStr != "" {
		id, err := ParseID[*User](idStr)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		user, err = db.FetchUser(ctx, id)
		if err != nil {
			httpError(w, err)
			return
		}
	}
	loginToken := loginTokenFromContext(ctx)
	if loginToken == nil {
		http.Redirect(w, req, "/login", http.StatusFound)
		return
	}
	if user.ID != 0 && loginToken.User != user.ID {
		http.Error(w, "Access denied", http.StatusForbidden)
		return
	}
	username := req.PostFormValue("username")
	password := req.PostFormValue("password")
	if username == "" {
		if err := tpl.ExecuteTemplate(w, "update-user.html", user); err != nil {
			panic(err)
		}
		return
	}
	user.Username = username
	if password != "" {
		if err := user.SetPassword(password); err != nil {
			httpError(w, err)
			return
		}
	}
	if err := db.StoreUser(ctx, user); err != nil {
		httpError(w, err)
		return
	}
	http.Redirect(w, req, "/", http.StatusFound)
}