// Fork of haileyok/atproto-oauth-golang (main branch).

1package oauth 2 3import ( 4 "context" 5 "crypto/ecdsa" 6 "encoding/json" 7 "fmt" 8 "io" 9 "net/http" 10 "net/url" 11 "strings" 12 "time" 13 14 "github.com/golang-jwt/jwt/v5" 15 "github.com/google/uuid" 16 "github.com/lestrrat-go/jwx/v2/jwk" 17 "tangled.org/anirudh.fi/atproto-oauth/helpers" 18 internal_helpers "tangled.org/anirudh.fi/atproto-oauth/internal/helpers" 19) 20 21type Client struct { 22 h *http.Client 23 clientPrivateKey *ecdsa.PrivateKey 24 clientKid string 25 clientId string 26 redirectUri string 27 insecure bool 28} 29 30type ClientArgs struct { 31 Http *http.Client 32 ClientJwk jwk.Key 33 ClientId string 34 RedirectUri string 35 Insecure bool 36} 37 38func NewClient(args ClientArgs) (*Client, error) { 39 if args.ClientId == "" { 40 return nil, fmt.Errorf("no client id provided") 41 } 42 43 if args.RedirectUri == "" { 44 return nil, fmt.Errorf("no redirect uri provided") 45 } 46 47 if args.Http == nil { 48 args.Http = &http.Client{ 49 Timeout: 5 * time.Second, 50 } 51 } 52 53 clientPkey, err := helpers.GetPrivateKey(args.ClientJwk) 54 if err != nil { 55 return nil, fmt.Errorf("could not load private key from provided client jwk: %w", err) 56 } 57 58 kid := args.ClientJwk.KeyID() 59 60 return &Client{ 61 h: args.Http, 62 clientKid: kid, 63 clientPrivateKey: clientPkey, 64 clientId: args.ClientId, 65 redirectUri: args.RedirectUri, 66 insecure: args.Insecure, 67 }, nil 68} 69 70func (c *Client) ResolvePdsAuthServer(ctx context.Context, ustr string) (string, error) { 71 u, err := helpers.IsUrlSafeAndParsed(ustr, c.insecure) 72 if err != nil { 73 return "", err 74 } 75 76 u.Path = "/.well-known/oauth-protected-resource" 77 78 req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) 79 if err != nil { 80 return "", fmt.Errorf("error creating request for oauth protected resource: %w", err) 81 } 82 83 resp, err := c.h.Do(req) 84 if err != nil { 85 return "", fmt.Errorf("could not get response from server: %w", err) 86 } 87 defer resp.Body.Close() 
88 89 if resp.StatusCode != http.StatusOK { 90 io.Copy(io.Discard, resp.Body) 91 return "", fmt.Errorf("received non-200 response from pds. code was %d", resp.StatusCode) 92 } 93 94 b, err := io.ReadAll(resp.Body) 95 if err != nil { 96 return "", fmt.Errorf("could not read body: %w", err) 97 } 98 99 var resource OauthProtectedResource 100 if err := resource.UnmarshalJSON(b); err != nil { 101 return "", fmt.Errorf("could not unmarshal json: %w", err) 102 } 103 104 if len(resource.AuthorizationServers) == 0 { 105 return "", fmt.Errorf("oauth protected resource contained no authorization servers") 106 } 107 108 return resource.AuthorizationServers[0], nil 109} 110 111func (c *Client) FetchAuthServerMetadata(ctx context.Context, ustr string) (*OauthAuthorizationMetadata, error) { 112 u, err := helpers.IsUrlSafeAndParsed(ustr, c.insecure) 113 if err != nil { 114 return nil, err 115 } 116 117 u.Path = "/.well-known/oauth-authorization-server" 118 119 req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) 120 if err != nil { 121 return nil, fmt.Errorf("error creating request to fetch auth metadata: %w", err) 122 } 123 124 resp, err := c.h.Do(req) 125 if err != nil { 126 return nil, fmt.Errorf("error getting response for authserver metadata: %w", err) 127 } 128 defer resp.Body.Close() 129 130 if resp.StatusCode != http.StatusOK { 131 io.Copy(io.Discard, resp.Body) 132 return nil, fmt.Errorf("received non-200 response from pds. 
status code was %d", resp.StatusCode) 133 } 134 135 b, err := io.ReadAll(resp.Body) 136 if err != nil { 137 return nil, fmt.Errorf("could not read body for authserver metadata response: %w", err) 138 } 139 140 var metadata OauthAuthorizationMetadata 141 if err := metadata.UnmarshalJSON(b); err != nil { 142 return nil, fmt.Errorf("could not unmarshal authserver metadata: %w", err) 143 } 144 145 if err := metadata.Validate(u, c.insecure); err != nil { 146 return nil, fmt.Errorf("could not validate authserver metadata: %w", err) 147 } 148 149 return &metadata, nil 150} 151 152func (c *Client) ClientAssertionJwt(authServerUrl string) (string, error) { 153 claims := jwt.MapClaims{ 154 "iss": c.clientId, 155 "sub": c.clientId, 156 "aud": authServerUrl, 157 "jti": uuid.NewString(), 158 "iat": time.Now().Add(-5 * time.Second).Unix(), 159 } 160 161 token := jwt.NewWithClaims(jwt.SigningMethodES256, claims) 162 token.Header["kid"] = c.clientKid 163 164 tokenString, err := token.SignedString(c.clientPrivateKey) 165 if err != nil { 166 return "", err 167 } 168 169 return tokenString, nil 170} 171 172func (c *Client) AuthServerDpopJwt(method, url, nonce string, privateJwk jwk.Key) (string, error) { 173 pubJwk, err := privateJwk.PublicKey() 174 if err != nil { 175 return "", err 176 } 177 178 b, err := json.Marshal(pubJwk) 179 if err != nil { 180 return "", err 181 } 182 183 var pubMap map[string]any 184 if err := json.Unmarshal(b, &pubMap); err != nil { 185 return "", err 186 } 187 188 now := time.Now().Unix() 189 190 claims := jwt.MapClaims{ 191 "jti": uuid.NewString(), 192 "htm": method, 193 "htu": url, 194 "iat": now, 195 "exp": now + 30, 196 } 197 198 if nonce != "" { 199 claims["nonce"] = nonce 200 } 201 202 token := jwt.NewWithClaims(jwt.SigningMethodES256, claims) 203 token.Header["typ"] = "dpop+jwt" 204 token.Header["alg"] = "ES256" 205 token.Header["jwk"] = pubMap 206 207 var rawKey any 208 if err := privateJwk.Raw(&rawKey); err != nil { 209 return "", err 210 } 211 
212 tokenString, err := token.SignedString(rawKey) 213 if err != nil { 214 return "", fmt.Errorf("failed to sign token: %w", err) 215 } 216 217 return tokenString, nil 218} 219 220func (c *Client) SendParAuthRequest(ctx context.Context, authServerUrl string, authServerMeta *OauthAuthorizationMetadata, loginHint, scope string, dpopPrivateKey jwk.Key) (*SendParAuthResponse, error) { 221 if authServerMeta == nil { 222 return nil, fmt.Errorf("nil metadata provided") 223 } 224 225 parUrl := authServerMeta.PushedAuthorizationRequestEndpoint 226 227 state, err := internal_helpers.GenerateToken(10) 228 if err != nil { 229 return nil, fmt.Errorf("could not generate state token: %w", err) 230 } 231 232 pkceVerifier, err := internal_helpers.GenerateToken(48) 233 if err != nil { 234 return nil, fmt.Errorf("could not generate pkce verifier: %w", err) 235 } 236 237 codeChallenge := internal_helpers.GenerateCodeChallenge(pkceVerifier) 238 codeChallengeMethod := "S256" 239 240 clientAssertion, err := c.ClientAssertionJwt(authServerUrl) 241 if err != nil { 242 return nil, fmt.Errorf("error getting client assertion: %w", err) 243 } 244 245 dpopAuthserverNonce := "" 246 dpopProof, err := c.AuthServerDpopJwt("POST", parUrl, dpopAuthserverNonce, dpopPrivateKey) 247 if err != nil { 248 return nil, fmt.Errorf("error getting dpop proof: %w", err) 249 } 250 251 params := url.Values{ 252 "response_type": {"code"}, 253 "code_challenge": {codeChallenge}, 254 "code_challenge_method": {codeChallengeMethod}, 255 "client_id": {c.clientId}, 256 "state": {state}, 257 "redirect_uri": {c.redirectUri}, 258 "scope": {scope}, 259 "client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"}, 260 "client_assertion": {clientAssertion}, 261 } 262 263 if loginHint != "" { 264 params.Set("login_hint", loginHint) 265 } 266 267 _, err = helpers.IsUrlSafeAndParsed(parUrl, c.insecure) 268 if err != nil { 269 return nil, err 270 } 271 272 req, err := http.NewRequestWithContext(ctx, "POST", 
parUrl, strings.NewReader(params.Encode())) 273 if err != nil { 274 return nil, err 275 } 276 277 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 278 req.Header.Set("DPoP", dpopProof) 279 280 resp, err := c.h.Do(req) 281 if err != nil { 282 return nil, err 283 } 284 defer resp.Body.Close() 285 286 var rmap map[string]any 287 if err := json.NewDecoder(resp.Body).Decode(&rmap); err != nil { 288 return nil, err 289 } 290 291 if resp.StatusCode != 201 { 292 if resp.StatusCode == 400 && rmap["error"] == "use_dpop_nonce" { 293 dpopAuthserverNonce = resp.Header.Get("DPoP-Nonce") 294 dpopProof, err := c.AuthServerDpopJwt("POST", parUrl, dpopAuthserverNonce, dpopPrivateKey) 295 if err != nil { 296 return nil, err 297 } 298 299 req2, err := http.NewRequestWithContext( 300 ctx, 301 "POST", 302 parUrl, 303 strings.NewReader(params.Encode()), 304 ) 305 if err != nil { 306 return nil, err 307 } 308 309 req2.Header.Set("Content-Type", "application/x-www-form-urlencoded") 310 req2.Header.Set("DPoP", dpopProof) 311 312 resp2, err := c.h.Do(req2) 313 if err != nil { 314 return nil, err 315 } 316 defer resp2.Body.Close() 317 318 rmap = map[string]any{} 319 if err := json.NewDecoder(resp2.Body).Decode(&rmap); err != nil { 320 return nil, err 321 } 322 323 if resp2.StatusCode != 201 { 324 return nil, fmt.Errorf("received error from server when submitting par request: %v", rmap) 325 } 326 } else { 327 return nil, fmt.Errorf("received error from server when submitting par request: %v", rmap) 328 } 329 } 330 331 return &SendParAuthResponse{ 332 PkceVerifier: pkceVerifier, 333 State: state, 334 DpopAuthserverNonce: dpopAuthserverNonce, 335 ExpiresIn: rmap["expires_in"].(float64), 336 RequestUri: rmap["request_uri"].(string), 337 }, nil 338} 339 340func (c *Client) InitialTokenRequest( 341 ctx context.Context, 342 code, 343 authserverIss, 344 pkceVerifier, 345 dpopAuthserverNonce string, 346 dpopPrivateJwk jwk.Key, 347) (*TokenResponse, error) { 348 // we might need to 
re-run to update dpop nonce 349 for range 2 { 350 authserverMeta, err := c.FetchAuthServerMetadata(ctx, authserverIss) 351 if err != nil { 352 return nil, err 353 } 354 355 clientAssertion, err := c.ClientAssertionJwt(authserverIss) 356 if err != nil { 357 return nil, err 358 } 359 360 params := url.Values{ 361 "client_id": {c.clientId}, 362 "redirect_uri": {c.redirectUri}, 363 "grant_type": {"authorization_code"}, 364 "code": {code}, 365 "code_verifier": {pkceVerifier}, 366 "client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"}, 367 "client_assertion": {clientAssertion}, 368 } 369 370 dpopProof, err := c.AuthServerDpopJwt("POST", authserverMeta.TokenEndpoint, dpopAuthserverNonce, dpopPrivateJwk) 371 if err != nil { 372 return nil, err 373 } 374 375 req, err := http.NewRequestWithContext(ctx, "POST", authserverMeta.TokenEndpoint, strings.NewReader(params.Encode())) 376 if err != nil { 377 return nil, err 378 } 379 380 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 381 req.Header.Set("DPoP", dpopProof) 382 383 resp, err := c.h.Do(req) 384 if err != nil { 385 return nil, err 386 } 387 defer resp.Body.Close() 388 389 if resp.StatusCode != 200 && resp.StatusCode != 201 { 390 var respMap map[string]string 391 if err := json.NewDecoder(resp.Body).Decode(&respMap); err != nil { 392 return nil, err 393 } 394 395 if resp.StatusCode == 400 && respMap["error"] == "use_dpop_nonce" { 396 dpopAuthserverNonce = resp.Header.Get("DPoP-Nonce") 397 continue 398 } 399 400 return nil, fmt.Errorf("token refresh error: %s", respMap["error"]) 401 } 402 403 var tokenResponse TokenResponse 404 if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil { 405 return nil, err 406 } 407 408 // set nonce so the updates are reflected in the response 409 tokenResponse.DpopAuthserverNonce = dpopAuthserverNonce 410 411 return &tokenResponse, nil 412 } 413 414 return nil, nil 415} 416 417func (c *Client) RefreshTokenRequest( 418 ctx 
context.Context, 419 refreshToken, 420 authserverIss, 421 dpopAuthserverNonce string, 422 dpopPrivateJwk jwk.Key, 423) (*TokenResponse, error) { 424 // we may need to update the dpop nonce 425 for range 2 { 426 authserverMeta, err := c.FetchAuthServerMetadata(ctx, authserverIss) 427 if err != nil { 428 return nil, err 429 } 430 431 clientAssertion, err := c.ClientAssertionJwt(authserverIss) 432 if err != nil { 433 return nil, err 434 } 435 436 params := url.Values{ 437 "client_id": {c.clientId}, 438 "grant_type": {"refresh_token"}, 439 "refresh_token": {refreshToken}, 440 "client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"}, 441 "client_assertion": {clientAssertion}, 442 } 443 444 dpopProof, err := c.AuthServerDpopJwt("POST", authserverMeta.TokenEndpoint, dpopAuthserverNonce, dpopPrivateJwk) 445 if err != nil { 446 return nil, err 447 } 448 449 req, err := http.NewRequestWithContext(ctx, "POST", authserverMeta.TokenEndpoint, strings.NewReader(params.Encode())) 450 if err != nil { 451 return nil, err 452 } 453 454 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 455 req.Header.Set("DPoP", dpopProof) 456 457 resp, err := c.h.Do(req) 458 if err != nil { 459 return nil, err 460 } 461 defer resp.Body.Close() 462 463 if resp.StatusCode != 200 && resp.StatusCode != 201 { 464 var respMap map[string]string 465 if err := json.NewDecoder(resp.Body).Decode(&respMap); err != nil { 466 return nil, err 467 } 468 469 if resp.StatusCode == 400 && respMap["error"] == "use_dpop_nonce" { 470 dpopAuthserverNonce = resp.Header.Get("DPoP-Nonce") 471 continue 472 } 473 474 return nil, fmt.Errorf("token refresh error: %s", respMap["error"]) 475 } 476 477 var tokenResponse TokenResponse 478 if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil { 479 return nil, err 480 } 481 482 // set the nonce so that updates are reflected in response 483 tokenResponse.DpopAuthserverNonce = dpopAuthserverNonce 484 485 return 
&tokenResponse, nil 486 } 487 488 return nil, nil 489}