Stateless auth proxy that turns AT Protocol native apps from public OAuth clients into confidential ones. Deploy it once and your apps get 180-day refresh tokens instead of 24-hour ones.
1package main
2
3import (
4 "encoding/json"
5 "log"
6 "net/http"
7 "net/url"
8)
9
// tokenRequest is the JSON body accepted by HandleToken.
//
// TokenEndpoint, Issuer and GrantType must be non-empty. Code, RedirectURI,
// CodeVerifier and RefreshToken are optional and are forwarded upstream as the
// form parameters of the same name only when non-empty.
type tokenRequest struct {
	// Upstream routing: the endpoint to POST to and the issuer it is
	// validated against (Issuer is also passed to GenerateClientAssertion).
	TokenEndpoint string `json:"token_endpoint"`
	Issuer        string `json:"issuer"`

	// KeyID optionally pins the signing key. When empty, HandleToken may try
	// several candidate keys in turn.
	KeyID string `json:"key_id,omitempty"`

	// OAuth token-request parameters.
	GrantType    string `json:"grant_type"`
	Code         string `json:"code,omitempty"`
	RedirectURI  string `json:"redirect_uri,omitempty"`
	CodeVerifier string `json:"code_verifier,omitempty"`
	RefreshToken string `json:"refresh_token,omitempty"`
}
20
21func HandleToken(signers *SignerSet, clientID string) http.HandlerFunc {
22 return func(w http.ResponseWriter, r *http.Request) {
23 var req tokenRequest
24 if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
25 writeJSONError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
26 return
27 }
28
29 if req.TokenEndpoint == "" {
30 writeJSONError(w, http.StatusBadRequest, "invalid_request", "token_endpoint is required")
31 return
32 }
33 if req.Issuer == "" {
34 writeJSONError(w, http.StatusBadRequest, "invalid_request", "issuer is required")
35 return
36 }
37 if req.GrantType == "" {
38 writeJSONError(w, http.StatusBadRequest, "invalid_request", "grant_type is required")
39 return
40 }
41
42 if err := ValidateTokenEndpointForIssuer(r.Context(), req.Issuer, req.TokenEndpoint); err != nil {
43 writeAPIError(w, err)
44 return
45 }
46
47 candidateKeyIDs, err := signers.CandidateKeyIDs(req.KeyID)
48 if err != nil {
49 writeJSONError(w, http.StatusBadRequest, "invalid_request", err.Error())
50 return
51 }
52
53 params := url.Values{}
54 params.Set("grant_type", req.GrantType)
55 params.Set("client_id", clientID)
56
57 if req.Code != "" {
58 params.Set("code", req.Code)
59 }
60 if req.RedirectURI != "" {
61 params.Set("redirect_uri", req.RedirectURI)
62 }
63 if req.CodeVerifier != "" {
64 params.Set("code_verifier", req.CodeVerifier)
65 }
66 if req.RefreshToken != "" {
67 params.Set("refresh_token", req.RefreshToken)
68 }
69
70 dpopHeader := r.Header.Get("DPoP")
71
72 var proxied *upstreamResponse
73 var usedKeyID string
74
75 for i, keyID := range candidateKeyIDs {
76 signer, err := signers.Lookup(keyID)
77 if err != nil {
78 writeJSONError(w, http.StatusBadRequest, "invalid_request", err.Error())
79 return
80 }
81
82 assertion, err := GenerateClientAssertion(signer, clientID, req.Issuer)
83 if err != nil {
84 log.Printf("failed to generate client assertion: %v", err)
85 writeJSONError(w, http.StatusInternalServerError, "server_error", "failed to generate client assertion")
86 return
87 }
88
89 attemptParams := cloneValues(params)
90 attemptParams.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
91 attemptParams.Set("client_assertion", assertion)
92
93 proxied, err = PostForm(r.Context(), req.TokenEndpoint, attemptParams, dpopHeader)
94 if err != nil {
95 log.Printf("proxy request failed: %v", err)
96 writeAPIError(w, upstreamRequestError("upstream request failed"))
97 return
98 }
99
100 usedKeyID = keyID
101 if req.KeyID == "" && i < len(candidateKeyIDs)-1 && isInvalidClientResponse(proxied) {
102 continue
103 }
104
105 break
106 }
107
108 w.Header().Set(authProxyKeyIDHeader, usedKeyID)
109 if err := WriteProxiedResponse(w, proxied); err != nil {
110 log.Printf("failed to write proxied response: %v", err)
111 }
112 }
113}
114
115func isInvalidClientResponse(resp *upstreamResponse) bool {
116 if resp == nil {
117 return false
118 }
119 if resp.statusCode != http.StatusBadRequest && resp.statusCode != http.StatusUnauthorized {
120 return false
121 }
122
123 var payload struct {
124 Error string `json:"error"`
125 }
126 if err := json.Unmarshal(resp.body, &payload); err != nil {
127 return false
128 }
129
130 return payload.Error == "invalid_client"
131}
132
// cloneValues returns a deep copy of values. The result is never nil, and
// every key gets its own backing slice, so writes to the copy (including
// in-place edits of a value slice) never reach values. Empty value slices
// come back as nil.
func cloneValues(values url.Values) url.Values {
	dup := make(url.Values, len(values))
	for k := range values {
		var copied []string
		copied = append(copied, values[k]...)
		dup[k] = copied
	}
	return dup
}
139}