Lindenii Project Forge
Login

server

Vireo IdP server

Warning: Due to various recent migrations, viewing non-HEAD refs may be broken.

/middleware.go (raw)

package main

import (
	"context"
	"fmt"
	"html/template"
	"io"
	"io/fs"
	"mime"
	"net/http"
)

const (
	// loginCookieName is the name of the HTTP cookie that carries the
	// marshaled login token between the browser and this server.
	loginCookieName    = "vireo-login"
	// internalTokenScope is the scope an access token must have to be
	// accepted as a login token; loginTokenMiddleware rejects any other
	// scope with 403 Forbidden.
	internalTokenScope = "_vireo"
)

// contextKey is the type of every key this package stores in a
// context.Context. Using a dedicated, unexported type instead of a plain
// string guarantees that the keys can never collide with keys set by other
// packages that happen to pick the same text.
type contextKey string

// Keys under which request-scoped values are stored in the context.
//
// Each constant is explicitly typed as contextKey: an untyped string constant
// would be stored as a plain string, which defeats the purpose of the
// dedicated key type.
const (
	// contextKeyDB maps to the *DB handle (see dbFromContext).
	contextKeyDB contextKey = "db"
	// contextKeyTemplate maps to the *Template (see templateFromContext).
	contextKeyTemplate contextKey = "template"
	// contextKeyLoginToken maps to the *AccessToken of the logged-in user,
	// when there is one (see loginTokenFromContext).
	contextKeyLoginToken contextKey = "login-token"
	// contextKeyOIDC maps to the *OIDCProvider (see oidcProviderFromContext).
	contextKeyOIDC contextKey = "oidc"
	// contextKeyCSRFToken is not read or written in this file.
	// NOTE(review): presumably it holds the per-request CSRF token consumed
	// by csrfTokenFromContext; confirm where that helper is defined.
	contextKeyCSRFToken contextKey = "csrf-token"
)

// dbFromContext returns the *DB stored in ctx under contextKeyDB.
//
// The type assertion is unchecked, so the call panics when ctx carries no
// value for that key or the value is not a *DB.
func dbFromContext(ctx context.Context) *DB {
	db := ctx.Value(contextKeyDB).(*DB)
	return db
}

// templateFromContext returns the *Template stored in ctx under
// contextKeyTemplate.
//
// The type assertion is unchecked, so the call panics when ctx carries no
// value for that key or the value is not a *Template.
func templateFromContext(ctx context.Context) *Template {
	tpl := ctx.Value(contextKeyTemplate).(*Template)
	return tpl
}

// oidcProviderFromContext returns the *OIDCProvider stored in ctx under
// contextKeyOIDC.
//
// The type assertion is unchecked, so the call panics when ctx carries no
// value for that key or the value is not an *OIDCProvider.
func oidcProviderFromContext(ctx context.Context) *OIDCProvider {
	provider := ctx.Value(contextKeyOIDC).(*OIDCProvider)
	return provider
}

// loginTokenFromContext returns the login token stored in ctx under
// contextKeyLoginToken, or nil when ctx carries no value for that key.
//
// A value of any type other than *AccessToken makes the call panic.
func loginTokenFromContext(ctx context.Context) *AccessToken {
	if stored := ctx.Value(contextKeyLoginToken); stored != nil {
		return stored.(*AccessToken)
	}
	return nil
}

// newBaseContext builds a context derived from context.Background that
// carries db, tpl and oidc under contextKeyDB, contextKeyTemplate and
// contextKeyOIDC respectively, so that the *FromContext helpers can retrieve
// them later.
func newBaseContext(db *DB, tpl *Template, oidc *OIDCProvider) context.Context {
	withDB := context.WithValue(context.Background(), contextKeyDB, db)
	withTemplate := context.WithValue(withDB, contextKeyTemplate, tpl)
	return context.WithValue(withTemplate, contextKeyOIDC, oidc)
}

