Lindenii Project Forge
Login

server

Vireo IdP server

Hi… I am well aware that this diff view is very suboptimal. It will be fixed when the refactored server comes along!

Commit info
ID
319ad984dec6499b61cdcb8f4e5948109add1108
Author
Author date
Sun, 18 Feb 2024 22:14:10 +0100
Committer
Committer date
Sun, 18 Feb 2024 22:14:10 +0100
Actions
Add template to context
package main

import (
	"context"
	"embed"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"html/template"
	"io"
	"log"
	"mime"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"git.sr.ht/~emersion/go-oauth2"
	"github.com/go-chi/chi/v5"
)

var (
	//go:embed template
	templateFS embed.FS
	//go:embed static
	staticFS embed.FS
)

func main() {
	var listenAddr string
	flag.StringVar(&listenAddr, "listen", ":8080", "HTTP listen address")
	flag.Parse()

	tpl := template.Must(template.ParseFS(templateFS, "template/*.html"))

	db, err := openDB("sinwon.db")
	if err != nil {
		log.Fatalf("Failed to open DB: %v", err)
	}

	mux := chi.NewRouter()

	mux.Handle("/static/*", http.FileServer(http.FS(staticFS)))

	mux.Get("/", func(w http.ResponseWriter, req *http.Request) {
		ctx := req.Context()

		loginToken := loginTokenFromContext(ctx)
		if loginToken == nil {
			http.Redirect(w, req, "/login", http.StatusFound)
			return
		}

		clients, err := db.ListClients(ctx, loginToken.User)
		if err != nil {
			httpError(w, err)
			return
		}

		if err := tpl.ExecuteTemplate(w, "index.html", clients); err != nil {
			panic(err)
		}
	})

	mux.Post("/client/new", func(w http.ResponseWriter, req *http.Request) {
		ctx := req.Context()

		loginToken := loginTokenFromContext(ctx)
		if loginToken == nil {
			http.Redirect(w, req, "/login", http.StatusFound)
			return
		}

		client, clientSecret, err := NewClient(loginToken.User)
		if err != nil {
			httpError(w, err)
			return
		}
		if err := db.StoreClient(ctx, client); err != nil {
			httpError(w, err)
			return
		}

		data := struct {
			ClientID     string
			ClientSecret string
		}{
			ClientID:     client.ClientID,
			ClientSecret: clientSecret,
		}
		if err := tpl.ExecuteTemplate(w, "client-secret.html", &data); err != nil {
			panic(err)
		}
	})

	mux.HandleFunc("/login", func(w http.ResponseWriter, req *http.Request) {
		ctx := req.Context()

		q := req.URL.Query()
		rawRedirectURI := q.Get("redirect_uri")
		if rawRedirectURI == "" {
			rawRedirectURI = "/"
		}

		redirectURI, err := url.Parse(rawRedirectURI)
		if err != nil || redirectURI.Scheme != "" || redirectURI.Opaque != "" || redirectURI.User != nil || redirectURI.Host != "" {
			http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
			return
		}

		if loginTokenFromContext(ctx) != nil {
			http.Redirect(w, req, redirectURI.String(), http.StatusFound)
			return
		}

		username := req.PostFormValue("username")
		password := req.PostFormValue("password")
		if username == "" {
			if err := tpl.ExecuteTemplate(w, "login.html", nil); err != nil {
				panic(err)
			}
			return
		}

		user, err := db.FetchUser(ctx, username)
		if err != nil && err != errNoDBRows {
			httpError(w, fmt.Errorf("failed to fetch user: %v", err))
			return
		}
		if err == nil {
			err = user.VerifyPassword(password)
		}
		if err != nil {
			log.Printf("login failed for user %q: %v", username, err)
			// TODO: show error message
			if err := tpl.ExecuteTemplate(w, "login.html", nil); err != nil {
				panic(err)
			}
			return
		}

		token := AccessToken{
			User:  user.ID,
			Scope: internalTokenScope,
		}
		secret, err := token.Generate()
		if err != nil {
			httpError(w, fmt.Errorf("failed to generate access token: %v", err))
			return
		}
		if err := db.CreateAccessToken(ctx, &token); err != nil {
			httpError(w, fmt.Errorf("failed to create access token: %v", err))
			return
		}

		setLoginTokenCookie(w, &token, secret)
		http.Redirect(w, req, redirectURI.String(), http.StatusFound)
	})

	mux.HandleFunc("/authorize", func(w http.ResponseWriter, req *http.Request) {
		ctx := req.Context()

		q := req.URL.Query()
		respType := oauth2.ResponseType(q.Get("response_type"))
		clientID := q.Get("client_id")
		rawRedirectURI := q.Get("redirect_uri")
		scope := q.Get("scope")
		state := q.Get("state")

		if clientID == "" {
			http.Error(w, "Missing client ID", http.StatusBadRequest)
			return
		}

		client, err := db.FetchClient(ctx, clientID)
		if err == errNoDBRows {
			http.Error(w, "Invalid client ID", http.StatusForbidden)
			return
		} else if err != nil {
			httpError(w, fmt.Errorf("failed to fetch client: %v", err))
			return
		}

		// TODO: validate redirect URI with client
		// TODO: make redirect URI optional
		redirectURI, err := url.Parse(rawRedirectURI)
		if err != nil {
			http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
			return
		}

		if respType != oauth2.ResponseTypeCode {
			redirectClientError(w, req, redirectURI, state, &oauth2.Error{
				Code: oauth2.ErrorCodeUnsupportedResponseType,
			})
			return
		}

		// TODO: add support for scope
		if scope != "" {
			redirectClientError(w, req, redirectURI, state, &oauth2.Error{
				Code: oauth2.ErrorCodeInvalidScope,
			})
			return
		}

		loginToken := loginTokenFromContext(ctx)
		if loginToken == nil {
			q := make(url.Values)
			q.Set("redirect_uri", req.URL.String())
			u := url.URL{
				Path:     "/login",
				RawQuery: q.Encode(),
			}
			http.Redirect(w, req, u.String(), http.StatusFound)
			return
		}

		_ = req.ParseForm()
		if _, ok := req.PostForm["authorize"]; !ok {
			if err := tpl.ExecuteTemplate(w, "authorize.html", nil); err != nil {
				panic(err)
			}
			return
		}

		authCode, secret, err := NewAuthCode(loginToken.User, client.ID, scope)
		if err != nil {
			httpError(w, fmt.Errorf("failed to generate authentication code: %v", err))
			return
		}

		if err := db.CreateAuthCode(ctx, authCode); err != nil {
			httpError(w, fmt.Errorf("failed to create authentication code: %v", err))
			return
		}

		code := MarshalSecret(authCode.ID, secret)

		values := make(url.Values)
		values.Set("code", code)
		if state != "" {
			values.Set("state", state)
		}
		redirectClient(w, req, redirectURI, values)
	})

	mux.Post("/token", func(w http.ResponseWriter, req *http.Request) {
		ctx := req.Context()

		values, err := parseRequestBody(req)
		if err != nil {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidRequest,
				Description: err.Error(),
			})
			return
		}

		clientID := values.Get("client_id")
		grantType := oauth2.GrantType(values.Get("grant_type"))
		scope := values.Get("scope")

		authClientID, clientSecret, _ := req.BasicAuth()
		if clientID == "" && authClientID == "" {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidRequest,
				Description: "Missing client ID",
			})
			return
		} else if clientID == "" {
			clientID = authClientID
		} else if clientID != authClientID {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidRequest,
				Description: "Client ID in request body doesn't match Authorization header field",
			})
			return
		}

		client, err := db.FetchClient(ctx, clientID)
		if err == errNoDBRows {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidClient,
				Description: "Invalid client ID",
			})
			return
		} else if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
			return
		}

		if client.ClientSecretHash != nil {
			if !client.VerifySecret(clientSecret) {
				oauthError(w, &oauth2.Error{
					Code:        oauth2.ErrorCodeAccessDenied,
					Description: "Invalid client secret",
				})
				return
			}
		}

		if grantType != oauth2.GrantTypeAuthorizationCode {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeUnsupportedGrantType,
				Description: "Unsupported grant type",
			})
			return
		}

		codeID, codeSecret, _ := UnmarshalSecret[*AuthCode](values.Get("code"))
		authCode, err := db.PopAuthCode(ctx, codeID)
		if err == errNoDBRows || (err == nil && !authCode.VerifySecret(codeSecret)) || authCode.Client != client.ID {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid authorization code",
			})
			return
		} else if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch authorization code: %v", err))
			return
		}

		if scope != authCode.Scope {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid scope",
			})
			return
		}

		// TODO: check redirect_uri

		token, secret, err := NewAccessTokenFromAuthCode(authCode)
		if err != nil {
			oauthError(w, err)
			return
		}

		if err := db.CreateAccessToken(ctx, token); err != nil {
			oauthError(w, fmt.Errorf("failed to create access token: %v", err))
			return
		}

		w.Header().Set("Content-Type", "application/json")
		w.Header().Set("Cache-Control", "no-store")
		json.NewEncoder(w).Encode(&oauth2.TokenResp{
			AccessToken: secret,
			TokenType:   oauth2.TokenTypeBearer,
			ExpiresIn:   time.Until(token.ExpiresAt),
			Scope:       strings.Split(token.Scope, " "),
		})
	})

	server := http.Server{
		Addr:    listenAddr,
		Handler: loginTokenMiddleware(mux),
		BaseContext: func(net.Listener) context.Context {
			return newBaseContext(db)
			return newBaseContext(db, tpl)
		},
	}
	log.Printf("OAuth server listening on %v", server.Addr)
	if err := server.ListenAndServe(); err != nil {
		log.Fatalf("Failed to listen and serve: %v", err)
	}
}

