Hi… I am well aware that this diff view is very suboptimal. It will be fixed when the refactored server comes along!
Advertise "none" client auth This is used by public clients.
package main
import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"mime"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"
	"git.sr.ht/~emersion/go-oauth2"
)
func getOAuthServerMetadata(w http.ResponseWriter, req *http.Request) {
	issuerURL := url.URL{
		Scheme: "https",
		Host:   req.Host,
	}
	if !isForwardedHTTPS(req) && isLoopback(req) {
		// TODO: add config option for allowed reverse proxy IPs
		issuerURL.Scheme = "http"
	}
	issuer := issuerURL.String()
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(&oauth2.ServerMetadata{
		Issuer:                                    issuer,
		AuthorizationEndpoint:                     issuer + "/authorize",
		TokenEndpoint:                             issuer + "/token",
		IntrospectionEndpoint:                     issuer + "/introspect",
		RevocationEndpoint:                        issuer + "/revoke",
		ResponseTypesSupported:                    []oauth2.ResponseType{oauth2.ResponseTypeCode},
		ResponseModesSupported:                    []oauth2.ResponseMode{oauth2.ResponseModeQuery},
		GrantTypesSupported:                       []oauth2.GrantType{oauth2.GrantTypeAuthorizationCode},
		TokenEndpointAuthMethodsSupported:         []oauth2.AuthMethod{oauth2.AuthMethodClientSecretBasic},
		IntrospectionEndpointAuthMethodsSupported: []oauth2.AuthMethod{oauth2.AuthMethodClientSecretBasic},
		RevocationEndpointAuthMethodsSupported:    []oauth2.AuthMethod{oauth2.AuthMethodClientSecretBasic},
		TokenEndpointAuthMethodsSupported:         []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic},
		IntrospectionEndpointAuthMethodsSupported: []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic},
		RevocationEndpointAuthMethodsSupported:    []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic},
	})
}
func isLoopback(req *http.Request) bool {
	host, _, _ := net.SplitHostPort(req.RemoteAddr)
	ip := net.ParseIP(host)
	return ip.IsLoopback()
}
func authorize(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)
	tpl := templateFromContext(ctx)
	q := req.URL.Query()
	respType := oauth2.ResponseType(q.Get("response_type"))
	clientID := q.Get("client_id")
	rawRedirectURI := q.Get("redirect_uri")
	scope := q.Get("scope")
	state := q.Get("state")
	if clientID == "" {
		http.Error(w, "Missing client ID", http.StatusBadRequest)
		return
	}
	client, err := db.FetchClientByClientID(ctx, clientID)
	if err == errNoDBRows {
		http.Error(w, "Invalid client ID", http.StatusForbidden)
		return
	} else if err != nil {
		httpError(w, fmt.Errorf("failed to fetch client: %v", err))
		return
	}
	var allowedRedirectURIs []*url.URL
	for _, s := range strings.Split(client.RedirectURIs, "\n") {
		if s == "" {
			continue
		}
		u, err := url.Parse(s)
		if err != nil {
			httpError(w, fmt.Errorf("failed to parse client redirect URI"))
			return
		}
		allowedRedirectURIs = append(allowedRedirectURIs, u)
	}
	var redirectURI *url.URL
	if rawRedirectURI != "" {
		redirectURI, err = url.Parse(rawRedirectURI)
		if err != nil {
			http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
			return
		}
		if !validateRedirectURI(redirectURI, allowedRedirectURIs) {
			http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
			return
		}
	} else {
		if len(allowedRedirectURIs) == 0 {
			http.Error(w, "Missing redirect URI", http.StatusBadRequest)
			return
		}
		redirectURI = allowedRedirectURIs[0]
	}
	if respType != oauth2.ResponseTypeCode {
		redirectClientError(w, req, redirectURI, state, &oauth2.Error{
			Code: oauth2.ErrorCodeUnsupportedResponseType,
		})
		return
	}
	// TODO: add support for scope
	if scope != "" {
		redirectClientError(w, req, redirectURI, state, &oauth2.Error{
			Code: oauth2.ErrorCodeInvalidScope,
		})
		return
	}
	loginToken := loginTokenFromContext(ctx)
	if loginToken == nil {
		q := make(url.Values)
		q.Set("redirect_uri", req.URL.String())
		u := url.URL{
			Path:     "/login",
			RawQuery: q.Encode(),
		}
		http.Redirect(w, req, u.String(), http.StatusFound)
		return
	}
	_ = req.ParseForm()
	if _, ok := req.PostForm["deny"]; ok {
		redirectClientError(w, req, redirectURI, state, &oauth2.Error{
			Code: oauth2.ErrorCodeAccessDenied,
		})
		return
	}
	if _, ok := req.PostForm["authorize"]; !ok {
		data := struct {
			Client *Client
		}{
			Client: client,
		}
		if err := tpl.ExecuteTemplate(w, "authorize.html", data); err != nil {
			panic(err)
		}
		return
	}
	authCode, secret, err := NewAuthCode(loginToken.User, client.ID, scope)
	if err != nil {
		httpError(w, fmt.Errorf("failed to generate authentication code: %v", err))
		return
	}
	if err := db.CreateAuthCode(ctx, authCode); err != nil {
		httpError(w, fmt.Errorf("failed to create authentication code: %v", err))
		return
	}
	code := MarshalSecret(authCode.ID, secret)
	values := make(url.Values)
	values.Set("code", code)
	if state != "" {
		values.Set("state", state)
	}
	redirectClient(w, req, redirectURI, values)
}
func exchangeToken(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)
	values, err := parseRequestBody(req)
	if err != nil {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: err.Error(),
		})
		return
	}
	clientID := values.Get("client_id")
	grantType := oauth2.GrantType(values.Get("grant_type"))
	scope := values.Get("scope")
	authClientID, clientSecret, _ := req.BasicAuth()
	if clientID == "" && authClientID == "" {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: "Missing client ID",
		})
		return
	} else if clientID == "" {
		clientID = authClientID
	} else if clientID != authClientID {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: "Client ID in request body doesn't match Authorization header field",
		})
		return
	}
	client, err := db.FetchClientByClientID(ctx, clientID)
	if err == errNoDBRows {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidClient,
			Description: "Invalid client ID",
		})
		return
	} else if err != nil {
		oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
		return
	}
	if !client.IsPublic() {
		if !client.VerifySecret(clientSecret) {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid client secret",
			})
			return
		}
	}
	if grantType != oauth2.GrantTypeAuthorizationCode {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeUnsupportedGrantType,
			Description: "Unsupported grant type",
		})
		return
	}
	codeID, codeSecret, _ := UnmarshalSecret[*AuthCode](values.Get("code"))
	authCode, err := db.PopAuthCode(ctx, codeID)
	if err == errNoDBRows || (err == nil && !authCode.VerifySecret(codeSecret)) || authCode.Client != client.ID {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeAccessDenied,
			Description: "Invalid authorization code",
		})
		return
	} else if err != nil {
		oauthError(w, fmt.Errorf("failed to fetch authorization code: %v", err))
		return
	}
	if scope != authCode.Scope {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeAccessDenied,
			Description: "Invalid scope",
		})
		return
	}
	// TODO: check redirect_uri
	token, secret, err := NewAccessTokenFromAuthCode(authCode)
	if err != nil {
		oauthError(w, err)
		return
	}
	if err := db.CreateAccessToken(ctx, token); err != nil {
		oauthError(w, fmt.Errorf("failed to create access token: %v", err))
		return
	}
	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Cache-Control", "no-store")
	json.NewEncoder(w).Encode(&oauth2.TokenResp{
		AccessToken: MarshalSecret(token.ID, secret),
		TokenType:   oauth2.TokenTypeBearer,
		ExpiresIn:   time.Until(token.ExpiresAt),
		Scope:       strings.Split(token.Scope, " "),
	})
}
func introspectToken(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)
	values, err := parseRequestBody(req)
	if err != nil {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: err.Error(),
		})
		return
	}
	client, err := maybeAuthenticateClient(w, req)
	if err != nil {
		oauthError(w, err)
		return
	}
	tokenID, secret, _ := UnmarshalSecret[*AccessToken](values.Get("token"))
	token, err := db.FetchAccessToken(ctx, tokenID)
	if err == errNoDBRows || (err == nil && !token.VerifySecret(secret)) {
		token = nil
	} else if err != nil {
		oauthError(w, fmt.Errorf("failed to fetch access token: %v", err))
		return
	}
	var resp oauth2.IntrospectionResp
	if token != nil {
		if client == nil {
			client, err = db.FetchClient(ctx, token.Client)
			if err != nil {
				oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
				return
			}
			if !client.IsPublic() {
				oauthError(w, &oauth2.Error{
					Code:        oauth2.ErrorCodeInvalidClient,
					Description: "Missing client ID and secret",
				})
				return
			}
		}
		if client.ID != token.Client {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidClient,
				Description: "Invalid client ID or secret",
			})
			return
		}
		user, err := db.FetchUser(ctx, token.User)
		if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch user: %v", err))
			return
		}
		resp.Active = true
		resp.TokenType = oauth2.TokenTypeBearer
		resp.ExpiresAt = token.ExpiresAt
		resp.IssuedAt = token.IssuedAt
		resp.ClientID = client.ClientID
		resp.Username = user.Username
	}
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(&resp)
}
func revokeToken(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	db := dbFromContext(ctx)
	values, err := parseRequestBody(req)
	if err != nil {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: err.Error(),
		})
		return
	}
	client, err := maybeAuthenticateClient(w, req)
	if err != nil {
		oauthError(w, err)
		return
	}
	tokenID, secret, _ := UnmarshalSecret[*AccessToken](values.Get("token"))
	token, err := db.FetchAccessToken(ctx, tokenID)
	if err == errNoDBRows || (err == nil && !token.VerifySecret(secret)) {
		return // ignore
	} else if err != nil {
		oauthError(w, fmt.Errorf("failed to fetch access token: %v", err))
		return
	}
	if client == nil {
		client, err = db.FetchClient(ctx, token.Client)
		if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
			return
		}
		if !client.IsPublic() {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidClient,
				Description: "Missing client ID and secret",
			})
			return
		}
	}
	if client.ID != token.Client {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidClient,
			Description: "Invalid client ID or secret",
		})
		return
	}
	if err := db.DeleteAccessToken(ctx, token.ID); err != nil {
		oauthError(w, err)
		return
	}
}
func parseRequestBody(req *http.Request) (url.Values, error) {
	ct := req.Header.Get("Content-Type")
	if ct != "" {
		mimeType, _, err := mime.ParseMediaType(req.Header.Get("Content-Type"))
		if err != nil {
			return nil, fmt.Errorf("malformed Content-Type header field")
		} else if mimeType != "application/x-www-form-urlencoded" {
			return nil, fmt.Errorf("unsupported request content type")
		}
	}
	r := io.LimitReader(req.Body, 10<<20)
	b, err := io.ReadAll(r)
	if err != nil {
		return nil, fmt.Errorf("failed to read request body: %v", err)
	}
	values, err := url.ParseQuery(string(b))
	if err != nil {
		return nil, fmt.Errorf("failed to parse request body: %v", err)
	}
	return values, nil
}
func oauthError(w http.ResponseWriter, err error) {
	var oauthErr *oauth2.Error
	if !errors.As(err, &oauthErr) {
		oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
		log.Print(err)
	}
	statusCode := http.StatusInternalServerError
	switch oauthErr.Code {
	case oauth2.ErrorCodeInvalidRequest, oauth2.ErrorCodeUnsupportedResponseType, oauth2.ErrorCodeInvalidScope, oauth2.ErrorCodeInvalidClient, oauth2.ErrorCodeInvalidGrant, oauth2.ErrorCodeUnsupportedGrantType:
		statusCode = http.StatusBadRequest
	case oauth2.ErrorCodeUnauthorizedClient, oauth2.ErrorCodeAccessDenied:
		statusCode = http.StatusForbidden
	case oauth2.ErrorCodeTemporarilyUnavailable:
		statusCode = http.StatusServiceUnavailable
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)
	json.NewEncoder(w).Encode(oauthErr)
}
func redirectClient(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, values url.Values) {
	q := redirectURI.Query()
	for k, v := range values {
		q[k] = v
	}
	u := *redirectURI
	u.RawQuery = q.Encode()
	http.Redirect(w, req, u.String(), http.StatusFound)
}
func redirectClientError(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, state string, err error) {
	var oauthErr *oauth2.Error
	if !errors.As(err, &oauthErr) {
		oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
		log.Print(err)
	}
	values := make(url.Values)
	values.Set("error", string(oauthErr.Code))
	if oauthErr.Description != "" {
		values.Set("error_description", oauthErr.Description)
	}
	if oauthErr.URI != "" {
		values.Set("error_uri", oauthErr.URI)
	}
	if state != "" {
		values.Set("state", state)
	}
	redirectClient(w, req, redirectURI, values)
}
func validateRedirectURI(u *url.URL, allowedURIs []*url.URL) bool {
	// Loopback interface, see RFC 8252 section 7.3
	host, _, _ := net.SplitHostPort(u.Host)
	ip := net.ParseIP(host)
	if u.Scheme == "http" && ip.IsLoopback() {
		uu := *u
		uu.Host = "localhost"
		u = &uu
	}
	for _, allowed := range allowedURIs {
		if u.String() == allowed.String() {
			return true
		}
	}
	return false
}
func maybeAuthenticateClient(w http.ResponseWriter, req *http.Request) (*Client, error) {
	ctx := req.Context()
	db := dbFromContext(ctx)
	clientID, clientSecret, ok := req.BasicAuth()
	if !ok {
		return nil, nil
	}
	client, err := db.FetchClientByClientID(ctx, clientID)
	if err == errNoDBRows || (err == nil && !client.VerifySecret(clientSecret)) {
		return nil, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidClient,
			Description: "Invalid client ID or secret",
		}
	} else if err != nil {
		return nil, fmt.Errorf("failed to fetch client: %v", err)
	}
	return client, nil
}