// setLoginTokenCookie adds a Set-Cookie header to w that stores the login
// token in the loginCookieName cookie.
//
// The cookie value is whatever MarshalSecret produces from the token ID,
// SecretKindAccessToken and secret. The cookie is HttpOnly, uses
// SameSite=Lax, and is marked Secure only when isForwardedHTTPS reports that
// req was forwarded over HTTPS. No expiry attribute is set.
func setLoginTokenCookie(w http.ResponseWriter, req *http.Request, token *AccessToken, secret string) {
	loginCookie := http.Cookie{
		Name:     loginCookieName,
		Value:    MarshalSecret(token.ID, SecretKindAccessToken, secret),
		Secure:   isForwardedHTTPS(req),
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	}
	http.SetCookie(w, &loginCookie)
}

// unsetLoginTokenCookie adds a Set-Cookie header to w that removes the
// loginCookieName cookie: the value is empty and MaxAge is negative, which
// net/http serializes as an immediate deletion.
//
// The remaining attributes (HttpOnly, SameSite=Lax, Secure derived from
// isForwardedHTTPS) mirror the ones used by setLoginTokenCookie.
func unsetLoginTokenCookie(w http.ResponseWriter, req *http.Request) {
	expired := http.Cookie{
		Name:     loginCookieName,
		MaxAge:   -1,
		Secure:   isForwardedHTTPS(req),
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	}
	http.SetCookie(w, &expired)
}

// isForwardedHTTPS reports whether a reverse proxy declared that the request
// reached it over HTTPS.
//
// The standard Forwarded header takes precedence; X-Forwarded-Proto is only
// consulted when Forwarded is absent. Both headers may hold a comma-separated
// list with one entry per proxy hop. Only the first entry is evaluated: it is
// the one added by the proxy closest to the client, so it describes the
// connection the browser actually used. The function returns false when
// neither header is present or when the header cannot be parsed.
func isForwardedHTTPS(req *http.Request) bool {
	if forwarded := req.Header.Get("Forwarded"); forwarded != "" {
		// A Forwarded element uses the same "key=value;key=value" grammar as
		// media type parameters, so prefix a dummy media type ("_") and reuse
		// the mime parser. On a parse error params is nil and the lookup
		// below yields "", i.e. "not HTTPS".
		_, params, _ := mime.ParseMediaType("_; " + firstListElement(forwarded))
		return params["proto"] == "https"
	}
	if forwardedProto := req.Header.Get("X-Forwarded-Proto"); forwardedProto != "" {
		return firstListElement(forwardedProto) == "https"
	}
	return false
}

// firstListElement returns the part of a comma-separated header value that
// precedes the first comma, ignoring commas inside double-quoted strings. The
// whole value is returned when it contains no such comma.
func firstListElement(value string) string {
	inQuotes := false
	for i := 0; i < len(value); i++ {
		switch value[i] {
		case '\\':
			if inQuotes {
				i++ // skip the escaped character so an escaped quote does not end the string
			}
		case '"':
			inQuotes = !inQuotes
		case ',':
			if !inQuotes {
				return value[:i]
			}
		}
	}
	return value
}

