Lindenii Project Forge
Login

server

Vireo IdP server

Hi… I am well aware that this diff view is very suboptimal. It will be fixed when the refactored server comes along!

Commit info
ID
4073e1e0d7d2f66b646cb4949bec53c199f87db7
Author
Author date
Mon, 19 Feb 2024 10:52:45 +0100
Committer
Committer date
Mon, 19 Feb 2024 10:52:45 +0100
Actions
Show client ID on authorize page
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"mime"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"git.sr.ht/~emersion/go-oauth2"
)

func getOAuthServerMetadata(w http.ResponseWriter, req *http.Request) {
	issuerURL := url.URL{
		Scheme: "https",
		Host:   req.Host,
	}
	if !isForwardedHTTPS(req) && isLoopback(req) {
		// TODO: add config option for allowed reverse proxy IPs
		issuerURL.Scheme = "http"
	}
	issuer := issuerURL.String()

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(&oauth2.ServerMetadata{
		Issuer:                            issuer,
		AuthorizationEndpoint:             issuer + "/authorize",
		TokenEndpoint:                     issuer + "/token",
		ResponseTypesSupported:            []oauth2.ResponseType{oauth2.ResponseTypeCode},
		ResponseModesSupported:            []oauth2.ResponseMode{oauth2.ResponseModeQuery},
		GrantTypesSupported:               []oauth2.GrantType{oauth2.GrantTypeAuthorizationCode},
		TokenEndpointAuthMethodsSupported: []oauth2.AuthMethod{oauth2.AuthMethodClientSecretBasic},
	})
}

func isLoopback(req *http.Request) bool {
	host, _, _ := net.SplitHostPort(req.RemoteAddr)
	ip := net.ParseIP(host)
	return ip.IsLoopback()
}

func authorize(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)
	tpl := templateFromContext(ctx)

	q := req.URL.Query()
	respType := oauth2.ResponseType(q.Get("response_type"))
	clientID := q.Get("client_id")
	rawRedirectURI := q.Get("redirect_uri")
	scope := q.Get("scope")
	state := q.Get("state")

	if clientID == "" {
		http.Error(w, "Missing client ID", http.StatusBadRequest)
		return
	}

	client, err := db.FetchClient(ctx, clientID)
	if err == errNoDBRows {
		http.Error(w, "Invalid client ID", http.StatusForbidden)
		return
	} else if err != nil {
		httpError(w, fmt.Errorf("failed to fetch client: %v", err))
		return
	}

	// TODO: validate redirect URI with client
	// TODO: make redirect URI optional
	redirectURI, err := url.Parse(rawRedirectURI)
	if err != nil {
		http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
		return
	}

	if respType != oauth2.ResponseTypeCode {
		redirectClientError(w, req, redirectURI, state, &oauth2.Error{
			Code: oauth2.ErrorCodeUnsupportedResponseType,
		})
		return
	}

	// TODO: add support for scope
	if scope != "" {
		redirectClientError(w, req, redirectURI, state, &oauth2.Error{
			Code: oauth2.ErrorCodeInvalidScope,
		})
		return
	}

	loginToken := loginTokenFromContext(ctx)
	if loginToken == nil {
		q := make(url.Values)
		q.Set("redirect_uri", req.URL.String())
		u := url.URL{
			Path:     "/login",
			RawQuery: q.Encode(),
		}
		http.Redirect(w, req, u.String(), http.StatusFound)
		return
	}

	_ = req.ParseForm()
	if _, ok := req.PostForm["authorize"]; !ok {
		if err := tpl.ExecuteTemplate(w, "authorize.html", nil); err != nil {
		data := struct {
			Client *Client
		}{
			Client: client,
		}
		if err := tpl.ExecuteTemplate(w, "authorize.html", data); err != nil {
			panic(err)
		}
		return
	}

	authCode, secret, err := NewAuthCode(loginToken.User, client.ID, scope)
	if err != nil {
		httpError(w, fmt.Errorf("failed to generate authentication code: %v", err))
		return
	}

	if err := db.CreateAuthCode(ctx, authCode); err != nil {
		httpError(w, fmt.Errorf("failed to create authentication code: %v", err))
		return
	}

	code := MarshalSecret(authCode.ID, secret)

	values := make(url.Values)
	values.Set("code", code)
	if state != "" {
		values.Set("state", state)
	}
	redirectClient(w, req, redirectURI, values)
}

