/csrf.go (raw)
package main
import (
	"context"
	"crypto/subtle"
	"errors"
	"net/http"
	"strings"
)
// Names used by the CSRF protection in this file: the cookie written by
// ensureCSRFToken and the form field read back by csrfMiddleware.
const (
	// csrfFormField is the form field in which submitted forms echo the token.
	csrfFormField = "_csrf"

	// csrfCookieName is the name of the cookie that carries the CSRF token.
	csrfCookieName = "vireo-csrf"
)
// csrfTokenFromContext returns the CSRF token stored in ctx under
// contextKeyCSRFToken (csrfMiddleware places it there). It returns "" when
// ctx is nil, when the key is absent, or when the stored value is not a
// string.
func csrfTokenFromContext(ctx context.Context) string {
	var token string
	if ctx != nil {
		// A failed comma-ok assertion leaves token at its zero value "".
		token, _ = ctx.Value(contextKeyCSRFToken).(string)
	}
	return token
}
// ensureCSRFToken returns the CSRF token for req, issuing a new one if the
// request does not carry one yet.
//
// If req already has a non-empty csrfCookieName cookie, its value is returned
// as-is and nothing is written to w. Otherwise a fresh token is obtained from
// generateUID and set on w as a site-wide (Path "/"), HttpOnly,
// SameSite=Strict cookie with no expiry.
//
// The cookie is marked Secure when the request itself arrived over TLS
// (req.TLS != nil) or when isForwardedHTTPS reports it. The latter is presumably
// a check of proxy forwarding headers; its details are not visible here.
//
// The only error returned is one from generateUID; in that case no cookie is
// set.
func ensureCSRFToken(w http.ResponseWriter, req *http.Request) (string, error) {
	if cookie, err := req.Cookie(csrfCookieName); err == nil && cookie.Value != "" {
		return cookie.Value, nil
	}
	token, err := generateUID()
	if err != nil {
		return "", err
	}
	// Checking req.TLS covers deployments where this server terminates TLS
	// itself. Without it, the Secure flag depends solely on forwarded-request
	// detection and could be omitted on a direct HTTPS connection.
	http.SetCookie(w, &http.Cookie{
		Name:     csrfCookieName,
		Value:    token,
		Path:     "/",
		HttpOnly: true,
		SameSite: http.SameSiteStrictMode,
		Secure:   req.TLS != nil || isForwardedHTTPS(req),
	})
	return token, nil
}
// csrfMiddleware wraps next with two layers of cross-site request forgery
// protection.
//
// The outer layer is http.CrossOriginProtection. It sees every request and
// rejects non-safe cross-origin browser requests with a plain 403 "Forbidden".
//
// The inner layer is a double-submit token check. The token issued in the
// csrfCookieName cookie by ensureCSRFToken must be echoed back on every
// request whose method is not GET, HEAD or OPTIONS. It may arrive in the
// csrfFormField form field (urlencoded or multipart body) or in the
// X-CSRF-Token header. A mismatch or missing token yields 403
// "Invalid CSRF token". An unparsable form body yields 400 "Bad request".
//
// Requests matching isCSRFBypass, or carrying any Authorization header, skip
// the token layer. They still pass through CrossOriginProtection, which wraps
// it. For every other request the token is stored in the request context
// under contextKeyCSRFToken, where downstream handlers can read it with
// csrfTokenFromContext.
func csrfMiddleware(next http.Handler) http.Handler {
	// Memory budget for buffering a multipart body while looking for the token
	// field (32 MiB, the same default net/http uses for FormValue).
	const multipartMaxMemory = 32 << 20

	cop := http.NewCrossOriginProtection()
	// Replace the default deny response with a terse 403.
	cop.SetDenyHandler(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		http.Error(w, "Forbidden", http.StatusForbidden)
	}))
	handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		if isCSRFBypass(req) || req.Header.Get("Authorization") != "" {
			next.ServeHTTP(w, req)
			return
		}
		token, err := ensureCSRFToken(w, req)
		if err != nil {
			httpError(w, err)
			return
		}
		ctx := context.WithValue(req.Context(), contextKeyCSRFToken, token)
		req = req.WithContext(ctx)
		switch req.Method {
		case http.MethodGet, http.MethodHead, http.MethodOptions:
			next.ServeHTTP(w, req)
			return
		}
		if err := req.ParseForm(); err != nil {
			http.Error(w, "Bad request", http.StatusBadRequest)
			return
		}
		formToken := req.PostFormValue(csrfFormField)
		if formToken == "" {
			formToken = req.Header.Get("X-CSRF-Token")
		}
		if formToken == "" && req.MultipartForm == nil {
			// ParseForm does not read multipart/form-data bodies, yet it
			// leaves req.PostForm non-nil. PostFormValue therefore never
			// parses the multipart body, and a token field in such a form
			// (e.g. a file upload) would be invisible. Parse it explicitly.
			// This is done only as a last resort, so requests that already
			// supplied the header keep their body unread.
			switch err := req.ParseMultipartForm(multipartMaxMemory); {
			case err == nil:
				// This middleware created the form, so it also removes any
				// temporary files it spilled to disk once the handler chain
				// has returned.
				defer req.MultipartForm.RemoveAll()
				formToken = req.PostFormValue(csrfFormField)
			case !errors.Is(err, http.ErrNotMultipart):
				http.Error(w, "Bad request", http.StatusBadRequest)
				return
			}
		}
		if formToken == "" || subtle.ConstantTimeCompare([]byte(formToken), []byte(token)) != 1 {
			http.Error(w, "Invalid CSRF token", http.StatusForbidden)
			return
		}
		next.ServeHTTP(w, req)
	})
	return cop.Handler(handler)
}
// isCSRFBypass reports whether req targets a path that csrfMiddleware exempts
// from the CSRF token check. These are the exact paths /token, /introspect,
// /revoke, /userinfo and /favicon.ico, plus anything under /static/ or
// /.well-known/.
//
// Matching is done on req.URL.Path as-is. It is case-sensitive and performs
// no path cleaning, so "/token/" or "/Token" are not exempt.
func isCSRFBypass(req *http.Request) bool {
	p := req.URL.Path
	switch p {
	case "/token", "/introspect", "/revoke", "/userinfo", "/favicon.ico":
		return true
	}
	for _, prefix := range [...]string{"/static/", "/.well-known/"} {
		if strings.HasPrefix(p, prefix) {
			return true
		}
	}
	return false
}