Lindenii Project Forge
Login

server

Vireo IdP server

Hi… I am well aware that this diff view is very suboptimal. It will be fixed when the refactored server comes along!

Commit info
ID
3a5d08a13a45d60b32b357d65cac6efcefc9ba07
Author
Author date
Mon, 19 Feb 2024 13:09:54 +0100
Committer
Committer date
Mon, 19 Feb 2024 13:09:54 +0100
Actions
Add support for token introspection

Closes: https://todo.sr.ht/~emersion/sinwon/7
package main

import (
	"context"
	"embed"
	"flag"
	"html/template"
	"log"
	"net"
	"net/http"
	"time"

	"github.com/go-chi/chi/v5"
)

var (
	//go:embed template
	templateFS embed.FS
	//go:embed static
	staticFS embed.FS
)

func main() {
	var configFilename, listenAddr string
	flag.StringVar(&configFilename, "config", "/etc/sinwon/config", "Configuration filename")
	flag.StringVar(&listenAddr, "listen", ":8080", "HTTP listen address")
	flag.Parse()

	cfg, err := loadConfig(configFilename)
	if err != nil {
		log.Fatalf("Failed to load config file: %v", err)
	}

	if listenAddr == "" {
		listenAddr = cfg.Listen
	}
	if listenAddr == "" {
		log.Fatalf("Missing listen configuration")
	}
	if cfg.Database == "" {
		log.Fatalf("Missing database configuration")
	}

	db, err := openDB(cfg.Database)
	if err != nil {
		log.Fatalf("Failed to open DB: %v", err)
	}

	tpl := template.Must(template.ParseFS(templateFS, "template/*.html"))

	mux := chi.NewRouter()
	mux.Handle("/static/*", http.FileServer(http.FS(staticFS)))
	mux.Get("/", index)
	mux.HandleFunc("/login", login)
	mux.Post("/logout", logout)
	mux.HandleFunc("/client/new", manageClient)
	mux.HandleFunc("/client/{id}", manageClient)
	mux.HandleFunc("/user/new", manageUser)
	mux.HandleFunc("/user/{id}", manageUser)
	mux.Get("/.well-known/oauth-authorization-server", getOAuthServerMetadata)
	mux.HandleFunc("/authorize", authorize)
	mux.Post("/token", exchangeToken)
	mux.Post("/introspect", introspectToken)

	go maintainDBLoop(db)

	server := http.Server{
		Addr:    listenAddr,
		Handler: loginTokenMiddleware(mux),
		BaseContext: func(net.Listener) context.Context {
			return newBaseContext(db, tpl)
		},
	}
	log.Printf("OAuth server listening on %v", server.Addr)
	if err := server.ListenAndServe(); err != nil {
		log.Fatalf("Failed to listen and serve: %v", err)
	}
}

func httpError(w http.ResponseWriter, err error) {
	log.Print(err)
	http.Error(w, "Internal server error", http.StatusInternalServerError)
}

func maintainDBLoop(db *DB) {
	ticker := time.NewTicker(15 * time.Minute)
	defer ticker.Stop()

	for range ticker.C {
		ctx := context.Background()
		ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
		if err := db.Maintain(ctx); err != nil {
			log.Printf("Failed to perform database maintenance: %v", err)
		}
		cancel()
	}
}
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"mime"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"git.sr.ht/~emersion/go-oauth2"
)

