Hi… I am well aware that this diff view is very suboptimal. It will be fixed when the refactored server comes along!
Implement basic OIDC and some fixes
package main
import (
"fmt"
"net/http"
"net/url"
"strings"
"github.com/go-chi/chi/v5"
)
func manageClient(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
db := dbFromContext(ctx)
tpl := templateFromContext(ctx)
loginToken := loginTokenFromContext(ctx)
if loginToken == nil {
http.Redirect(w, req, "/login", http.StatusFound)
return
}
me, err := db.FetchUser(ctx, loginToken.User)
if err != nil {
httpError(w, err)
return
} else if !me.Admin {
http.Error(w, "Access denied", http.StatusForbidden)
return
}
client := &Client{Owner: loginToken.User}
if idStr := chi.URLParam(req, "id"); idStr != "" {
id, err := ParseID[*Client](idStr)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
client, err = db.FetchClient(ctx, id)
if err != nil {
httpError(w, err)
return
}
}
if normalized, err := normalizeClientPKCERequirement(client.PKCERequirement); err == nil {
client.PKCERequirement = normalized
} else {
client.PKCERequirement = pkceRequirementNone
}
if req.Method != http.MethodPost {
data := struct {
TemplateBaseData
Client *Client
}{
Client: client,
}
tpl.MustExecuteTemplate(w, "manage-client.html", &data)
tpl.MustExecuteTemplate(req.Context(), w, "manage-client.html", &data)
return
}
_ = req.ParseForm()
if _, ok := req.PostForm["delete"]; ok {
if err := db.DeleteClient(ctx, client.ID); err != nil {
httpError(w, err)
return
}
http.Redirect(w, req, "/", http.StatusFound)
return
}
_, rotate := req.PostForm["rotate"]
var isPublic bool
if client.ID != 0 {
isPublic = client.IsPublic()
} else {
isPublic = req.PostFormValue("client_type") == "public"
}
if !rotate {
client.ClientName = req.PostFormValue("client_name")
client.ClientURI = req.PostFormValue("client_uri")
client.RedirectURIs = req.PostFormValue("redirect_uris")
pkceRequirement := req.PostFormValue("pkce_requirement")
if !isPublic {
pkceRequirement = pkceRequirementNone
}
normalizedRequirement, err := normalizeClientPKCERequirement(pkceRequirement)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
client.PKCERequirement = normalizedRequirement
if err := validateAllowedRedirectURIs(client.RedirectURIs); err != nil {
// TODO: nicer error message
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
}
var clientSecret string
if client.ID == 0 || rotate {
clientSecret, err = client.Generate(isPublic)
if err != nil {
httpError(w, err)
return
}
}
if err := db.StoreClient(ctx, client); err != nil {
httpError(w, err)
return
}
if clientSecret == "" {
http.Redirect(w, req, "/", http.StatusFound)
return
}
data := struct {
TemplateBaseData
ClientID string
ClientSecret string
}{
ClientID: client.ClientID,
ClientSecret: clientSecret,
}
tpl.MustExecuteTemplate(w, "client-secret.html", &data)
tpl.MustExecuteTemplate(req.Context(), w, "client-secret.html", &data)
}
func validateAllowedRedirectURIs(rawRedirectURIs string) error {
for _, s := range strings.Split(rawRedirectURIs, "\n") {
if s == "" {
continue
}
u, err := url.Parse(s)
if err != nil {
// TODO: nicer error message
return fmt.Errorf("Invalid redirect URI %q: %v", s, err)
}
switch u.Scheme {
case "https":
// ok
case "http":
if u.Host != "localhost" {
return fmt.Errorf("Only http://localhost is allowed for insecure HTTP URIs")
}
// insecure but let's just trust the admin
default:
if !strings.Contains(u.Scheme, ".") {
return fmt.Errorf("Only private-use URIs referring to domain names are allowed")
}
}
}
return nil
}
func revokeClient(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
db := dbFromContext(ctx)
id, err := ParseID[*Client](chi.URLParam(req, "id"))
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
loginToken := loginTokenFromContext(ctx)
if loginToken == nil {
http.Redirect(w, req, "/login", http.StatusFound)
return
}
if err := db.RevokeClientUser(ctx, id, loginToken.User); err != nil {
httpError(w, err)
return
}
http.Redirect(w, req, "/", http.StatusFound)
}
package main
import (
"context"
"crypto/subtle"
"net/http"
"strings"
)
const (
csrfCookieName = "vireo-csrf"
csrfFormField = "_csrf"
)
func csrfTokenFromContext(ctx context.Context) string {
if ctx == nil {
return ""
}
if token, ok := ctx.Value(contextKeyCSRFToken).(string); ok {
return token
}
return ""
}
func ensureCSRFToken(w http.ResponseWriter, req *http.Request) (string, error) {
if cookie, err := req.Cookie(csrfCookieName); err == nil && cookie.Value != "" {
return cookie.Value, nil
}
token, err := generateUID()
if err != nil {
return "", err
}
http.SetCookie(w, &http.Cookie{
Name: csrfCookieName,
Value: token,
Path: "/",
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
Secure: isForwardedHTTPS(req),
})
return token, nil
}
func csrfMiddleware(next http.Handler) http.Handler {
cop := http.NewCrossOriginProtection()
cop.SetDenyHandler(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
http.Error(w, "Forbidden", http.StatusForbidden)
}))
handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if isCSRFBypass(req) || req.Header.Get("Authorization") != "" {
next.ServeHTTP(w, req)
return
}
token, err := ensureCSRFToken(w, req)
if err != nil {
httpError(w, err)
return
}
ctx := context.WithValue(req.Context(), contextKeyCSRFToken, token)
req = req.WithContext(ctx)
switch req.Method {
case http.MethodGet, http.MethodHead, http.MethodOptions:
next.ServeHTTP(w, req)
return
}
if err := req.ParseForm(); err != nil {
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
formToken := req.PostFormValue(csrfFormField)
if formToken == "" {
formToken = req.Header.Get("X-CSRF-Token")
}
if formToken == "" || subtle.ConstantTimeCompare([]byte(formToken), []byte(token)) != 1 {
http.Error(w, "Invalid CSRF token", http.StatusForbidden)
return
}
next.ServeHTTP(w, req)
})
return cop.Handler(handler)
}
func isCSRFBypass(req *http.Request) bool {
path := req.URL.Path
switch {
case path == "/token",
path == "/introspect",
path == "/revoke",
path == "/userinfo",
strings.HasPrefix(path, "/static/"),
strings.HasPrefix(path, "/.well-known/"),
path == "/favicon.ico":
return true
}
return false
}
package main
import (
"context"
"database/sql"
_ "embed"
"fmt"
"time"
"github.com/mattn/go-sqlite3"
)
//go:embed schema.sql
var schema string
var migrations = []string{
"", // migration #0 is reserved for schema initialization
`
ALTER TABLE AccessToken ADD COLUMN refresh_hash BLOB;
ALTER TABLE AccessToken ADD COLUMN refresh_expires_at datetime;
CREATE UNIQUE INDEX access_token_refresh_hash ON AccessToken(refresh_hash);
`,
` ALTER TABLE AuthCode ADD COLUMN nonce TEXT; CREATE TABLE IF NOT EXISTS SigningKey ( id INTEGER PRIMARY KEY, kid TEXT NOT NULL UNIQUE, algorithm TEXT NOT NULL, private_key BLOB NOT NULL, created_at datetime NOT NULL ); `, ` ALTER TABLE AccessToken ADD COLUMN auth_time datetime; ALTER TABLE SigningKey RENAME TO SigningKey_old; CREATE TABLE SigningKey ( id INTEGER PRIMARY KEY, kid TEXT NOT NULL UNIQUE, algorithm TEXT NOT NULL, private_key BLOB NOT NULL, created_at datetime NOT NULL ); INSERT INTO SigningKey(id, kid, algorithm, private_key, created_at) SELECT id, kid, algorithm, private_key, created_at FROM SigningKey_old; DROP TABLE SigningKey_old; CREATE INDEX IF NOT EXISTS signing_key_created_at ON SigningKey(created_at); `, ` ALTER TABLE User ADD COLUMN email TEXT; `, ` ALTER TABLE User ADD COLUMN name TEXT; `, ` ALTER TABLE AuthCode ADD COLUMN code_challenge TEXT; ALTER TABLE AuthCode ADD COLUMN code_challenge_method TEXT; ALTER TABLE Client ADD COLUMN pkce_requirement TEXT; `,
}
var errNoDBRows = sql.ErrNoRows
type DB struct {
db *sql.DB
}
func openDB(filename string) (*DB, error) {
sqlDB, err := sql.Open("sqlite3", filename+"?cache=shared&_foreign_keys=1")
if err != nil {
return nil, err
}
sqlDB.SetMaxOpenConns(1)
db := &DB{sqlDB}
if err := db.init(context.TODO()); err != nil {
db.Close()
return nil, err
}
return db, nil
}
func (db *DB) init(ctx context.Context) error {
version, err := db.upgrade(ctx)
if err != nil {
return err
}
if version > 0 {
return nil
}
// TODO: drop this
defaultUser := User{Username: "root", Admin: true}
defaultUser := User{Username: "root", Name: "Root User", Email: "root@example.invalid", Admin: true}
if err := defaultUser.SetPassword("root"); err != nil {
return err
}
return db.StoreUser(ctx, &defaultUser)
}
func (db *DB) upgrade(ctx context.Context) (version int, err error) {
if err := db.db.QueryRowContext(ctx, "PRAGMA user_version").Scan(&version); err != nil {
return 0, fmt.Errorf("failed to query schema version: %v", err)
}
if version == len(migrations) {
return version, nil
} else if version > len(migrations) {
return version, fmt.Errorf("vireo (version %d) older than schema (version %d)", len(migrations), version)
}
tx, err := db.db.BeginTx(ctx, nil)
if err != nil {
return version, err
}
defer tx.Rollback()
if version == 0 {
if _, err := tx.ExecContext(ctx, schema); err != nil {
return version, fmt.Errorf("failed to initialize schema: %v", err)
}
} else {
for i := version; i < len(migrations); i++ {
if _, err := tx.ExecContext(ctx, migrations[i]); err != nil {
return version, fmt.Errorf("failed to execute migration #%v: %v", i, err)
}
}
}
// For some reason prepared statements don't work here
_, err = tx.ExecContext(ctx, fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
if err != nil {
return version, fmt.Errorf("failed to bump schema version: %v", err)
}
return version, tx.Commit()
}
func (db *DB) Close() error {
return db.db.Close()
}
func (db *DB) FetchUser(ctx context.Context, id ID[*User]) (*User, error) {
rows, err := db.db.QueryContext(ctx, "SELECT * FROM User WHERE id = ?", id)
if err != nil {
return nil, err
}
var user User
err = scanRow(&user, rows)
return &user, err
}
func (db *DB) FetchUserByUsername(ctx context.Context, username string) (*User, error) {
rows, err := db.db.QueryContext(ctx, "SELECT * FROM User WHERE username = ?", username)
if err != nil {
return nil, err
}
var user User
err = scanRow(&user, rows)
return &user, err
}
func (db *DB) StoreUser(ctx context.Context, user *User) error {
return db.db.QueryRowContext(ctx, `
INSERT INTO User(id, username, password_hash, admin) VALUES (:id, :username, :password_hash, :admin)
INSERT INTO User(id, username, name, email, password_hash, admin) VALUES (:id, :username, :name, :email, :password_hash, :admin)
ON CONFLICT(id) DO UPDATE SET username = :username,
name = :name, email = :email,
password_hash = :password_hash,
admin = :admin
RETURNING id
`, entityArgs(user)...).Scan(&user.ID)
}
func (db *DB) ListUsers(ctx context.Context) ([]User, error) {
rows, err := db.db.QueryContext(ctx, "SELECT * FROM User")
if err != nil {
return nil, err
}
defer rows.Close()
var l []User
for rows.Next() {
var user User
if err := scan(&user, rows); err != nil {
return nil, err
}
l = append(l, user)
}
return l, rows.Close()
}
func (db *DB) FetchClient(ctx context.Context, id ID[*Client]) (*Client, error) {
rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE id = ?", id)
if err != nil {
return nil, err
}
var client Client
err = scanRow(&client, rows)
return &client, err
}
func (db *DB) FetchClientByClientID(ctx context.Context, clientID string) (*Client, error) {
rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE client_id = ?", clientID)
if err != nil {
return nil, err
}
var client Client
err = scanRow(&client, rows)
return &client, err
}
func (db *DB) StoreClient(ctx context.Context, client *Client) error {
return db.db.QueryRowContext(ctx, `
INSERT INTO Client(id, client_id, client_secret_hash, owner,
redirect_uris, client_name, client_uri)
redirect_uris, client_name, client_uri, pkce_requirement)
VALUES (:id, :client_id, :client_secret_hash, :owner,
:redirect_uris, :client_name, :client_uri)
:redirect_uris, :client_name, :client_uri, :pkce_requirement)
ON CONFLICT(id) DO UPDATE SET client_id = :client_id, client_secret_hash = :client_secret_hash, owner = :owner, redirect_uris = :redirect_uris, client_name = :client_name,
client_uri = :client_uri
client_uri = :client_uri, pkce_requirement = :pkce_requirement
RETURNING id
`, entityArgs(client)...).Scan(&client.ID)
}
func (db *DB) ListClients(ctx context.Context, owner ID[*User]) ([]Client, error) {
rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE owner IS ?", owner)
if err != nil {
return nil, err
}
defer rows.Close()
var l []Client
for rows.Next() {
var client Client
if err := scan(&client, rows); err != nil {
return nil, err
}
l = append(l, client)
}
return l, rows.Close()
}
func (db *DB) ListAuthorizedClients(ctx context.Context, user ID[*User]) ([]AuthorizedClient, error) {
rows, err := db.db.QueryContext(ctx, `
SELECT id, client_id, client_name, client_uri, token.expires_at
FROM Client,
(
SELECT client, MAX(COALESCE(refresh_expires_at, expires_at)) as expires_at
FROM AccessToken
WHERE user = ?
GROUP BY client
) AS token
WHERE Client.id = token.client
`, user)
if err != nil {
return nil, err
}
var l []AuthorizedClient
for rows.Next() {
var authClient AuthorizedClient
columns := authClient.Client.columns()
var expiresAt string
err := rows.Scan(columns["id"], columns["client_id"], columns["client_name"], columns["client_uri"], &expiresAt)
if err != nil {
return nil, err
}
authClient.ExpiresAt, err = time.Parse(sqlite3.SQLiteTimestampFormats[0], expiresAt)
if err != nil {
return nil, err
}
l = append(l, authClient)
}
return l, rows.Close()
}
func (db *DB) DeleteClient(ctx context.Context, id ID[*Client]) error {
_, err := db.db.ExecContext(ctx, "DELETE FROM Client WHERE id = ?", id)
return err
}
func (db *DB) FetchAccessToken(ctx context.Context, id ID[*AccessToken]) (*AccessToken, error) {
rows, err := db.db.QueryContext(ctx, "SELECT * FROM AccessToken WHERE id = ?", id)
if err != nil {
return nil, err
}
var token AccessToken
err = scanRow(&token, rows)
return &token, err
}
func (db *DB) StoreAccessToken(ctx context.Context, token *AccessToken) error {
return db.db.QueryRowContext(ctx, `
INSERT INTO AccessToken(id, hash, user, client, scope, issued_at,
expires_at, refresh_hash, refresh_expires_at)
expires_at, auth_time, refresh_hash, refresh_expires_at)
VALUES (:id, :hash, :user, :client, :scope, :issued_at, :expires_at,
:refresh_hash, :refresh_expires_at)
:auth_time, :refresh_hash, :refresh_expires_at)
ON CONFLICT(id) DO UPDATE SET hash = :hash, user = :user, client = :client, scope = :scope, issued_at = :issued_at, expires_at = :expires_at,
auth_time = :auth_time,
refresh_hash = :refresh_hash,
refresh_expires_at = :refresh_expires_at
RETURNING id
`, entityArgs(token)...).Scan(&token.ID)
}
func (db *DB) DeleteAccessToken(ctx context.Context, id ID[*AccessToken]) error {
_, err := db.db.ExecContext(ctx, "DELETE FROM AccessToken WHERE id = ?", id)
return err
}
func (db *DB) RevokeClientUser(ctx context.Context, clientID ID[*Client], userID ID[*User]) error {
tx, err := db.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
_, err = tx.ExecContext(ctx, `
DELETE FROM AccessToken
WHERE client = ? AND user = ?
`, clientID, userID)
if err != nil {
return err
}
_, err = tx.ExecContext(ctx, `
DELETE FROM AuthCode
WHERE client = ? AND user = ?
`, clientID, userID)
if err != nil {
return err
}
return tx.Commit()
}
func (db *DB) CreateAuthCode(ctx context.Context, code *AuthCode) error {
return db.db.QueryRowContext(ctx, `
INSERT INTO AuthCode(hash, created_at, user, client, scope, redirect_uri) VALUES (:hash, :created_at, :user, :client, :scope, :redirect_uri)
INSERT INTO AuthCode(hash, created_at, user, client, scope, redirect_uri, nonce, code_challenge, code_challenge_method) VALUES (:hash, :created_at, :user, :client, :scope, :redirect_uri, :nonce, :code_challenge, :code_challenge_method)
RETURNING id
`, entityArgs(code)...).Scan(&code.ID)
}
func (db *DB) PopAuthCode(ctx context.Context, id ID[*AuthCode]) (*AuthCode, error) {
rows, err := db.db.QueryContext(ctx, `
DELETE FROM AuthCode
WHERE id = ?
RETURNING *
`, id)
if err != nil {
return nil, err
}
var authCode AuthCode
err = scanRow(&authCode, rows)
return &authCode, err
}
func (db *DB) FetchSigningKeys(ctx context.Context) ([]SigningKey, error) {
rows, err := db.db.QueryContext(ctx, `
SELECT * FROM SigningKey
ORDER BY created_at DESC
`)
if err != nil {
return nil, err
}
defer rows.Close()
var keys []SigningKey
for rows.Next() {
var key SigningKey
if err := scan(&key, rows); err != nil {
return nil, err
}
keys = append(keys, key)
}
if err := rows.Err(); err != nil {
return nil, err
}
if len(keys) == 0 {
return nil, errNoDBRows
}
return keys, nil
}
func (db *DB) StoreSigningKey(ctx context.Context, key *SigningKey) error {
return db.db.QueryRowContext(ctx, `
INSERT INTO SigningKey(kid, algorithm, private_key, created_at)
VALUES (:kid, :algorithm, :private_key, :created_at)
RETURNING id
`, sql.Named("kid", key.KID), sql.Named("algorithm", key.Algorithm), sql.Named("private_key", key.PrivateKey), sql.Named("created_at", key.CreatedAt)).Scan(&key.ID)
}
func (db *DB) Maintain(ctx context.Context) error {
_, err := db.db.ExecContext(ctx, `
DELETE FROM AccessToken
WHERE timediff('now', COALESCE(refresh_expires_at, expires_at)) > 0
`)
if err != nil {
return err
}
_, err = db.db.ExecContext(ctx, `
DELETE FROM AuthCode
WHERE timediff(?, created_at) > 0
`, time.Now().Add(-authCodeExpiration))
if err != nil {
return err
}
return nil
}
func scan(e entity, rows *sql.Rows) error {
columns := e.columns()
keys, err := rows.Columns()
if err != nil {
panic(err)
}
out := make([]interface{}, len(keys))
for i, k := range keys {
v, ok := columns[k]
if !ok {
panic(fmt.Errorf("unknown column %q", k))
}
out[i] = v
}
return rows.Scan(out...)
}
func scanRow(e entity, rows *sql.Rows) error {
if !rows.Next() {
return sql.ErrNoRows
}
if err := scan(e, rows); err != nil {
return err
}
return rows.Close()
}
func entityArgs(e entity) []interface{} {
columns := e.columns()
l := make([]interface{}, 0, len(columns))
for k, v := range columns {
l = append(l, sql.Named(k, v))
}
return l
}
package main
import (
"crypto/rand"
"crypto/sha512"
"crypto/subtle"
"database/sql"
"database/sql/driver"
"encoding/base64"
"fmt"
"reflect"
"strconv"
"strings"
"time"
"golang.org/x/crypto/bcrypt"
)
const (
accessTokenExpiration = 30 * 24 * time.Hour
refreshTokenExpiration = 2 * accessTokenExpiration
authCodeExpiration = 10 * time.Minute
)
type entity interface {
columns() map[string]interface{}
}
var (
_ entity = (*User)(nil)
_ entity = (*Client)(nil)
_ entity = (*AccessToken)(nil)
_ entity = (*AuthCode)(nil)
_ entity = (*SigningKey)(nil)
)
type ID[T entity] int64
var (
_ sql.Scanner = (*ID[*User])(nil)
_ driver.Valuer = ID[*User](0)
)
func ParseID[T entity](s string) (ID[T], error) {
u, _ := strconv.ParseUint(s, 10, 63)
if u == 0 {
return 0, fmt.Errorf("invalid ID")
}
return ID[T](u), nil
}
func (ptr *ID[T]) Scan(v interface{}) error {
if v == nil {
*ptr = 0
return nil
}
id, ok := v.(int64)
if !ok {
return fmt.Errorf("cannot scan ID from %T", v)
}
*ptr = ID[T](id)
return nil
}
func (id ID[T]) Value() (driver.Value, error) {
if id == 0 {
return nil, nil
} else {
return int64(id), nil
}
}
type nullValue struct {
ptr interface{}
}
var (
_ sql.Scanner = nullValue{nil}
_ driver.Valuer = nullValue{nil}
)
func (nv nullValue) Scan(v interface{}) error {
out := reflect.ValueOf(nv.ptr).Elem()
if v == nil {
out.SetZero()
return nil
}
rv := reflect.ValueOf(v)
if rv.Type() != out.Type() {
return fmt.Errorf("cannot scan %v into %v", rv.Type(), out.Type())
}
out.Set(rv)
return nil
}
func (nv nullValue) Value() (driver.Value, error) {
in := reflect.ValueOf(nv.ptr).Elem()
if in.IsZero() {
return nil, nil
}
return in.Interface(), nil
}
type User struct {
ID ID[*User]
Username string
Name string Email string
PasswordHash string
Admin bool
}
func (user *User) columns() map[string]interface{} {
return map[string]interface{}{
"id": &user.ID,
"username": &user.Username,
"name": nullValue{&user.Name},
"email": nullValue{&user.Email},
"password_hash": nullValue{&user.PasswordHash},
"admin": &user.Admin,
}
}
func (user *User) VerifyPassword(password string) error {
return bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
}
func (user *User) SetPassword(password string) error {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return err
}
user.PasswordHash = string(hash)
return nil
}
func (user *User) PasswordNeedsRehash() bool {
cost, _ := bcrypt.Cost([]byte(user.PasswordHash))
return cost != bcrypt.DefaultCost
}
type Client struct {
ID ID[*Client]
ClientID string
ClientSecretHash []byte
Owner ID[*User]
RedirectURIs string
ClientName string
ClientURI string
PKCERequirement string
}
func (client *Client) Generate(isPublic bool) (secret string, err error) {
id, err := generateUID()
if err != nil {
return "", fmt.Errorf("failed to generate client ID: %v", err)
}
client.ClientID = id
if !isPublic {
var hash []byte
secret, hash, err = generateSecret()
if err != nil {
return "", fmt.Errorf("failed to generate client secret: %v", err)
}
client.ClientSecretHash = hash
}
return secret, nil
}
func (client *Client) columns() map[string]interface{} {
return map[string]interface{}{
"id": &client.ID,
"client_id": &client.ClientID,
"client_secret_hash": &client.ClientSecretHash,
"owner": &client.Owner,
"redirect_uris": nullValue{&client.RedirectURIs},
"client_name": nullValue{&client.ClientName},
"client_uri": nullValue{&client.ClientURI},
"pkce_requirement": nullValue{&client.PKCERequirement},
}
}
func (client *Client) VerifySecret(secret string) bool {
return verifyHash(client.ClientSecretHash, secret)
}
func (client *Client) IsPublic() bool {
return client.ClientSecretHash == nil
}
type AccessToken struct {
ID ID[*AccessToken]
Hash []byte
User ID[*User]
Client ID[*Client]
Scope string
IssuedAt time.Time
ExpiresAt time.Time
AuthTime time.Time
RefreshHash []byte
RefreshExpiresAt time.Time
}
func (token *AccessToken) Generate(expiration time.Duration) (secret string, err error) {
secret, hash, err := generateSecret()
if err != nil {
return "", fmt.Errorf("failed to generate access token secret: %v", err)
}
token.Hash = hash
token.IssuedAt = time.Now()
token.ExpiresAt = time.Now().Add(expiration)
return secret, nil
}
func (token *AccessToken) GenerateRefresh() (secret string, err error) {
secret, hash, err := generateSecret()
if err != nil {
return "", fmt.Errorf("failed to generate refresh token secret: %v", err)
}
token.RefreshHash = hash
token.RefreshExpiresAt = time.Now().Add(refreshTokenExpiration)
return secret, nil
}
func NewAccessTokenFromAuthCode(authCode *AuthCode) *AccessToken {
return &AccessToken{
User: authCode.User, Client: authCode.Client, Scope: authCode.Scope,
User: authCode.User, Client: authCode.Client, Scope: authCode.Scope, AuthTime: authCode.CreatedAt,
}
}
func (token *AccessToken) columns() map[string]interface{} {
return map[string]interface{}{
"id": &token.ID,
"hash": &token.Hash,
"user": &token.User,
"client": &token.Client,
"scope": nullValue{&token.Scope},
"issued_at": &token.IssuedAt,
"expires_at": &token.ExpiresAt,
"auth_time": nullValue{&token.AuthTime},
"refresh_hash": &token.RefreshHash,
"refresh_expires_at": nullValue{&token.RefreshExpiresAt},
}
}
func (token *AccessToken) VerifySecret(secret string) bool {
return verifyHash(token.Hash, secret) && verifyExpiration(token.ExpiresAt)
}
func (token *AccessToken) VerifyRefreshSecret(secret string) bool {
return verifyHash(token.RefreshHash, secret) && verifyExpiration(token.RefreshExpiresAt)
}
type AuthorizedClient struct {
Client Client
ExpiresAt time.Time
}
type AuthCode struct {
ID ID[*AuthCode] Hash []byte CreatedAt time.Time User ID[*User] Client ID[*Client] Scope string RedirectURI string
ID ID[*AuthCode] Hash []byte CreatedAt time.Time User ID[*User] Client ID[*Client] Scope string RedirectURI string Nonce string CodeChallenge string CodeChallengeMethod string
}
func (code *AuthCode) Generate() (secret string, err error) {
secret, hash, err := generateSecret()
if err != nil {
return "", fmt.Errorf("failed to generate authentication code secret: %v", err)
}
code.Hash = hash
code.CreatedAt = time.Now()
return secret, nil
}
func (code *AuthCode) columns() map[string]interface{} {
return map[string]interface{}{
"id": &code.ID,
"hash": &code.Hash,
"created_at": &code.CreatedAt,
"user": &code.User,
"client": &code.Client,
"scope": nullValue{&code.Scope},
"redirect_uri": nullValue{&code.RedirectURI},
"id": &code.ID,
"hash": &code.Hash,
"created_at": &code.CreatedAt,
"user": &code.User,
"client": &code.Client,
"scope": nullValue{&code.Scope},
"redirect_uri": nullValue{&code.RedirectURI},
"nonce": nullValue{&code.Nonce},
"code_challenge": nullValue{&code.CodeChallenge},
"code_challenge_method": nullValue{&code.CodeChallengeMethod},
}
}
func (code *AuthCode) VerifySecret(secret string) bool {
return verifyHash(code.Hash, secret) && verifyExpiration(code.CreatedAt.Add(authCodeExpiration))
}
type SecretKind byte
const (
SecretKindAccessToken = SecretKind('a')
SecretKindRefreshToken = SecretKind('r')
SecretKindAuthCode = SecretKind('c')
)
type SigningKey struct {
ID ID[*SigningKey]
KID string
Algorithm string
PrivateKey []byte
CreatedAt time.Time
}
func (key *SigningKey) columns() map[string]interface{} {
return map[string]interface{}{
"id": &key.ID,
"kid": &key.KID,
"algorithm": &key.Algorithm,
"private_key": &key.PrivateKey,
"created_at": &key.CreatedAt,
}
}
func UnmarshalSecret[T entity](s string) (id ID[T], secret string, err error) {
kind, s, _ := strings.Cut(s, ".")
idStr, secret, ok := strings.Cut(s, ".")
if !ok || len(kind) != 1 {
return 0, "", fmt.Errorf("malformed secret")
}
switch SecretKind(kind[0]) {
case SecretKindAccessToken, SecretKindRefreshToken:
_, ok = interface{}(id).(ID[*AccessToken])
case SecretKindAuthCode:
_, ok = interface{}(id).(ID[*AuthCode])
}
if !ok {
return 0, "", fmt.Errorf("invalid secret kind %q", kind)
}
id, err = ParseID[T](idStr)
return id, secret, err
}
func MarshalSecret[T entity](id ID[T], kind SecretKind, secret string) string {
if id == 0 {
panic("cannot marshal zero ID")
}
var ok bool
switch interface{}(id).(type) {
case ID[*AccessToken]:
ok = kind == SecretKindAccessToken || kind == SecretKindRefreshToken
case ID[*AuthCode]:
ok = kind == SecretKindAuthCode
}
if !ok {
panic(fmt.Sprintf("unsupported secret kind %q for ID type %T", string(kind), id))
}
return fmt.Sprintf("%v.%v.%v", string(kind), int64(id), secret)
}
func generateUID() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
func generateSecret() (secret string, hash []byte, err error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", nil, err
}
secret = base64.RawURLEncoding.EncodeToString(b)
h := sha512.Sum512(b)
return secret, h[:], nil
}
func verifyHash(hash []byte, secret string) bool {
b, _ := base64.RawURLEncoding.DecodeString(secret)
h := sha512.Sum512(b)
return subtle.ConstantTimeCompare(hash, h[:]) == 1
}
func verifyExpiration(t time.Time) bool {
return time.Now().Before(t)
}
module vireo go 1.24.0 require ( codeberg.org/emersion/go-scfg v0.1.0 github.com/emersion/go-oauth2 v0.0.0-20250228145955-eaead4148315 github.com/go-chi/chi/v5 v5.2.3
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/mattn/go-sqlite3 v1.14.32 golang.org/x/crypto v0.42.0 )
codeberg.org/emersion/go-scfg v0.1.0 h1:6dnGU0ZI4gX+O5rMjwhoaySItzHG710eXL5TIQKl+uM= codeberg.org/emersion/go-scfg v0.1.0/go.mod h1:0nooW1ufBB4SlJEdTtiVN9Or+bnNM1icOkQ6Tbrq6O0= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/emersion/go-oauth2 v0.0.0-20250228145955-eaead4148315 h1:sXzwA8yItbg3ji0UuTLkuO4NKPqQJjC035hPoZI40h8= github.com/emersion/go-oauth2 v0.0.0-20250228145955-eaead4148315/go.mod h1:pSj8CBn/jb+ynRxt/ESIJisazza/Sh2DjwUn31l2tI0=
github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8= github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE= github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
package main
import (
"context"
"embed"
"flag"
"log"
"net"
"net/http"
"time"
"github.com/go-chi/chi/v5"
)
var (
//go:embed template
templateFS embed.FS
//go:embed static
staticFS embed.FS
)
func main() {
var configFilename string
flag.StringVar(&configFilename, "config", "/etc/lindenii/vireo/config", "Configuration filename")
flag.Parse()
cfg, err := loadConfig(configFilename)
if err != nil {
log.Fatalf("Failed to load config file: %v", err)
}
listenAddr := cfg.Listen
if cfg.Database == "" {
log.Fatalf("Missing database configuration")
}
db, err := openDB(cfg.Database)
if err != nil {
log.Fatalf("Failed to open DB: %v", err)
}
tplBaseData := &TemplateBaseData{
ServerName: cfg.ServerName,
}
if tplBaseData.ServerName == "" {
tplBaseData.ServerName = "vireo"
}
tpl, err := loadTemplate(templateFS, "template/*.html", tplBaseData)
if err != nil {
log.Fatalf("Failed to load template: %v", err)
}
oidcProvider, err := newOIDCProvider(context.Background(), db)
if err != nil {
log.Fatalf("Failed to initialize OpenID Connect provider: %v", err)
}
mux := chi.NewRouter()
mux.Handle("/static/*", http.FileServer(http.FS(staticFS)))
mux.Get("/", index)
mux.HandleFunc("/login", login)
mux.Post("/logout", logout)
mux.HandleFunc("/client/new", manageClient)
mux.HandleFunc("/client/{id}", manageClient)
mux.Post("/client/{id}/revoke", revokeClient)
mux.HandleFunc("/user/new", manageUser)
mux.HandleFunc("/user/{id}", manageUser)
mux.Get("/.well-known/oauth-authorization-server", getOAuthServerMetadata)
mux.Get("/.well-known/openid-configuration", getOpenIDConfiguration)
mux.Get("/.well-known/jwks.json", getOIDCJWKS)
mux.HandleFunc("/authorize", authorize)
mux.Post("/token", exchangeToken)
mux.Post("/introspect", introspectToken)
mux.Post("/revoke", revokeToken)
mux.HandleFunc("/userinfo", userInfo)
go maintainDBLoop(db)
server := http.Server{
Addr: listenAddr,
Handler: loginTokenMiddleware(mux),
Handler: csrfMiddleware(loginTokenMiddleware(mux)),
BaseContext: func(net.Listener) context.Context {
return newBaseContext(db, tpl)
return newBaseContext(db, tpl, oidcProvider)
},
}
log.Printf("OAuth server listening on %v", server.Addr)
if err := server.ListenAndServe(); err != nil {
log.Fatalf("Failed to listen and serve: %v", err)
}
}
func httpError(w http.ResponseWriter, err error) {
log.Print(err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
}
func maintainDBLoop(db *DB) {
ticker := time.NewTicker(15 * time.Minute)
defer ticker.Stop()
for range ticker.C {
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
if err := db.Maintain(ctx); err != nil {
log.Printf("Failed to perform database maintenance: %v", err)
}
cancel()
}
}
package main import ( "context" "fmt" "html/template" "io" "io/fs" "mime" "net/http" ) const ( loginCookieName = "vireo-login" internalTokenScope = "_vireo" ) type contextKey string const ( contextKeyDB = "db" contextKeyTemplate = "template" contextKeyLoginToken = "login-token"
contextKeyOIDC = "oidc" contextKeyCSRFToken = "csrf-token"
)
func dbFromContext(ctx context.Context) *DB {
return ctx.Value(contextKeyDB).(*DB)
}
func templateFromContext(ctx context.Context) *Template {
return ctx.Value(contextKeyTemplate).(*Template)
}
func oidcProviderFromContext(ctx context.Context) *OIDCProvider {
return ctx.Value(contextKeyOIDC).(*OIDCProvider)
}
func loginTokenFromContext(ctx context.Context) *AccessToken {
v := ctx.Value(contextKeyLoginToken)
if v == nil {
return nil
}
return v.(*AccessToken)
}
func newBaseContext(db *DB, tpl *Template) context.Context {
func newBaseContext(db *DB, tpl *Template, oidc *OIDCProvider) context.Context {
ctx := context.Background() ctx = context.WithValue(ctx, contextKeyDB, db) ctx = context.WithValue(ctx, contextKeyTemplate, tpl)
ctx = context.WithValue(ctx, contextKeyOIDC, oidc)
return ctx
}
func setLoginTokenCookie(w http.ResponseWriter, req *http.Request, token *AccessToken, secret string) {
http.SetCookie(w, &http.Cookie{
Name: loginCookieName,
Value: MarshalSecret(token.ID, SecretKindAccessToken, secret),
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
SameSite: http.SameSiteLaxMode,
Secure: isForwardedHTTPS(req),
})
}
func unsetLoginTokenCookie(w http.ResponseWriter, req *http.Request) {
http.SetCookie(w, &http.Cookie{
Name: loginCookieName,
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
SameSite: http.SameSiteLaxMode,
Secure: isForwardedHTTPS(req),
MaxAge: -1,
})
}
func isForwardedHTTPS(req *http.Request) bool {
if forwarded := req.Header.Get("Forwarded"); forwarded != "" {
_, params, _ := mime.ParseMediaType("_; " + forwarded)
return params["proto"] == "https"
}
if forwardedProto := req.Header.Get("X-Forwarded-Proto"); forwardedProto != "" {
return forwardedProto == "https"
}
return false
}
func loginTokenMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
cookie, _ := req.Cookie(loginCookieName)
if cookie == nil {
next.ServeHTTP(w, req)
return
}
ctx := req.Context()
db := dbFromContext(ctx)
tokenID, tokenSecret, _ := UnmarshalSecret[*AccessToken](cookie.Value)
token, err := db.FetchAccessToken(ctx, tokenID)
if err == errNoDBRows || (err == nil && !token.VerifySecret(tokenSecret)) {
unsetLoginTokenCookie(w, req)
next.ServeHTTP(w, req)
return
} else if err != nil {
httpError(w, fmt.Errorf("failed to fetch access token: %v", err))
return
}
if token.Scope != internalTokenScope {
http.Error(w, "Invalid login token scope", http.StatusForbidden)
return
}
if token.User == 0 {
panic("login token with zero user ID")
}
ctx = context.WithValue(ctx, contextKeyLoginToken, token)
req = req.WithContext(ctx)
next.ServeHTTP(w, req)
})
}
type TemplateBaseData struct {
ServerName string
CSRFToken string
}
func (data *TemplateBaseData) Base() *TemplateBaseData {
return data
}
type TemplateData interface {
Base() *TemplateBaseData
}
type Template struct {
tpl *template.Template
baseData *TemplateBaseData
}
func loadTemplate(fs fs.FS, pattern string, baseData *TemplateBaseData) (*Template, error) {
tpl, err := template.ParseFS(fs, pattern)
if err != nil {
return nil, err
}
return &Template{tpl: tpl, baseData: baseData}, nil
}
func (tpl *Template) MustExecuteTemplate(w io.Writer, filename string, data TemplateData) {
func (tpl *Template) MustExecuteTemplate(ctx context.Context, w io.Writer, filename string, data TemplateData) {
baseCopy := *tpl.baseData
if token := csrfTokenFromContext(ctx); token != "" {
baseCopy.CSRFToken = token
}
if data == nil {
data = tpl.baseData
base := baseCopy data = &base
} else {
*data.Base() = *tpl.baseData
*data.Base() = baseCopy
}
if err := tpl.tpl.ExecuteTemplate(w, filename, data); err != nil {
panic(err)
}
}
package main import (
"crypto/sha256" "encoding/base64"
"encoding/json" "errors" "fmt" "io" "log" "mime" "net" "net/http" "net/url" "strings" "time" "github.com/emersion/go-oauth2" )
const (
scopeOpenID = "openid"
scopeProfile = "profile"
scopeEmail = "email"
scopeOfflineAccess = "offline_access"
pkceMethodPlain = "plain"
pkceMethodS256 = "S256"
)
var allowedScopes = map[string]struct{}{
scopeOpenID: {},
scopeProfile: {},
scopeEmail: {},
scopeOfflineAccess: {},
}
type oidcTokenResponse struct {
AccessToken string `json:"access_token"`
TokenType oauth2.TokenType `json:"token_type"`
ExpiresIn int64 `json:"expires_in,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
IDToken string `json:"id_token,omitempty"`
}
func parseScopes(scope string) []string {
if scope == "" {
return nil
}
parts := strings.Fields(scope)
var scopes []string
seen := make(map[string]struct{}, len(parts))
for _, p := range parts {
if p == "" {
continue
}
p = strings.ToLower(p)
if _, ok := seen[p]; ok {
continue
}
seen[p] = struct{}{}
scopes = append(scopes, p)
}
return scopes
}
func normalizeScope(scope string) (string, []string) {
scopes := parseScopes(scope)
if len(scopes) == 0 {
return "", nil
}
return strings.Join(scopes, " "), scopes
}
func validateScopes(scopes []string) error {
for _, scope := range scopes {
if _, ok := allowedScopes[scope]; !ok {
return fmt.Errorf("unsupported scope %q", scope)
}
}
return nil
}
func normalizeCodeChallengeMethod(method string) (string, error) {
if method == "" {
return pkceMethodPlain, nil
}
switch {
case strings.EqualFold(method, pkceMethodPlain):
return pkceMethodPlain, nil
case strings.EqualFold(method, pkceMethodS256):
return pkceMethodS256, nil
default:
return "", fmt.Errorf("unsupported code_challenge_method")
}
}
func validateCodeVerifier(verifier string) error {
if verifier == "" {
return fmt.Errorf("missing code_verifier")
}
if len(verifier) < 43 || len(verifier) > 128 {
return fmt.Errorf("invalid code_verifier length")
}
for _, r := range verifier {
if !(r >= 'A' && r <= 'Z' || r >= 'a' && r <= 'z' || r >= '0' && r <= '9' || r == '-' || r == '.' || r == '_' || r == '~') {
return fmt.Errorf("invalid character in code_verifier")
}
}
return nil
}
func validateCodeChallenge(method, challenge string) error {
if challenge == "" {
return fmt.Errorf("missing code_challenge")
}
if err := validateCodeVerifier(challenge); err != nil {
return err
}
if method != pkceMethodPlain && method != pkceMethodS256 {
return fmt.Errorf("unsupported code_challenge_method")
}
return nil
}
func verifyCodeVerifier(method, challenge, verifier string) error {
if err := validateCodeVerifier(verifier); err != nil {
return err
}
switch method {
case "", pkceMethodPlain:
if challenge != verifier {
return fmt.Errorf("code_verifier mismatch")
}
case pkceMethodS256:
hash := sha256.Sum256([]byte(verifier))
expected := base64.RawURLEncoding.EncodeToString(hash[:])
if expected != challenge {
return fmt.Errorf("code_verifier mismatch")
}
default:
return fmt.Errorf("unsupported code_challenge_method")
}
return nil
}
func getOAuthServerMetadata(w http.ResponseWriter, req *http.Request) {
issuer := getIssuer(req)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(&oauth2.ServerMetadata{
Issuer: issuer,
AuthorizationEndpoint: issuer + "/authorize",
TokenEndpoint: issuer + "/token",
IntrospectionEndpoint: issuer + "/introspect",
RevocationEndpoint: issuer + "/revoke",
ResponseTypesSupported: []oauth2.ResponseType{oauth2.ResponseTypeCode},
ResponseModesSupported: []oauth2.ResponseMode{oauth2.ResponseModeQuery},
GrantTypesSupported: []oauth2.GrantType{oauth2.GrantTypeAuthorizationCode},
TokenEndpointAuthMethodsSupported: []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic},
Issuer: issuer,
AuthorizationEndpoint: issuer + "/authorize",
TokenEndpoint: issuer + "/token",
IntrospectionEndpoint: issuer + "/introspect",
RevocationEndpoint: issuer + "/revoke",
JWKSURI: issuer + "/.well-known/jwks.json",
ScopesSupported: []string{scopeOpenID, scopeProfile, scopeEmail, scopeOfflineAccess},
ResponseTypesSupported: []oauth2.ResponseType{oauth2.ResponseTypeCode},
ResponseModesSupported: []oauth2.ResponseMode{oauth2.ResponseModeQuery},
GrantTypesSupported: []oauth2.GrantType{oauth2.GrantTypeAuthorizationCode, oauth2.GrantTypeRefreshToken},
TokenEndpointAuthMethodsSupported: []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic},
IntrospectionEndpointAuthMethodsSupported: []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic},
RevocationEndpointAuthMethodsSupported: []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic},
CodeChallengeMethodsSupported: []string{pkceMethodPlain, pkceMethodS256},
AuthorizationResponseIssParameterSupported: true,
})
}
func getIssuer(req *http.Request) string {
issuerURL := url.URL{
Scheme: "https",
Host: req.Host,
}
if !isForwardedHTTPS(req) && isLoopback(req) {
// TODO: add config option for allowed reverse proxy IPs
issuerURL.Scheme = "http"
}
return issuerURL.String()
}
func isLoopback(req *http.Request) bool {
host, _, _ := net.SplitHostPort(req.RemoteAddr)
ip := net.ParseIP(host)
return ip.IsLoopback()
}
func authorize(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
db := dbFromContext(ctx)
tpl := templateFromContext(ctx)
q := req.URL.Query()
respType := oauth2.ResponseType(q.Get("response_type"))
clientID := q.Get("client_id")
rawRedirectURI := q.Get("redirect_uri")
scope := q.Get("scope")
state := q.Get("state")
_, stateProvided := q["state"]
codeChallenge := q.Get("code_challenge")
codeChallengeMethod := q.Get("code_challenge_method")
var normalizedCodeChallengeMethod string
nonce := q.Get("nonce")
if clientID == "" {
http.Error(w, "Missing client ID", http.StatusBadRequest)
return
}
client, err := db.FetchClientByClientID(ctx, clientID)
if err == errNoDBRows {
http.Error(w, "Invalid client ID", http.StatusForbidden)
return
} else if err != nil {
httpError(w, fmt.Errorf("failed to fetch client: %v", err))
return
}
requiredPKCE, err := normalizeClientPKCERequirement(client.PKCERequirement)
if err != nil {
httpError(w, fmt.Errorf("invalid PKCE requirement configuration: %v", err))
return
}
var allowedRedirectURIs []*url.URL
for _, s := range strings.Split(client.RedirectURIs, "\n") {
if s == "" {
continue
}
u, err := url.Parse(s)
if err != nil {
httpError(w, fmt.Errorf("failed to parse client redirect URI"))
return
}
allowedRedirectURIs = append(allowedRedirectURIs, u)
}
var redirectURI *url.URL
if rawRedirectURI != "" {
redirectURI, err = url.Parse(rawRedirectURI)
if err != nil {
http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
return
}
if !validateRedirectURI(redirectURI, allowedRedirectURIs) {
http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
return
}
} else {
if len(allowedRedirectURIs) == 0 {
http.Error(w, "Missing redirect URI", http.StatusBadRequest)
return
}
redirectURI = allowedRedirectURIs[0]
}
if codeChallenge != "" {
method, err := normalizeCodeChallengeMethod(codeChallengeMethod)
if err != nil {
redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: err.Error(),
})
return
}
if err := validateCodeChallenge(method, codeChallenge); err != nil {
redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: err.Error(),
})
return
}
normalizedCodeChallengeMethod = method
}
if codeChallenge == "" && codeChallengeMethod != "" {
redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: "code_challenge_method without code_challenge",
})
return
}
switch requiredPKCE {
case pkceRequirementPlain:
if codeChallenge == "" {
redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: "PKCE is required",
})
return
}
if !allowPKCERequirement(normalizedCodeChallengeMethod, pkceRequirementPlain) {
redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: "PKCE method does not satisfy requirement",
})
return
}
case pkceRequirementS256:
if codeChallenge == "" {
redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: "PKCE (S256) is required",
})
return
}
if !allowPKCERequirement(normalizedCodeChallengeMethod, pkceRequirementS256) {
redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: "PKCE (S256) is required",
})
return
}
}
codeChallengeMethod = normalizedCodeChallengeMethod
if respType != oauth2.ResponseTypeCode {
redirectClientError(w, req, redirectURI, state, &oauth2.Error{
redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
Code: oauth2.ErrorCodeUnsupportedResponseType, }) return }
// TODO: add support for scope
if scope != "" {
redirectClientError(w, req, redirectURI, state, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidScope,
normalizedScope, scopes := normalizeScope(scope)
if len(scopes) == 0 {
redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidScope,
Description: "Missing required openid scope",
}) return }
if err := validateScopes(scopes); err != nil {
redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidScope,
Description: err.Error(),
})
return
}
if !containsScope(scopes, scopeOpenID) {
redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidScope,
Description: "Scope openid is required",
})
return
}
scope = normalizedScope
loginToken := loginTokenFromContext(ctx)
if loginToken == nil {
q := make(url.Values)
q.Set("redirect_uri", req.URL.String())
u := url.URL{
Path: "/login",
RawQuery: q.Encode(),
}
http.Redirect(w, req, u.String(), http.StatusFound)
return
}
_ = req.ParseForm()
if _, ok := req.PostForm["deny"]; ok {
redirectClientError(w, req, redirectURI, state, &oauth2.Error{
redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
})
return
}
if _, ok := req.PostForm["authorize"]; !ok {
data := struct {
TemplateBaseData
Client *Client
}{
Client: client,
}
tpl.MustExecuteTemplate(w, "authorize.html", &data)
tpl.MustExecuteTemplate(req.Context(), w, "authorize.html", &data)
return
}
authCode := AuthCode{
User: loginToken.User, Client: client.ID, Scope: scope, RedirectURI: rawRedirectURI,
User: loginToken.User, Client: client.ID, Scope: scope, RedirectURI: rawRedirectURI, Nonce: nonce, CodeChallenge: codeChallenge, CodeChallengeMethod: codeChallengeMethod,
}
secret, err := authCode.Generate()
if err != nil {
httpError(w, fmt.Errorf("failed to generate authentication code: %v", err))
return
}
if err := db.CreateAuthCode(ctx, &authCode); err != nil {
httpError(w, fmt.Errorf("failed to create authentication code: %v", err))
return
}
code := MarshalSecret(authCode.ID, SecretKindAuthCode, secret)
values := make(url.Values)
values.Set("code", code)
if state != "" {
if stateProvided {
values.Set("state", state)
}
redirectClient(w, req, redirectURI, values)
}
func exchangeToken(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
db := dbFromContext(ctx)
values, err := parseRequestBody(req)
if err != nil {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: err.Error(),
})
return
}
clientID := values.Get("client_id")
grantType := oauth2.GrantType(values.Get("grant_type"))
scope := values.Get("scope")
codeVerifier := values.Get("code_verifier")
authClientID, clientSecret, _ := req.BasicAuth()
if clientID == "" {
clientID = authClientID
} else if clientID != authClientID {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: "Client ID in request body doesn't match Authorization header field",
})
return
}
var client *Client
if clientID != "" {
client, err = db.FetchClientByClientID(ctx, clientID)
if err == errNoDBRows {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidClient,
Description: "Invalid client ID",
})
return
} else if err != nil {
oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
return
}
if !client.IsPublic() && !client.VerifySecret(clientSecret) {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
Description: "Invalid client secret",
})
return
}
}
var token *AccessToken
var ( token *AccessToken authorizationCode *AuthCode currentClient *Client nonceValue string )
switch grantType {
case oauth2.GrantTypeAuthorizationCode:
if client == nil {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: "Missing client ID",
})
return
}
codeID, codeSecret, _ := UnmarshalSecret[*AuthCode](values.Get("code"))
authCode, err := db.PopAuthCode(ctx, codeID)
if err == errNoDBRows || (err == nil && !authCode.VerifySecret(codeSecret)) || authCode.Client != client.ID {
authorizationCode, err = db.PopAuthCode(ctx, codeID)
if err == errNoDBRows || (err == nil && !authorizationCode.VerifySecret(codeSecret)) || authorizationCode.Client != client.ID {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
Description: "Invalid authorization code",
})
return
} else if err != nil {
oauthError(w, fmt.Errorf("failed to fetch authorization code: %v", err))
return
}
if scope != authCode.Scope {
if scope != "" && scope != authorizationCode.Scope {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
Description: "Invalid scope",
})
return
}
if values.Get("redirect_uri") != authCode.RedirectURI {
if values.Get("redirect_uri") != authorizationCode.RedirectURI {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
Description: "Invalid redirect URI",
})
return
}
if authorizationCode.CodeChallenge != "" {
if codeVerifier == "" {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: "Missing code_verifier",
})
return
}
if err := verifyCodeVerifier(authorizationCode.CodeChallengeMethod, authorizationCode.CodeChallenge, codeVerifier); err != nil {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidGrant,
Description: "Invalid code_verifier",
})
return
}
} else if codeVerifier != "" {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: "Unexpected code_verifier",
})
return
}
token = NewAccessTokenFromAuthCode(authCode)
token = NewAccessTokenFromAuthCode(authorizationCode) currentClient = client nonceValue = authorizationCode.Nonce
case oauth2.GrantTypeRefreshToken:
tokenID, refreshSecret, _ := UnmarshalSecret[*AccessToken](values.Get("refresh_token"))
token, err = db.FetchAccessToken(ctx, tokenID)
if err == errNoDBRows || (err == nil && !token.VerifyRefreshSecret(refreshSecret)) {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
Description: "Invalid refresh token",
})
return
} else if err != nil {
oauthError(w, fmt.Errorf("failed to fetch access token: %v", err))
return
}
if client != nil && client.ID != token.Client {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
Description: "Invalid refresh token",
})
return
}
tokenClient, err := db.FetchClient(ctx, token.Client)
currentClient, err = db.FetchClient(ctx, token.Client)
if err != nil {
oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
return
}
if !tokenClient.IsPublic() && client == nil {
if !currentClient.IsPublic() && client == nil {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
Description: "Invalid client secret",
})
return
}
if scope != token.Scope {
if scope != "" && scope != token.Scope {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
Description: "Invalid scope",
})
return
}
default:
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeUnsupportedGrantType,
Description: "Unsupported grant type",
})
}
secret, err := token.Generate(accessTokenExpiration)
if err != nil {
oauthError(w, err)
return
}
refreshSecret, err := token.GenerateRefresh()
if err != nil {
oauthError(w, err)
return
tokenScopes := parseScopes(token.Scope)
issueRefresh := containsScope(tokenScopes, scopeOfflineAccess)
var refreshSecret string
if issueRefresh {
refreshSecret, err = token.GenerateRefresh()
if err != nil {
oauthError(w, err)
return
}
} else {
token.RefreshHash = nil
token.RefreshExpiresAt = time.Time{}
}
if err := db.StoreAccessToken(ctx, token); err != nil {
oauthError(w, fmt.Errorf("failed to create access token: %v", err))
return
}
accessTokenValue := MarshalSecret(token.ID, SecretKindAccessToken, secret)
if token.AuthTime.IsZero() {
token.AuthTime = token.IssuedAt
}
var idToken string
if containsScope(tokenScopes, scopeOpenID) {
if currentClient == nil {
currentClient, err = db.FetchClient(ctx, token.Client)
if err != nil {
oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
return
}
}
user, err := db.FetchUser(ctx, token.User)
if err != nil {
oauthError(w, fmt.Errorf("failed to fetch user: %v", err))
return
}
issuer := getIssuer(req)
oidcProvider := oidcProviderFromContext(ctx)
idToken, err = oidcProvider.MintIDToken(issuer, currentClient, user, token, tokenScopes, nonceValue, accessTokenValue, token.AuthTime)
if err != nil {
oauthError(w, fmt.Errorf("failed to mint ID token: %v", err))
return
}
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-store")
json.NewEncoder(w).Encode(&oauth2.TokenResp{
AccessToken: MarshalSecret(token.ID, SecretKindAccessToken, secret),
TokenType: oauth2.TokenTypeBearer,
ExpiresIn: time.Until(token.ExpiresAt),
Scope: strings.Split(token.Scope, " "),
RefreshToken: MarshalSecret(token.ID, SecretKindRefreshToken, refreshSecret),
})
resp := oidcTokenResponse{
AccessToken: accessTokenValue,
TokenType: oauth2.TokenTypeBearer,
ExpiresIn: int64(time.Until(token.ExpiresAt).Seconds()),
Scope: token.Scope,
IDToken: idToken,
}
if issueRefresh {
refreshTokenValue := MarshalSecret(token.ID, SecretKindRefreshToken, refreshSecret)
resp.RefreshToken = refreshTokenValue
}
json.NewEncoder(w).Encode(&resp)
}
func introspectToken(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
db := dbFromContext(ctx)
values, err := parseRequestBody(req)
if err != nil {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: err.Error(),
})
return
}
client, err := maybeAuthenticateClient(w, req)
if err != nil {
oauthError(w, err)
return
}
tokenID, secret, _ := UnmarshalSecret[*AccessToken](values.Get("token"))
token, err := db.FetchAccessToken(ctx, tokenID)
if err == errNoDBRows || (err == nil && !token.VerifySecret(secret)) {
token = nil
} else if err != nil {
oauthError(w, fmt.Errorf("failed to fetch access token: %v", err))
return
}
var resp oauth2.IntrospectionResp
if token != nil {
if client == nil {
client, err = db.FetchClient(ctx, token.Client)
if err != nil {
oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
return
}
if !client.IsPublic() {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidClient,
Description: "Missing client ID and secret",
})
return
}
}
if client.ID != token.Client {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidClient,
Description: "Invalid client ID or secret",
})
return
}
user, err := db.FetchUser(ctx, token.User)
if err != nil {
oauthError(w, fmt.Errorf("failed to fetch user: %v", err))
return
}
resp.Active = true
resp.TokenType = oauth2.TokenTypeBearer
resp.ExpiresAt = token.ExpiresAt
resp.IssuedAt = token.IssuedAt
resp.ClientID = client.ClientID
resp.Username = user.Username
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(&resp)
}
func revokeToken(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
db := dbFromContext(ctx)
values, err := parseRequestBody(req)
if err != nil {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: err.Error(),
})
return
}
client, err := maybeAuthenticateClient(w, req)
if err != nil {
oauthError(w, err)
return
}
tokenID, secret, _ := UnmarshalSecret[*AccessToken](values.Get("token"))
token, err := db.FetchAccessToken(ctx, tokenID)
if err == errNoDBRows || (err == nil && !token.VerifySecret(secret)) {
return // ignore
} else if err != nil {
oauthError(w, fmt.Errorf("failed to fetch access token: %v", err))
return
}
if client == nil {
client, err = db.FetchClient(ctx, token.Client)
if err != nil {
oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
return
}
if !client.IsPublic() {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidClient,
Description: "Missing client ID and secret",
})
return
}
}
if client.ID != token.Client {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidClient,
Description: "Invalid client ID or secret",
})
return
}
if err := db.DeleteAccessToken(ctx, token.ID); err != nil {
oauthError(w, err)
return
}
}
func parseRequestBody(req *http.Request) (url.Values, error) {
ct := req.Header.Get("Content-Type")
if ct != "" {
mimeType, _, err := mime.ParseMediaType(req.Header.Get("Content-Type"))
if err != nil {
return nil, fmt.Errorf("malformed Content-Type header field")
} else if mimeType != "application/x-www-form-urlencoded" {
return nil, fmt.Errorf("unsupported request content type")
}
}
r := io.LimitReader(req.Body, 10<<20)
b, err := io.ReadAll(r)
if err != nil {
return nil, fmt.Errorf("failed to read request body: %v", err)
}
values, err := url.ParseQuery(string(b))
if err != nil {
return nil, fmt.Errorf("failed to parse request body: %v", err)
}
return values, nil
}
func oauthError(w http.ResponseWriter, err error) {
var oauthErr *oauth2.Error
if !errors.As(err, &oauthErr) {
oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
log.Print(err)
}
statusCode := http.StatusInternalServerError
switch oauthErr.Code {
case oauth2.ErrorCodeInvalidRequest, oauth2.ErrorCodeUnsupportedResponseType, oauth2.ErrorCodeInvalidScope, oauth2.ErrorCodeInvalidClient, oauth2.ErrorCodeInvalidGrant, oauth2.ErrorCodeUnsupportedGrantType:
statusCode = http.StatusBadRequest
case oauth2.ErrorCodeUnauthorizedClient, oauth2.ErrorCodeAccessDenied:
statusCode = http.StatusForbidden
case oauth2.ErrorCodeTemporarilyUnavailable:
statusCode = http.StatusServiceUnavailable
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
json.NewEncoder(w).Encode(oauthErr)
}
func redirectClient(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, values url.Values) {
q := redirectURI.Query()
for k, v := range values {
q[k] = v
}
q.Set("iss", getIssuer(req))
u := *redirectURI
u.RawQuery = q.Encode()
http.Redirect(w, req, u.String(), http.StatusFound)
}
func redirectClientError(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, state string, err error) {
func redirectClientError(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, state string, stateProvided bool, err error) {
var oauthErr *oauth2.Error
if !errors.As(err, &oauthErr) {
oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
log.Print(err)
}
values := make(url.Values)
values.Set("error", string(oauthErr.Code))
if oauthErr.Description != "" {
values.Set("error_description", oauthErr.Description)
}
if oauthErr.URI != "" {
values.Set("error_uri", oauthErr.URI)
}
if state != "" {
if stateProvided {
values.Set("state", state)
}
redirectClient(w, req, redirectURI, values)
}
func validateRedirectURI(u *url.URL, allowedURIs []*url.URL) bool {
// Loopback interface, see RFC 8252 section 7.3
host, _, _ := net.SplitHostPort(u.Host)
ip := net.ParseIP(host)
if u.Scheme == "http" && ip.IsLoopback() {
uu := *u
uu.Host = "localhost"
u = &uu
}
for _, allowed := range allowedURIs {
if u.String() == allowed.String() {
return true
}
}
return false
}
func maybeAuthenticateClient(w http.ResponseWriter, req *http.Request) (*Client, error) {
ctx := req.Context()
db := dbFromContext(ctx)
clientID, clientSecret, ok := req.BasicAuth()
if !ok {
return nil, nil
}
client, err := db.FetchClientByClientID(ctx, clientID)
if err == errNoDBRows || (err == nil && !client.VerifySecret(clientSecret)) {
return nil, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidClient,
Description: "Invalid client ID or secret",
}
} else if err != nil {
return nil, fmt.Errorf("failed to fetch client: %v", err)
}
return client, nil
}
package main
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"net/http"
"sort"
"strconv"
"strings"
"time"
oauth2 "github.com/emersion/go-oauth2"
"github.com/golang-jwt/jwt/v5"
)
type OIDCProvider struct {
signingKeys []*oidcSigningKey
}
type oidcSigningKey struct {
key *SigningKey
private *rsa.PrivateKey
publicJWK jwk
}
type jwk struct {
Kty string `json:"kty"`
Use string `json:"use,omitempty"`
Alg string `json:"alg,omitempty"`
Kid string `json:"kid,omitempty"`
N string `json:"n,omitempty"`
E string `json:"e,omitempty"`
}
type jwks struct {
Keys []jwk `json:"keys"`
}
const idTokenTTL = 15 * time.Minute
func newOIDCProvider(ctx context.Context, db *DB) (*OIDCProvider, error) {
signingRecords, err := db.FetchSigningKeys(ctx)
if err == errNoDBRows {
generated, genErr := generateSigningKey()
if genErr != nil {
return nil, genErr
}
if storeErr := db.StoreSigningKey(ctx, generated); storeErr != nil {
return nil, fmt.Errorf("failed to persist signing key: %w", storeErr)
}
signingRecords = []SigningKey{*generated}
} else if err != nil {
return nil, fmt.Errorf("failed to fetch signing keys: %w", err)
}
signingKeys := make([]*oidcSigningKey, 0, len(signingRecords))
for i := range signingRecords {
material, convErr := toOIDCSigningKey(&signingRecords[i])
if convErr != nil {
return nil, convErr
}
signingKeys = append(signingKeys, material)
}
return &OIDCProvider{signingKeys: signingKeys}, nil
}
func generateSigningKey() (*SigningKey, error) {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, fmt.Errorf("failed to generate signing key: %w", err)
}
pemBlock := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(priv),
})
if pemBlock == nil {
return nil, fmt.Errorf("failed to encode signing key")
}
kid, err := generateUID()
if err != nil {
return nil, fmt.Errorf("failed to generate signing key ID: %w", err)
}
return &SigningKey{
KID: kid,
Algorithm: "RS256",
PrivateKey: pemBlock,
CreatedAt: time.Now(),
}, nil
}
func toOIDCSigningKey(signing *SigningKey) (*oidcSigningKey, error) {
block, _ := pem.Decode(signing.PrivateKey)
if block == nil {
return nil, fmt.Errorf("failed to decode signing key PEM")
}
priv, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse signing key: %w", err)
}
jwk := jwk{
Kty: "RSA",
Use: "sig",
Alg: signing.Algorithm,
Kid: signing.KID,
N: base64.RawURLEncoding.EncodeToString(priv.N.Bytes()),
}
e := big.NewInt(int64(priv.E)).Bytes()
jwk.E = base64.RawURLEncoding.EncodeToString(e)
return &oidcSigningKey{
key: signing,
private: priv,
publicJWK: jwk,
}, nil
}
func (op *OIDCProvider) currentSigningKey() *oidcSigningKey {
if len(op.signingKeys) == 0 {
return nil
}
return op.signingKeys[0]
}
func (op *OIDCProvider) signingMethod() (*jwt.SigningMethodRSA, *oidcSigningKey, error) {
key := op.currentSigningKey()
if key == nil {
return nil, nil, fmt.Errorf("no signing key configured")
}
switch key.key.Algorithm {
case "RS256":
return jwt.SigningMethodRS256, key, nil
default:
return nil, nil, fmt.Errorf("unsupported signing algorithm %q", key.key.Algorithm)
}
}
func (op *OIDCProvider) MintIDToken(issuer string, client *Client, user *User, token *AccessToken, scopes []string, nonce string, accessToken string, authTime time.Time) (string, error) {
method, signingKey, err := op.signingMethod()
if err != nil {
return "", err
}
now := time.Now()
expiresAt := now.Add(idTokenTTL)
if token.ExpiresAt.Before(expiresAt) {
expiresAt = token.ExpiresAt
}
if expiresAt.Before(now) {
expiresAt = now
}
claims := jwt.MapClaims{
"iss": issuer,
"sub": subjectForUser(user),
"aud": client.ClientID,
"exp": jwt.NewNumericDate(expiresAt),
"iat": jwt.NewNumericDate(now),
}
if !authTime.IsZero() {
claims["auth_time"] = jwt.NewNumericDate(authTime)
}
if nonce != "" {
claims["nonce"] = nonce
}
if accessToken != "" {
claims["at_hash"] = computeAtHash(accessToken)
}
if containsScope(scopes, scopeProfile) {
displayName := user.Name
if displayName == "" {
displayName = user.Username
}
claims["preferred_username"] = user.Username
claims["name"] = displayName
}
if containsScope(scopes, scopeEmail) && user.Email != "" {
claims["email"] = user.Email
claims["email_verified"] = false
}
tokenJWT := jwt.NewWithClaims(method, claims)
tokenJWT.Header["kid"] = signingKey.key.KID
return tokenJWT.SignedString(signingKey.private)
}
func (op *OIDCProvider) JWKS() jwks {
keys := make([]jwk, 0, len(op.signingKeys))
for _, key := range op.signingKeys {
keys = append(keys, key.publicJWK)
}
return jwks{Keys: keys}
}
func subjectForUser(user *User) string {
return strconv.FormatInt(int64(user.ID), 10)
}
func computeAtHash(accessToken string) string {
sum := sha256.Sum256([]byte(accessToken))
return base64.RawURLEncoding.EncodeToString(sum[:len(sum)/2])
}
func containsScope(scopes []string, scope string) bool {
for _, s := range scopes {
if strings.EqualFold(s, scope) {
return true
}
}
return false
}
func getOpenIDConfiguration(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
oidc := oidcProviderFromContext(ctx)
issuer := getIssuer(req)
currentKey := oidc.currentSigningKey()
scopes := make([]string, 0, len(allowedScopes))
for scope := range allowedScopes {
scopes = append(scopes, scope)
}
sort.Strings(scopes)
idTokenAlgs := []string{"RS256"}
if currentKey != nil && currentKey.key.Algorithm != "" {
idTokenAlgs = []string{currentKey.key.Algorithm}
}
config := map[string]interface{}{
"issuer": issuer,
"authorization_endpoint": issuer + "/authorize",
"token_endpoint": issuer + "/token",
"userinfo_endpoint": issuer + "/userinfo",
"jwks_uri": issuer + "/.well-known/jwks.json",
"response_types_supported": []string{string(oauth2.ResponseTypeCode)},
"response_modes_supported": []string{string(oauth2.ResponseModeQuery)},
"grant_types_supported": []string{string(oauth2.GrantTypeAuthorizationCode), string(oauth2.GrantTypeRefreshToken)},
"subject_types_supported": []string{"public"},
"id_token_signing_alg_values_supported": idTokenAlgs,
"scopes_supported": scopes,
"claims_supported": []string{"sub", "preferred_username", "name", "email", "email_verified"},
"token_endpoint_auth_methods_supported": []string{string(oauth2.AuthMethodNone), string(oauth2.AuthMethodClientSecretBasic)},
"introspection_endpoint": issuer + "/introspect",
"revocation_endpoint": issuer + "/revoke",
"authorization_response_iss_parameter_supported": true,
"claims_parameter_supported": false,
"request_parameter_supported": false,
"request_uri_parameter_supported": false,
"code_challenge_methods_supported": []string{pkceMethodPlain, pkceMethodS256},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(config)
}
func getOIDCJWKS(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
oidc := oidcProviderFromContext(ctx)
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "public, max-age=300")
json.NewEncoder(w).Encode(oidc.JWKS())
}
func userInfo(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
db := dbFromContext(ctx)
if req.Method != http.MethodGet && req.Method != http.MethodPost {
w.Header().Set("Allow", "GET, POST")
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return
}
tokenValue, err := bearerTokenFromRequest(req)
if err != nil {
writeBearerError(w, http.StatusUnauthorized, "invalid_token", err.Error())
return
}
tokenID, secret, err := UnmarshalSecret[*AccessToken](tokenValue)
if err != nil {
writeBearerError(w, http.StatusUnauthorized, "invalid_token", "Malformed access token")
return
}
token, err := db.FetchAccessToken(ctx, tokenID)
if err == errNoDBRows || (err == nil && !token.VerifySecret(secret)) {
writeBearerError(w, http.StatusUnauthorized, "invalid_token", "Invalid access token")
return
} else if err != nil {
httpError(w, fmt.Errorf("failed to fetch access token: %v", err))
return
}
scopes := parseScopes(token.Scope)
if !containsScope(scopes, scopeOpenID) {
writeBearerError(w, http.StatusForbidden, "insufficient_scope", "Scope openid missing")
return
}
user, err := db.FetchUser(ctx, token.User)
if err != nil {
httpError(w, fmt.Errorf("failed to fetch user: %v", err))
return
}
resp := map[string]interface{}{
"sub": subjectForUser(user),
}
if containsScope(scopes, scopeProfile) {
displayName := user.Name
if displayName == "" {
displayName = user.Username
}
resp["preferred_username"] = user.Username
resp["name"] = displayName
}
if containsScope(scopes, scopeEmail) && user.Email != "" {
resp["email"] = user.Email
resp["email_verified"] = false
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
func bearerTokenFromRequest(req *http.Request) (string, error) {
authz := req.Header.Get("Authorization")
if authz == "" {
return "", fmt.Errorf("Authorization header missing")
}
if len(authz) < 7 || !strings.EqualFold(authz[:7], "Bearer ") {
return "", fmt.Errorf("Unsupported authorization scheme")
}
token := strings.TrimSpace(authz[7:])
if token == "" {
return "", fmt.Errorf("Missing access token")
}
return token, nil
}
func writeBearerError(w http.ResponseWriter, status int, code, description string) {
challenge := "Bearer"
if code != "" {
challenge += fmt.Sprintf(" error=\"%s\"", code)
}
if description != "" {
if code == "" {
challenge += " "
} else {
challenge += ", "
}
challenge += fmt.Sprintf("error_description=\"%s\"", description)
}
w.Header().Set("WWW-Authenticate", challenge)
http.Error(w, http.StatusText(status), status)
}
package main
import (
"fmt"
"strings"
)
const (
pkceRequirementNone = ""
pkceRequirementPlain = pkceMethodPlain
pkceRequirementS256 = pkceMethodS256
)
func normalizeClientPKCERequirement(value string) (string, error) {
switch strings.ToUpper(strings.TrimSpace(value)) {
case "", "NONE":
return pkceRequirementNone, nil
case strings.ToUpper(pkceRequirementPlain):
return pkceRequirementPlain, nil
case pkceRequirementS256:
return pkceRequirementS256, nil
default:
return "", fmt.Errorf("invalid PKCE requirement")
}
}
func allowPKCERequirement(method, requirement string) bool {
requirement = strings.ToUpper(requirement)
method = strings.ToUpper(method)
switch requirement {
case "", "NONE":
return true
case strings.ToUpper(pkceRequirementPlain):
return method == strings.ToUpper(pkceMethodPlain) || method == strings.ToUpper(pkceMethodS256)
case pkceRequirementS256:
return method == strings.ToUpper(pkceMethodS256)
default:
return false
}
}
CREATE TABLE User ( id INTEGER PRIMARY KEY, username TEXT NOT NULL UNIQUE,
name TEXT, email TEXT,
password_hash TEXT, admin INTEGER NOT NULL DEFAULT 0 ); CREATE TABLE Client ( id INTEGER PRIMARY KEY, client_id TEXT NOT NULL UNIQUE, client_secret_hash BLOB, owner INTEGER REFERENCES User(id) ON DELETE CASCADE, redirect_uris TEXT, client_name TEXT,
client_uri TEXT
client_uri TEXT, pkce_requirement TEXT
); CREATE TABLE AccessToken ( id INTEGER PRIMARY KEY, hash BLOB NOT NULL UNIQUE, user INTEGER NOT NULL REFERENCES User(id) ON DELETE CASCADE, client INTEGER REFERENCES Client(id) ON DELETE CASCADE, scope TEXT, issued_at datetime NOT NULL, expires_at datetime NOT NULL,
auth_time datetime,
refresh_hash BLOB UNIQUE, refresh_expires_at datetime ); CREATE TABLE AuthCode ( id INTEGER PRIMARY KEY, hash BLOB NOT NULL UNIQUE, created_at datetime NOT NULL, user INTEGER NOT NULL REFERENCES User(id) ON DELETE CASCADE, client INTEGER NOT NULL REFERENCES Client(id) ON DELETE CASCADE, redirect_uri TEXT,
scope TEXT
scope TEXT, nonce TEXT, code_challenge TEXT, code_challenge_method TEXT ); CREATE TABLE SigningKey ( id INTEGER PRIMARY KEY, kid TEXT NOT NULL UNIQUE, algorithm TEXT NOT NULL, private_key BLOB NOT NULL, created_at datetime NOT NULL
);
CREATE INDEX signing_key_created_at ON SigningKey(created_at);
body {
font-family: sans-serif;
margin: 0;
color: #444;
}
main, #nav-inner, footer {
padding: 0 5px;
max-width: 800px;
margin: 0 auto;
}
main {
padding-bottom: 20px;
}
nav {
border-bottom: 1px solid #eee;
}
nav h1 {
margin: 0;
padding: 10px 0;
font-size: 1.2em;
}
nav h1 a {
color: inherit;
text-decoration: none;
}
h2 {
font-size: 1.2em;
}
table {
border-collapse: collapse;
}
td, th {
border: 1px solid rgb(208, 210, 215);
padding: 5px;
}
button {
border: 1px solid rgb(208, 210, 215);
border-radius: 4px;
padding: 6px 12px;
margin: 4px 0;
color: #444;
background-color: transparent;
cursor: pointer;
}
button:hover {
background-color: rgba(0, 0, 0, 0.02);
}
button[type="submit"]:not(.btn-regular):first-of-type {
background-color: rgb(0, 128, 0);
border-color: rgb(0, 128, 0);
color: white;
}
button[type="submit"]::not(.btn-regular):first-of-type:hover {
background-color: rgb(0, 150, 0);
border-color: rgb(0, 150, 0);
}
input[type="text"], input[type="password"], input[type="url"], textarea {
input[type="text"], input[type="email"], input[type="password"], input[type="url"], textarea {
border: 1px solid rgb(208, 210, 215); border-radius: 4px; padding: 6px; margin: 4px 0; color: #444; }
input[type="text"]:focus, input[type="password"]:focus, input[type="url"]:focus, textarea:focus {
input[type="email"]:focus, input[type="text"]:focus, input[type="password"]:focus, input[type="url"]:focus, textarea:focus {
outline: none;
border-color: rgb(0, 128, 0);
}
label {
display: block;
margin: 15px 0;
}
label input[type="text"], label input[type="password"], label input[type="url"] {
label input[type="email"], label input[type="text"], label input[type="password"], label input[type="url"] {
display: block;
width: 100%;
max-width: 350px;
box-sizing: border-box;
}
label:has(input[type="radio"]) {
margin: 5px 0;
}
label textarea {
display: block;
width: 100%;
resize: vertical;
}
main.narrow {
max-width: 400px;
}
main.narrow input {
max-width: 100%;
}
@media (prefers-color-scheme: dark) {
body {
background: #212529;
color: #f8f9fa;
}
nav {
border-color: rgba(255, 255, 2555, 0.02);
}
a {
color: #809fff;
}
button {
color: #f8f9fa;
}
button:hover {
background-color: rgba(255, 255, 255, 0.02);
}
input[type="text"], input[type="password"], input[type="url"], textarea {
input[type="email"], input[type="text"], input[type="password"], input[type="url"], textarea {
background-color: rgba(255, 255, 255, 0.05); color: inherit; } }
{{ template "head.html" .Base }}
<main class="narrow">
<h1>{{ .ServerName }}</h1>
<p>
Authorize
{{ if .Client.ClientURI }}
<a href="{{ .Client.ClientURI }}" target="_blank">
{{ end }}
{{- if .Client.ClientName -}}
{{- .Client.ClientName -}}
{{- else -}}
<code>{{- .Client.ClientID -}}</code>
{{- end -}}
{{- if .Client.ClientURI -}}
</a>
{{- end -}}
?
</p>
<form method="post" action="">
<input type="hidden" name="_csrf" value="{{ .Base.CSRFToken }}">
<button type="submit" name="authorize">Authorize</button>
<button type="submit" name="deny">Cancel</button>
</form>
</main>
{{ template "foot.html" }}
<!DOCTYPE HTML> <html lang="en"> <head>
<meta charset="utf-8"/>
<meta charset="utf-8"/>
<title>{{ .ServerName }}</title>
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<link rel="stylesheet" href="/static/style.css"/>
<meta name="viewport" content="width=device-width, initial-scale=1"/> <link rel="stylesheet" href="/static/style.css"/>
</head> <body>
{{ template "head.html" .Base }}
{{ template "nav.html" .Base }}
<main>
<p>Welcome, {{ .Me.Username }}!</p>
<form method="post">
<input type="hidden" name="_csrf" value="{{ .Base.CSRFToken }}">
<a href="/user/{{ .Me.ID }}"><button type="button">Settings</button></a>
<button type="submit" formaction="/logout" class="btn-regular">Logout</button>
</form>
<h2>Authorized clients</h2>
{{ with .AuthorizedClients }}
<table>
<tr>
<th>Client</th>
<th>Authorized until</th>
<th></th>
</tr>
{{ range . }}
<tr>
<td>
{{ with .Client }}
{{ if .ClientURI }}
<a href="{{ .ClientURI }}" target="_blank">
{{ end }}
{{ if .ClientName }}
{{ .ClientName }}
{{ else }}
<code>{{ .ClientID }}</code>
{{ end }}
{{ if .ClientURI }}
</a>
{{ end }}
{{ end }}
</td>
<td>{{ .ExpiresAt }}</td>
<td>
<form method="post" action="/client/{{ .Client.ID }}/revoke">
<button type="submit">Revoke</button>
</form>
<form method="post" action="/client/{{ .Client.ID }}/revoke">
<input type="hidden" name="_csrf" value="{{ $.Base.CSRFToken }}">
<button type="submit">Revoke</button>
</form>
</td>
</tr>
{{ end }}
</table>
{{ else }}
<p>No client authorized yet.</p>
{{ end }}
{{ if .Me.Admin }}
<h2>Registered clients</h2>
<p>
<a href="/client/new"><button type="button">Register new client</button></a>
</p>
{{ with .Clients }}
<table>
<tr>
<th>Client ID</th>
<th>Name</th>
</tr>
{{ range . }}
<tr>
<td><a href="/client/{{ .ID }}"><code>{{ .ClientID }}</code></a></td>
<td>{{ .ClientName }}</td>
</tr>
{{ end }}
</table>
{{ else }}
<p>No client registered yet.</p>
{{ end }}
<h2>Users</h2>
<p>
<a href="/user/new"><button type="button">Create user</button></a>
</p>
<table>
<tr>
<th>Username</th>
<th>Name</th> <th>Email</th>
<th>Role</th>
</tr>
{{ range .Users }}
<tr>
<td><a href="/user/{{ .ID }}">{{ .Username }}</a></td>
<td>{{ .Name }}</td>
<td>{{ .Email }}</td>
<td>
{{ if .Admin }}
Administrator
{{ else }}
Regular user
{{ end}}
</td>
</tr>
{{ end }}
</table>
{{ end }}
</main>
{{ template "foot.html" }}
{{ template "head.html" .Base }}
<main class="narrow">
<h1>{{ .ServerName }}</h1>
<form method="post" action="">
<input type="hidden" name="_csrf" value="{{ .Base.CSRFToken }}">
<label>
Username
<input type="text" name="username" autocomplete="username" autofocus>
</label>
<label>
Password
<input type="password" name="password">
</label>
<button type="submit">Login</button>
</form>
</main>
{{ template "foot.html" }}
{{ template "head.html" .Base }}
{{ template "nav.html" .Base }}
<main>
<h2>
{{ if .Client.ID }}
Update client
{{ else }}
Create client
{{ end }}
</h2>
<form method="post" action="">
<input type="hidden" name="_csrf" value="{{ .Base.CSRFToken }}">
{{ if .Client.ClientID }}
Client ID: <code>{{ .Client.ClientID }}</code><br>
{{ end }}
<label>
Name
<input type="text" name="client_name" value="{{ .Client.ClientName }}">
</label>
<label>
Website
<input type="url" name="client_uri" value="{{ .Client.ClientURI }}">
</label>
{{ if .Client.ID }}
<p>
Client type
<br>
<strong>
{{ if .Client.IsPublic }}
Public
{{ else }}
Confidential
{{ end }}
</strong>
</p>
{{ else }}
Client type
<label>
<input type="radio" name="client_type" value="confidential" checked>
Confidential
</label>
<label>
<input type="radio" name="client_type" value="public">
Public
</label>
{{ end }}
<label>
Redirect URIs
<textarea name="redirect_uris" wrap="off">{{ .Client.RedirectURIs }}</textarea>
<small>The special URI <code>http://localhost</code> matches all loopback interfaces.</small><br>
</label>
<label>
PKCE requirement
<select name="pkce_requirement" {{ if not .Client.IsPublic }}disabled{{ end }}>
<option value="" {{ if eq .Client.PKCERequirement "" }}selected{{ end }}>Optional</option>
<option value="plain" {{ if eq .Client.PKCERequirement "plain" }}selected{{ end }}>Require PKCE (plain or S256)</option>
<option value="S256" {{ if eq .Client.PKCERequirement "S256" }}selected{{ end }}>Require PKCE (S256)</option>
</select>
<small>Applies to public clients only.</small>
</label>
<button type="submit">
{{ if .Client.ID }}
Update client
{{ else }}
Create client
{{ end }}
</button>
{{ if .Client.ID }}
{{ if not .Client.IsPublic }}
<button type="submit" name="rotate">Rotate client secret</button>
{{ end }}
<button type="submit" name="delete">Delete client</button>
{{ end }}
<a href="/"><button type="button">Cancel</button></a>
</form>
</main>
{{ template "foot.html" }}
{{ template "head.html" .Base }}
{{ template "nav.html" .Base }}
<main>
<h2>
{{ if .User.ID }}
Update user
{{ else }}
Create user
{{ end }}
</h2>
<form method="post" action="">
<input type="hidden" name="_csrf" value="{{ .Base.CSRFToken }}">
<label>
Username
<input type="text" name="username" value="{{ .User.Username }}" required>
</label>
<label>
Display name
<input type="text" name="name" value="{{ .User.Name }}">
</label>
<label>
Email
<input type="email" name="email" value="{{ .User.Email }}">
</label>
<label>
Password
<input type="password" name="password">
</label>
{{ if not (eq .Me.ID .User.ID) }}
<label>
<input type="checkbox" name="admin" {{ if .User.Admin }}checked{{ end }}>
Administrator
</label>
{{ end }}
<button type="submit">
{{ if .User.ID }}
Update user
{{ else }}
Create user
{{ end }}
</button>
<a href="/"><button type="button">Cancel</button></a>
</form>
</main>
{{ template "foot.html" }}
package main import ( "fmt" "log" "net/http" "net/url"
"strings"
"time"
"github.com/go-chi/chi/v5"
)
func index(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
db := dbFromContext(ctx)
tpl := templateFromContext(ctx)
loginToken := loginTokenFromContext(ctx)
if loginToken == nil {
http.Redirect(w, req, "/login", http.StatusFound)
return
}
me, err := db.FetchUser(ctx, loginToken.User)
if err != nil {
httpError(w, err)
return
}
authorizedClients, err := db.ListAuthorizedClients(ctx, loginToken.User)
if err != nil {
httpError(w, err)
return
}
clients, err := db.ListClients(ctx, loginToken.User)
if err != nil {
httpError(w, err)
return
}
var users []User
if me.Admin {
users, err = db.ListUsers(ctx)
if err != nil {
httpError(w, err)
return
}
}
data := struct {
TemplateBaseData
Me *User
AuthorizedClients []AuthorizedClient
Clients []Client
Users []User
}{
Me: me,
AuthorizedClients: authorizedClients,
Clients: clients,
Users: users,
}
tpl.MustExecuteTemplate(w, "index.html", &data)
tpl.MustExecuteTemplate(req.Context(), w, "index.html", &data)
}
func login(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
db := dbFromContext(ctx)
tpl := templateFromContext(ctx)
q := req.URL.Query()
rawRedirectURI := q.Get("redirect_uri")
if rawRedirectURI == "" {
rawRedirectURI = "/"
}
redirectURI, err := url.Parse(rawRedirectURI)
if err != nil || redirectURI.Scheme != "" || redirectURI.Opaque != "" || redirectURI.User != nil || redirectURI.Host != "" {
http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
return
}
if loginTokenFromContext(ctx) != nil {
http.Redirect(w, req, redirectURI.String(), http.StatusFound)
return
}
username := req.PostFormValue("username")
username := strings.TrimSpace(req.PostFormValue("username"))
password := req.PostFormValue("password")
if username == "" {
tpl.MustExecuteTemplate(w, "login.html", nil)
tpl.MustExecuteTemplate(req.Context(), w, "login.html", nil)
return
}
user, err := db.FetchUserByUsername(ctx, username)
if err != nil && err != errNoDBRows {
httpError(w, fmt.Errorf("failed to fetch user: %v", err))
return
}
if err == nil {
err = user.VerifyPassword(password)
}
if err != nil {
log.Printf("login failed for user %q: %v", username, err)
// TODO: show error message
tpl.MustExecuteTemplate(w, "login.html", nil)
tpl.MustExecuteTemplate(req.Context(), w, "login.html", nil)
return
}
if user.PasswordNeedsRehash() {
if err := user.SetPassword(password); err != nil {
httpError(w, fmt.Errorf("failed to rehash password: %v", err))
return
}
if err := db.StoreUser(ctx, user); err != nil {
httpError(w, fmt.Errorf("failed to store user: %v", err))
return
}
}
token := AccessToken{
User: user.ID,
Scope: internalTokenScope,
}
secret, err := token.Generate(4 * time.Hour)
if err != nil {
httpError(w, fmt.Errorf("failed to generate access token: %v", err))
return
}
token.AuthTime = token.IssuedAt
if err := db.StoreAccessToken(ctx, &token); err != nil {
httpError(w, fmt.Errorf("failed to create access token: %v", err))
return
}
setLoginTokenCookie(w, req, &token, secret)
http.Redirect(w, req, redirectURI.String(), http.StatusFound)
}
func logout(w http.ResponseWriter, req *http.Request) {
unsetLoginTokenCookie(w, req)
http.Redirect(w, req, "/login", http.StatusFound)
}
func manageUser(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
db := dbFromContext(ctx)
tpl := templateFromContext(ctx)
user := new(User)
if idStr := chi.URLParam(req, "id"); idStr != "" {
id, err := ParseID[*User](idStr)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
user, err = db.FetchUser(ctx, id)
if err != nil {
httpError(w, err)
return
}
}
loginToken := loginTokenFromContext(ctx)
if loginToken == nil {
http.Redirect(w, req, "/login", http.StatusFound)
return
}
me, err := db.FetchUser(ctx, loginToken.User)
if err != nil {
httpError(w, err)
return
} else if loginToken.User != user.ID && !me.Admin {
http.Error(w, "Access denied", http.StatusForbidden)
return
}
username := req.PostFormValue("username")
username := strings.TrimSpace(req.PostFormValue("username"))
name := strings.TrimSpace(req.PostFormValue("name"))
email := strings.TrimSpace(req.PostFormValue("email"))
password := req.PostFormValue("password")
admin := req.PostFormValue("admin") == "on"
if username == "" {
data := struct {
TemplateBaseData
User *User
Me *User
}{
User: user,
Me: me,
}
tpl.MustExecuteTemplate(w, "manage-user.html", &data)
tpl.MustExecuteTemplate(req.Context(), w, "manage-user.html", &data)
return } user.Username = username
user.Name = name user.Email = email
if me.Admin && user.ID != me.ID {
user.Admin = admin
}
if password != "" {
if err := user.SetPassword(password); err != nil {
httpError(w, err)
return
}
}
if err := db.StoreUser(ctx, user); err != nil {
httpError(w, err)
return
}
http.Redirect(w, req, "/", http.StatusFound)
}