func parseRequestBody(req *http.Request) (url.Values, error) {
	ct := req.Header.Get("Content-Type")
	if ct != "" {
		mimeType, _, err := mime.ParseMediaType(req.Header.Get("Content-Type"))
		if err != nil {
			return nil, fmt.Errorf("malformed Content-Type header field")
		} else if mimeType != "application/x-www-form-urlencoded" {
			return nil, fmt.Errorf("unsupported request content type")
		}
	}

	r := io.LimitReader(req.Body, 10<<20)
	b, err := io.ReadAll(r)
	if err != nil {
		return nil, fmt.Errorf("failed to read request body: %v", err)
	}

	values, err := url.ParseQuery(string(b))
	if err != nil {
		return nil, fmt.Errorf("failed to parse request body: %v", err)
	}

	return values, nil
}

func httpError(w http.ResponseWriter, err error) {
	log.Print(err)
	http.Error(w, "Internal server error", http.StatusInternalServerError)
}

func oauthError(w http.ResponseWriter, err error) {
	var oauthErr *oauth2.Error
	if !errors.As(err, &oauthErr) {
		oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
		log.Print(err)
	}

	statusCode := http.StatusInternalServerError
	switch oauthErr.Code {
	case oauth2.ErrorCodeInvalidRequest, oauth2.ErrorCodeUnsupportedResponseType, oauth2.ErrorCodeInvalidScope, oauth2.ErrorCodeInvalidClient, oauth2.ErrorCodeInvalidGrant, oauth2.ErrorCodeUnsupportedGrantType:
		statusCode = http.StatusBadRequest
	case oauth2.ErrorCodeUnauthorizedClient, oauth2.ErrorCodeAccessDenied:
		statusCode = http.StatusForbidden
	case oauth2.ErrorCodeTemporarilyUnavailable:
		statusCode = http.StatusServiceUnavailable
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)
	json.NewEncoder(w).Encode(oauthErr)
}