func getOAuthServerMetadata(w http.ResponseWriter, req *http.Request) {
	issuerURL := url.URL{
		Scheme: "https",
		Host:   req.Host,
	}
	if !isForwardedHTTPS(req) && isLoopback(req) {
		// TODO: add config option for allowed reverse proxy IPs
		issuerURL.Scheme = "http"
	}
	issuer := issuerURL.String()

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(&oauth2.ServerMetadata{
		Issuer:                            issuer,
		AuthorizationEndpoint:             issuer + "/authorize",
		TokenEndpoint:                     issuer + "/token",
		ResponseTypesSupported:            []oauth2.ResponseType{oauth2.ResponseTypeCode},
		ResponseModesSupported:            []oauth2.ResponseMode{oauth2.ResponseModeQuery},
		GrantTypesSupported:               []oauth2.GrantType{oauth2.GrantTypeAuthorizationCode},
		TokenEndpointAuthMethodsSupported: []oauth2.AuthMethod{oauth2.AuthMethodClientSecretBasic},
		Issuer:                                    issuer,
		AuthorizationEndpoint:                     issuer + "/authorize",
		TokenEndpoint:                             issuer + "/token",
		IntrospectionEndpoint:                     issuer + "/introspect",
		ResponseTypesSupported:                    []oauth2.ResponseType{oauth2.ResponseTypeCode},
		ResponseModesSupported:                    []oauth2.ResponseMode{oauth2.ResponseModeQuery},
		GrantTypesSupported:                       []oauth2.GrantType{oauth2.GrantTypeAuthorizationCode},
		TokenEndpointAuthMethodsSupported:         []oauth2.AuthMethod{oauth2.AuthMethodClientSecretBasic},
		IntrospectionEndpointAuthMethodsSupported: []oauth2.AuthMethod{oauth2.AuthMethodClientSecretBasic},
	})
}

func isLoopback(req *http.Request) bool {
	host, _, _ := net.SplitHostPort(req.RemoteAddr)
	ip := net.ParseIP(host)
	return ip.IsLoopback()
}

func authorize(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)
	tpl := templateFromContext(ctx)

	q := req.URL.Query()
	respType := oauth2.ResponseType(q.Get("response_type"))
	clientID := q.Get("client_id")
	rawRedirectURI := q.Get("redirect_uri")
	scope := q.Get("scope")
	state := q.Get("state")

	if clientID == "" {
		http.Error(w, "Missing client ID", http.StatusBadRequest)
		return
	}

	client, err := db.FetchClientByClientID(ctx, clientID)
	if err == errNoDBRows {
		http.Error(w, "Invalid client ID", http.StatusForbidden)
		return
	} else if err != nil {
		httpError(w, fmt.Errorf("failed to fetch client: %v", err))
		return
	}

	var allowedRedirectURIs []*url.URL
	for _, s := range strings.Split(client.RedirectURIs, "\n") {
		if s == "" {
			continue
		}
		u, err := url.Parse(s)
		if err != nil {
			httpError(w, fmt.Errorf("failed to parse client redirect URI"))
			return
		}
		allowedRedirectURIs = append(allowedRedirectURIs, u)
	}

	var redirectURI *url.URL
	if rawRedirectURI != "" {
		redirectURI, err = url.Parse(rawRedirectURI)
		if err != nil {
			http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
			return
		}
		if !validateRedirectURI(redirectURI, allowedRedirectURIs) {
			http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
			return
		}
	} else {
		if len(allowedRedirectURIs) == 0 {
			http.Error(w, "Missing redirect URI", http.StatusBadRequest)
			return
		}
		redirectURI = allowedRedirectURIs[0]
	}

	if respType != oauth2.ResponseTypeCode {
		redirectClientError(w, req, redirectURI, state, &oauth2.Error{
			Code: oauth2.ErrorCodeUnsupportedResponseType,
		})
		return
	}

	// TODO: add support for scope
	if scope != "" {
		redirectClientError(w, req, redirectURI, state, &oauth2.Error{
			Code: oauth2.ErrorCodeInvalidScope,
		})
		return
	}

	loginToken := loginTokenFromContext(ctx)
	if loginToken == nil {
		q := make(url.Values)
		q.Set("redirect_uri", req.URL.String())
		u := url.URL{
			Path:     "/login",
			RawQuery: q.Encode(),
		}
		http.Redirect(w, req, u.String(), http.StatusFound)
		return
	}

	_ = req.ParseForm()
	if _, ok := req.PostForm["deny"]; ok {
		redirectClientError(w, req, redirectURI, state, &oauth2.Error{
			Code: oauth2.ErrorCodeAccessDenied,
		})
		return
	}
	if _, ok := req.PostForm["authorize"]; !ok {
		data := struct {
			Client *Client
		}{
			Client: client,
		}
		if err := tpl.ExecuteTemplate(w, "authorize.html", data); err != nil {
			panic(err)
		}
		return
	}

	authCode, secret, err := NewAuthCode(loginToken.User, client.ID, scope)
	if err != nil {
		httpError(w, fmt.Errorf("failed to generate authentication code: %v", err))
		return
	}

	if err := db.CreateAuthCode(ctx, authCode); err != nil {
		httpError(w, fmt.Errorf("failed to create authentication code: %v", err))
		return
	}

	code := MarshalSecret(authCode.ID, secret)

	values := make(url.Values)
	values.Set("code", code)
	if state != "" {
		values.Set("state", state)
	}
	redirectClient(w, req, redirectURI, values)
}

