Lindenii Project Forge
Login

server

Vireo IdP server

Hi… I am well aware that this diff view is very suboptimal. It will be fixed when the refactored server comes along!

Commit info
ID
4e6b789167266f312630101773867906d75bada8
Author
Author date
Mon, 19 Feb 2024 01:19:56 +0100
Committer
Committer date
Mon, 19 Feb 2024 01:19:56 +0100
Actions
Add page to create/update user
package main

import (
	"context"
	"database/sql"
	_ "embed"
	"fmt"

	_ "github.com/mattn/go-sqlite3"
)

//go:embed schema.sql
var schema string

var errNoDBRows = sql.ErrNoRows

type DB struct {
	db *sql.DB
}

func openDB(filename string) (*DB, error) {
	sqlDB, err := sql.Open("sqlite3", filename)
	if err != nil {
		return nil, err
	}

	db := &DB{sqlDB}
	if err := db.init(context.TODO()); err != nil {
		db.Close()
		return nil, err
	}

	return db, nil
}

func (db *DB) init(ctx context.Context) error {
	var n int
	if err := db.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM sqlite_schema").Scan(&n); err != nil {
		return err
	} else if n != 0 {
		return nil
	}

	if _, err := db.db.ExecContext(ctx, schema); err != nil {
		return err
	}

	// TODO: drop this
	defaultUser := User{Username: "root"}
	if err := defaultUser.SetPassword("root"); err != nil {
		return err
	}
	return db.StoreUser(ctx, &defaultUser)
}

func (db *DB) Close() error {
	return db.db.Close()
}

func (db *DB) FetchUser(ctx context.Context, username string) (*User, error) {
func (db *DB) FetchUser(ctx context.Context, id ID[*User]) (*User, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM User WHERE id = ?", id)
	if err != nil {
		return nil, err
	}
	var user User
	err = scanRow(&user, rows)
	return &user, err
}

func (db *DB) FetchUserByUsername(ctx context.Context, username string) (*User, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM User WHERE username = ?", username)
	if err != nil {
		return nil, err
	}
	var user User
	err = scanRow(&user, rows)
	return &user, err
}

func (db *DB) StoreUser(ctx context.Context, user *User) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO User(id, username, password_hash)
		VALUES (:id, :username, :password_hash)
		ON CONFLICT(id) DO UPDATE SET
			username = :username,
			password_hash = :password_hash
		RETURNING id
	`, entityArgs(user)...).Scan(&user.ID)
}

func (db *DB) FetchClient(ctx context.Context, clientID string) (*Client, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE client_id = ?", clientID)
	if err != nil {
		return nil, err
	}
	var client Client
	err = scanRow(&client, rows)
	return &client, err
}

func (db *DB) StoreClient(ctx context.Context, client *Client) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO Client(id, client_id, client_secret_hash, owner)
		VALUES (:id, :client_id, :client_secret_hash, :owner)
		ON CONFLICT(id) DO UPDATE SET
			client_id = :client_id,
			client_secret_hash = :client_secret_hash,
			owner = :owner
		RETURNING id
	`, entityArgs(client)...).Scan(&client.ID)
}

func (db *DB) ListClients(ctx context.Context, owner ID[*User]) ([]Client, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE owner IS ?", owner)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var l []Client
	for rows.Next() {
		var client Client
		if err := scan(&client, rows); err != nil {
			return nil, err
		}
		l = append(l, client)
	}

	return l, rows.Close()
}

func (db *DB) FetchAccessToken(ctx context.Context, id ID[*AccessToken]) (*AccessToken, error) {
	rows, err := db.db.QueryContext(ctx, "SELECT * FROM AccessToken WHERE id = ?", id)
	if err != nil {
		return nil, err
	}
	var token AccessToken
	err = scanRow(&token, rows)
	return &token, err
}

func (db *DB) CreateAccessToken(ctx context.Context, token *AccessToken) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO AccessToken(hash, user, client, scope, issued_at, expires_at)
		VALUES (:hash, :user, :client, :scope, :issued_at, :expires_at)
		RETURNING id
	`, entityArgs(token)...).Scan(&token.ID)
}

func (db *DB) CreateAuthCode(ctx context.Context, code *AuthCode) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO AuthCode(hash, created_at, user, client, scope)
		VALUES (:hash, :created_at, :user, :client, :scope)
		RETURNING id
	`, entityArgs(code)...).Scan(&code.ID)
}

func (db *DB) PopAuthCode(ctx context.Context, id ID[*AuthCode]) (*AuthCode, error) {
	rows, err := db.db.QueryContext(ctx, `
		DELETE FROM AuthCode
		WHERE id = ?
		RETURNING *
	`, id)
	if err != nil {
		return nil, err
	}
	var authCode AuthCode
	err = scanRow(&authCode, rows)
	return &authCode, err
}

