Lindenii Project Forge
Login

server

Vireo IdP server

Hi… I am well aware that this diff view is very suboptimal. It will be fixed when the refactored server comes along!

Commit info
ID
6cf70981de33cee2d88e38e18f9fb33238fc4507
Author
Author date
Sun, 18 Feb 2024 22:10:08 +0100
Committer
Committer date
Sun, 18 Feb 2024 22:10:08 +0100
Actions
Use chi
module git.sr.ht/~emersion/sinwon

go 1.18

require git.sr.ht/~emersion/go-oauth2 v0.0.0-20240217160856-2e0d6e20b088

require github.com/mattn/go-sqlite3 v1.14.22

require golang.org/x/crypto v0.19.0

require github.com/go-chi/chi/v5 v5.0.12
git.sr.ht/~emersion/go-oauth2 v0.0.0-20240217160856-2e0d6e20b088 h1:KuPliLD8CQM1WbCHdjHR6mhadIzLaAJCNENmvB1y9gs=
git.sr.ht/~emersion/go-oauth2 v0.0.0-20240217160856-2e0d6e20b088/go.mod h1:VHj0jSCLIkrfEwmOvJ4+ykpoVbD/YLN7BM523oKKBHc=
github.com/go-chi/chi/v5 v5.0.12 h1:9euLV5sTrTNTRUU9POmDUvfxyj6LAABLUcEWO+JJb4s=
github.com/go-chi/chi/v5 v5.0.12/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
package main

import (
	"context"
	"embed"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"html/template"
	"io"
	"log"
	"mime"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"git.sr.ht/~emersion/go-oauth2"
	"github.com/go-chi/chi/v5"
)

var (
	//go:embed template
	templateFS embed.FS
	//go:embed static
	staticFS embed.FS
)

