Lindenii Project Forge
Login

server

Vireo IdP server

Hi… I am well aware that this diff view is very suboptimal. It will be fixed when the refactored server comes along!

Commit info
ID
03bad951f6713fb5f70184550d8cb4546f22bc73
Author
Author date
Tue, 20 Feb 2024 12:11:04 +0100
Committer
Committer date
Tue, 20 Feb 2024 12:11:18 +0100
Actions
Fix admin ignored in DB.StoreUser
package main

import (
	"context"
	"database/sql"
	_ "embed"
	"fmt"
	"time"

	"github.com/mattn/go-sqlite3"
)

//go:embed schema.sql
var schema string

var errNoDBRows = sql.ErrNoRows

type DB struct {
	db *sql.DB
}

func openDB(filename string) (*DB, error) {
	sqlDB, err := sql.Open("sqlite3", filename)
	if err != nil {
		return nil, err
	}

	db := &DB{sqlDB}
	if err := db.init(context.TODO()); err != nil {
		db.Close()
		return nil, err
	}

	return db, nil
}

func (db *DB) init(ctx context.Context) error {
	var n int
	if err := db.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM sqlite_schema").Scan(&n); err != nil {
		return err
	} else if n != 0 {
		return nil
	}

	if _, err := db.db.ExecContext(ctx, schema); err != nil {
		return err
	}

	// TODO: drop this
	defaultUser := User{Username: "root", Admin: true}
	if err := defaultUser.SetPassword("root"); err != nil {
		return err
	}
	return db.StoreUser(ctx, &defaultUser)
}

func (db *DB) Close() error {
	return db.db.Close()
}

func (db *DB) FetchUser(ctx context.Context, id ID[*User]) (*User, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM User WHERE id = ?", id)
	if err != nil {
		return nil, err
	}
	var user User
	err = scanRow(&user, rows)
	return &user, err
}

func (db *DB) FetchUserByUsername(ctx context.Context, username string) (*User, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM User WHERE username = ?", username)
	if err != nil {
		return nil, err
	}
	var user User
	err = scanRow(&user, rows)
	return &user, err
}

func (db *DB) StoreUser(ctx context.Context, user *User) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO User(id, username, password_hash)
		VALUES (:id, :username, :password_hash)
		INSERT INTO User(id, username, password_hash, admin)
		VALUES (:id, :username, :password_hash, :admin)
		ON CONFLICT(id) DO UPDATE SET
			username = :username,
			password_hash = :password_hash
			password_hash = :password_hash,
			admin = :admin
		RETURNING id
	`, entityArgs(user)...).Scan(&user.ID)
}

func (db *DB) ListUsers(ctx context.Context) ([]User, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM User")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var l []User
	for rows.Next() {
		var user User
		if err := scan(&user, rows); err != nil {
			return nil, err
		}
		l = append(l, user)
	}

	return l, rows.Close()
}

func (db *DB) FetchClient(ctx context.Context, id ID[*Client]) (*Client, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE id = ?", id)
	if err != nil {
		return nil, err
	}
	var client Client
	err = scanRow(&client, rows)
	return &client, err
}

func (db *DB) FetchClientByClientID(ctx context.Context, clientID string) (*Client, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE client_id = ?", clientID)
	if err != nil {
		return nil, err
	}
	var client Client
	err = scanRow(&client, rows)
	return &client, err
}

func (db *DB) StoreClient(ctx context.Context, client *Client) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO Client(id, client_id, client_secret_hash, owner,
			redirect_uris, client_name, client_uri)
		VALUES (:id, :client_id, :client_secret_hash, :owner,
			:redirect_uris, :client_name, :client_uri)
		ON CONFLICT(id) DO UPDATE SET
			client_id = :client_id,
			client_secret_hash = :client_secret_hash,
			owner = :owner,
			redirect_uris = :redirect_uris,
			client_name = :client_name,
			client_uri = :client_uri
		RETURNING id
	`, entityArgs(client)...).Scan(&client.ID)
}

func (db *DB) ListClients(ctx context.Context, owner ID[*User]) ([]Client, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE owner IS ?", owner)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var l []Client
	for rows.Next() {
		var client Client
		if err := scan(&client, rows); err != nil {
			return nil, err
		}
		l = append(l, client)
	}

	return l, rows.Close()
}