func scan(e entity, rows *sql.Rows) error {
	columns := e.columns()

	keys, err := rows.Columns()
	if err != nil {
		panic(err)
	}
	out := make([]interface{}, len(keys))
	for i, k := range keys {
		v, ok := columns[k]
		if !ok {
			panic(fmt.Errorf("unknown column %q", k))
		}
		out[i] = v
	}

	return rows.Scan(out...)
}

func scanRow(e entity, rows *sql.Rows) error {
	if !rows.Next() {
		return sql.ErrNoRows
	}
	if err := scan(e, rows); err != nil {
		return err
	}
	return rows.Close()
}

func entityArgs(e entity) []interface{} {
	columns := e.columns()

	l := make([]interface{}, 0, len(columns))
	for k, v := range columns {
		l = append(l, sql.Named(k, v))
	}

	return l
}
package main

import (
	"crypto/rand"
	"crypto/sha512"
	"crypto/subtle"
	"database/sql"
	"database/sql/driver"
	"encoding/base64"
	"fmt"
	"strconv"
	"strings"
	"time"

	"golang.org/x/crypto/bcrypt"
)

type entity interface {
	columns() map[string]interface{}
}

var (
	_ entity = (*User)(nil)
	_ entity = (*Client)(nil)
	_ entity = (*AccessToken)(nil)
	_ entity = (*AuthCode)(nil)
)

type ID[T entity] int64

var (
	_ sql.Scanner   = (*ID[*User])(nil)
	_ driver.Valuer = ID[*User](0)
)

func ParseID[T entity](s string) (ID[T], error) {
	u, _ := strconv.ParseUint(s, 10, 63)
	if u == 0 {
		return 0, fmt.Errorf("invalid ID")
	}
	return ID[T](u), nil
}

func (ptr *ID[T]) Scan(v interface{}) error {
	if v == nil {
		*ptr = 0
		return nil
	}
	id, ok := v.(int64)
	if !ok {
		return fmt.Errorf("cannot scan ID from %T", v)
	}
	*ptr = ID[T](id)
	return nil
}

func (id ID[T]) Value() (driver.Value, error) {
	if id == 0 {
		return nil, nil
	} else {
		return int64(id), nil
	}
}

type User struct {
	ID           ID[*User]
	Username     string
	PasswordHash string
}

func (user *User) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":            &user.ID,
		"username":      &user.Username,
		"password_hash": &user.PasswordHash,
	}
}

func (user *User) VerifyPassword(password string) error {
	// TODO: upgrade hash
	return bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
}

func (user *User) SetPassword(password string) error {
	hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return err
	}
	user.PasswordHash = string(hash)
	return nil
}

type Client struct {
	ID               ID[*Client]
	ClientID         string
	ClientSecretHash []byte
	Owner            ID[*User]
}

func NewClient(owner ID[*User]) (client *Client, secret string, err error) {
	id, err := generateUID()
	if err != nil {
		return nil, "", fmt.Errorf("failed to generate client ID: %v", err)
	}
	secret, hash, err := generateSecret()
	if err != nil {
		return nil, "", fmt.Errorf("failed to generate client secret: %v", err)
	}
	client = &Client{
		ClientID:         id,
		ClientSecretHash: hash,
		Owner:            owner,
	}
	return client, secret, nil
}

func (client *Client) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":                 &client.ID,
		"client_id":          &client.ClientID,
		"client_secret_hash": &client.ClientSecretHash,
		"owner":              &client.Owner,
	}
}

func (client *Client) VerifySecret(secret string) bool {
	return verifyHash(client.ClientSecretHash, secret)
}

type AccessToken struct {
	ID        ID[*AccessToken]
	Hash      []byte
	User      ID[*User]
	Client    ID[*Client]
	Scope     string
	IssuedAt  time.Time
	ExpiresAt time.Time
}

func (token *AccessToken) Generate() (secret string, err error) {
	secret, hash, err := generateSecret()
	if err != nil {
		return "", fmt.Errorf("failed to generate access token secret: %v", err)
	}
	token.Hash = hash
	token.IssuedAt = time.Now()
	token.ExpiresAt = time.Now().Add(2 * time.Hour)
	return secret, nil
}

func NewAccessTokenFromAuthCode(authCode *AuthCode) (token *AccessToken, secret string, err error) {
	token = &AccessToken{
		User:   authCode.User,
		Client: authCode.Client,
		Scope:  authCode.Scope,
	}
	secret, err = token.Generate()
	return token, secret, err
}