func redirectClient(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, values url.Values) {
	q := redirectURI.Query()
	for k, v := range values {
		q[k] = v
	}

	u := *redirectURI
	u.RawQuery = q.Encode()

	http.Redirect(w, req, u.String(), http.StatusFound)
}

func redirectClientError(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, state string, err error) {
	var oauthErr *oauth2.Error
	if !errors.As(err, &oauthErr) {
		oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
		log.Print(err)
	}

	values := make(url.Values)
	values.Set("error", string(oauthErr.Code))
	if oauthErr.Description != "" {
		values.Set("error_description", oauthErr.Description)
	}
	if oauthErr.URI != "" {
		values.Set("error_uri", oauthErr.URI)
	}
	if state != "" {
		values.Set("state", state)
	}
	redirectClient(w, req, redirectURI, values)
}
package main

import (
	"context"
	"fmt"
	"html/template"
	"net/http"
)

const internalTokenScope = "_sinwon"

type contextKey string

const (
	contextKeyDB         = "db"
	contextKeyTemplate   = "template"
	contextKeyLoginToken = "login-token"
)

func dbFromContext(ctx context.Context) *DB {
	return ctx.Value(contextKeyDB).(*DB)
}

func templateFromContext(ctx context.Context) *template.Template {
	return ctx.Value(contextKeyTemplate).(*template.Template)
}

func loginTokenFromContext(ctx context.Context) *AccessToken {
	v := ctx.Value(contextKeyLoginToken)
	if v == nil {
		return nil
	}
	return v.(*AccessToken)
}

func newBaseContext(db *DB) context.Context {
	return context.WithValue(context.Background(), contextKeyDB, db)
func newBaseContext(db *DB, tpl *template.Template) context.Context {
	ctx := context.Background()
	ctx = context.WithValue(ctx, contextKeyDB, db)
	ctx = context.WithValue(ctx, contextKeyLoginToken, tpl)
	return ctx
}

func setLoginTokenCookie(w http.ResponseWriter, token *AccessToken, secret string) {
	http.SetCookie(w, &http.Cookie{
		Name:     "sinwon-token",
		Value:    MarshalSecret(token.ID, secret),
		HttpOnly: true,
		SameSite: http.SameSiteStrictMode,
		// TODO: Secure
	})
}

func loginTokenMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		cookie, _ := req.Cookie("sinwon-token")
		if cookie == nil {
			next.ServeHTTP(w, req)
			return
		}

		ctx := req.Context()
		db := dbFromContext(ctx)
		tokenID, tokenSecret, _ := UnmarshalSecret[*AccessToken](cookie.Value)
		token, err := db.FetchAccessToken(ctx, tokenID)
		if err == errNoDBRows || (err == nil && !token.VerifySecret(tokenSecret)) {
			http.SetCookie(w, &http.Cookie{
				Name:     "sinwon-token",
				HttpOnly: true,
				SameSite: http.SameSiteStrictMode,
				MaxAge:   -1,
			})
			next.ServeHTTP(w, req)
			return
		} else if err != nil {
			httpError(w, fmt.Errorf("failed to fetch access token: %v", err))
			return
		}

		if token.Scope != internalTokenScope {
			http.Error(w, "Invalid login token scope", http.StatusForbidden)
			return
		}
		if token.User == 0 {
			panic("login token with zero user ID")
		}

		ctx = context.WithValue(ctx, contextKeyLoginToken, token)
		req = req.WithContext(ctx)
		next.ServeHTTP(w, req)
	})
}