Lindenii Project Forge
Login

server

Vireo IdP server

Hi… I am well aware that this diff view is very suboptimal. It will be fixed when the refactored server comes along!

Commit info
ID
2f403be1236a514bc23fafba9b072683e2c69161
Author
Author date
Mon, 19 Feb 2024 09:52:55 +0100
Committer
Committer date
Mon, 19 Feb 2024 09:54:07 +0100
Actions
Cleanup expired tokens and codes from DB

Closes: https://todo.sr.ht/~emersion/sinwon/4
package main

import (
	"context"
	"database/sql"
	_ "embed"
	"fmt"
	"time"

	_ "github.com/mattn/go-sqlite3"
)

//go:embed schema.sql
var schema string

var errNoDBRows = sql.ErrNoRows

type DB struct {
	db *sql.DB
}

func openDB(filename string) (*DB, error) {
	sqlDB, err := sql.Open("sqlite3", filename)
	if err != nil {
		return nil, err
	}

	db := &DB{sqlDB}
	if err := db.init(context.TODO()); err != nil {
		db.Close()
		return nil, err
	}

	return db, nil
}

func (db *DB) init(ctx context.Context) error {
	var n int
	if err := db.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM sqlite_schema").Scan(&n); err != nil {
		return err
	} else if n != 0 {
		return nil
	}

	if _, err := db.db.ExecContext(ctx, schema); err != nil {
		return err
	}

	// TODO: drop this
	defaultUser := User{Username: "root"}
	if err := defaultUser.SetPassword("root"); err != nil {
		return err
	}
	return db.StoreUser(ctx, &defaultUser)
}

func (db *DB) Close() error {
	return db.db.Close()
}

func (db *DB) FetchUser(ctx context.Context, id ID[*User]) (*User, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM User WHERE id = ?", id)
	if err != nil {
		return nil, err
	}
	var user User
	err = scanRow(&user, rows)
	return &user, err
}

func (db *DB) FetchUserByUsername(ctx context.Context, username string) (*User, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM User WHERE username = ?", username)
	if err != nil {
		return nil, err
	}
	var user User
	err = scanRow(&user, rows)
	return &user, err
}

func (db *DB) StoreUser(ctx context.Context, user *User) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO User(id, username, password_hash)
		VALUES (:id, :username, :password_hash)
		ON CONFLICT(id) DO UPDATE SET
			username = :username,
			password_hash = :password_hash
		RETURNING id
	`, entityArgs(user)...).Scan(&user.ID)
}

func (db *DB) FetchClient(ctx context.Context, clientID string) (*Client, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE client_id = ?", clientID)
	if err != nil {
		return nil, err
	}
	var client Client
	err = scanRow(&client, rows)
	return &client, err
}

func (db *DB) StoreClient(ctx context.Context, client *Client) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO Client(id, client_id, client_secret_hash, owner)
		VALUES (:id, :client_id, :client_secret_hash, :owner)
		ON CONFLICT(id) DO UPDATE SET
			client_id = :client_id,
			client_secret_hash = :client_secret_hash,
			owner = :owner
		RETURNING id
	`, entityArgs(client)...).Scan(&client.ID)
}

func (db *DB) ListClients(ctx context.Context, owner ID[*User]) ([]Client, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE owner IS ?", owner)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var l []Client
	for rows.Next() {
		var client Client
		if err := scan(&client, rows); err != nil {
			return nil, err
		}
		l = append(l, client)
	}

	return l, rows.Close()
}

func (db *DB) FetchAccessToken(ctx context.Context, id ID[*AccessToken]) (*AccessToken, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM AccessToken WHERE id = ?", id)
	if err != nil {
		return nil, err
	}
	var token AccessToken
	err = scanRow(&token, rows)
	return &token, err
}

func (db *DB) CreateAccessToken(ctx context.Context, token *AccessToken) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO AccessToken(hash, user, client, scope, issued_at, expires_at)
		VALUES (:hash, :user, :client, :scope, :issued_at, :expires_at)
		RETURNING id
	`, entityArgs(token)...).Scan(&token.ID)
}

func (db *DB) CreateAuthCode(ctx context.Context, code *AuthCode) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO AuthCode(hash, created_at, user, client, scope)
		VALUES (:hash, :created_at, :user, :client, :scope)
		RETURNING id
	`, entityArgs(code)...).Scan(&code.ID)
}

func (db *DB) PopAuthCode(ctx context.Context, id ID[*AuthCode]) (*AuthCode, error) {
	rows, err := db.db.QueryContext(ctx, `
		DELETE FROM AuthCode
		WHERE id = ?
		RETURNING *
	`, id)
	if err != nil {
		return nil, err
	}
	var authCode AuthCode
	err = scanRow(&authCode, rows)
	return &authCode, err
}