func (token *AccessToken) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":         &token.ID,
		"hash":       &token.Hash,
		"user":       &token.User,
		"client":     &token.Client,
		"scope":      &token.Scope,
		"issued_at":  &token.IssuedAt,
		"expires_at": &token.ExpiresAt,
	}
}

func (token *AccessToken) VerifySecret(secret string) bool {
	return verifyHash(token.Hash, secret) && verifyExpiration(token.ExpiresAt)
}

type AuthCode struct {
	ID        ID[*AuthCode]
	Hash      []byte
	CreatedAt time.Time
	User      ID[*User]
	Client    ID[*Client]
	Scope     string
}

func NewAuthCode(user ID[*User], client ID[*Client], scope string) (code *AuthCode, secret string, err error) {
	secret, hash, err := generateSecret()
	if err != nil {
		return nil, "", fmt.Errorf("failed to generate authentication code secret: %v", err)
	}
	code = &AuthCode{
		Hash:      hash,
		CreatedAt: time.Now(),
		User:      user,
		Client:    client,
		Scope:     scope,
	}
	return code, secret, nil
}

func (code *AuthCode) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":         &code.ID,
		"hash":       &code.Hash,
		"created_at": &code.CreatedAt,
		"user":       &code.User,
		"client":     &code.Client,
		"scope":      &code.Scope,
	}
}

func (code *AuthCode) VerifySecret(secret string) bool {
	return verifyHash(code.Hash, secret) && verifyExpiration(code.CreatedAt.Add(10*time.Minute))
}

func UnmarshalSecret[T entity](s string) (id ID[T], secret string, err error) {
	idStr, secret, _ := strings.Cut(s, ".")
	u, _ := strconv.ParseUint(idStr, 10, 63)
	if u == 0 {
		return 0, "", fmt.Errorf("invalid ID")
	}
	return ID[T](int64(u)), secret, nil
	id, err = ParseID[T](idStr)
	return id, secret, err
}

func MarshalSecret[T entity](id ID[T], secret string) string {
	if id == 0 {
		panic("cannot marshal zero ID")
	}
	return fmt.Sprintf("%v.%v", int64(id), secret)
}

func generateUID() (string, error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(b), nil
}

func generateSecret() (secret string, hash []byte, err error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", nil, err
	}
	secret = base64.RawURLEncoding.EncodeToString(b)
	h := sha512.Sum512(b)
	return secret, h[:], nil
}

func verifyHash(hash []byte, secret string) bool {
	b, _ := base64.RawURLEncoding.DecodeString(secret)
	h := sha512.Sum512(b)
	return subtle.ConstantTimeCompare(hash, h[:]) == 1
}

func verifyExpiration(t time.Time) bool {
	return time.Now().Before(t)
}
package main

import (
	"context"
	"embed"
	"flag"
	"html/template"
	"log"
	"net"
	"net/http"

	"github.com/go-chi/chi/v5"
)

var (
	//go:embed template
	templateFS embed.FS
	//go:embed static
	staticFS embed.FS
)

func main() {
	var listenAddr string
	flag.StringVar(&listenAddr, "listen", ":8080", "HTTP listen address")
	flag.Parse()

	tpl := template.Must(template.ParseFS(templateFS, "template/*.html"))

	db, err := openDB("sinwon.db")
	if err != nil {
		log.Fatalf("Failed to open DB: %v", err)
	}

	mux := chi.NewRouter()
	mux.Handle("/static/*", http.FileServer(http.FS(staticFS)))
	mux.Get("/", index)
	mux.Post("/client/new", createClient)
	mux.HandleFunc("/login", login)
	mux.HandleFunc("/user/new", updateUser)
	mux.HandleFunc("/user/{id}", updateUser)
	mux.HandleFunc("/authorize", authorize)
	mux.Post("/token", exchangeToken)

	server := http.Server{
		Addr:    listenAddr,
		Handler: loginTokenMiddleware(mux),
		BaseContext: func(net.Listener) context.Context {
			return newBaseContext(db, tpl)
		},
	}
	log.Printf("OAuth server listening on %v", server.Addr)
	if err := server.ListenAndServe(); err != nil {
		log.Fatalf("Failed to listen and serve: %v", err)
	}
}

func httpError(w http.ResponseWriter, err error) {
	log.Print(err)
	http.Error(w, "Internal server error", http.StatusInternalServerError)
}
{{ template "head.html" }}

<main>

<h1>sinwon</h1>

<form method="post" action="/client/new">
	<a href="/user/new"><button type="button">Create user</button></a>
	<a href="/user/{{ .Me }}"><button type="button">Edit user</button></a>
	<button type="submit">Register new client</button>