func exchangeToken(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)

	values, err := parseRequestBody(req)
	if err != nil {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: err.Error(),
		})
		return
	}

	clientID := values.Get("client_id")
	grantType := oauth2.GrantType(values.Get("grant_type"))
	scope := values.Get("scope")

	authClientID, clientSecret, _ := req.BasicAuth()
	if clientID == "" && authClientID == "" {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: "Missing client ID",
		})
		return
	} else if clientID == "" {
		clientID = authClientID
	} else if clientID != authClientID {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: "Client ID in request body doesn't match Authorization header field",
		})
		return
	}

	client, err := db.FetchClientByClientID(ctx, clientID)
	if err == errNoDBRows {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidClient,
			Description: "Invalid client ID",
		})
		return
	} else if err != nil {
		oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
		return
	}

	if client.ClientSecretHash != nil {
		if !client.VerifySecret(clientSecret) {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid client secret",
			})
			return
		}
	}

	if grantType != oauth2.GrantTypeAuthorizationCode {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeUnsupportedGrantType,
			Description: "Unsupported grant type",
		})
		return
	}

	codeID, codeSecret, _ := UnmarshalSecret[*AuthCode](values.Get("code"))
	authCode, err := db.PopAuthCode(ctx, codeID)
	if err == errNoDBRows || (err == nil && !authCode.VerifySecret(codeSecret)) || authCode.Client != client.ID {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeAccessDenied,
			Description: "Invalid authorization code",
		})
		return
	} else if err != nil {
		oauthError(w, fmt.Errorf("failed to fetch authorization code: %v", err))
		return
	}

	if scope != authCode.Scope {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeAccessDenied,
			Description: "Invalid scope",
		})
		return
	}

	// TODO: check redirect_uri

	token, secret, err := NewAccessTokenFromAuthCode(authCode)
	if err != nil {
		oauthError(w, err)
		return
	}

	if err := db.CreateAccessToken(ctx, token); err != nil {
		oauthError(w, fmt.Errorf("failed to create access token: %v", err))
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Cache-Control", "no-store")
	json.NewEncoder(w).Encode(&oauth2.TokenResp{
		AccessToken: MarshalSecret(token.ID, secret),
		TokenType:   oauth2.TokenTypeBearer,
		ExpiresIn:   time.Until(token.ExpiresAt),
		Scope:       strings.Split(token.Scope, " "),
	})
}