func main() {
	var listenAddr string
	flag.StringVar(&listenAddr, "listen", ":8080", "HTTP listen address")
	flag.Parse()

	tpl := template.Must(template.ParseFS(templateFS, "template/*.html"))

	db, err := openDB("sinwon.db")
	if err != nil {
		log.Fatalf("Failed to open DB: %v", err)
	}

	mux := http.NewServeMux()
	mux := chi.NewRouter()

	mux.Handle("/static/*", http.FileServer(http.FS(staticFS)))

	mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
	mux.Get("/", func(w http.ResponseWriter, req *http.Request) {
		ctx := req.Context()

		loginToken := loginTokenFromContext(ctx)
		if loginToken == nil {
			http.Redirect(w, req, "/login", http.StatusFound)
			return
		}

		clients, err := db.ListClients(ctx, loginToken.User)
		if err != nil {
			httpError(w, err)
			return
		}

		if err := tpl.ExecuteTemplate(w, "index.html", clients); err != nil {
			panic(err)
		}
	})

	mux.HandleFunc("/client/new", func(w http.ResponseWriter, req *http.Request) {
	mux.Post("/client/new", func(w http.ResponseWriter, req *http.Request) {
		ctx := req.Context()

		loginToken := loginTokenFromContext(ctx)
		if loginToken == nil {
			http.Redirect(w, req, "/login", http.StatusFound)
			return
		}

		client, clientSecret, err := NewClient(loginToken.User)
		if err != nil {
			httpError(w, err)
			return
		}
		if err := db.StoreClient(ctx, client); err != nil {
			httpError(w, err)
			return
		}

		data := struct {
			ClientID     string
			ClientSecret string
		}{
			ClientID:     client.ClientID,
			ClientSecret: clientSecret,
		}
		if err := tpl.ExecuteTemplate(w, "client-secret.html", &data); err != nil {
			panic(err)
		}
	})

	mux.HandleFunc("/login", func(w http.ResponseWriter, req *http.Request) {
		ctx := req.Context()

		q := req.URL.Query()
		rawRedirectURI := q.Get("redirect_uri")
		if rawRedirectURI == "" {
			rawRedirectURI = "/"
		}

		redirectURI, err := url.Parse(rawRedirectURI)
		if err != nil || redirectURI.Scheme != "" || redirectURI.Opaque != "" || redirectURI.User != nil || redirectURI.Host != "" {
			http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
			return
		}

		if loginTokenFromContext(ctx) != nil {
			http.Redirect(w, req, redirectURI.String(), http.StatusFound)
			return
		}

		username := req.PostFormValue("username")
		password := req.PostFormValue("password")
		if username == "" {
			if err := tpl.ExecuteTemplate(w, "login.html", nil); err != nil {
				panic(err)
			}
			return
		}

		user, err := db.FetchUser(ctx, username)
		if err != nil && err != errNoDBRows {
			httpError(w, fmt.Errorf("failed to fetch user: %v", err))
			return
		}
		if err == nil {
			err = user.VerifyPassword(password)
		}
		if err != nil {
			log.Printf("login failed for user %q: %v", username, err)
			// TODO: show error message
			if err := tpl.ExecuteTemplate(w, "login.html", nil); err != nil {
				panic(err)
			}
			return
		}

		token := AccessToken{
			User:  user.ID,
			Scope: internalTokenScope,
		}
		secret, err := token.Generate()
		if err != nil {
			httpError(w, fmt.Errorf("failed to generate access token: %v", err))
			return
		}
		if err := db.CreateAccessToken(ctx, &token); err != nil {
			httpError(w, fmt.Errorf("failed to create access token: %v", err))
			return
		}

		setLoginTokenCookie(w, &token, secret)
		http.Redirect(w, req, redirectURI.String(), http.StatusFound)
	})

	mux.HandleFunc("/authorize", func(w http.ResponseWriter, req *http.Request) {
		ctx := req.Context()

		q := req.URL.Query()
		respType := oauth2.ResponseType(q.Get("response_type"))
		clientID := q.Get("client_id")
		rawRedirectURI := q.Get("redirect_uri")
		scope := q.Get("scope")
		state := q.Get("state")

		if clientID == "" {
			http.Error(w, "Missing client ID", http.StatusBadRequest)
			return
		}

		client, err := db.FetchClient(ctx, clientID)
		if err == errNoDBRows {
			http.Error(w, "Invalid client ID", http.StatusForbidden)
			return
		} else if err != nil {
			httpError(w, fmt.Errorf("failed to fetch client: %v", err))
			return
		}

		// TODO: validate redirect URI with client
		// TODO: make redirect URI optional
		redirectURI, err := url.Parse(rawRedirectURI)
		if err != nil {
			http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
			return
		}

		if respType != oauth2.ResponseTypeCode {
			redirectClientError(w, req, redirectURI, state, &oauth2.Error{
				Code: oauth2.ErrorCodeUnsupportedResponseType,
			})
			return
		}

		// TODO: add support for scope
		if scope != "" {
			redirectClientError(w, req, redirectURI, state, &oauth2.Error{
				Code: oauth2.ErrorCodeInvalidScope,
			})
			return
		}

		loginToken := loginTokenFromContext(ctx)
		if loginToken == nil {
			q := make(url.Values)
			q.Set("redirect_uri", req.URL.String())
			u := url.URL{
				Path:     "/login",
				RawQuery: q.Encode(),
			}
			http.Redirect(w, req, u.String(), http.StatusFound)
			return
		}

		_ = req.ParseForm()
		if _, ok := req.PostForm["authorize"]; !ok {
			if err := tpl.ExecuteTemplate(w, "authorize.html", nil); err != nil {
				panic(err)
			}
			return
		}

		authCode, secret, err := NewAuthCode(loginToken.User, client.ID, scope)
		if err != nil {
			httpError(w, fmt.Errorf("failed to generate authentication code: %v", err))
			return
		}

		if err := db.CreateAuthCode(ctx, authCode); err != nil {
			httpError(w, fmt.Errorf("failed to create authentication code: %v", err))
			return
		}

		code := MarshalSecret(authCode.ID, secret)

		values := make(url.Values)
		values.Set("code", code)
		if state != "" {
			values.Set("state", state)
		}
		redirectClient(w, req, redirectURI, values)
	})

	mux.HandleFunc("/token", func(w http.ResponseWriter, req *http.Request) {
	mux.Post("/token", func(w http.ResponseWriter, req *http.Request) {
		ctx := req.Context()

		values, err := parseRequestBody(req)
		if err != nil {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidRequest,
				Description: err.Error(),
			})
			return
		}

		clientID := values.Get("client_id")
		grantType := oauth2.GrantType(values.Get("grant_type"))
		scope := values.Get("scope")

		authClientID, clientSecret, _ := req.BasicAuth()
		if clientID == "" && authClientID == "" {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidRequest,
				Description: "Missing client ID",
			})
			return
		} else if clientID == "" {
			clientID = authClientID
		} else if clientID != authClientID {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidRequest,
				Description: "Client ID in request body doesn't match Authorization header field",
			})
			return
		}

		client, err := db.FetchClient(ctx, clientID)
		if err == errNoDBRows {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidClient,
				Description: "Invalid client ID",
			})
			return
		} else if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
			return
		}

		if client.ClientSecretHash != nil {
			if !client.VerifySecret(clientSecret) {
				oauthError(w, &oauth2.Error{
					Code:        oauth2.ErrorCodeAccessDenied,
					Description: "Invalid client secret",
				})
				return
			}
		}

		if grantType != oauth2.GrantTypeAuthorizationCode {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeUnsupportedGrantType,
				Description: "Unsupported grant type",
			})
			return
		}

		codeID, codeSecret, _ := UnmarshalSecret[*AuthCode](values.Get("code"))
		authCode, err := db.PopAuthCode(ctx, codeID)
		if err == errNoDBRows || (err == nil && !authCode.VerifySecret(codeSecret)) || authCode.Client != client.ID {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid authorization code",
			})
			return
		} else if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch authorization code: %v", err))
			return
		}

		if scope != authCode.Scope {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid scope",
			})
			return
		}

		// TODO: check redirect_uri

		token, secret, err := NewAccessTokenFromAuthCode(authCode)
		if err != nil {
			oauthError(w, err)
			return
		}

		if err := db.CreateAccessToken(ctx, token); err != nil {
			oauthError(w, fmt.Errorf("failed to create access token: %v", err))
			return
		}

		w.Header().Set("Content-Type", "application/json")
		w.Header().Set("Cache-Control", "no-store")
		json.NewEncoder(w).Encode(&oauth2.TokenResp{
			AccessToken: secret,
			TokenType:   oauth2.TokenTypeBearer,
			ExpiresIn:   time.Until(token.ExpiresAt),
			Scope:       strings.Split(token.Scope, " "),
		})
	})

	server := http.Server{
		Addr:    listenAddr,
		Handler: loginTokenMiddleware(mux),
		BaseContext: func(net.Listener) context.Context {
			return newBaseContext(db)
		},
	}
	log.Printf("OAuth server listening on %v", server.Addr)
	if err := server.ListenAndServe(); err != nil {
		log.Fatalf("Failed to listen and serve: %v", err)
	}
}