func (db *DB) Maintain(ctx context.Context) error {
	_, err := db.db.ExecContext(ctx, `
		DELETE FROM AccessToken
		WHERE timediff('now', expires_at) > 0
	`)
	if err != nil {
		return err
	}

	_, err = db.db.ExecContext(ctx, `
		DELETE FROM AuthCode
		WHERE timediff(?, created_at) > 0
	`, time.Now().Add(-authCodeExpiration))
	if err != nil {
		return err
	}

	return nil
}

func scan(e entity, rows *sql.Rows) error {
	columns := e.columns()

	keys, err := rows.Columns()
	if err != nil {
		panic(err)
	}
	out := make([]interface{}, len(keys))
	for i, k := range keys {
		v, ok := columns[k]
		if !ok {
			panic(fmt.Errorf("unknown column %q", k))
		}
		out[i] = v
	}

	return rows.Scan(out...)
}

func scanRow(e entity, rows *sql.Rows) error {
	if !rows.Next() {
		return sql.ErrNoRows
	}
	if err := scan(e, rows); err != nil {
		return err
	}
	return rows.Close()
}

func entityArgs(e entity) []interface{} {
	columns := e.columns()

	l := make([]interface{}, 0, len(columns))
	for k, v := range columns {
		l = append(l, sql.Named(k, v))
	}

	return l
}
package main

import (
	"crypto/rand"
	"crypto/sha512"
	"crypto/subtle"
	"database/sql"
	"database/sql/driver"
	"encoding/base64"
	"fmt"
	"strconv"
	"strings"
	"time"

	"golang.org/x/crypto/bcrypt"
)

const authCodeExpiration = 10 * time.Minute

type entity interface {
	columns() map[string]interface{}
}

var (
	_ entity = (*User)(nil)
	_ entity = (*Client)(nil)
	_ entity = (*AccessToken)(nil)
	_ entity = (*AuthCode)(nil)
)

type ID[T entity] int64

var (
	_ sql.Scanner   = (*ID[*User])(nil)
	_ driver.Valuer = ID[*User](0)
)

func ParseID[T entity](s string) (ID[T], error) {
	u, _ := strconv.ParseUint(s, 10, 63)
	if u == 0 {
		return 0, fmt.Errorf("invalid ID")
	}
	return ID[T](u), nil
}

func (ptr *ID[T]) Scan(v interface{}) error {
	if v == nil {
		*ptr = 0
		return nil
	}
	id, ok := v.(int64)
	if !ok {
		return fmt.Errorf("cannot scan ID from %T", v)
	}
	*ptr = ID[T](id)
	return nil
}

func (id ID[T]) Value() (driver.Value, error) {
	if id == 0 {
		return nil, nil
	} else {
		return int64(id), nil
	}
}

type User struct {
	ID           ID[*User]
	Username     string
	PasswordHash string
}

func (user *User) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":            &user.ID,
		"username":      &user.Username,
		"password_hash": &user.PasswordHash,
	}
}

func (user *User) VerifyPassword(password string) error {
	// TODO: upgrade hash
	return bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
}

func (user *User) SetPassword(password string) error {
	hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return err
	}
	user.PasswordHash = string(hash)
	return nil
}

type Client struct {
	ID               ID[*Client]
	ClientID         string
	ClientSecretHash []byte
	Owner            ID[*User]
}

func NewClient(owner ID[*User]) (client *Client, secret string, err error) {
	id, err := generateUID()
	if err != nil {
		return nil, "", fmt.Errorf("failed to generate client ID: %v", err)
	}
	secret, hash, err := generateSecret()
	if err != nil {
		return nil, "", fmt.Errorf("failed to generate client secret: %v", err)
	}
	client = &Client{
		ClientID:         id,
		ClientSecretHash: hash,
		Owner:            owner,
	}
	return client, secret, nil
}

func (client *Client) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":                 &client.ID,
		"client_id":          &client.ClientID,
		"client_secret_hash": &client.ClientSecretHash,
		"owner":              &client.Owner,
	}
}

func (client *Client) VerifySecret(secret string) bool {
	return verifyHash(client.ClientSecretHash, secret)
}

type AccessToken struct {
	ID        ID[*AccessToken]
	Hash      []byte
	User      ID[*User]
	Client    ID[*Client]
	Scope     string
	IssuedAt  time.Time
	ExpiresAt time.Time
}

func (token *AccessToken) Generate() (secret string, err error) {
	secret, hash, err := generateSecret()
	if err != nil {
		return "", fmt.Errorf("failed to generate access token secret: %v", err)
	}
	token.Hash = hash
	token.IssuedAt = time.Now()
	token.ExpiresAt = time.Now().Add(2 * time.Hour)
	return secret, nil
}

func NewAccessTokenFromAuthCode(authCode *AuthCode) (token *AccessToken, secret string, err error) {
	token = &AccessToken{
		User:   authCode.User,
		Client: authCode.Client,
		Scope:  authCode.Scope,
	}
	secret, err = token.Generate()
	return token, secret, err
}

func (token *AccessToken) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":         &token.ID,
		"hash":       &token.Hash,
		"user":       &token.User,
		"client":     &token.Client,
		"scope":      &token.Scope,
		"issued_at":  &token.IssuedAt,
		"expires_at": &token.ExpiresAt,
	}
}