// loginTokenMiddleware wraps next with login-cookie handling.
//
// For each request it looks up the loginCookieName cookie and then:
//   - without a cookie, calls next unchanged;
//   - when the token is not found in the database, or its secret does not
//     verify, clears the cookie and calls next without a login token;
//   - on any other database error, reports it through httpError and stops;
//   - when the token scope is not internalTokenScope, replies 403 Forbidden;
//   - otherwise stores the token in the request context under
//     contextKeyLoginToken and calls next with the updated request.
//
// A valid login token with a zero user ID is treated as a broken invariant
// and panics.
func loginTokenMiddleware(next http.Handler) http.Handler {
	handle := func(w http.ResponseWriter, req *http.Request) {
		// The lookup error is ignored on purpose: the nil check below covers
		// the "no cookie" case.
		loginCookie, _ := req.Cookie(loginCookieName)
		if loginCookie == nil {
			next.ServeHTTP(w, req)
			return
		}

		ctx := req.Context()
		store := dbFromContext(ctx)
		// NOTE(review): the third result of UnmarshalSecret is discarded;
		// presumably a malformed cookie yields an ID that the lookup below
		// does not find. Confirm against UnmarshalSecret.
		id, secret, _ := UnmarshalSecret[*AccessToken](loginCookie.Value)
		token, err := store.FetchAccessToken(ctx, id)
		switch {
		case err == errNoDBRows, err == nil && !token.VerifySecret(secret):
			// Unknown or mismatching token: drop the stale cookie and carry
			// on as an anonymous request.
			unsetLoginTokenCookie(w, req)
			next.ServeHTTP(w, req)
			return
		case err != nil:
			httpError(w, fmt.Errorf("failed to fetch access token: %v", err))
			return
		}

		if token.Scope != internalTokenScope {
			http.Error(w, "Invalid login token scope", http.StatusForbidden)
			return
		}
		if token.User == 0 {
			panic("login token with zero user ID")
		}

		authenticated := context.WithValue(ctx, contextKeyLoginToken, token)
		next.ServeHTTP(w, req.WithContext(authenticated))
	}
	return http.HandlerFunc(handle)
}

// TemplateBaseData holds the fields shared by every rendered template.
type TemplateBaseData struct {
	// ServerName is copied unchanged from the Template's base data on every
	// render. NOTE(review): presumably the display name of this server;
	// confirm where it is populated.
	ServerName string
	// CSRFToken is replaced by MustExecuteTemplate with the token found in
	// the request context, when that token is non-empty.
	CSRFToken string
}

// Base returns the receiver itself. It lets *TemplateBaseData, and any type
// that embeds TemplateBaseData, satisfy TemplateData.
func (base *TemplateBaseData) Base() *TemplateBaseData {
	return base
}

// TemplateData is implemented by every value that can be passed to
// Template.MustExecuteTemplate as template data.
type TemplateData interface {
	// Base returns a pointer to the value's TemplateBaseData;
	// MustExecuteTemplate writes the shared base fields through this pointer
	// before rendering.
	Base() *TemplateBaseData
}

// Template bundles a parsed html/template set with the base data that is
// injected into every render.
type Template struct {
	// tpl is the parsed template set executed by MustExecuteTemplate.
	tpl      *template.Template
	// baseData holds the default base fields; MustExecuteTemplate copies
	// them for each render and never modifies the original.
	baseData *TemplateBaseData
}

// loadTemplate parses every file in fs that matches pattern with
// template.ParseFS and pairs the result with baseData.
//
// It returns the parse error unchanged, together with a nil *Template, when
// parsing fails.
func loadTemplate(fs fs.FS, pattern string, baseData *TemplateBaseData) (*Template, error) {
	var (
		loaded = &Template{baseData: baseData}
		err    error
	)
	if loaded.tpl, err = template.ParseFS(fs, pattern); err != nil {
		return nil, err
	}
	return loaded, nil
}

// MustExecuteTemplate renders the template named filename to w.
//
// Before rendering, it takes a copy of the Template's base data and, when
// csrfTokenFromContext returns a non-empty token for ctx, stores that token
// in the copy. If data is nil the copy itself is used as the template data;
// otherwise the copy overwrites the value behind data.Base().
//
// It panics with the error returned by ExecuteTemplate when rendering fails.
func (tpl *Template) MustExecuteTemplate(ctx context.Context, w io.Writer, filename string, data TemplateData) {
	// Work on a copy so the shared base data is never modified by a render.
	base := *tpl.baseData
	if csrf := csrfTokenFromContext(ctx); csrf != "" {
		base.CSRFToken = csrf
	}

	if data != nil {
		*data.Base() = base
	} else {
		data = &base
	}

	err := tpl.tpl.ExecuteTemplate(w, filename, data)
	if err != nil {
		panic(err)
	}
}