fork of haileyok/atproto-oauth-golang
1package oauth
2
3import (
4 "context"
5 "crypto/ecdsa"
6 "encoding/json"
7 "fmt"
8 "io"
9 "net/http"
10 "net/url"
11 "strings"
12 "time"
13
14 "github.com/golang-jwt/jwt/v5"
15 "github.com/google/uuid"
16 "github.com/lestrrat-go/jwx/v2/jwk"
17 "tangled.org/anirudh.fi/atproto-oauth/helpers"
18 internal_helpers "tangled.org/anirudh.fi/atproto-oauth/internal/helpers"
19)
20
// Client is an atproto OAuth client. It owns the HTTP client used for every
// outbound request, the key used to sign client assertion JWTs, and the
// client metadata identifiers sent to authorization servers. Build one with
// NewClient.
type Client struct {
	// h performs every outbound HTTP request.
	h *http.Client

	// clientPrivateKey signs client assertion JWTs; clientKid is written to
	// their "kid" header. clientId and redirectUri are sent as the
	// client_id and redirect_uri request parameters.
	clientPrivateKey                 *ecdsa.PrivateKey
	clientKid, clientId, redirectUri string

	// insecure is passed to the URL safety checks and metadata validation.
	// NOTE(review): presumably relaxes those checks for local development — confirm.
	insecure bool
}
29
30type ClientArgs struct {
31 Http *http.Client
32 ClientJwk jwk.Key
33 ClientId string
34 RedirectUri string
35 Insecure bool
36}
37
38func NewClient(args ClientArgs) (*Client, error) {
39 if args.ClientId == "" {
40 return nil, fmt.Errorf("no client id provided")
41 }
42
43 if args.RedirectUri == "" {
44 return nil, fmt.Errorf("no redirect uri provided")
45 }
46
47 if args.Http == nil {
48 args.Http = &http.Client{
49 Timeout: 5 * time.Second,
50 }
51 }
52
53 clientPkey, err := helpers.GetPrivateKey(args.ClientJwk)
54 if err != nil {
55 return nil, fmt.Errorf("could not load private key from provided client jwk: %w", err)
56 }
57
58 kid := args.ClientJwk.KeyID()
59
60 return &Client{
61 h: args.Http,
62 clientKid: kid,
63 clientPrivateKey: clientPkey,
64 clientId: args.ClientId,
65 redirectUri: args.RedirectUri,
66 insecure: args.Insecure,
67 }, nil
68}
69
70func (c *Client) ResolvePdsAuthServer(ctx context.Context, ustr string) (string, error) {
71 u, err := helpers.IsUrlSafeAndParsed(ustr, c.insecure)
72 if err != nil {
73 return "", err
74 }
75
76 u.Path = "/.well-known/oauth-protected-resource"
77
78 req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
79 if err != nil {
80 return "", fmt.Errorf("error creating request for oauth protected resource: %w", err)
81 }
82
83 resp, err := c.h.Do(req)
84 if err != nil {
85 return "", fmt.Errorf("could not get response from server: %w", err)
86 }
87 defer resp.Body.Close()
88
89 if resp.StatusCode != http.StatusOK {
90 io.Copy(io.Discard, resp.Body)
91 return "", fmt.Errorf("received non-200 response from pds. code was %d", resp.StatusCode)
92 }
93
94 b, err := io.ReadAll(resp.Body)
95 if err != nil {
96 return "", fmt.Errorf("could not read body: %w", err)
97 }
98
99 var resource OauthProtectedResource
100 if err := resource.UnmarshalJSON(b); err != nil {
101 return "", fmt.Errorf("could not unmarshal json: %w", err)
102 }
103
104 if len(resource.AuthorizationServers) == 0 {
105 return "", fmt.Errorf("oauth protected resource contained no authorization servers")
106 }
107
108 return resource.AuthorizationServers[0], nil
109}
110
111func (c *Client) FetchAuthServerMetadata(ctx context.Context, ustr string) (*OauthAuthorizationMetadata, error) {
112 u, err := helpers.IsUrlSafeAndParsed(ustr, c.insecure)
113 if err != nil {
114 return nil, err
115 }
116
117 u.Path = "/.well-known/oauth-authorization-server"
118
119 req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
120 if err != nil {
121 return nil, fmt.Errorf("error creating request to fetch auth metadata: %w", err)
122 }
123
124 resp, err := c.h.Do(req)
125 if err != nil {
126 return nil, fmt.Errorf("error getting response for authserver metadata: %w", err)
127 }
128 defer resp.Body.Close()
129
130 if resp.StatusCode != http.StatusOK {
131 io.Copy(io.Discard, resp.Body)
132 return nil, fmt.Errorf("received non-200 response from pds. status code was %d", resp.StatusCode)
133 }
134
135 b, err := io.ReadAll(resp.Body)
136 if err != nil {
137 return nil, fmt.Errorf("could not read body for authserver metadata response: %w", err)
138 }
139
140 var metadata OauthAuthorizationMetadata
141 if err := metadata.UnmarshalJSON(b); err != nil {
142 return nil, fmt.Errorf("could not unmarshal authserver metadata: %w", err)
143 }
144
145 if err := metadata.Validate(u, c.insecure); err != nil {
146 return nil, fmt.Errorf("could not validate authserver metadata: %w", err)
147 }
148
149 return &metadata, nil
150}
151
152func (c *Client) ClientAssertionJwt(authServerUrl string) (string, error) {
153 claims := jwt.MapClaims{
154 "iss": c.clientId,
155 "sub": c.clientId,
156 "aud": authServerUrl,
157 "jti": uuid.NewString(),
158 "iat": time.Now().Add(-5 * time.Second).Unix(),
159 }
160
161 token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
162 token.Header["kid"] = c.clientKid
163
164 tokenString, err := token.SignedString(c.clientPrivateKey)
165 if err != nil {
166 return "", err
167 }
168
169 return tokenString, nil
170}
171
172func (c *Client) AuthServerDpopJwt(method, url, nonce string, privateJwk jwk.Key) (string, error) {
173 pubJwk, err := privateJwk.PublicKey()
174 if err != nil {
175 return "", err
176 }
177
178 b, err := json.Marshal(pubJwk)
179 if err != nil {
180 return "", err
181 }
182
183 var pubMap map[string]any
184 if err := json.Unmarshal(b, &pubMap); err != nil {
185 return "", err
186 }
187
188 now := time.Now().Unix()
189
190 claims := jwt.MapClaims{
191 "jti": uuid.NewString(),
192 "htm": method,
193 "htu": url,
194 "iat": now,
195 "exp": now + 30,
196 }
197
198 if nonce != "" {
199 claims["nonce"] = nonce
200 }
201
202 token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
203 token.Header["typ"] = "dpop+jwt"
204 token.Header["alg"] = "ES256"
205 token.Header["jwk"] = pubMap
206
207 var rawKey any
208 if err := privateJwk.Raw(&rawKey); err != nil {
209 return "", err
210 }
211
212 tokenString, err := token.SignedString(rawKey)
213 if err != nil {
214 return "", fmt.Errorf("failed to sign token: %w", err)
215 }
216
217 return tokenString, nil
218}
219
220func (c *Client) SendParAuthRequest(ctx context.Context, authServerUrl string, authServerMeta *OauthAuthorizationMetadata, loginHint, scope string, dpopPrivateKey jwk.Key) (*SendParAuthResponse, error) {
221 if authServerMeta == nil {
222 return nil, fmt.Errorf("nil metadata provided")
223 }
224
225 parUrl := authServerMeta.PushedAuthorizationRequestEndpoint
226
227 state, err := internal_helpers.GenerateToken(10)
228 if err != nil {
229 return nil, fmt.Errorf("could not generate state token: %w", err)
230 }
231
232 pkceVerifier, err := internal_helpers.GenerateToken(48)
233 if err != nil {
234 return nil, fmt.Errorf("could not generate pkce verifier: %w", err)
235 }
236
237 codeChallenge := internal_helpers.GenerateCodeChallenge(pkceVerifier)
238 codeChallengeMethod := "S256"
239
240 clientAssertion, err := c.ClientAssertionJwt(authServerUrl)
241 if err != nil {
242 return nil, fmt.Errorf("error getting client assertion: %w", err)
243 }
244
245 dpopAuthserverNonce := ""
246 dpopProof, err := c.AuthServerDpopJwt("POST", parUrl, dpopAuthserverNonce, dpopPrivateKey)
247 if err != nil {
248 return nil, fmt.Errorf("error getting dpop proof: %w", err)
249 }
250
251 params := url.Values{
252 "response_type": {"code"},
253 "code_challenge": {codeChallenge},
254 "code_challenge_method": {codeChallengeMethod},
255 "client_id": {c.clientId},
256 "state": {state},
257 "redirect_uri": {c.redirectUri},
258 "scope": {scope},
259 "client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
260 "client_assertion": {clientAssertion},
261 }
262
263 if loginHint != "" {
264 params.Set("login_hint", loginHint)
265 }
266
267 _, err = helpers.IsUrlSafeAndParsed(parUrl, c.insecure)
268 if err != nil {
269 return nil, err
270 }
271
272 req, err := http.NewRequestWithContext(ctx, "POST", parUrl, strings.NewReader(params.Encode()))
273 if err != nil {
274 return nil, err
275 }
276
277 req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
278 req.Header.Set("DPoP", dpopProof)
279
280 resp, err := c.h.Do(req)
281 if err != nil {
282 return nil, err
283 }
284 defer resp.Body.Close()
285
286 var rmap map[string]any
287 if err := json.NewDecoder(resp.Body).Decode(&rmap); err != nil {
288 return nil, err
289 }
290
291 if resp.StatusCode != 201 {
292 if resp.StatusCode == 400 && rmap["error"] == "use_dpop_nonce" {
293 dpopAuthserverNonce = resp.Header.Get("DPoP-Nonce")
294 dpopProof, err := c.AuthServerDpopJwt("POST", parUrl, dpopAuthserverNonce, dpopPrivateKey)
295 if err != nil {
296 return nil, err
297 }
298
299 req2, err := http.NewRequestWithContext(
300 ctx,
301 "POST",
302 parUrl,
303 strings.NewReader(params.Encode()),
304 )
305 if err != nil {
306 return nil, err
307 }
308
309 req2.Header.Set("Content-Type", "application/x-www-form-urlencoded")
310 req2.Header.Set("DPoP", dpopProof)
311
312 resp2, err := c.h.Do(req2)
313 if err != nil {
314 return nil, err
315 }
316 defer resp2.Body.Close()
317
318 rmap = map[string]any{}
319 if err := json.NewDecoder(resp2.Body).Decode(&rmap); err != nil {
320 return nil, err
321 }
322
323 if resp2.StatusCode != 201 {
324 return nil, fmt.Errorf("received error from server when submitting par request: %v", rmap)
325 }
326 } else {
327 return nil, fmt.Errorf("received error from server when submitting par request: %v", rmap)
328 }
329 }
330
331 return &SendParAuthResponse{
332 PkceVerifier: pkceVerifier,
333 State: state,
334 DpopAuthserverNonce: dpopAuthserverNonce,
335 ExpiresIn: rmap["expires_in"].(float64),
336 RequestUri: rmap["request_uri"].(string),
337 }, nil
338}
339
340func (c *Client) InitialTokenRequest(
341 ctx context.Context,
342 code,
343 authserverIss,
344 pkceVerifier,
345 dpopAuthserverNonce string,
346 dpopPrivateJwk jwk.Key,
347) (*TokenResponse, error) {
348 // we might need to re-run to update dpop nonce
349 for range 2 {
350 authserverMeta, err := c.FetchAuthServerMetadata(ctx, authserverIss)
351 if err != nil {
352 return nil, err
353 }
354
355 clientAssertion, err := c.ClientAssertionJwt(authserverIss)
356 if err != nil {
357 return nil, err
358 }
359
360 params := url.Values{
361 "client_id": {c.clientId},
362 "redirect_uri": {c.redirectUri},
363 "grant_type": {"authorization_code"},
364 "code": {code},
365 "code_verifier": {pkceVerifier},
366 "client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
367 "client_assertion": {clientAssertion},
368 }
369
370 dpopProof, err := c.AuthServerDpopJwt("POST", authserverMeta.TokenEndpoint, dpopAuthserverNonce, dpopPrivateJwk)
371 if err != nil {
372 return nil, err
373 }
374
375 req, err := http.NewRequestWithContext(ctx, "POST", authserverMeta.TokenEndpoint, strings.NewReader(params.Encode()))
376 if err != nil {
377 return nil, err
378 }
379
380 req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
381 req.Header.Set("DPoP", dpopProof)
382
383 resp, err := c.h.Do(req)
384 if err != nil {
385 return nil, err
386 }
387 defer resp.Body.Close()
388
389 if resp.StatusCode != 200 && resp.StatusCode != 201 {
390 var respMap map[string]string
391 if err := json.NewDecoder(resp.Body).Decode(&respMap); err != nil {
392 return nil, err
393 }
394
395 if resp.StatusCode == 400 && respMap["error"] == "use_dpop_nonce" {
396 dpopAuthserverNonce = resp.Header.Get("DPoP-Nonce")
397 continue
398 }
399
400 return nil, fmt.Errorf("token refresh error: %s", respMap["error"])
401 }
402
403 var tokenResponse TokenResponse
404 if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil {
405 return nil, err
406 }
407
408 // set nonce so the updates are reflected in the response
409 tokenResponse.DpopAuthserverNonce = dpopAuthserverNonce
410
411 return &tokenResponse, nil
412 }
413
414 return nil, nil
415}
416
417func (c *Client) RefreshTokenRequest(
418 ctx context.Context,
419 refreshToken,
420 authserverIss,
421 dpopAuthserverNonce string,
422 dpopPrivateJwk jwk.Key,
423) (*TokenResponse, error) {
424 // we may need to update the dpop nonce
425 for range 2 {
426 authserverMeta, err := c.FetchAuthServerMetadata(ctx, authserverIss)
427 if err != nil {
428 return nil, err
429 }
430
431 clientAssertion, err := c.ClientAssertionJwt(authserverIss)
432 if err != nil {
433 return nil, err
434 }
435
436 params := url.Values{
437 "client_id": {c.clientId},
438 "grant_type": {"refresh_token"},
439 "refresh_token": {refreshToken},
440 "client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
441 "client_assertion": {clientAssertion},
442 }
443
444 dpopProof, err := c.AuthServerDpopJwt("POST", authserverMeta.TokenEndpoint, dpopAuthserverNonce, dpopPrivateJwk)
445 if err != nil {
446 return nil, err
447 }
448
449 req, err := http.NewRequestWithContext(ctx, "POST", authserverMeta.TokenEndpoint, strings.NewReader(params.Encode()))
450 if err != nil {
451 return nil, err
452 }
453
454 req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
455 req.Header.Set("DPoP", dpopProof)
456
457 resp, err := c.h.Do(req)
458 if err != nil {
459 return nil, err
460 }
461 defer resp.Body.Close()
462
463 if resp.StatusCode != 200 && resp.StatusCode != 201 {
464 var respMap map[string]string
465 if err := json.NewDecoder(resp.Body).Decode(&respMap); err != nil {
466 return nil, err
467 }
468
469 if resp.StatusCode == 400 && respMap["error"] == "use_dpop_nonce" {
470 dpopAuthserverNonce = resp.Header.Get("DPoP-Nonce")
471 continue
472 }
473
474 return nil, fmt.Errorf("token refresh error: %s", respMap["error"])
475 }
476
477 var tokenResponse TokenResponse
478 if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil {
479 return nil, err
480 }
481
482 // set the nonce so that updates are reflected in response
483 tokenResponse.DpopAuthserverNonce = dpopAuthserverNonce
484
485 return &tokenResponse, nil
486 }
487
488 return nil, nil
489}