func parseRequestBody(req *http.Request) (url.Values, error) {
	ct := req.Header.Get("Content-Type")
	if ct != "" {
		mimeType, _, err := mime.ParseMediaType(req.Header.Get("Content-Type"))
		if err != nil {
			return nil, fmt.Errorf("malformed Content-Type header field")
		} else if mimeType != "application/x-www-form-urlencoded" {
			return nil, fmt.Errorf("unsupported request content type")
		}
	}

	r := io.LimitReader(req.Body, 10<<20)
	b, err := io.ReadAll(r)
	if err != nil {
		return nil, fmt.Errorf("failed to read request body: %v", err)
	}

	values, err := url.ParseQuery(string(b))
	if err != nil {
		return nil, fmt.Errorf("failed to parse request body: %v", err)
	}

	return values, nil
}

func httpError(w http.ResponseWriter, err error) {
	log.Print(err)
	http.Error(w, "Internal server error", http.StatusInternalServerError)
}

func oauthError(w http.ResponseWriter, err error) {
	var oauthErr *oauth2.Error
	if !errors.As(err, &oauthErr) {
		oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
		log.Print(err)
	}

	statusCode := http.StatusInternalServerError
	switch oauthErr.Code {
	case oauth2.ErrorCodeInvalidRequest, oauth2.ErrorCodeUnsupportedResponseType, oauth2.ErrorCodeInvalidScope, oauth2.ErrorCodeInvalidClient, oauth2.ErrorCodeInvalidGrant, oauth2.ErrorCodeUnsupportedGrantType:
		statusCode = http.StatusBadRequest
	case oauth2.ErrorCodeUnauthorizedClient, oauth2.ErrorCodeAccessDenied:
		statusCode = http.StatusForbidden
	case oauth2.ErrorCodeTemporarilyUnavailable:
		statusCode = http.StatusServiceUnavailable
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)
	json.NewEncoder(w).Encode(oauthErr)
}

func redirectClient(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, values url.Values) {
	q := redirectURI.Query()
	for k, v := range values {
		q[k] = v
	}

	u := *redirectURI
	u.RawQuery = q.Encode()

	http.Redirect(w, req, u.String(), http.StatusFound)
}

func redirectClientError(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, state string, err error) {
	var oauthErr *oauth2.Error
	if !errors.As(err, &oauthErr) {
		oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
		log.Print(err)
	}

	values := make(url.Values)
	values.Set("error", string(oauthErr.Code))
	if oauthErr.Description != "" {
		values.Set("error_description", oauthErr.Description)
	}
	if oauthErr.URI != "" {
		values.Set("error_uri", oauthErr.URI)
	}
	if state != "" {
		values.Set("state", state)
	}
	redirectClient(w, req, redirectURI, values)
}