func exchangeToken(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)

	values, err := parseRequestBody(req)
	if err != nil {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: err.Error(),
		})
		return
	}

	clientID := values.Get("client_id")
	grantType := oauth2.GrantType(values.Get("grant_type"))
	scope := values.Get("scope")

	authClientID, clientSecret, _ := req.BasicAuth()
	if clientID == "" && authClientID == "" {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: "Missing client ID",
		})
		return
	} else if clientID == "" {
		clientID = authClientID
	} else if clientID != authClientID {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: "Client ID in request body doesn't match Authorization header field",
		})
		return
	}

	client, err := db.FetchClient(ctx, clientID)
	if err == errNoDBRows {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidClient,
			Description: "Invalid client ID",
		})
		return
	} else if err != nil {
		oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
		return
	}

	if client.ClientSecretHash != nil {
		if !client.VerifySecret(clientSecret) {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid client secret",
			})
			return
		}
	}

	if grantType != oauth2.GrantTypeAuthorizationCode {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeUnsupportedGrantType,
			Description: "Unsupported grant type",
		})
		return
	}

	codeID, codeSecret, _ := UnmarshalSecret[*AuthCode](values.Get("code"))
	authCode, err := db.PopAuthCode(ctx, codeID)
	if err == errNoDBRows || (err == nil && !authCode.VerifySecret(codeSecret)) || authCode.Client != client.ID {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeAccessDenied,
			Description: "Invalid authorization code",
		})
		return
	} else if err != nil {
		oauthError(w, fmt.Errorf("failed to fetch authorization code: %v", err))
		return
	}

	if scope != authCode.Scope {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeAccessDenied,
			Description: "Invalid scope",
		})
		return
	}

	// TODO: check redirect_uri

	token, secret, err := NewAccessTokenFromAuthCode(authCode)
	if err != nil {
		oauthError(w, err)
		return
	}

	if err := db.CreateAccessToken(ctx, token); err != nil {
		oauthError(w, fmt.Errorf("failed to create access token: %v", err))
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Cache-Control", "no-store")
	json.NewEncoder(w).Encode(&oauth2.TokenResp{
		AccessToken: secret,
		TokenType:   oauth2.TokenTypeBearer,
		ExpiresIn:   time.Until(token.ExpiresAt),
		Scope:       strings.Split(token.Scope, " "),
	})
}

func parseRequestBody(req *http.Request) (url.Values, error) {
	ct := req.Header.Get("Content-Type")
	if ct != "" {
		mimeType, _, err := mime.ParseMediaType(req.Header.Get("Content-Type"))
		if err != nil {
			return nil, fmt.Errorf("malformed Content-Type header field")
		} else if mimeType != "application/x-www-form-urlencoded" {
			return nil, fmt.Errorf("unsupported request content type")
		}
	}

	r := io.LimitReader(req.Body, 10<<20)
	b, err := io.ReadAll(r)
	if err != nil {
		return nil, fmt.Errorf("failed to read request body: %v", err)
	}

	values, err := url.ParseQuery(string(b))
	if err != nil {
		return nil, fmt.Errorf("failed to parse request body: %v", err)
	}

	return values, nil
}

func oauthError(w http.ResponseWriter, err error) {
	var oauthErr *oauth2.Error
	if !errors.As(err, &oauthErr) {
		oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
		log.Print(err)
	}

	statusCode := http.StatusInternalServerError
	switch oauthErr.Code {
	case oauth2.ErrorCodeInvalidRequest, oauth2.ErrorCodeUnsupportedResponseType, oauth2.ErrorCodeInvalidScope, oauth2.ErrorCodeInvalidClient, oauth2.ErrorCodeInvalidGrant, oauth2.ErrorCodeUnsupportedGrantType:
		statusCode = http.StatusBadRequest
	case oauth2.ErrorCodeUnauthorizedClient, oauth2.ErrorCodeAccessDenied:
		statusCode = http.StatusForbidden
	case oauth2.ErrorCodeTemporarilyUnavailable:
		statusCode = http.StatusServiceUnavailable
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)
	json.NewEncoder(w).Encode(oauthErr)
}

func redirectClient(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, values url.Values) {
	q := redirectURI.Query()
	for k, v := range values {
		q[k] = v
	}

	u := *redirectURI
	u.RawQuery = q.Encode()

	http.Redirect(w, req, u.String(), http.StatusFound)
}

func redirectClientError(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, state string, err error) {
	var oauthErr *oauth2.Error
	if !errors.As(err, &oauthErr) {
		oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
		log.Print(err)
	}

	values := make(url.Values)
	values.Set("error", string(oauthErr.Code))
	if oauthErr.Description != "" {
		values.Set("error_description", oauthErr.Description)
	}
	if oauthErr.URI != "" {
		values.Set("error_uri", oauthErr.URI)
	}
	if state != "" {
		values.Set("state", state)
	}
	redirectClient(w, req, redirectURI, values)
}
{{ template "head.html" }}

<main>

<h1>sinwon</h1>

<p>Authorize client?</p>
<p>Authorize client <code>{{ .Client.ClientID }}</code>?</p>

<form method="post" action="">
	<button type="submit" name="authorize">Authorize</button>
</form>

</main>

{{ template "foot.html" }}