</form>

{{ with . }}
{{ with .Clients }}
	<p>{{ . | len }} clients registered:</p>
	<ul>
		{{ range . }}
			<li><code>{{ .ClientID }}</code></li>
		{{ end }}
	</ul>
{{ else }}
	<p>No client registered yet.</p>
{{ end }}

</main>

{{ template "foot.html" }}
{{ template "head.html" }}

<main>

<h1>sinwon</h1>

<form method="post" action="">
	Username: <input type="text" name="username" value="{{ .Username }}" required><br>
	Password: <input type="password" name="password"><br>
	<button type="submit">
		{{ if .Username }}
			Update user
		{{ else }}
			Create user
		{{ end }}
	</button>
</form>

</main>

{{ template "foot.html" }}
package main

import (
	"fmt"
	"log"
	"net/http"
	"net/url"

	"github.com/go-chi/chi/v5"
)

func index(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)
	tpl := templateFromContext(ctx)

	loginToken := loginTokenFromContext(ctx)
	if loginToken == nil {
		http.Redirect(w, req, "/login", http.StatusFound)
		return
	}

	clients, err := db.ListClients(ctx, loginToken.User)
	if err != nil {
		httpError(w, err)
		return
	}

	if err := tpl.ExecuteTemplate(w, "index.html", clients); err != nil {
	data := struct {
		Clients []Client
		Me      ID[*User]
	}{
		Clients: clients,
		Me:      loginToken.User,
	}
	if err := tpl.ExecuteTemplate(w, "index.html", &data); err != nil {
		panic(err)
	}
}

func login(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)
	tpl := templateFromContext(ctx)

	q := req.URL.Query()
	rawRedirectURI := q.Get("redirect_uri")
	if rawRedirectURI == "" {
		rawRedirectURI = "/"
	}

	redirectURI, err := url.Parse(rawRedirectURI)
	if err != nil || redirectURI.Scheme != "" || redirectURI.Opaque != "" || redirectURI.User != nil || redirectURI.Host != "" {
		http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
		return
	}

	if loginTokenFromContext(ctx) != nil {
		http.Redirect(w, req, redirectURI.String(), http.StatusFound)
		return
	}

	username := req.PostFormValue("username")
	password := req.PostFormValue("password")
	if username == "" {
		if err := tpl.ExecuteTemplate(w, "login.html", nil); err != nil {
			panic(err)
		}
		return
	}

	user, err := db.FetchUser(ctx, username)
	user, err := db.FetchUserByUsername(ctx, username)
	if err != nil && err != errNoDBRows {
		httpError(w, fmt.Errorf("failed to fetch user: %v", err))
		return
	}
	if err == nil {
		err = user.VerifyPassword(password)
	}
	if err != nil {
		log.Printf("login failed for user %q: %v", username, err)
		// TODO: show error message
		if err := tpl.ExecuteTemplate(w, "login.html", nil); err != nil {
			panic(err)
		}
		return
	}

	token := AccessToken{
		User:  user.ID,
		Scope: internalTokenScope,
	}
	secret, err := token.Generate()
	if err != nil {
		httpError(w, fmt.Errorf("failed to generate access token: %v", err))
		return
	}
	if err := db.CreateAccessToken(ctx, &token); err != nil {
		httpError(w, fmt.Errorf("failed to create access token: %v", err))
		return
	}

	setLoginTokenCookie(w, &token, secret)
	http.Redirect(w, req, redirectURI.String(), http.StatusFound)
}

func updateUser(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)
	tpl := templateFromContext(ctx)

	user := new(User)
	if idStr := chi.URLParam(req, "id"); idStr != "" {
		id, err := ParseID[*User](idStr)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}

		user, err = db.FetchUser(ctx, id)
		if err != nil {
			httpError(w, err)
			return
		}
	}

	loginToken := loginTokenFromContext(ctx)
	if loginToken == nil {
		http.Redirect(w, req, "/login", http.StatusFound)
		return
	}

	if user.ID != 0 && loginToken.User != user.ID {
		http.Error(w, "Access denied", http.StatusForbidden)
		return
	}

	username := req.PostFormValue("username")
	password := req.PostFormValue("password")
	if username == "" {
		if err := tpl.ExecuteTemplate(w, "update-user.html", user); err != nil {
			panic(err)
		}
		return
	}

	user.Username = username
	if password != "" {
		if err := user.SetPassword(password); err != nil {
			httpError(w, err)
			return
		}
	}

	if err := db.StoreUser(ctx, user); err != nil {
		httpError(w, err)
		return
	}

	http.Redirect(w, req, "/", http.StatusFound)
}