Warning: Due to various recent migrations, viewing non-HEAD refs may be broken.
/csrf.go (raw)
package main
import (
"context"
"crypto/subtle"
"net/http"
"strings"
)
// csrfCookieName is the cookie in which each client's CSRF token is stored.
const csrfCookieName = "vireo-csrf"

// csrfFormField is the POST form field expected to carry the submitted CSRF token.
const csrfFormField = "_csrf"
func csrfTokenFromContext(ctx context.Context) string {
if ctx == nil {
return ""
}
if token, ok := ctx.Value(contextKeyCSRFToken).(string); ok {
return token
}
return ""
}
// ensureCSRFToken returns the CSRF token for the client that sent req.
//
// If req already carries a non-empty csrfCookieName cookie, its value is
// returned unchanged. Otherwise a fresh token is obtained from generateUID and
// set on w as a host-wide (Path "/"), HttpOnly, SameSite=Strict cookie. The
// cookie is marked Secure when isForwardedHTTPS reports the request as HTTPS.
// Any error from generateUID is returned with an empty token.
//
// NOTE(review): an existing cookie value is trusted without any format check.
func ensureCSRFToken(w http.ResponseWriter, req *http.Request) (string, error) {
	existing, cookieErr := req.Cookie(csrfCookieName)
	if cookieErr == nil && existing.Value != "" {
		return existing.Value, nil
	}

	fresh, err := generateUID()
	if err != nil {
		return "", err
	}

	issued := &http.Cookie{
		Name:     csrfCookieName,
		Value:    fresh,
		Path:     "/",
		Secure:   isForwardedHTTPS(req),
		HttpOnly: true,
		SameSite: http.SameSiteStrictMode,
	}
	http.SetCookie(w, issued)
	return fresh, nil
}
func csrfMiddleware(next http.Handler) http.Handler {
cop := http.NewCrossOriginProtection()
cop.SetDenyHandler(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
http.Error(w, "Forbidden", http.StatusForbidden)
}))
handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if isCSRFBypass(req) || req.Header.Get("Authorization") != "" {
next.ServeHTTP(w, req)
return
}
token, err := ensureCSRFToken(w, req)
if err != nil {
httpError(w, err)
return
}
ctx := context.WithValue(req.Context(), contextKeyCSRFToken, token)
req = req.WithContext(ctx)
switch req.Method {
case http.MethodGet, http.MethodHead, http.MethodOptions:
next.ServeHTTP(w, req)
return
}
if err := req.ParseForm(); err != nil {
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
formToken := req.PostFormValue(csrfFormField)
if formToken == "" {
formToken = req.Header.Get("X-CSRF-Token")
}
if formToken == "" || subtle.ConstantTimeCompare([]byte(formToken), []byte(token)) != 1 {
http.Error(w, "Invalid CSRF token", http.StatusForbidden)
return
}
next.ServeHTTP(w, req)
})
return cop.Handler(handler)
}
// isCSRFBypass reports whether req targets a path exempt from the CSRF token
// check. The exempt paths are:
//   - the exact paths /token, /introspect, /revoke, /userinfo and /favicon.ico;
//   - anything under /static/ or /.well-known/.
//
// Matching is on req.URL.Path and is case-sensitive.
func isCSRFBypass(req *http.Request) bool {
	p := req.URL.Path
	switch p {
	case "/token", "/introspect", "/revoke", "/userinfo", "/favicon.ico":
		return true
	}
	return strings.HasPrefix(p, "/static/") || strings.HasPrefix(p, "/.well-known/")
}