func introspectToken(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)

	values, err := parseRequestBody(req)
	if err != nil {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: err.Error(),
		})
		return
	}

	var client *Client
	if clientID, clientSecret, ok := req.BasicAuth(); ok {
		client, err = db.FetchClientByClientID(ctx, clientID)
		if err == errNoDBRows || (err == nil && !client.VerifySecret(clientSecret)) {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidClient,
				Description: "Invalid client ID or secret",
			})
			return
		} else if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
			return
		}
	}

	tokenID, secret, _ := UnmarshalSecret[*AccessToken](values.Get("token"))
	token, err := db.FetchAccessToken(ctx, tokenID)
	if err == errNoDBRows || (err == nil && !token.VerifySecret(secret)) {
		token = nil
	} else if err != nil {
		oauthError(w, fmt.Errorf("failed to fetch access token: %v", err))
		return
	}

	var resp oauth2.IntrospectionResp
	if token != nil {
		if client == nil {
			if client.ClientSecretHash != nil {
				oauthError(w, &oauth2.Error{
					Code:        oauth2.ErrorCodeInvalidClient,
					Description: "Missing client ID and secret",
				})
				return
			}

			client, err = db.FetchClient(ctx, token.Client)
			if err != nil {
				oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
				return
			}
		}

		user, err := db.FetchUser(ctx, token.User)
		if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch user: %v", err))
			return
		}

		resp.Active = true
		resp.TokenType = oauth2.TokenTypeBearer
		resp.ExpiresAt = token.ExpiresAt
		resp.IssuedAt = token.IssuedAt
		resp.ClientID = client.ClientID
		resp.Username = user.Username
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(&resp)
}

func parseRequestBody(req *http.Request) (url.Values, error) {
	ct := req.Header.Get("Content-Type")
	if ct != "" {
		mimeType, _, err := mime.ParseMediaType(req.Header.Get("Content-Type"))
		if err != nil {
			return nil, fmt.Errorf("malformed Content-Type header field")
		} else if mimeType != "application/x-www-form-urlencoded" {
			return nil, fmt.Errorf("unsupported request content type")
		}
	}

	r := io.LimitReader(req.Body, 10<<20)
	b, err := io.ReadAll(r)
	if err != nil {
		return nil, fmt.Errorf("failed to read request body: %v", err)
	}

	values, err := url.ParseQuery(string(b))
	if err != nil {
		return nil, fmt.Errorf("failed to parse request body: %v", err)
	}

	return values, nil
}

func oauthError(w http.ResponseWriter, err error) {
	var oauthErr *oauth2.Error
	if !errors.As(err, &oauthErr) {
		oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
		log.Print(err)
	}

	statusCode := http.StatusInternalServerError
	switch oauthErr.Code {
	case oauth2.ErrorCodeInvalidRequest, oauth2.ErrorCodeUnsupportedResponseType, oauth2.ErrorCodeInvalidScope, oauth2.ErrorCodeInvalidClient, oauth2.ErrorCodeInvalidGrant, oauth2.ErrorCodeUnsupportedGrantType:
		statusCode = http.StatusBadRequest
	case oauth2.ErrorCodeUnauthorizedClient, oauth2.ErrorCodeAccessDenied:
		statusCode = http.StatusForbidden
	case oauth2.ErrorCodeTemporarilyUnavailable:
		statusCode = http.StatusServiceUnavailable
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)
	json.NewEncoder(w).Encode(oauthErr)
}

func redirectClient(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, values url.Values) {
	q := redirectURI.Query()
	for k, v := range values {
		q[k] = v
	}

	u := *redirectURI
	u.RawQuery = q.Encode()

	http.Redirect(w, req, u.String(), http.StatusFound)
}

func redirectClientError(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, state string, err error) {
	var oauthErr *oauth2.Error
	if !errors.As(err, &oauthErr) {
		oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
		log.Print(err)
	}

	values := make(url.Values)
	values.Set("error", string(oauthErr.Code))
	if oauthErr.Description != "" {
		values.Set("error_description", oauthErr.Description)
	}
	if oauthErr.URI != "" {
		values.Set("error_uri", oauthErr.URI)
	}
	if state != "" {
		values.Set("state", state)
	}
	redirectClient(w, req, redirectURI, values)
}

func validateRedirectURI(u *url.URL, allowedURIs []*url.URL) bool {
	// Loopback interface, see RFC 8252 section 7.3
	host, _, _ := net.SplitHostPort(u.Host)
	ip := net.ParseIP(host)
	if u.Scheme == "http" && ip.IsLoopback() {
		uu := *u
		uu.Host = "localhost"
		u = &uu
	}

	for _, allowed := range allowedURIs {
		if u.String() == allowed.String() {
			return true
		}
	}

	return false
}