func (db *DB) ListAuthorizedClients(ctx context.Context, user ID[*User]) ([]AuthorizedClient, error) {
	rows, err := db.db.QueryContext(ctx, `
		SELECT id, client_id, client_name, client_uri, token.expires_at
		FROM Client,
		(
			SELECT client, MAX(expires_at) as expires_at
			FROM AccessToken
			WHERE user = ?
			GROUP BY client
		) AS token
		WHERE Client.id = token.client
	`, user)
	if err != nil {
		return nil, err
	}

	var l []AuthorizedClient
	for rows.Next() {
		var authClient AuthorizedClient
		columns := authClient.Client.columns()
		var expiresAt string
		err := rows.Scan(columns["id"], columns["client_id"], columns["client_name"], columns["client_uri"], &expiresAt)
		if err != nil {
			return nil, err
		}
		authClient.ExpiresAt, err = time.Parse(sqlite3.SQLiteTimestampFormats[0], expiresAt)
		if err != nil {
			return nil, err
		}
		l = append(l, authClient)
	}

	return l, rows.Close()
}

func (db *DB) DeleteClient(ctx context.Context, id ID[*Client]) error {
	_, err := db.db.ExecContext(ctx, "DELETE FROM Client WHERE id = ?", id)
	return err
}

func (db *DB) FetchAccessToken(ctx context.Context, id ID[*AccessToken]) (*AccessToken, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM AccessToken WHERE id = ?", id)
	if err != nil {
		return nil, err
	}
	var token AccessToken
	err = scanRow(&token, rows)
	return &token, err
}

func (db *DB) CreateAccessToken(ctx context.Context, token *AccessToken) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO AccessToken(hash, user, client, scope, issued_at, expires_at)
		VALUES (:hash, :user, :client, :scope, :issued_at, :expires_at)
		RETURNING id
	`, entityArgs(token)...).Scan(&token.ID)
}

func (db *DB) DeleteAccessToken(ctx context.Context, id ID[*AccessToken]) error {
	_, err := db.db.ExecContext(ctx, "DELETE FROM AccessToken WHERE id = ?", id)
	return err
}

func (db *DB) RevokeAccessTokens(ctx context.Context, clientID ID[*Client], userID ID[*User]) error {
	_, err := db.db.ExecContext(ctx, `
		DELETE FROM AccessToken
		WHERE client = ? AND user = ?
	`, clientID, userID)
	return err
}

func (db *DB) CreateAuthCode(ctx context.Context, code *AuthCode) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO AuthCode(hash, created_at, user, client, scope, redirect_uri)
		VALUES (:hash, :created_at, :user, :client, :scope, :redirect_uri)
		RETURNING id
	`, entityArgs(code)...).Scan(&code.ID)
}

func (db *DB) PopAuthCode(ctx context.Context, id ID[*AuthCode]) (*AuthCode, error) {
	rows, err := db.db.QueryContext(ctx, `
		DELETE FROM AuthCode
		WHERE id = ?
		RETURNING *
	`, id)
	if err != nil {
		return nil, err
	}
	var authCode AuthCode
	err = scanRow(&authCode, rows)
	return &authCode, err
}

func (db *DB) Maintain(ctx context.Context) error {
	_, err := db.db.ExecContext(ctx, `
		DELETE FROM AccessToken
		WHERE timediff('now', expires_at) > 0
	`)
	if err != nil {
		return err
	}

	_, err = db.db.ExecContext(ctx, `
		DELETE FROM AuthCode
		WHERE timediff(?, created_at) > 0
	`, time.Now().Add(-authCodeExpiration))
	if err != nil {
		return err
	}

	return nil
}

func scan(e entity, rows *sql.Rows) error {
	columns := e.columns()

	keys, err := rows.Columns()
	if err != nil {
		panic(err)
	}
	out := make([]interface{}, len(keys))
	for i, k := range keys {
		v, ok := columns[k]
		if !ok {
			panic(fmt.Errorf("unknown column %q", k))
		}
		out[i] = v
	}

	return rows.Scan(out...)
}

func scanRow(e entity, rows *sql.Rows) error {
	if !rows.Next() {
		return sql.ErrNoRows
	}
	if err := scan(e, rows); err != nil {
		return err
	}
	return rows.Close()
}

func entityArgs(e entity) []interface{} {
	columns := e.columns()

	l := make([]interface{}, 0, len(columns))
	for k, v := range columns {
		l = append(l, sql.Named(k, v))
	}

	return l
}