func (token *AccessToken) VerifySecret(secret string) bool {
	return verifyHash(token.Hash, secret) && verifyExpiration(token.ExpiresAt)
}

type AuthCode struct {
	ID        ID[*AuthCode]
	Hash      []byte
	CreatedAt time.Time
	User      ID[*User]
	Client    ID[*Client]
	Scope     string
}

func NewAuthCode(user ID[*User], client ID[*Client], scope string) (code *AuthCode, secret string, err error) {
	secret, hash, err := generateSecret()
	if err != nil {
		return nil, "", fmt.Errorf("failed to generate authentication code secret: %v", err)
	}
	code = &AuthCode{
		Hash:      hash,
		CreatedAt: time.Now(),
		User:      user,
		Client:    client,
		Scope:     scope,
	}
	return code, secret, nil
}

func (code *AuthCode) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":         &code.ID,
		"hash":       &code.Hash,
		"created_at": &code.CreatedAt,
		"user":       &code.User,
		"client":     &code.Client,
		"scope":      &code.Scope,
	}
}

func (code *AuthCode) VerifySecret(secret string) bool {
	return verifyHash(code.Hash, secret) && verifyExpiration(code.CreatedAt.Add(10*time.Minute))
	return verifyHash(code.Hash, secret) && verifyExpiration(code.CreatedAt.Add(authCodeExpiration))
}

func UnmarshalSecret[T entity](s string) (id ID[T], secret string, err error) {
	idStr, secret, _ := strings.Cut(s, ".")
	id, err = ParseID[T](idStr)
	return id, secret, err
}

func MarshalSecret[T entity](id ID[T], secret string) string {
	if id == 0 {
		panic("cannot marshal zero ID")
	}
	return fmt.Sprintf("%v.%v", int64(id), secret)
}

func generateUID() (string, error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(b), nil
}

func generateSecret() (secret string, hash []byte, err error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", nil, err
	}
	secret = base64.RawURLEncoding.EncodeToString(b)
	h := sha512.Sum512(b)
	return secret, h[:], nil
}

func verifyHash(hash []byte, secret string) bool {
	b, _ := base64.RawURLEncoding.DecodeString(secret)
	h := sha512.Sum512(b)
	return subtle.ConstantTimeCompare(hash, h[:]) == 1
}

func verifyExpiration(t time.Time) bool {
	return time.Now().Before(t)
}
package main

import (
	"context"
	"embed"
	"flag"
	"html/template"
	"log"
	"net"
	"net/http"
	"time"

	"github.com/go-chi/chi/v5"
)

var (
	//go:embed template
	templateFS embed.FS
	//go:embed static
	staticFS embed.FS
)

func main() {
	var configFilename, listenAddr string
	flag.StringVar(&configFilename, "config", "/etc/sinwon/config", "Configuration filename")
	flag.StringVar(&listenAddr, "listen", ":8080", "HTTP listen address")
	flag.Parse()

	cfg, err := loadConfig(configFilename)
	if err != nil {
		log.Fatalf("Failed to load config file: %v", err)
	}

	if listenAddr == "" {
		listenAddr = cfg.Listen
	}
	if listenAddr == "" {
		log.Fatalf("Missing listen configuration")
	}
	if cfg.Database == "" {
		log.Fatalf("Missing database configuration")
	}

	db, err := openDB(cfg.Database)
	if err != nil {
		log.Fatalf("Failed to open DB: %v", err)
	}

	tpl := template.Must(template.ParseFS(templateFS, "template/*.html"))

	mux := chi.NewRouter()
	mux.Handle("/static/*", http.FileServer(http.FS(staticFS)))
	mux.Get("/", index)
	mux.Post("/client/new", createClient)
	mux.HandleFunc("/login", login)
	mux.Post("/logout", logout)
	mux.HandleFunc("/user/new", updateUser)
	mux.HandleFunc("/user/{id}", updateUser)
	mux.HandleFunc("/authorize", authorize)
	mux.Post("/token", exchangeToken)

	go maintainDBLoop(db)

	server := http.Server{
		Addr:    listenAddr,
		Handler: loginTokenMiddleware(mux),
		BaseContext: func(net.Listener) context.Context {
			return newBaseContext(db, tpl)
		},
	}
	log.Printf("OAuth server listening on %v", server.Addr)
	if err := server.ListenAndServe(); err != nil {
		log.Fatalf("Failed to listen and serve: %v", err)
	}
}

func httpError(w http.ResponseWriter, err error) {
	log.Print(err)
	http.Error(w, "Internal server error", http.StatusInternalServerError)
}

func maintainDBLoop(db *DB) {
	ticker := time.NewTicker(15 * time.Minute)
	defer ticker.Stop()

	for range ticker.C {
		ctx := context.Background()
		ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
		if err := db.Maintain(ctx); err != nil {
			log.Printf("Failed to perform database maintenance: %v", err)
		}
		cancel()
	}
}