A better Rust ATProto crate

home stretch. big pile of tests

Orual d02398d1 220b03d3

+1
Cargo.lock
··· 1680 1680 dependencies = [ 1681 1681 "async-trait", 1682 1682 "base64 0.22.1", 1683 + "bytes", 1683 1684 "chrono", 1684 1685 "dashmap", 1685 1686 "elliptic-curve",
+177 -14
crates/jacquard-common/src/types/xrpc.rs
··· 1 + //! Stateless XRPC utilities and request/response mapping 2 + //! 3 + //! Mapping overview: 4 + //! - Success (2xx): parse body into the endpoint's typed output. 5 + //! - 400: try typed error; on failure, fall back to a generic XRPC error (with 6 + //! `nsid`, `method`, and `http_status`) and map common auth errors. 7 + //! - 401: if `WWW-Authenticate` is present, return 8 + //! `ClientError::Auth(AuthError::Other(header))` so higher layers (OAuth/DPoP) 9 + //! can inspect `error="invalid_token"` or `error="use_dpop_nonce"` and refresh/retry. 10 + //! If the header is absent, parse the body and map auth errors to 11 + //! `AuthError::TokenExpired`/`InvalidToken`. 12 + //! 1 13 use bytes::Bytes; 2 14 use http::{ 3 15 HeaderName, HeaderValue, Request, StatusCode, ··· 251 263 } 252 264 253 265 /// Send the given typed XRPC request and return a response wrapper. 266 + /// 267 + /// Note on 401 handling: 268 + /// - When the server returns 401 with a `WWW-Authenticate` header, this surfaces as 269 + /// `ClientError::Auth(AuthError::Other(header))` so higher layers (e.g., OAuth/DPoP) can 270 + /// inspect the header for `error="invalid_token"` or `error="use_dpop_nonce"` and react 271 + /// (refresh/retry). If the header is absent, the 401 body flows through to `Response` and 272 + /// can be parsed/mapped to `AuthError` as appropriate. 254 273 pub async fn send<R: XrpcRequest + Send>(self, request: &R) -> XrpcResult<Response<R>> { 255 274 let http_request = build_http_request(&self.base, request, &self.opts) 256 275 .map_err(crate::error::TransportError::from)?; ··· 262 281 .map_err(|e| crate::error::TransportError::Other(Box::new(e)))?; 263 282 264 283 let status = http_response.status(); 284 + // If the server returned 401 with a WWW-Authenticate header, expose it so higher layers 285 + // (e.g., DPoP handling) can detect `error="invalid_token"` and trigger refresh. 286 + if status.as_u16() == 401 { 287 + if let Some(hv) = http_response.headers().get(http::header::WWW_AUTHENTICATE) { 288 + return Err(crate::error::ClientError::Auth( 289 + crate::error::AuthError::Other(hv.clone()), 290 + )); 291 + } 292 + } 265 293 let buffer = Bytes::from(http_response.into_body()); 266 294 267 295 if !status.is_success() && !matches!(status.as_u16(), 400 | 401) { ··· 430 458 Err(_) => { 431 459 // Fallback to generic error (InvalidRequest, ExpiredToken, etc.) 432 460 match serde_json::from_slice::<GenericXrpcError>(&self.buffer) { 433 - Ok(generic) => { 461 + Ok(mut generic) => { 462 + generic.nsid = R::NSID; 463 + generic.method = R::METHOD.as_str(); 464 + generic.http_status = self.status; 434 465 // Map auth-related errors to AuthError 435 466 match generic.error.as_str() { 436 467 "ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)), ··· 445 476 // 401: always auth error 446 477 } else { 447 478 match serde_json::from_slice::<GenericXrpcError>(&self.buffer) { 448 - Ok(generic) => match generic.error.as_str() { 449 - "ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)), 450 - "InvalidToken" => Err(XrpcError::Auth(AuthError::InvalidToken)), 451 - _ => Err(XrpcError::Auth(AuthError::NotAuthenticated)), 452 - }, 479 + Ok(mut generic) => { 480 + generic.nsid = R::NSID; 481 + generic.method = R::METHOD.as_str(); 482 + generic.http_status = self.status; 483 + match generic.error.as_str() { 484 + "ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)), 485 + "InvalidToken" => Err(XrpcError::Auth(AuthError::InvalidToken)), 486 + _ => Err(XrpcError::Auth(AuthError::NotAuthenticated)), 487 + } 488 + } 453 489 Err(e) => Err(XrpcError::Decode(e)), 454 490 } 455 491 } ··· 487 523 Err(_) => { 488 524 // Fallback to generic error (InvalidRequest, ExpiredToken, etc.) 489 525 match serde_json::from_slice::<GenericXrpcError>(&self.buffer) { 490 - Ok(generic) => { 526 + Ok(mut generic) => { 527 + generic.nsid = R::NSID; 528 + generic.method = R::METHOD.as_str(); 529 + generic.http_status = self.status; 491 530 // Map auth-related errors to AuthError 492 531 match generic.error.as_ref() { 493 532 "ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)), ··· 502 541 // 401: always auth error 503 542 } else { 504 543 match serde_json::from_slice::<GenericXrpcError>(&self.buffer) { 505 - Ok(generic) => match generic.error.as_ref() { 506 - "ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)), 507 - "InvalidToken" => Err(XrpcError::Auth(AuthError::InvalidToken)), 508 - _ => Err(XrpcError::Auth(AuthError::NotAuthenticated)), 509 - }, 544 + Ok(mut generic) => { 545 + let status = self.status; 546 + generic.nsid = R::NSID; 547 + generic.method = R::METHOD.as_str(); 548 + generic.http_status = status; 549 + match generic.error.as_ref() { 550 + "ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)), 551 + "InvalidToken" => Err(XrpcError::Auth(AuthError::InvalidToken)), 552 + _ => Err(XrpcError::Auth(AuthError::NotAuthenticated)), 553 + } 554 + } 510 555 Err(e) => Err(XrpcError::Decode(e)), 511 556 } 512 557 } ··· 527 572 pub error: SmolStr, 528 573 /// Optional error message with details 529 574 pub message: Option<SmolStr>, 575 + /// XRPC method NSID that produced this error (context only; not serialized) 576 + #[serde(skip)] 577 + pub nsid: &'static str, 578 + /// HTTP method used (GET/POST) (context only; not serialized) 579 + #[serde(skip)] 580 + pub method: &'static str, 581 + /// HTTP status code (context only; not serialized) 582 + #[serde(skip)] 583 + pub http_status: StatusCode, 530 584 } 531 585 532 586 impl std::fmt::Display for GenericXrpcError { 533 587 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 534 588 if let Some(msg) = &self.message { 535 - write!(f, "{}: {}", self.error, msg) 589 + write!( 590 + f, 591 + "{}: {} (nsid={}, method={}, status={})", 592 + self.error, msg, self.nsid, self.method, self.http_status 593 + ) 536 594 } else { 537 - write!(f, "{}", self.error) 595 + write!( 596 + f, 597 + "{} (nsid={}, method={}, status={})", 598 + self.error, self.nsid, self.method, self.http_status 599 + ) 538 600 } 539 601 } 540 602 } ··· 549 611 pub enum XrpcError<E: std::error::Error + IntoStatic> { 550 612 /// Typed XRPC error from the endpoint's specific error enum 551 613 #[error("XRPC error: {0}")] 614 + #[diagnostic(code(jacquard_common::xrpc::typed))] 552 615 Xrpc(E), 553 616 554 617 /// Authentication error (ExpiredToken, InvalidToken, etc.) 555 618 #[error("Authentication error: {0}")] 619 + #[diagnostic(code(jacquard_common::xrpc::auth))] 556 620 Auth(#[from] AuthError), 557 621 558 622 /// Generic XRPC error not in the endpoint's error enum (e.g., InvalidRequest) 559 623 #[error("XRPC error: {0}")] 624 + #[diagnostic(code(jacquard_common::xrpc::generic))] 560 625 Generic(GenericXrpcError), 561 626 562 627 /// Failed to decode the response body 563 628 #[error("Failed to decode response: {0}")] 629 + #[diagnostic(code(jacquard_common::xrpc::decode))] 564 630 Decode(#[from] serde_json::Error), 631 + } 632 + 633 + #[cfg(test)] 634 + mod tests { 635 + use super::*; 636 + use serde::{Deserialize, Serialize}; 637 + 638 + #[derive(Serialize)] 639 + struct DummyReq; 640 + 641 + #[derive(Deserialize, Debug, thiserror::Error)] 642 + #[error("{0}")] 643 + struct DummyErr<'a>(#[serde(borrow)] CowStr<'a>); 644 + 645 + impl IntoStatic for DummyErr<'_> { 646 + type Output = DummyErr<'static>; 647 + fn into_static(self) -> Self::Output { 648 + DummyErr(self.0.into_static()) 649 + } 650 + } 651 + 652 + impl XrpcRequest for DummyReq { 653 + const NSID: &'static str = "test.dummy"; 654 + const METHOD: XrpcMethod = XrpcMethod::Procedure("application/json"); 655 + const OUTPUT_ENCODING: &'static str = "application/json"; 656 + type Output<'de> = (); 657 + type Err<'de> = DummyErr<'de>; 658 + } 659 + 660 + #[test] 661 + fn generic_error_carries_context() { 662 + let body = serde_json::json!({"error":"InvalidRequest","message":"missing"}); 663 + let buf = Bytes::from(serde_json::to_vec(&body).unwrap()); 664 + let resp: Response<DummyReq> = Response::new(buf, StatusCode::BAD_REQUEST); 665 + match resp.parse().unwrap_err() { 666 + XrpcError::Generic(g) => { 667 + assert_eq!(g.error.as_str(), "InvalidRequest"); 668 + assert_eq!(g.message.as_deref(), Some("missing")); 669 + assert_eq!(g.nsid, DummyReq::NSID); 670 + assert_eq!(g.method, DummyReq::METHOD.as_str()); 671 + assert_eq!(g.http_status, StatusCode::BAD_REQUEST); 672 + } 673 + other => panic!("unexpected: {other:?}"), 674 + } 675 + } 676 + 677 + #[test] 678 + fn auth_error_mapping() { 679 + for (code, expect) in [ 680 + ("ExpiredToken", AuthError::TokenExpired), 681 + ("InvalidToken", AuthError::InvalidToken), 682 + ] { 683 + let body = serde_json::json!({"error": code}); 684 + let buf = Bytes::from(serde_json::to_vec(&body).unwrap()); 685 + let resp: Response<DummyReq> = Response::new(buf, StatusCode::UNAUTHORIZED); 686 + match resp.parse().unwrap_err() { 687 + XrpcError::Auth(e) => match (e, expect) { 688 + (AuthError::TokenExpired, AuthError::TokenExpired) => {} 689 + (AuthError::InvalidToken, AuthError::InvalidToken) => {} 690 + other => panic!("mismatch: {other:?}"), 691 + }, 692 + other => panic!("unexpected: {other:?}"), 693 + } 694 + } 695 + } 696 + 697 + #[test] 698 + fn no_double_slash_in_path() { 699 + #[derive(Serialize)] 700 + struct Req; 701 + #[derive(Deserialize, Debug, thiserror::Error)] 702 + #[error("{0}")] 703 + struct Err<'a>(#[serde(borrow)] CowStr<'a>); 704 + impl IntoStatic for Err<'_> { 705 + type Output = Err<'static>; 706 + fn into_static(self) -> Self::Output { Err(self.0.into_static()) } 707 + } 708 + impl XrpcRequest for Req { 709 + const NSID: &'static str = "com.example.test"; 710 + const METHOD: XrpcMethod = XrpcMethod::Query; 711 + const OUTPUT_ENCODING: &'static str = "application/json"; 712 + type Output<'de> = (); 713 + type Err<'de> = Err<'de>; 714 + } 715 + 716 + let opts = CallOptions::default(); 717 + for base in [ 718 + Url::parse("https://pds").unwrap(), 719 + Url::parse("https://pds/").unwrap(), 720 + Url::parse("https://pds/base/").unwrap(), 721 + ] { 722 + let req = build_http_request(&base, &Req, &opts).unwrap(); 723 + let uri = req.uri().to_string(); 724 + assert!(uri.contains("/xrpc/com.example.test")); 725 + assert!(!uri.contains("//xrpc")); 726 + } 727 + } 565 728 } 566 729 567 730 /// Stateful XRPC call trait
+12
crates/jacquard-identity/src/resolver.rs
··· 36 36 #[allow(missing_docs)] 37 37 pub enum IdentityError { 38 38 #[error("unsupported DID method: {0}")] 39 + #[diagnostic(code(jacquard_identity::unsupported_did_method), help("supported DID methods: did:web, did:plc"))] 39 40 UnsupportedDidMethod(String), 40 41 #[error("invalid well-known atproto-did content")] 42 + #[diagnostic(code(jacquard_identity::invalid_well_known), help("expected first non-empty line to be a DID"))] 41 43 InvalidWellKnown, 42 44 #[error("missing PDS endpoint in DID document")] 45 + #[diagnostic(code(jacquard_identity::missing_pds_endpoint))] 43 46 MissingPdsEndpoint, 44 47 #[error("HTTP error: {0}")] 48 + #[diagnostic(code(jacquard_identity::http), help("check network connectivity and TLS configuration"))] 45 49 Http(#[from] TransportError), 46 50 #[error("HTTP status {0}")] 51 + #[diagnostic(code(jacquard_identity::http_status), help("verify well-known paths or PDS XRPC endpoints"))] 47 52 HttpStatus(StatusCode), 48 53 #[error("XRPC error: {0}")] 54 + #[diagnostic(code(jacquard_identity::xrpc), help("enable PDS fallback or public resolver if needed"))] 49 55 Xrpc(String), 50 56 #[error("URL parse error: {0}")] 57 + #[diagnostic(code(jacquard_identity::url))] 51 58 Url(#[from] url::ParseError), 52 59 #[error("DNS error: {0}")] 53 60 #[cfg(feature = "dns")] 61 + #[diagnostic(code(jacquard_identity::dns))] 54 62 Dns(#[from] hickory_resolver::error::ResolveError), 55 63 #[error("serialize/deserialize error: {0}")] 64 + #[diagnostic(code(jacquard_identity::serde))] 56 65 Serde(#[from] serde_json::Error), 57 66 #[error("invalid DID document: {0}")] 67 + #[diagnostic(code(jacquard_identity::invalid_doc), help("validate keys and services; ensure AtprotoPersonalDataServer service exists"))] 58 68 InvalidDoc(String), 59 69 #[error(transparent)] 70 + #[diagnostic(code(jacquard_identity::data))] 60 71 Data(#[from] AtDataError), 61 72 /// DID document id did not match requested DID; includes the fetched document 62 73 #[error("DID doc id mismatch")] 74 + #[diagnostic(code(jacquard_identity::doc_id_mismatch), help("document id differs from requested DID; do not trust this document"))] 63 75 DocIdMismatch { 64 76 expected: Did<'static>, 65 77 doc: DidDocument<'static>,
+1
crates/jacquard-oauth/Cargo.toml
··· 25 25 chrono = "0.4" 26 26 elliptic-curve = "0.13.8" 27 27 http.workspace = true 28 + bytes.workspace = true 28 29 rand = { version = "0.8.5", features = ["small_rng"] } 29 30 async-trait = "0.1.89" 30 31 dashmap = "6.1.0"
+77 -26
crates/jacquard-oauth/src/atproto.rs
··· 109 109 mut redirect_uris: Option<Vec<Url>>, 110 110 scopes: Option<Vec<Scope<'m>>>, 111 111 ) -> Self { 112 - // coerce redirect uris to localhost 112 + // Coerce provided redirect URIs to http://localhost while preserving path 113 113 if let Some(redirect_uris) = &mut redirect_uris { 114 114 for redirect_uri in redirect_uris { 115 - redirect_uri.set_host(Some("http://localhost")).unwrap(); 115 + let _ = redirect_uri.set_scheme("http"); 116 + redirect_uri.set_host(Some("localhost")).unwrap(); 117 + let _ = redirect_uri.set_port(None); 116 118 } 117 119 } 118 120 // determine client_id ··· 154 156 metadata: AtprotoClientMetadata<'m>, 155 157 keyset: &Option<Keyset>, 156 158 ) -> Result<OAuthClientMetadata<'m>> { 159 + // For non-loopback clients, require a keyset/JWKs. 160 + let is_loopback = metadata.client_id.scheme() == "http" 161 + && metadata.client_id.host_str() == Some("localhost"); 162 + if !is_loopback && keyset.is_none() { 163 + return Err(Error::EmptyJwks); 164 + } 157 165 if metadata.redirect_uris.is_empty() { 158 166 return Err(Error::EmptyRedirectUris); 159 167 } ··· 179 187 client_uri: metadata.client_uri, 180 188 redirect_uris: metadata.redirect_uris, 181 189 token_endpoint_auth_method: Some(auth_method.into()), 182 - grant_types: Some(metadata.grant_types.into_iter().map(|v| v.into()).collect()), 183 - scope: Some(Scope::serialize_multiple(metadata.scopes.as_slice())), 184 - dpop_bound_access_tokens: Some(true), 190 + grant_types: if keyset.is_some() { 191 + Some(metadata.grant_types.into_iter().map(|v| v.into()).collect()) 192 + } else { 193 + None 194 + }, 195 + scope: if keyset.is_some() { 196 + Some(Scope::serialize_multiple(metadata.scopes.as_slice())) 197 + } else { 198 + None 199 + }, 200 + dpop_bound_access_tokens: if keyset.is_some() { Some(true) } else { None }, 185 201 jwks_uri, 186 202 jwks, 187 203 token_endpoint_auth_signing_alg: if keyset.is_some() { ··· 251 267 .expect("failed to convert metadata"), 252 268 OAuthClientMetadata { 253 269 client_id: Url::from_str( 254 - "http://localhost?redirect_uri=http%3A%2F%2F127.0.0.1%2Fcallback&redirect_uri=http%3A%2F%2F%5B%3A%3A1%5D%2Fcallback&scope=account%3Aemail+atproto+transition%3Ageneric" 270 + "http://localhost?redirect_uri=http%3A%2F%2Flocalhost%2Fcallback&redirect_uri=http%3A%2F%2Flocalhost%2Fcallback&scope=account%3Aemail+atproto+transition%3Ageneric" 255 271 ).unwrap(), 256 272 client_uri: None, 257 273 redirect_uris: vec![ 258 - Url::from_str("http://127.0.0.1/callback").unwrap(), 259 - Url::from_str("http://[::1]/callback").unwrap(), 274 + Url::from_str("http://localhost/callback").unwrap(), 275 + Url::from_str("http://localhost/callback").unwrap(), 260 276 ], 261 277 scope: None, 262 278 grant_types: None, ··· 271 287 272 288 #[test] 273 289 fn test_localhost_client_metadata_invalid() { 290 + // Invalid inputs are coerced to http://localhost rather than failing 274 291 { 275 - let err = atproto_client_metadata( 292 + let out = atproto_client_metadata( 276 293 AtprotoClientMetadata::new_localhost( 277 294 Some(vec![Url::from_str("https://127.0.0.1/").unwrap()]), 278 295 None, 279 296 ), 280 297 &None, 281 298 ) 282 - .expect_err("expected to fail"); 283 - assert!(matches!( 284 - err, 285 - Error::LocalhostClient(LocalhostClientError::NotHttpScheme) 286 - )); 299 + .expect("should coerce to localhost"); 300 + assert_eq!( 301 + out, 302 + OAuthClientMetadata { 303 + client_id: Url::from_str("http://localhost?redirect_uri=http%3A%2F%2Flocalhost%2F").unwrap(), 304 + client_uri: None, 305 + redirect_uris: vec![Url::from_str("http://localhost/").unwrap()], 306 + scope: None, 307 + grant_types: None, 308 + token_endpoint_auth_method: Some(AuthMethod::None.into()), 309 + dpop_bound_access_tokens: None, 310 + jwks_uri: None, 311 + jwks: None, 312 + token_endpoint_auth_signing_alg: None, 313 + } 314 + ); 287 315 } 288 316 { 289 - let err = atproto_client_metadata( 317 + let out = atproto_client_metadata( 290 318 AtprotoClientMetadata::new_localhost( 291 319 Some(vec![Url::from_str("http://localhost:8000/").unwrap()]), 292 320 None, 293 321 ), 294 322 &None, 295 323 ) 296 - .expect_err("expected to fail"); 297 - assert!(matches!( 298 - err, 299 - Error::LocalhostClient(LocalhostClientError::Localhost) 300 - )); 324 + .expect("should coerce to localhost"); 325 + assert_eq!( 326 + out, 327 + OAuthClientMetadata { 328 + client_id: Url::from_str("http://localhost?redirect_uri=http%3A%2F%2Flocalhost%2F").unwrap(), 329 + client_uri: None, 330 + redirect_uris: vec![Url::from_str("http://localhost/").unwrap()], 331 + scope: None, 332 + grant_types: None, 333 + token_endpoint_auth_method: Some(AuthMethod::None.into()), 334 + dpop_bound_access_tokens: None, 335 + jwks_uri: None, 336 + jwks: None, 337 + token_endpoint_auth_signing_alg: None, 338 + } 339 + ); 301 340 } 302 341 { 303 - let err = atproto_client_metadata( 342 + let out = atproto_client_metadata( 304 343 AtprotoClientMetadata::new_localhost( 305 344 Some(vec![Url::from_str("http://192.168.0.0/").unwrap()]), 306 345 None, 307 346 ), 308 347 &None, 309 348 ) 310 - .expect_err("expected to fail"); 311 - assert!(matches!( 312 - err, 313 - Error::LocalhostClient(LocalhostClientError::NotLoopbackHost) 314 - )); 349 + .expect("should coerce to localhost"); 350 + assert_eq!( 351 + out, 352 + OAuthClientMetadata { 353 + client_id: Url::from_str("http://localhost?redirect_uri=http%3A%2F%2Flocalhost%2F").unwrap(), 354 + client_uri: None, 355 + redirect_uris: vec![Url::from_str("http://localhost/").unwrap()], 356 + scope: None, 357 + grant_types: None, 358 + token_endpoint_auth_method: Some(AuthMethod::None.into()), 359 + dpop_bound_access_tokens: None, 360 + jwks_uri: None, 361 + jwks: None, 362 + token_endpoint_auth_signing_alg: None, 363 + } 364 + ); 315 365 } 316 366 } 317 367 ··· 326 376 jwks_uri: None, 327 377 }; 328 378 { 379 + // Non-loopback clients without a keyset should fail (must provide JWKS) 329 380 let metadata = metadata.clone(); 330 381 let err = atproto_client_metadata(metadata, &None).expect_err("expected to fail"); 331 382 assert!(matches!(err, Error::EmptyJwks));
+26 -22
crates/jacquard-oauth/src/client.rs
··· 2 2 atproto::atproto_client_metadata, 3 3 authstore::ClientAuthStore, 4 4 dpop::DpopExt, 5 - error::{OAuthError, Result}, 5 + error::{CallbackError, Result}, 6 6 request::{OAuthMetadata, exchange_code, par}, 7 7 resolver::OAuthResolver, 8 8 scopes::Scope, ··· 105 105 106 106 pub async fn callback(&self, params: CallbackParams<'_>) -> Result<OAuthSession<T, S>> { 107 107 let Some(state_key) = params.state else { 108 - return Err(OAuthError::Callback("missing state parameter".into())); 108 + return Err(CallbackError::MissingState.into()); 109 109 }; 110 110 111 111 let Some(auth_req_info) = self.registry.store.get_auth_req_info(&state_key).await? else { 112 - return Err(OAuthError::Callback(format!( 113 - "unknown authorization state: {state_key}" 114 - ))); 112 + return Err(CallbackError::MissingState.into()); 115 113 }; 116 114 117 115 self.registry.store.delete_auth_req_info(&state_key).await?; ··· 122 120 .await?; 123 121 124 122 if let Some(iss) = params.iss { 125 - if iss != metadata.issuer { 126 - return Err(OAuthError::Callback(format!( 127 - "issuer mismatch: expected {}, got {iss}", 128 - metadata.issuer 129 - ))); 123 + if !crate::resolver::issuer_equivalent(&iss, &metadata.issuer) { 124 + return Err(CallbackError::IssuerMismatch { expected: metadata.issuer.to_string(), got: iss.to_string() }.into()); 130 125 } 131 126 } else if metadata.authorization_response_iss_parameter_supported == Some(true) { 132 - return Err(OAuthError::Callback("missing `iss` parameter".into())); 127 + return Err(CallbackError::MissingIssuer.into()); 133 128 } 134 129 let metadata = OAuthMetadata { 135 130 server_metadata: metadata, ··· 257 252 } 258 253 259 254 pub async fn refresh_token(&self) -> Option<AuthorizationToken<'_>> { 260 - self.data 261 - .read() 262 - .await 263 - .token_set 264 - .refresh_token 265 - .as_ref() 266 - .map(|token| AuthorizationToken::Dpop(token.clone())) 255 + self.data.read().await.token_set.refresh_token.as_ref().map(|t| AuthorizationToken::Dpop(t.clone())) 267 256 } 268 257 } 269 258 impl<T, S> OAuthSession<T, S> ··· 272 261 T: OAuthResolver + DpopExt + Send + Sync + 'static, 273 262 { 274 263 pub async fn refresh(&self) -> Result<AuthorizationToken<'_>> { 275 - let mut data = self.data.write().await; 264 + // Read identifiers without holding the lock across await 265 + let (did, sid) = { 266 + let data = self.data.read().await; 267 + (data.account_did.clone(), data.session_id.clone()) 268 + }; 276 269 let refreshed = self 277 270 .registry 278 271 .as_ref() 279 - .get(&data.account_did, &data.session_id, true) 272 + .get(&did, &sid, true) 280 273 .await?; 281 274 let token = AuthorizationToken::Dpop(refreshed.token_set.access_token.clone()); 282 - *data = refreshed.into_static(); 275 + // Write back updated session 276 + *self.data.write().await = refreshed.into_static(); 283 277 Ok(token) 284 278 } 285 279 } ··· 305 299 T: OAuthResolver + DpopExt + XrpcExt + Send + Sync + 'static, 306 300 { 307 301 fn base_uri(&self) -> Url { 308 - self.data.blocking_read().host_url.clone() 302 + // base_uri is a synchronous trait method; we must avoid async `.read().await`. 303 + // Use `block_in_place` under Tokio to perform a blocking RwLock read safely. 304 + if tokio::runtime::Handle::try_current().is_ok() { 305 + tokio::task::block_in_place(|| self.data.blocking_read().host_url.clone()) 306 + } else { 307 + self.data.blocking_read().host_url.clone() 308 + } 309 309 } 310 310 311 311 async fn opts(&self) -> CallOptions<'_> { ··· 349 349 Err(ClientError::Auth(AuthError::Other(value))) => value 350 350 .to_str() 351 351 .is_ok_and(|s| s.starts_with("DPoP ") && s.contains("error=\"invalid_token\"")), 352 + Ok(resp) => match resp.parse() { 353 + Err(jacquard_common::types::xrpc::XrpcError::Auth(AuthError::InvalidToken)) => true, 354 + _ => false, 355 + }, 352 356 _ => false, 353 357 } 354 358 }
+50 -39
crates/jacquard-oauth/src/error.rs
··· 1 1 use jacquard_common::session::SessionStoreError; 2 2 use miette::Diagnostic; 3 3 4 + use crate::request::RequestError; 4 5 use crate::resolver::ResolverError; 5 6 6 - /// Errors emitted by OAuth helpers. 7 + /// High-level errors emitted by OAuth helpers. 7 8 #[derive(Debug, thiserror::Error, Diagnostic)] 8 9 pub enum OAuthError { 9 - /// Invalid or unsupported JWK 10 - #[error("invalid JWK: {0}")] 11 - #[diagnostic( 12 - code(jacquard_oauth::jwk), 13 - help("Ensure EC P-256 JWK with base64url x,y,d values") 14 - )] 15 - Jwk(String), 16 - /// Signing error 17 - #[error("signing error: {0}")] 18 - #[diagnostic( 19 - code(jacquard_oauth::signing), 20 - help("Check ES256 key material and input payloads") 21 - )] 22 - Signing(String), 23 - /// Serialization error 10 + #[error(transparent)] 11 + #[diagnostic(code(jacquard_oauth::resolver))] 12 + Resolver(#[from] ResolverError), 13 + 14 + #[error(transparent)] 15 + #[diagnostic(code(jacquard_oauth::request))] 16 + Request(#[from] RequestError), 17 + 24 18 #[error(transparent)] 25 - #[diagnostic(code(jacquard_oauth::serde))] 26 - Serde(#[from] serde_json::Error), 27 - /// URL error 19 + #[diagnostic(code(jacquard_oauth::storage))] 20 + Storage(#[from] SessionStoreError), 21 + 28 22 #[error(transparent)] 29 - #[diagnostic(code(jacquard_oauth::url))] 30 - Url(#[from] url::ParseError), 31 - /// URL error 23 + #[diagnostic(code(jacquard_oauth::dpop))] 24 + Dpop(#[from] crate::dpop::Error), 25 + 32 26 #[error(transparent)] 33 - #[diagnostic(code(jacquard_oauth::url))] 34 - UrlEncoding(#[from] serde_html_form::ser::Error), 35 - /// PKCE error 36 - #[error("pkce error: {0}")] 37 - #[diagnostic( 38 - code(jacquard_oauth::pkce), 39 - help("PKCE must use S256; ensure verifier/challenge generated") 40 - )] 41 - Pkce(String), 42 - #[error("authorize error: {0}")] 43 - Authorize(String), 27 + #[diagnostic(code(jacquard_oauth::keyset))] 28 + Keyset(#[from] crate::keyset::Error), 29 + 44 30 #[error(transparent)] 31 + #[diagnostic(code(jacquard_oauth::atproto))] 45 32 Atproto(#[from] crate::atproto::Error), 46 - #[error("callback error: {0}")] 47 - Callback(String), 48 - #[error(transparent)] 49 - Storage(#[from] SessionStoreError), 33 + 50 34 #[error(transparent)] 35 + #[diagnostic(code(jacquard_oauth::session))] 51 36 Session(#[from] crate::session::Error), 37 + 52 38 #[error(transparent)] 53 - Request(#[from] crate::request::Error), 39 + #[diagnostic(code(jacquard_oauth::serde_json))] 40 + SerdeJson(#[from] serde_json::Error), 41 + 54 42 #[error(transparent)] 55 - Client(#[from] ResolverError), 43 + #[diagnostic(code(jacquard_oauth::url))] 44 + Url(#[from] url::ParseError), 45 + 46 + #[error(transparent)] 47 + #[diagnostic(code(jacquard_oauth::form))] 48 + Form(#[from] serde_html_form::ser::Error), 49 + 50 + #[error(transparent)] 51 + #[diagnostic(code(jacquard_oauth::callback))] 52 + Callback(#[from] CallbackError), 53 + } 54 + 55 + /// Typed callback validation errors (redirect handling). 56 + #[derive(Debug, thiserror::Error, Diagnostic)] 57 + pub enum CallbackError { 58 + #[error("missing state parameter in callback")] 59 + #[diagnostic(code(jacquard_oauth::callback::missing_state))] 60 + MissingState, 61 + #[error("missing `iss` parameter")] 62 + #[diagnostic(code(jacquard_oauth::callback::missing_iss))] 63 + MissingIssuer, 64 + #[error("issuer mismatch: expected {expected}, got {got}")] 65 + #[diagnostic(code(jacquard_oauth::callback::issuer_mismatch))] 66 + IssuerMismatch { expected: String, got: String }, 56 67 } 57 68 58 69 pub type Result<T> = core::result::Result<T, OAuthError>;
+207 -17
crates/jacquard-oauth/src/request.rs
··· 40 40 const CLIENT_ASSERTION_TYPE_JWT_BEARER: &str = 41 41 "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"; 42 42 43 - #[derive(Error, Debug)] 44 - pub enum Error { 43 + #[derive(Error, Debug, miette::Diagnostic)] 44 + pub enum RequestError { 45 45 #[error("no {0} endpoint available")] 46 + #[diagnostic(code(jacquard_oauth::request::no_endpoint), help("server does not advertise this endpoint"))] 46 47 NoEndpoint(CowStr<'static>), 47 48 #[error("token response verification failed")] 48 - Token(CowStr<'static>), 49 + #[diagnostic(code(jacquard_oauth::request::token_verification))] 50 + TokenVerification, 49 51 #[error("unsupported authentication method")] 52 + #[diagnostic( 53 + code(jacquard_oauth::request::unsupported_auth_method), 54 + help("server must support `private_key_jwt` or `none`; configure client metadata accordingly") 55 + )] 50 56 UnsupportedAuthMethod, 51 57 #[error("no refresh token available")] 52 - TokenRefresh, 58 + #[diagnostic(code(jacquard_oauth::request::no_refresh_token))] 59 + NoRefreshToken, 53 60 #[error("failed to parse DID: {0}")] 61 + #[diagnostic(code(jacquard_oauth::request::invalid_did))] 54 62 InvalidDid(#[from] AtStrError), 55 63 #[error(transparent)] 64 + #[diagnostic(code(jacquard_oauth::request::dpop))] 56 65 DpopClient(#[from] crate::dpop::Error), 57 66 #[error(transparent)] 67 + #[diagnostic(code(jacquard_oauth::request::storage))] 58 68 Storage(#[from] SessionStoreError), 59 69 60 70 #[error(transparent)] 71 + #[diagnostic(code(jacquard_oauth::request::resolver))] 61 72 ResolverError(#[from] crate::resolver::ResolverError), 62 73 // #[error(transparent)] 63 74 // OAuthSession(#[from] crate::oauth_session::Error), 64 75 #[error(transparent)] 76 + #[diagnostic(code(jacquard_oauth::request::http_build))] 65 77 Http(#[from] http::Error), 66 - #[error("http client error: {0}")] 67 - HttpClient(Box<dyn std::error::Error + Send + Sync + 'static>), 68 78 #[error("http status: {0}")] 79 + #[diagnostic(code(jacquard_oauth::request::http_status), help("see server response for details"))] 69 80 HttpStatus(StatusCode), 70 81 #[error("http status: {0}, body: {1:?}")] 82 + #[diagnostic(code(jacquard_oauth::request::http_status_body), help("server returned error JSON; inspect fields like `error`, `error_description`"))] 71 83 HttpStatusWithBody(StatusCode, Value), 72 84 #[error(transparent)] 85 + #[diagnostic(code(jacquard_oauth::request::identity))] 73 86 Identity(#[from] IdentityError), 74 87 #[error(transparent)] 88 + #[diagnostic(code(jacquard_oauth::request::keyset))] 75 89 Keyset(#[from] crate::keyset::Error), 76 90 #[error(transparent)] 91 + #[diagnostic(code(jacquard_oauth::request::serde_form))] 77 92 SerdeHtmlForm(#[from] serde_html_form::ser::Error), 78 93 #[error(transparent)] 94 + #[diagnostic(code(jacquard_oauth::request::serde_json))] 79 95 SerdeJson(#[from] serde_json::Error), 80 96 #[error(transparent)] 97 + #[diagnostic(code(jacquard_oauth::request::atproto))] 81 98 Atproto(#[from] crate::atproto::Error), 82 99 } 83 100 84 - pub type Result<T> = core::result::Result<T, Error>; 101 + pub type Result<T> = core::result::Result<T, RequestError>; 85 102 86 103 #[allow(dead_code)] 87 104 pub enum OAuthRequest<'a> { ··· 113 130 } 114 131 } 115 132 133 + #[cfg(test)] 134 + mod tests { 135 + use super::*; 136 + use crate::types::{OAuthAuthorizationServerMetadata, OAuthClientMetadata}; 137 + use http::{Response as HttpResponse, StatusCode}; 138 + use bytes::Bytes; 139 + use jacquard_common::http_client::HttpClient; 140 + use jacquard_identity::resolver::IdentityResolver; 141 + use std::sync::Arc; 142 + use tokio::sync::Mutex; 143 + 144 + #[derive(Clone, Default)] 145 + struct MockClient { 146 + resp: Arc<Mutex<Option<HttpResponse<Vec<u8>>>>>, 147 + } 148 + 149 + impl HttpClient for MockClient { 150 + type Error = std::convert::Infallible; 151 + fn send_http( 152 + &self, 153 + _request: http::Request<Vec<u8>>, 154 + ) -> impl core::future::Future< 155 + Output = core::result::Result<http::Response<Vec<u8>>, Self::Error>, 156 + > + Send { 157 + let resp = self.resp.clone(); 158 + async move { Ok(resp.lock().await.take().unwrap()) } 159 + } 160 + } 161 + 162 + // IdentityResolver methods won't be called in these tests; provide stubs. 163 + #[async_trait::async_trait] 164 + impl IdentityResolver for MockClient { 165 + fn options(&self) -> &jacquard_identity::resolver::ResolverOptions { 166 + use std::sync::LazyLock; 167 + static OPTS: LazyLock<jacquard_identity::resolver::ResolverOptions> = 168 + LazyLock::new(|| jacquard_identity::resolver::ResolverOptions::default()); 169 + &OPTS 170 + } 171 + async fn resolve_handle( 172 + &self, 173 + _handle: &jacquard_common::types::string::Handle<'_>, 174 + ) -> std::result::Result< 175 + jacquard_common::types::string::Did<'static>, 176 + jacquard_identity::resolver::IdentityError, 177 + > { 178 + Ok(jacquard_common::types::string::Did::new_static("did:plc:alice").unwrap()) 179 + } 180 + async fn resolve_did_doc( 181 + &self, 182 + _did: &jacquard_common::types::string::Did<'_>, 183 + ) -> std::result::Result< 184 + jacquard_identity::resolver::DidDocResponse, 185 + jacquard_identity::resolver::IdentityError, 186 + > { 187 + let doc = serde_json::json!({ 188 + "id": "did:plc:alice", 189 + "service": [{ 190 + "id": "#pds", 191 + "type": "AtprotoPersonalDataServer", 192 + "serviceEndpoint": "https://pds" 193 + }] 194 + }); 195 + let buf = Bytes::from(serde_json::to_vec(&doc).unwrap()); 196 + Ok(jacquard_identity::resolver::DidDocResponse { 197 + buffer: buf, 198 + status: StatusCode::OK, 199 + requested: None, 200 + }) 201 + } 202 + } 203 + 204 + // Allow using DPoP helpers on MockClient 205 + impl crate::dpop::DpopExt for MockClient {} 206 + impl crate::resolver::OAuthResolver for MockClient {} 207 + 208 + fn base_metadata() -> OAuthMetadata { 209 + let mut server = OAuthAuthorizationServerMetadata::default(); 210 + server.issuer = CowStr::from("https://issuer"); 211 + server.authorization_endpoint = CowStr::from("https://issuer/authorize"); 212 + server.token_endpoint = CowStr::from("https://issuer/token"); 213 + OAuthMetadata { 214 + server_metadata: server, 215 + client_metadata: OAuthClientMetadata { 216 + client_id: url::Url::parse("https://client").unwrap(), 217 + client_uri: None, 218 + redirect_uris: vec![url::Url::parse("https://client/cb").unwrap()], 219 + scope: Some(CowStr::from("atproto")), 220 + grant_types: None, 221 + token_endpoint_auth_method: Some(CowStr::from("none")), 222 + dpop_bound_access_tokens: None, 223 + jwks_uri: None, 224 + jwks: None, 225 + token_endpoint_auth_signing_alg: None, 226 + }, 227 + keyset: None, 228 + } 229 + } 230 + 231 + #[tokio::test] 232 + async fn par_missing_endpoint() { 233 + let mut meta = base_metadata(); 234 + meta.server_metadata.require_pushed_authorization_requests = Some(true); 235 + meta.server_metadata.pushed_authorization_request_endpoint = None; 236 + // require_pushed_authorization_requests is true and no endpoint 237 + let err = super::par(&MockClient::default(), None, None, &meta) 238 + .await 239 + .unwrap_err(); 240 + match err { 241 + RequestError::NoEndpoint(name) => { 242 + assert_eq!(name.as_ref(), "pushed_authorization_request"); 243 + } 244 + other => panic!("unexpected: {other:?}"), 245 + } 246 + } 247 + 248 + #[tokio::test] 249 + async fn refresh_no_refresh_token() { 250 + let client = MockClient::default(); 251 + let meta = base_metadata(); 252 + let mut session = ClientSessionData { 253 + account_did: jacquard_common::types::string::Did::new_static("did:plc:alice").unwrap(), 254 + session_id: CowStr::from("state"), 255 + host_url: url::Url::parse("https://pds").unwrap(), 256 + authserver_url: url::Url::parse("https://issuer").unwrap(), 257 + authserver_token_endpoint: CowStr::from("https://issuer/token"), 258 + authserver_revocation_endpoint: None, 259 + scopes: vec![], 260 + dpop_data: DpopClientData { 261 + dpop_key: crate::utils::generate_key(&[CowStr::from("ES256")]).unwrap(), 262 + dpop_authserver_nonce: CowStr::from(""), 263 + dpop_host_nonce: CowStr::from(""), 264 + }, 265 + token_set: crate::types::TokenSet { 266 + iss: CowStr::from("https://issuer"), 267 + sub: jacquard_common::types::string::Did::new_static("did:plc:alice").unwrap(), 268 + aud: CowStr::from("https://pds"), 269 + scope: None, 270 + refresh_token: None, 271 + access_token: CowStr::from("abc"), 272 + token_type: crate::types::OAuthTokenType::DPoP, 273 + expires_at: None, 274 + }, 275 + }; 276 + let err = super::refresh(&client, session, &meta).await.unwrap_err(); 277 + matches!(err, RequestError::NoRefreshToken); 278 + } 279 + 280 + #[tokio::test] 281 + async fn exchange_code_missing_sub() { 282 + let client = MockClient::default(); 283 + // set mock HTTP response body: token response without `sub` 284 + *client.resp.lock().await = Some( 285 + HttpResponse::builder() 286 + .status(StatusCode::OK) 287 + .body(serde_json::to_vec(&serde_json::json!({ 288 + "access_token":"tok", 289 + "token_type":"DPoP", 290 + "expires_in": 3600 291 + })).unwrap()) 292 + .unwrap(), 293 + ); 294 + let meta = base_metadata(); 295 + let mut dpop = DpopReqData { 296 + dpop_key: crate::utils::generate_key(&[CowStr::from("ES256")]).unwrap(), 297 + dpop_authserver_nonce: None, 298 + }; 299 + let err = super::exchange_code(&client, &mut dpop, "abc", "verifier", &meta) 300 + .await 301 + .unwrap_err(); 302 + matches!(err, RequestError::TokenVerification); 303 + } 304 + } 305 + 116 306 #[derive(Debug, Serialize)] 117 307 pub struct RequestPayload<'a, T> 118 308 where ··· 162 352 let (code_challenge, verifier) = generate_pkce(); 163 353 164 354 let Some(dpop_key) = generate_dpop_key(&metadata.server_metadata) else { 165 - return Err(Error::Token("none of the algorithms worked".into())); 355 + return Err(RequestError::TokenVerification); 166 356 }; 167 357 let mut dpop_data = DpopReqData { 168 358 dpop_key, ··· 218 408 .require_pushed_authorization_requests 219 409 == Some(true) 220 410 { 221 - Err(Error::NoEndpoint(CowStr::new_static( 222 - "server requires PAR but no endpoint is available", 411 + Err(RequestError::NoEndpoint(CowStr::new_static( 412 + "pushed_authorization_request", 223 413 ))) 224 414 } else { 225 415 todo!("use of PAR is mandatory") ··· 235 425 T: OAuthResolver + DpopExt + Send + Sync + 'static, 236 426 { 237 427 let Some(refresh_token) = session_data.token_set.refresh_token.as_ref() else { 238 - return Err(Error::TokenRefresh); 428 + return Err(RequestError::NoRefreshToken); 239 429 }; 240 430 241 431 // /!\ IMPORTANT /!\ ··· 312 502 ) 313 503 .await?; 314 504 let Some(sub) = token_response.sub else { 315 - return Err(Error::Token("missing `sub` in token response".into())); 505 + return Err(RequestError::TokenVerification); 316 506 }; 317 507 let sub = Did::new_owned(sub)?; 318 508 let iss = metadata.server_metadata.issuer.clone(); ··· 376 566 D: DpopDataSource, 377 567 { 378 568 let Some(url) = endpoint_for_req(&metadata.server_metadata, &request) else { 379 - return Err(Error::NoEndpoint(request.name())); 569 + return Err(RequestError::NoEndpoint(request.name())); 380 570 }; 381 571 let client_assertions = build_auth( 382 572 metadata.keyset.as_ref(), ··· 401 591 .dpop_server_call(data_source) 402 592 .send(req) 403 593 .await 404 - .map_err(Error::DpopClient)?; 594 + .map_err(RequestError::DpopClient)?; 405 595 if res.status() == request.expected_status() { 406 596 let body = res.body(); 407 597 if body.is_empty() { ··· 412 602 Ok(output) 413 603 } 414 604 } else if res.status().is_client_error() { 415 - Err(Error::HttpStatusWithBody( 605 + Err(RequestError::HttpStatusWithBody( 416 606 res.status(), 417 607 serde_json::from_slice(res.body())?, 418 608 )) 419 609 } else { 420 - Err(Error::HttpStatus(res.status())) 610 + Err(RequestError::HttpStatus(res.status())) 421 611 } 422 612 } 423 613 ··· 528 718 } 529 719 } 530 720 531 - Err(Error::UnsupportedAuthMethod) 721 + Err(RequestError::UnsupportedAuthMethod) 532 722 }
+129 -17
crates/jacquard-oauth/src/resolver.rs
··· 1 1 use crate::types::{OAuthAuthorizationServerMetadata, OAuthProtectedResourceMetadata}; 2 2 use http::{Request, StatusCode}; 3 - use jacquard_common::IntoStatic; 3 + use jacquard_common::{IntoStatic, error::TransportError}; 4 4 use jacquard_common::types::did_doc::DidDocument; 5 5 use jacquard_common::types::ident::AtIdentifier; 6 6 use jacquard_common::{http_client::HttpClient, types::did::Did}; 7 7 use jacquard_identity::resolver::{IdentityError, IdentityResolver}; 8 8 use url::Url; 9 9 10 + /// Compare two issuer strings strictly but without spuriously failing on trivial differences. 11 + /// 12 + /// Rules: 13 + /// - Schemes must match exactly. 14 + /// - Hostnames and effective ports must match (treat missing port the same as default port). 15 + /// - Path must match, except that an empty path and `/` are equivalent. 16 + /// - Query/fragment are not considered; if present on either side, the comparison fails. 17 + pub(crate) fn issuer_equivalent(a: &str, b: &str) -> bool { 18 + fn normalize(url: &Url) -> Option<(String, String, u16, String)> { 19 + if url.query().is_some() || url.fragment().is_some() { 20 + return None; 21 + } 22 + let scheme = url.scheme().to_string(); 23 + let host = url.host_str()?.to_string(); 24 + let port = url.port_or_known_default()?; 25 + let path = match url.path() { 26 + "" => "/".to_string(), 27 + "/" => "/".to_string(), 28 + other => other.to_string(), 29 + }; 30 + Some((scheme, host, port, path)) 31 + } 32 + 33 + match (Url::parse(a), Url::parse(b)) { 34 + (Ok(ua), Ok(ub)) => match (normalize(&ua), normalize(&ub)) { 35 + (Some((sa, ha, pa, pa_path)), Some((sb, hb, pb, pb_path))) => { 36 + if sa != sb || ha != hb || pa != pb { 37 + return false; 38 + } 39 + if pa_path == "/" && pb_path == "/" { 40 + return true; 41 + } 42 + pa_path == pb_path 43 + } 44 + _ => false, 45 + }, 46 + _ => a == b, 47 + } 48 + } 49 + 10 50 #[derive(thiserror::Error, Debug, miette::Diagnostic)] 11 51 pub enum ResolverError { 12 52 #[error("resource not found")] 53 + #[diagnostic(code(jacquard_oauth::resolver::not_found), help("check the base URL or identifier"))] 13 54 NotFound, 14 55 #[error("invalid at identifier: {0}")] 56 + #[diagnostic(code(jacquard_oauth::resolver::at_identifier), help("ensure a valid handle or DID was provided"))] 15 57 AtIdentifier(String), 16 58 #[error("invalid did: {0}")] 59 + #[diagnostic(code(jacquard_oauth::resolver::did), help("ensure DID is correctly formed (did:plc or did:web)"))] 17 60 Did(String), 18 61 #[error("invalid did document: {0}")] 62 + #[diagnostic(code(jacquard_oauth::resolver::did_document), help("verify the DID document structure and service entries"))] 19 63 DidDocument(String), 20 64 #[error("protected resource metadata is invalid: {0}")] 65 + #[diagnostic(code(jacquard_oauth::resolver::protected_resource_metadata), help("PDS must advertise an authorization server in its protected resource metadata"))] 21 66 ProtectedResourceMetadata(String), 22 67 #[error("authorization server metadata is invalid: {0}")] 68 + #[diagnostic(code(jacquard_oauth::resolver::authorization_server_metadata), help("issuer must match and include the PDS resource"))] 23 69 AuthorizationServerMetadata(String), 24 70 #[error("error resolving identity: {0}")] 71 + #[diagnostic(code(jacquard_oauth::resolver::identity))] 25 72 IdentityResolverError(#[from] IdentityError), 26 73 #[error("unsupported did method: {0:?}")] 74 + #[diagnostic(code(jacquard_oauth::resolver::unsupported_did_method), help("supported DID methods: did:web, did:plc"))] 27 75 UnsupportedDidMethod(Did<'static>), 28 76 #[error(transparent)] 29 - Http(#[from] http::Error), 30 - #[error("http client error: {0}")] 31 - HttpClient(Box<dyn std::error::Error + Send + Sync + 'static>), 77 + #[diagnostic(code(jacquard_oauth::resolver::transport))] 78 + Transport(#[from] TransportError), 32 79 #[error("http status: {0:?}")] 80 + #[diagnostic(code(jacquard_oauth::resolver::http_status), help("check well-known paths and server configuration"))] 33 81 HttpStatus(StatusCode), 34 82 #[error(transparent)] 83 + #[diagnostic(code(jacquard_oauth::resolver::serde_json))] 35 84 SerdeJson(#[from] serde_json::Error), 36 85 #[error(transparent)] 86 + #[diagnostic(code(jacquard_oauth::resolver::serde_form))] 37 87 SerdeHtmlForm(#[from] serde_html_form::ser::Error), 38 88 #[error(transparent)] 89 + #[diagnostic(code(jacquard_oauth::resolver::url))] 39 90 Uri(#[from] url::ParseError), 40 91 } 41 92 ··· 47 98 sub: &Did<'_>, 48 99 ) -> Result<Url, ResolverError> { 49 100 let (metadata, identity) = self.resolve_from_identity(sub).await?; 50 - if metadata.issuer != server_metadata.issuer { 51 - return Err(ResolverError::Did(format!("DIDs did not match"))); 101 + if !issuer_equivalent(&metadata.issuer, &server_metadata.issuer) { 102 + return Err(ResolverError::AuthorizationServerMetadata( 103 + "issuer mismatch".to_string(), 104 + )); 52 105 } 53 106 Ok(identity 54 107 .pds_endpoint() ··· 110 163 &self, 111 164 issuer: &Url, 112 165 ) -> Result<OAuthAuthorizationServerMetadata<'static>, ResolverError> { 113 - Ok(resolve_authorization_server(self, issuer).await?) 166 + let mut md = resolve_authorization_server(self, issuer).await?; 167 + // Normalize issuer string to the input URL representation to avoid slash quirks 168 + md.issuer = jacquard_common::CowStr::from(issuer.as_str()).into_static(); 169 + Ok(md) 114 170 } 115 171 async fn get_resource_server_metadata( 116 172 &self, ··· 166 222 ) -> Result<OAuthAuthorizationServerMetadata<'static>, ResolverError> { 167 223 let url = server 168 224 .join("/.well-known/oauth-authorization-server") 169 - .map_err(|e| ResolverError::HttpClient(e.into()))?; 225 + .map_err(|e| ResolverError::Transport(TransportError::Other(Box::new(e))))?; 170 226 171 227 let req = Request::builder() 172 228 .uri(url.to_string()) 173 229 .body(Vec::new()) 174 - .map_err(|e| ResolverError::HttpClient(e.into()))?; 230 + .map_err(|e| ResolverError::Transport(TransportError::InvalidRequest(e.to_string())))?; 175 231 let res = client 176 232 .send_http(req) 177 233 .await 178 - .map_err(|e| ResolverError::HttpClient(e.into()))?; 234 + .map_err(|e| ResolverError::Transport(TransportError::Other(Box::new(e))))?; 179 235 if res.status() == StatusCode::OK { 180 - let metadata = serde_json::from_slice::<OAuthAuthorizationServerMetadata>(res.body()) 236 + let mut metadata = serde_json::from_slice::<OAuthAuthorizationServerMetadata>(res.body()) 181 237 .map_err(ResolverError::SerdeJson)?; 182 238 // https://datatracker.ietf.org/doc/html/rfc8414#section-3.3 183 - if metadata.issuer == server.as_str() { 239 + // Accept semantically equivalent issuer (normalize to the requested URL form) 240 + if issuer_equivalent(&metadata.issuer, server.as_str()) { 241 + metadata.issuer = server.as_str().into(); 184 242 Ok(metadata.into_static()) 185 243 } else { 186 244 Err(ResolverError::AuthorizationServerMetadata(format!( ··· 199 257 ) -> Result<OAuthProtectedResourceMetadata<'static>, ResolverError> { 200 258 let url = server 201 259 .join("/.well-known/oauth-protected-resource") 202 - .map_err(|e| ResolverError::HttpClient(e.into()))?; 260 + .map_err(|e| ResolverError::Transport(TransportError::Other(Box::new(e))))?; 203 261 204 262 let req = Request::builder() 205 263 .uri(url.to_string()) 206 264 .body(Vec::new()) 207 - .map_err(|e| ResolverError::HttpClient(e.into()))?; 265 + .map_err(|e| ResolverError::Transport(TransportError::InvalidRequest(e.to_string())))?; 208 266 let res = client 209 267 .send_http(req) 210 268 .await 211 - .map_err(|e| ResolverError::HttpClient(e.into()))?; 269 + .map_err(|e| ResolverError::Transport(TransportError::Other(Box::new(e))))?; 212 270 if res.status() == StatusCode::OK { 213 - let metadata = serde_json::from_slice::<OAuthProtectedResourceMetadata>(res.body()) 271 + let mut metadata = serde_json::from_slice::<OAuthProtectedResourceMetadata>(res.body()) 214 272 .map_err(ResolverError::SerdeJson)?; 215 273 // https://datatracker.ietf.org/doc/html/rfc8414#section-3.3 216 - if metadata.resource == server.as_str() { 274 + // Accept semantically equivalent resource URL (normalize to the requested URL form) 275 + if issuer_equivalent(&metadata.resource, server.as_str()) { 276 + metadata.resource = server.as_str().into(); 217 277 Ok(metadata.into_static()) 218 278 } else { 219 279 Err(ResolverError::AuthorizationServerMetadata(format!( ··· 228 288 229 289 #[async_trait::async_trait] 230 290 impl OAuthResolver for jacquard_identity::JacquardResolver {} 291 + 292 + #[cfg(test)] 293 + mod tests { 294 + use super::*; 295 + use http::{Request as HttpRequest, Response as HttpResponse, StatusCode}; 296 + use jacquard_common::http_client::HttpClient; 297 + 298 + #[derive(Default, Clone)] 299 + struct MockHttp { 300 + next: std::sync::Arc<tokio::sync::Mutex<Option<HttpResponse<Vec<u8>>>>>, 301 + } 302 + 303 + impl HttpClient for MockHttp { 304 + type Error = std::convert::Infallible; 305 + fn send_http( 306 + &self, 307 + _request: HttpRequest<Vec<u8>>, 308 + ) -> impl core::future::Future< 309 + Output = core::result::Result<HttpResponse<Vec<u8>>, Self::Error>, 310 + > + Send { 311 + let next = self.next.clone(); 312 + async move { Ok(next.lock().await.take().unwrap()) } 313 + } 314 + } 315 + 316 + #[tokio::test] 317 + async fn authorization_server_http_status() { 318 + let client = MockHttp::default(); 319 + *client.next.lock().await = Some(HttpResponse::builder().status(StatusCode::NOT_FOUND).body(Vec::new()).unwrap()); 320 + let issuer = url::Url::parse("https://issuer").unwrap(); 321 + let err = super::resolve_authorization_server(&client, &issuer).await.unwrap_err(); 322 + matches!(err, ResolverError::HttpStatus(StatusCode::NOT_FOUND)); 323 + } 324 + 325 + #[tokio::test] 326 + async fn authorization_server_bad_json() { 327 + let client = MockHttp::default(); 328 + *client.next.lock().await = Some(HttpResponse::builder().status(StatusCode::OK).body(b"{not json}".to_vec()).unwrap()); 329 + let issuer = url::Url::parse("https://issuer").unwrap(); 330 + let err = super::resolve_authorization_server(&client, &issuer).await.unwrap_err(); 331 + matches!(err, ResolverError::SerdeJson(_)); 332 + } 333 + 334 + #[test] 335 + fn issuer_equivalence_rules() { 336 + assert!(super::issuer_equivalent("https://issuer", "https://issuer/")); 337 + assert!(super::issuer_equivalent("https://issuer:443/", "https://issuer/")); 338 + assert!(!super::issuer_equivalent("http://issuer/", "https://issuer/")); 339 + assert!(!super::issuer_equivalent("https://issuer/foo", "https://issuer/")); 340 + assert!(!super::issuer_equivalent("https://issuer/?q=1", "https://issuer/")); 341 + } 342 + }
+6 -3
crates/jacquard-oauth/src/session.rs
··· 263 263 server_metadata: client 264 264 .get_authorization_server_metadata(&self.session_data.authserver_url) 265 265 .await 266 - .map_err(|e| Error::ServerAgent(crate::request::Error::ResolverError(e)))?, 266 + .map_err(|e| Error::ServerAgent(crate::request::RequestError::ResolverError(e)))?, 267 267 client_metadata: atproto_client_metadata(self.config.clone(), &self.keyset) 268 268 .unwrap() 269 269 .into_static(), ··· 272 272 } 273 273 } 274 274 275 - #[derive(thiserror::Error, Debug)] 275 + #[derive(thiserror::Error, Debug, miette::Diagnostic)] 276 276 pub enum Error { 277 277 #[error(transparent)] 278 - ServerAgent(#[from] crate::request::Error), 278 + #[diagnostic(code(jacquard_oauth::session::request))] 279 + ServerAgent(#[from] crate::request::RequestError), 279 280 #[error(transparent)] 281 + #[diagnostic(code(jacquard_oauth::session::storage))] 280 282 Store(#[from] SessionStoreError), 281 283 #[error("session does not exist")] 284 + #[diagnostic(code(jacquard_oauth::session::not_found))] 282 285 SessionNotFound, 283 286 } 284 287
+7 -3
crates/jacquard/src/client.rs
··· 3 3 //! This module provides HTTP and XRPC client traits along with an authenticated 4 4 //! client implementation that manages session tokens. 5 5 6 + /// Stateful session client for app‑password auth with auto‑refresh. 6 7 pub mod credential_session; 8 + /// Token storage and on‑disk formats shared across app‑password and OAuth. 7 9 pub mod token; 8 10 9 11 use core::future::Future; ··· 72 74 } 73 75 } 74 76 75 - /// A unified indicator for the type of authenticated session. 77 + /// Identifies the active authentication mode for an agent/session. 76 78 #[derive(Debug, Clone, Copy, PartialEq, Eq)] 77 79 pub enum AgentKind { 78 80 /// App password (Bearer) session ··· 82 84 } 83 85 84 86 /// Common interface for stateful sessions used by the Agent wrapper. 87 + /// 88 + /// Implemented by `CredentialSession` (app‑password) and `OAuthSession` (DPoP). 85 89 pub trait AgentSession: XrpcClient + HttpClient + Send + Sync { 86 90 /// Identify the kind of session. 87 91 fn session_kind(&self) -> AgentKind; ··· 188 192 } 189 193 } 190 194 191 - /// Thin wrapper that erases the concrete session type while preserving type-safety. 195 + /// Thin wrapper over a stateful session providing a uniform `XrpcClient`. 192 196 pub struct Agent<A: AgentSession> { 193 197 inner: A, 194 198 } ··· 214 218 self.inner.endpoint().await 215 219 } 216 220 217 - /// Override call options. 221 + /// Override call options for subsequent requests. 218 222 pub async fn set_options(&self, opts: CallOptions<'_>) { 219 223 self.inner.set_options(opts).await 220 224 }
+41 -4
crates/jacquard/src/client/credential_session.rs
··· 18 18 use jacquard_identity::resolver::IdentityResolver; 19 19 use std::any::Any; 20 20 21 + /// Storage key for app‑password sessions: `(account DID, session id)`. 21 22 pub type SessionKey = (Did<'static>, CowStr<'static>); 22 23 24 + /// Stateful client for app‑password based sessions. 25 + /// 26 + /// - Persists sessions via a pluggable `SessionStore`. 27 + /// - Automatically refreshes on token expiry. 28 + /// - Tracks a base endpoint, defaulting to the public appview until login/restore. 23 29 pub struct CredentialSession<S, T> 24 30 where 25 31 S: SessionStore<SessionKey, AtpSession>, 26 32 { 27 33 store: Arc<S>, 28 34 client: Arc<T>, 35 + /// Default call options applied to each request (auth/headers/labelers). 29 36 pub options: RwLock<CallOptions<'static>>, 37 + /// Active session key, if any. 30 38 pub key: RwLock<Option<SessionKey>>, 39 + /// Current base endpoint (PDS); defaults to public appview when unset. 31 40 pub endpoint: RwLock<Option<Url>>, 32 41 } 33 42 ··· 35 44 where 36 45 S: SessionStore<SessionKey, AtpSession>, 37 46 { 47 + /// Create a new credential session using the given store and client. 38 48 pub fn new(store: Arc<S>, client: Arc<T>) -> Self { 39 49 Self { 40 50 store, ··· 50 60 where 51 61 S: SessionStore<SessionKey, AtpSession>, 52 62 { 63 + /// Return a copy configured with the provided default call options. 53 64 pub fn with_options(self, options: CallOptions<'_>) -> Self { 54 65 Self { 55 66 client: self.client, ··· 60 71 } 61 72 } 62 73 74 + /// Replace default call options. 63 75 pub async fn set_options(&self, options: CallOptions<'_>) { 64 76 *self.options.write().await = options.into_static(); 65 77 } 66 78 79 + /// Get the active session key (account DID and session id), if any. 67 80 pub async fn session_info(&self) -> Option<SessionKey> { 68 81 self.key.read().await.clone() 69 82 } 70 83 84 + /// Current base endpoint. Defaults to the public appview when unset. 71 85 pub async fn endpoint(&self) -> Url { 72 86 self.endpoint.read().await.clone().unwrap_or( 73 87 Url::parse("https://public.bsky.app").expect("public appview should be valid url"), 74 88 ) 75 89 } 76 90 91 + /// Override the current base endpoint. 77 92 pub async fn set_endpoint(&self, endpoint: Url) { 78 93 *self.endpoint.write().await = Some(endpoint); 79 94 } 80 95 96 + /// Current access token (Bearer), if logged in. 81 97 pub async fn access_token(&self) -> Option<AuthorizationToken<'_>> { 82 98 let key = self.key.read().await.clone()?; 83 99 let session = self.store.get(&key).await; 84 100 session.map(|session| AuthorizationToken::Bearer(session.access_jwt)) 85 101 } 86 102 103 + /// Current refresh token (Bearer), if logged in. 87 104 pub async fn refresh_token(&self) -> Option<AuthorizationToken<'_>> { 88 105 let key = self.key.read().await.clone()?; 89 106 let session = self.store.get(&key).await; ··· 96 113 S: SessionStore<SessionKey, AtpSession>, 97 114 T: HttpClient, 98 115 { 116 + /// Refresh the active session by calling `com.atproto.server.refreshSession`. 99 117 pub async fn refresh(&self) -> Result<AuthorizationToken<'_>, ClientError> { 100 118 let key = self.key.read().await.clone().ok_or(ClientError::Auth( 101 119 jacquard_common::error::AuthError::NotAuthenticated, ··· 134 152 /// 135 153 /// - `identifier`: handle (preferred), DID, or `https://` PDS base URL. 136 154 /// - `session_id`: optional session label; defaults to "session". 155 + /// - Persists and activates the session, and updates the base endpoint to the user's PDS. 137 156 pub async fn login( 138 157 &self, 139 158 identifier: CowStr<'_>, ··· 298 317 Ok(()) 299 318 } 300 319 301 - /// Switch to a different stored session (and refresh endpoint from DID). 320 + /// Switch to a different stored session (and refresh endpoint/PDS). 302 321 pub async fn switch_session( 303 322 &self, 304 323 did: Did<'_>, ··· 380 399 T: HttpClient + XrpcExt + Send + Sync + 'static, 381 400 { 382 401 fn base_uri(&self) -> Url { 383 - self.endpoint.blocking_read().clone().unwrap_or( 384 - Url::parse("https://public.bsky.app").expect("public appview should be valid url"), 385 - ) 402 + // base_uri is a synchronous trait method; avoid `.await` here. 403 + // Under Tokio, use `block_in_place` to make a blocking RwLock read safe. 404 + if tokio::runtime::Handle::try_current().is_ok() { 405 + tokio::task::block_in_place(|| { 406 + self.endpoint 407 + .blocking_read() 408 + .clone() 409 + .unwrap_or( 410 + Url::parse("https://public.bsky.app") 411 + .expect("public appview should be valid url"), 412 + ) 413 + }) 414 + } else { 415 + self.endpoint 416 + .blocking_read() 417 + .clone() 418 + .unwrap_or( 419 + Url::parse("https://public.bsky.app") 420 + .expect("public appview should be valid url"), 421 + ) 422 + } 386 423 } 387 424 async fn send<R: jacquard_common::types::xrpc::XrpcRequest + Send>( 388 425 self,
+95 -25
crates/jacquard/src/client/token.rs
··· 10 10 use serde_json::Value; 11 11 use url::Url; 12 12 13 + /// On-disk session records for app-password and OAuth flows, sharing a single JSON map. 13 14 #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] 14 15 pub enum StoredSession { 16 + /// App-password session 15 17 Atp(StoredAtSession), 18 + /// OAuth client session 16 19 OAuth(OAuthSession), 20 + /// OAuth authorization request state 17 21 OAuthState(OAuthState), 18 22 } 19 23 24 + /// Minimal persisted representation of an app‑password session. 20 25 #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] 21 26 pub struct StoredAtSession { 27 + /// Access token (JWT) 22 28 access_jwt: String, 29 + /// Refresh token (JWT) 23 30 refresh_jwt: String, 31 + /// Account DID 24 32 did: String, 33 + /// Optional PDS endpoint for faster resume 25 34 #[serde(skip_serializing_if = "std::option::Option::is_none")] 26 35 pds: Option<String>, 36 + /// Session id label (e.g., "session") 27 37 session_id: String, 38 + /// Last known handle 28 39 handle: String, 29 40 } 30 41 42 + /// Persisted OAuth client session (on-disk format). 31 43 #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] 32 44 pub struct OAuthSession { 45 + /// Account DID 33 46 account_did: String, 47 + /// Client-generated session id (usually auth `state`) 34 48 session_id: String, 35 49 36 - // Base URL of the "resource server" (eg, PDS). Should include scheme, hostname, port; no path or auth info. 50 + /// Base URL of the resource server (PDS) 37 51 host_url: Url, 38 52 39 - // Base URL of the "auth server" (eg, PDS or entryway). Should include scheme, hostname, port; no path or auth info. 53 + /// Base URL of the authorization server (PDS or entryway) 40 54 authserver_url: Url, 41 55 42 - // Full token endpoint 56 + /// Full token endpoint URL 43 57 authserver_token_endpoint: String, 44 58 45 - // Full revocation endpoint, if it exists 59 + /// Full revocation endpoint URL, if available 46 60 #[serde(skip_serializing_if = "std::option::Option::is_none")] 47 61 authserver_revocation_endpoint: Option<String>, 48 62 49 - // The set of scopes approved for this session (returned in the initial token request) 63 + /// Granted scopes 50 64 scopes: Vec<String>, 51 65 66 + /// Client DPoP key material 52 67 pub dpop_key: Key, 53 - // Current auth server DPoP nonce 68 + /// Current auth server DPoP nonce 54 69 pub dpop_authserver_nonce: String, 55 - // Current host ("resource server", eg PDS) DPoP nonce 70 + /// Current resource server (PDS) DPoP nonce 56 71 pub dpop_host_nonce: String, 57 72 73 + /// Token response issuer 58 74 pub iss: String, 75 + /// Token subject (DID) 59 76 pub sub: String, 77 + /// Token audience (verified PDS URL) 60 78 pub aud: String, 79 + /// Token scopes (raw) if provided 61 80 pub scope: Option<String>, 62 81 82 + /// Refresh token 63 83 pub refresh_token: Option<String>, 84 + /// Access token 64 85 pub access_token: String, 86 + /// Token type (e.g., DPoP) 65 87 pub token_type: OAuthTokenType, 66 88 89 + /// Expiration timestamp 67 90 pub expires_at: Option<Datetime>, 68 91 } 69 92 ··· 130 153 } 131 154 } 132 155 156 + /// Persisted OAuth authorization request state. 133 157 #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] 134 158 pub struct OAuthState { 135 - // The random identifier generated by the client for the auth request flow. Can be used as "primary key" for storing and retrieving this information. 159 + /// Random identifier generated for the authorization flow (`state`) 136 160 pub state: String, 137 161 138 - // URL of the auth server (eg, PDS or entryway) 162 + /// Base URL of the authorization server (PDS or entryway) 139 163 pub authserver_url: Url, 140 164 141 - // If the flow started with an account identifier (DID or handle), it should be persisted, to verify against the initial token response. 165 + /// Optional pre-known account DID 142 166 #[serde(skip_serializing_if = "std::option::Option::is_none")] 143 167 pub account_did: Option<String>, 144 168 145 - // OAuth scope strings 169 + /// Requested scopes 146 170 pub scopes: Vec<String>, 147 171 148 - // unique token in URI format, which will be used by the client in the auth flow redirect 172 + /// Request URI for the authorization step 149 173 pub request_uri: String, 150 174 151 - // Full token endpoint URL 175 + /// Full token endpoint URL 152 176 pub authserver_token_endpoint: String, 153 177 154 - // Full revocation endpoint, if it exists 178 + /// Full revocation endpoint URL, if available 155 179 #[serde(skip_serializing_if = "std::option::Option::is_none")] 156 180 pub authserver_revocation_endpoint: Option<String>, 157 181 158 - // The secret token/nonce which a code challenge was generated from 182 + /// PKCE verifier 159 183 pub pkce_verifier: String, 160 184 185 + /// Client DPoP key material 161 186 pub dpop_key: Key, 162 - // Current auth server DPoP nonce 187 + /// Auth server DPoP nonce at PAR time 163 188 #[serde(skip_serializing_if = "std::option::Option::is_none")] 164 189 pub dpop_authserver_nonce: Option<String>, 165 190 } ··· 211 236 } 212 237 } 213 238 239 + /// Convenience wrapper over `FileTokenStore` offering unified storage across auth modes. 214 240 pub struct FileAuthStore(FileTokenStore); 215 241 216 242 impl FileAuthStore { ··· 326 352 let mut store: Value = serde_json::from_str(&file)?; 327 353 if let Some(map) = store.as_object_mut() { 328 354 if let Some(value) = map.get_mut(&key_str) { 329 - if let Some(obj) = value.as_object_mut() { 330 - obj.insert( 331 - "pds".to_string(), 332 - serde_json::Value::String(pds.to_string()), 333 - ); 334 - std::fs::write(&self.0.path, serde_json::to_string_pretty(&store)?)?; 335 - return Ok(()); 355 + if let Some(outer) = value.as_object_mut() { 356 + if let Some(inner) = outer.get_mut("Atp").and_then(|v| v.as_object_mut()) { 357 + inner.insert( 358 + "pds".to_string(), 359 + serde_json::Value::String(pds.to_string()), 360 + ); 361 + std::fs::write(&self.0.path, serde_json::to_string_pretty(&store)?)?; 362 + return Ok(()); 363 + } 336 364 } 337 365 } 338 366 } ··· 349 377 let store: Value = serde_json::from_str(&file)?; 350 378 if let Some(value) = store.get(&key_str) { 351 379 if let Some(obj) = value.as_object() { 352 - if let Some(serde_json::Value::String(pds)) = obj.get("pds") { 353 - return Ok(Url::parse(pds).ok()); 380 + if let Some(serde_json::Value::Object(inner)) = obj.get("Atp") { 381 + if let Some(serde_json::Value::String(pds)) = inner.get("pds") { 382 + return Ok(Url::parse(pds).ok()); 383 + } 354 384 } 355 385 } 356 386 } ··· 418 448 } 419 449 } 420 450 } 451 + 452 + #[cfg(test)] 453 + mod tests { 454 + use super::*; 455 + use crate::client::credential_session::SessionKey; 456 + use crate::client::AtpSession; 457 + use jacquard_common::types::string::{Did, Handle}; 458 + use std::fs; 459 + use std::path::PathBuf; 460 + 461 + fn temp_file() -> PathBuf { 462 + let mut p = std::env::temp_dir(); 463 + p.push(format!("jacquard-test-{}.json", std::process::id())); 464 + p 465 + } 466 + 467 + #[tokio::test] 468 + async fn file_auth_store_roundtrip_atp() { 469 + let path = temp_file(); 470 + // initialize empty store file 471 + fs::write(&path, "{}").unwrap(); 472 + let store = FileAuthStore::new(&path); 473 + let session = AtpSession { 474 + access_jwt: "a".into(), 475 + refresh_jwt: "r".into(), 476 + did: Did::new_static("did:plc:alice").unwrap(), 477 + handle: Handle::new_static("alice.bsky.social").unwrap(), 478 + }; 479 + let key: SessionKey = (session.did.clone(), "session".into()); 480 + jacquard_common::session::SessionStore::set(&store, key.clone(), session.clone()) 481 + .await 482 + .unwrap(); 483 + let restored = jacquard_common::session::SessionStore::get(&store, &key) 484 + .await 485 + .unwrap(); 486 + assert_eq!(restored.access_jwt.as_ref(), "a"); 487 + // clean up 488 + let _ = fs::remove_file(&path); 489 + } 490 + }
+4 -4
crates/jacquard/src/lib.rs
··· 27 27 //! use jacquard::client::credential_session::{CredentialSession, SessionKey}; 28 28 //! use jacquard::client::{AtpSession, FileAuthStore, MemorySessionStore}; 29 29 //! use jacquard::identity::PublicResolver as JacquardResolver; 30 + //! use jacquard::types::xrpc::XrpcClient; 30 31 //! # use miette::IntoDiagnostic; 31 32 //! 32 33 //! # #[derive(Parser, Debug)] ··· 57 58 //! .into_diagnostic()?; 58 59 //! // Fetch timeline 59 60 //! let timeline = session 60 - //! .clone() 61 - //! .send(GetTimeline::new().limit(5).build()) 61 + //! .send(&GetTimeline::new().limit(5).build()) 62 62 //! .await 63 63 //! .into_diagnostic()? 64 64 //! .into_output() ··· 90 90 //! let resp = http 91 91 //! .xrpc(base) 92 92 //! .send( 93 - //! GetAuthorFeed::new() 93 + //! &GetAuthorFeed::new() 94 94 //! .actor(AtIdentifier::new_static("pattern.atproto.systems").unwrap()) 95 95 //! .limit(5) 96 96 //! .build(), ··· 124 124 //! .accept_labelers(vec![CowStr::from("did:plc:labelerid")]) 125 125 //! .header(http::header::USER_AGENT, http::HeaderValue::from_static("jacquard-example")) 126 126 //! .send( 127 - //! GetAuthorFeed::new() 127 + //! &GetAuthorFeed::new() 128 128 //! .actor(AtIdentifier::new_static("pattern.atproto.systems").unwrap()) 129 129 //! .limit(5) 130 130 //! .build(),
+147
crates/jacquard/tests/agent.rs
··· 1 + use std::collections::VecDeque; 2 + use std::sync::Arc; 3 + 4 + use http::{HeaderValue, Response as HttpResponse, StatusCode}; 5 + use jacquard::client::credential_session::{CredentialSession, SessionKey}; 6 + use jacquard::client::{Agent, AtpSession}; 7 + use jacquard::identity::resolver::{DidDocResponse, IdentityResolver, ResolverOptions}; 8 + use jacquard::types::did::Did; 9 + use jacquard::types::string::Handle; 10 + use jacquard_common::http_client::HttpClient; 11 + use jacquard_common::session::MemorySessionStore; 12 + use tokio::sync::Mutex; 13 + 14 + #[derive(Clone, Default)] 15 + struct MockClient { 16 + queue: Arc<Mutex<VecDeque<http::Response<Vec<u8>>>>>, 17 + log: Arc<Mutex<Vec<http::Request<Vec<u8>>>>>, 18 + } 19 + 20 + impl MockClient { 21 + async fn push(&self, resp: http::Response<Vec<u8>>) { 22 + self.queue.lock().await.push_back(resp); 23 + } 24 + } 25 + 26 + impl HttpClient for MockClient { 27 + type Error = std::convert::Infallible; 28 + fn send_http( 29 + &self, 30 + request: http::Request<Vec<u8>>, 31 + ) -> impl core::future::Future< 32 + Output = core::result::Result<http::Response<Vec<u8>>, Self::Error>, 33 + > + Send { 34 + let log = self.log.clone(); 35 + let queue = self.queue.clone(); 36 + async move { 37 + log.lock().await.push(request); 38 + Ok(queue.lock().await.pop_front().expect("no queued response")) 39 + } 40 + } 41 + } 42 + 43 + #[async_trait::async_trait] 44 + impl IdentityResolver for MockClient { 45 + fn options(&self) -> &ResolverOptions { 46 + use std::sync::LazyLock; 47 + static OPTS: LazyLock<ResolverOptions> = LazyLock::new(ResolverOptions::default); 48 + &OPTS 49 + } 50 + async fn resolve_handle( 51 + &self, 52 + _handle: &Handle<'_>, 53 + ) -> std::result::Result<Did<'static>, jacquard::identity::resolver::IdentityError> { 54 + Ok(Did::new_static("did:plc:alice").unwrap()) 55 + } 56 + async fn resolve_did_doc( 57 + &self, 58 + _did: &Did<'_>, 59 + ) -> std::result::Result<DidDocResponse, jacquard::identity::resolver::IdentityError> { 60 + let doc = serde_json::json!({ 61 + "id": "did:plc:alice", 62 + "service": [{ 63 + "id": "#pds", 64 + "type": "AtprotoPersonalDataServer", 65 + "serviceEndpoint": "https://pds" 66 + }] 67 + }); 68 + Ok(DidDocResponse { 69 + buffer: bytes::Bytes::from(serde_json::to_vec(&doc).unwrap()), 70 + status: StatusCode::OK, 71 + requested: None, 72 + }) 73 + } 74 + } 75 + 76 + // XrpcExt blanket impl applies via HttpClient 77 + 78 + fn refresh_session_body(access: &str, refresh: &str) -> Vec<u8> { 79 + serde_json::to_vec(&serde_json::json!({ 80 + "accessJwt": access, 81 + "refreshJwt": refresh, 82 + "did": "did:plc:alice", 83 + "handle": "alice.bsky.social" 84 + })) 85 + .unwrap() 86 + } 87 + 88 + #[tokio::test] 89 + async fn agent_delegates_to_session_and_refreshes() { 90 + let client = Arc::new(MockClient::default()); 91 + let store: Arc<MemorySessionStore<SessionKey, AtpSession>> = Arc::new(Default::default()); 92 + let session = CredentialSession::new(store.clone(), client.clone()); 93 + 94 + // Seed a session in the store and activate it via restore (sets endpoint to PDS) 95 + let atp = AtpSession { 96 + access_jwt: "acc1".into(), 97 + refresh_jwt: "ref1".into(), 98 + did: Did::new_static("did:plc:alice").unwrap(), 99 + handle: Handle::new_static("alice.bsky.social").unwrap(), 100 + }; 101 + let key: SessionKey = (atp.did.clone(), "session".into()); 102 + jacquard_common::session::SessionStore::set(store.as_ref(), key.clone(), atp) 103 + .await 104 + .unwrap(); 105 + session 106 + .restore(Did::new_static("did:plc:alice").unwrap(), "session".into()) 107 + .await 108 + .unwrap(); 109 + 110 + let agent: Agent<_> = Agent::from(session); 111 + assert_eq!(agent.kind(), jacquard::client::AgentKind::AppPassword); 112 + let info = agent.info().await.expect("session info"); 113 + assert_eq!(info.0.as_str(), "did:plc:alice"); 114 + assert_eq!(info.1.as_ref().unwrap().as_str(), "session"); 115 + assert_eq!(agent.endpoint().await.as_str(), "https://pds/"); 116 + 117 + // Queue a refresh response and call agent.refresh(); Authorization header must use refresh token 118 + client 119 + .push( 120 + HttpResponse::builder() 121 + .status(StatusCode::OK) 122 + .header(http::header::CONTENT_TYPE, "application/json") 123 + .body(refresh_session_body("acc2", "ref2")) 124 + .unwrap(), 125 + ) 126 + .await; 127 + 128 + let token = agent.refresh().await.expect("refresh ok"); 129 + match token { 130 + jacquard::AuthorizationToken::Bearer(s) => assert_eq!(s.as_ref(), "acc2"), 131 + _ => panic!("expected Bearer token"), 132 + } 133 + 134 + // Validate the refreshSession call used the refresh token header 135 + let log = client.log.lock().await; 136 + assert_eq!(log.len(), 1); 137 + assert!( 138 + log[0] 139 + .uri() 140 + .to_string() 141 + .ends_with("/xrpc/com.atproto.server.refreshSession") 142 + ); 143 + assert_eq!( 144 + log[0].headers().get(http::header::AUTHORIZATION), 145 + Some(&HeaderValue::from_static("Bearer ref1")) 146 + ); 147 + }
+261
crates/jacquard/tests/credential_session.rs
··· 1 + use std::collections::VecDeque; 2 + use std::sync::Arc; 3 + 4 + use bytes::Bytes; 5 + use http::{HeaderValue, Method, Response as HttpResponse, StatusCode}; 6 + use jacquard::client::AtpSession; 7 + use jacquard::client::credential_session::{CredentialSession, SessionKey}; 8 + use jacquard::identity::resolver::{DidDocResponse, IdentityResolver, ResolverOptions}; 9 + use jacquard::types::did::Did; 10 + use jacquard::types::string::Handle; 11 + use jacquard::types::xrpc::XrpcClient; 12 + use jacquard_common::http_client::HttpClient; 13 + use jacquard_common::session::{MemorySessionStore, SessionStore}; 14 + use tokio::sync::{Mutex, RwLock}; 15 + 16 + #[derive(Clone, Default)] 17 + struct MockClient { 18 + // Queue of HTTP responses to pop for each send_http call 19 + queue: Arc<Mutex<VecDeque<HttpResponse<Vec<u8>>>>>, 20 + // Capture requests for assertions 21 + log: Arc<Mutex<Vec<http::Request<Vec<u8>>>>>, 22 + // Count calls to identity resolver helpers 23 + did_doc_calls: Arc<RwLock<usize>>, 24 + } 25 + 26 + impl MockClient { 27 + async fn push(&self, resp: HttpResponse<Vec<u8>>) { 28 + self.queue.lock().await.push_back(resp); 29 + } 30 + async fn take_log(&self) -> Vec<http::Request<Vec<u8>>> { 31 + let mut log = self.log.lock().await; 32 + let out = log.clone(); 33 + log.clear(); 34 + out 35 + } 36 + } 37 + 38 + impl HttpClient for MockClient { 39 + type Error = std::convert::Infallible; 40 + 41 + fn send_http( 42 + &self, 43 + request: http::Request<Vec<u8>>, 44 + ) -> impl core::future::Future< 45 + Output = core::result::Result<http::Response<Vec<u8>>, Self::Error>, 46 + > + Send { 47 + let log = self.log.clone(); 48 + let queue = self.queue.clone(); 49 + async move { 50 + log.lock().await.push(request); 51 + Ok(queue.lock().await.pop_front().expect("no queued response")) 52 + } 53 + } 54 + } 55 + 56 + #[async_trait::async_trait] 57 + impl IdentityResolver for MockClient { 58 + fn options(&self) -> &ResolverOptions { 59 + use std::sync::LazyLock; 60 + static OPTS: LazyLock<ResolverOptions> = LazyLock::new(ResolverOptions::default); 61 + &OPTS 62 + } 63 + 64 + async fn resolve_handle( 65 + &self, 66 + handle: &Handle<'_>, 67 + ) -> std::result::Result<Did<'static>, jacquard::identity::resolver::IdentityError> { 68 + // Return a fixed DID for any handle 69 + assert!(handle.as_str().contains('.')); 70 + Ok(Did::new_static("did:plc:alice").unwrap()) 71 + } 72 + 73 + async fn resolve_did_doc( 74 + &self, 75 + did: &Did<'_>, 76 + ) -> std::result::Result<DidDocResponse, jacquard::identity::resolver::IdentityError> { 77 + // Track calls and return a minimal DID doc with a PDS endpoint 78 + *self.did_doc_calls.write().await += 1; 79 + assert_eq!(did.as_str(), "did:plc:alice"); 80 + let doc = serde_json::json!({ 81 + "id": "did:plc:alice", 82 + "service": [{ 83 + "id": "#pds", 84 + "type": "AtprotoPersonalDataServer", 85 + "serviceEndpoint": "https://pds" 86 + }] 87 + }); 88 + Ok(DidDocResponse { 89 + buffer: Bytes::from(serde_json::to_vec(&doc).unwrap()), 90 + status: StatusCode::OK, 91 + requested: None, 92 + }) 93 + } 94 + } 95 + 96 + // XrpcExt blanket impl applies via HttpClient 97 + 98 + fn create_session_body() -> Vec<u8> { 99 + serde_json::to_vec(&serde_json::json!({ 100 + "accessJwt": "acc1", 101 + "refreshJwt": "ref1", 102 + "did": "did:plc:alice", 103 + "handle": "alice.bsky.social" 104 + })) 105 + .unwrap() 106 + } 107 + 108 + fn refresh_session_body(access: &str, refresh: &str) -> Vec<u8> { 109 + serde_json::to_vec(&serde_json::json!({ 110 + "accessJwt": access, 111 + "refreshJwt": refresh, 112 + "did": "did:plc:alice", 113 + "handle": "alice.bsky.social" 114 + })) 115 + .unwrap() 116 + } 117 + 118 + fn get_session_ok_body() -> Vec<u8> { 119 + serde_json::to_vec(&serde_json::json!({ 120 + "did": "did:plc:alice", 121 + "handle": "alice.bsky.social", 122 + "active": true 123 + })) 124 + .unwrap() 125 + } 126 + 127 + #[tokio::test(flavor = "multi_thread")] 128 + async fn credential_login_and_auto_refresh() { 129 + let client = Arc::new(MockClient::default()); 130 + 131 + // Queue responses in order: createSession 200 → getSession 401 → refreshSession 200 → getSession 200 132 + client 133 + .push( 134 + HttpResponse::builder() 135 + .status(StatusCode::OK) 136 + .header(http::header::CONTENT_TYPE, "application/json") 137 + .body(create_session_body()) 138 + .unwrap(), 139 + ) 140 + .await; 141 + client 142 + .push( 143 + HttpResponse::builder() 144 + .status(StatusCode::UNAUTHORIZED) 145 + .header(http::header::CONTENT_TYPE, "application/json") 146 + .body(serde_json::to_vec(&serde_json::json!({"error":"ExpiredToken"})).unwrap()) 147 + .unwrap(), 148 + ) 149 + .await; 150 + client 151 + .push( 152 + HttpResponse::builder() 153 + .status(StatusCode::OK) 154 + .header(http::header::CONTENT_TYPE, "application/json") 155 + .body(refresh_session_body("acc2", "ref2")) 156 + .unwrap(), 157 + ) 158 + .await; 159 + client 160 + .push( 161 + HttpResponse::builder() 162 + .status(StatusCode::OK) 163 + .header(http::header::CONTENT_TYPE, "application/json") 164 + .body(get_session_ok_body()) 165 + .unwrap(), 166 + ) 167 + .await; 168 + 169 + let store: Arc<MemorySessionStore<SessionKey, AtpSession>> = Arc::new(Default::default()); 170 + let session = CredentialSession::new(store.clone(), client.clone()); 171 + 172 + // Before login, default endpoint should be public appview 173 + assert_eq!( 174 + session.endpoint().await.as_str(), 175 + "https://public.bsky.app/" 176 + ); 177 + 178 + // Login using handle; resolves to PDS and persists session 179 + session 180 + .login( 181 + jacquard::CowStr::from("alice.bsky.social"), 182 + jacquard::CowStr::from("apppass"), 183 + Some(jacquard::CowStr::from("session")), 184 + None, 185 + None, 186 + ) 187 + .await 188 + .expect("login ok"); 189 + 190 + // Endpoint switches to PDS 191 + assert_eq!(session.endpoint().await.as_str(), "https://pds/"); 192 + 193 + // Send a request that will first 401 (ExpiredToken), then refresh, then succeed 194 + let resp = session 195 + .send(&jacquard::api::com_atproto::server::get_session::GetSession) 196 + .await 197 + .expect("xrpc send ok"); 198 + assert_eq!(resp.status(), StatusCode::OK); 199 + let out = resp 200 + .parse() 201 + .expect("parse ok after refresh (GetSession output)"); 202 + assert_eq!(out.handle.as_str(), "alice.bsky.social"); 203 + 204 + // Verify request sequence and Authorization headers used 205 + let log = client.take_log().await; 206 + assert_eq!(log.len(), 4, "expected four HTTP calls"); 207 + // 0: createSession (no auth) 208 + assert_eq!(log[0].method(), Method::POST); 209 + assert!( 210 + log[0] 211 + .uri() 212 + .to_string() 213 + .ends_with("/xrpc/com.atproto.server.createSession") 214 + ); 215 + assert!(log[0].headers().get(http::header::AUTHORIZATION).is_none()); 216 + // 1: getSession (uses access token acc1) 217 + assert_eq!(log[1].method(), Method::GET); 218 + assert!( 219 + log[1] 220 + .uri() 221 + .to_string() 222 + .ends_with("/xrpc/com.atproto.server.getSession") 223 + ); 224 + assert_eq!( 225 + log[1].headers().get(http::header::AUTHORIZATION), 226 + Some(&HeaderValue::from_static("Bearer acc1")) 227 + ); 228 + // 2: refreshSession (uses refresh token ref1) 229 + assert_eq!(log[2].method(), Method::POST); 230 + assert!( 231 + log[2] 232 + .uri() 233 + .to_string() 234 + .ends_with("/xrpc/com.atproto.server.refreshSession") 235 + ); 236 + assert_eq!( 237 + log[2].headers().get(http::header::AUTHORIZATION), 238 + Some(&HeaderValue::from_static("Bearer ref1")) 239 + ); 240 + // 3: getSession (re-sent with new access token acc2) 241 + assert_eq!(log[3].method(), Method::GET); 242 + assert!( 243 + log[3] 244 + .uri() 245 + .to_string() 246 + .ends_with("/xrpc/com.atproto.server.getSession") 247 + ); 248 + assert_eq!( 249 + log[3].headers().get(http::header::AUTHORIZATION), 250 + Some(&HeaderValue::from_static("Bearer acc2")) 251 + ); 252 + 253 + // Verify store updated with refreshed tokens 254 + let key: SessionKey = ( 255 + Did::new_static("did:plc:alice").unwrap(), 256 + jacquard::CowStr::from("session"), 257 + ); 258 + let updated = store.get(&key).await.expect("session present"); 259 + assert_eq!(updated.access_jwt.as_ref(), "acc2"); 260 + assert_eq!(updated.refresh_jwt.as_ref(), "ref2"); 261 + }
+374
crates/jacquard/tests/oauth_auto_refresh.rs
··· 1 + use std::collections::VecDeque; 2 + use std::sync::Arc; 3 + 4 + use bytes::Bytes; 5 + use http::{HeaderValue, Method, Response as HttpResponse, StatusCode}; 6 + use jacquard::client::Agent; 7 + use jacquard::IntoStatic; 8 + use jacquard::types::did::Did; 9 + use jacquard::types::xrpc::XrpcClient; 10 + use jacquard_common::http_client::HttpClient; 11 + use jacquard_oauth::atproto::AtprotoClientMetadata; 12 + use jacquard_oauth::client::OAuthSession; 13 + use jacquard_oauth::session::SessionRegistry; 14 + use jacquard_oauth::resolver::OAuthResolver; 15 + use jacquard_oauth::scopes::Scope; 16 + use jacquard_oauth::session::{ClientData, ClientSessionData, DpopClientData}; 17 + use jacquard_oauth::types::{OAuthAuthorizationServerMetadata, OAuthTokenType, TokenSet}; 18 + use tokio::sync::Mutex; 19 + 20 + #[derive(Clone, Default)] 21 + struct MockClient { 22 + queue: Arc<Mutex<VecDeque<http::Response<Vec<u8>>>>>, 23 + log: Arc<Mutex<Vec<http::Request<Vec<u8>>>>>, 24 + } 25 + 26 + impl MockClient { 27 + async fn push(&self, resp: http::Response<Vec<u8>>) { 28 + self.queue.lock().await.push_back(resp); 29 + } 30 + } 31 + 32 + impl HttpClient for MockClient { 33 + type Error = std::convert::Infallible; 34 + fn send_http( 35 + &self, 36 + request: http::Request<Vec<u8>>, 37 + ) -> impl core::future::Future< 38 + Output = core::result::Result<http::Response<Vec<u8>>, Self::Error>, 39 + > + Send { 40 + let log = self.log.clone(); 41 + let queue = self.queue.clone(); 42 + async move { 43 + log.lock().await.push(request); 44 + Ok(queue 45 + .lock() 46 + .await 47 + .pop_front() 48 + .expect("no queued response")) 49 + } 50 + } 51 + } 52 + 53 + #[async_trait::async_trait] 54 + impl jacquard::identity::resolver::IdentityResolver for MockClient { 55 + fn options(&self) -> &jacquard::identity::resolver::ResolverOptions { 56 + use std::sync::LazyLock; 57 + static OPTS: LazyLock<jacquard::identity::resolver::ResolverOptions> = 58 + LazyLock::new(jacquard::identity::resolver::ResolverOptions::default); 59 + &OPTS 60 + } 61 + async fn resolve_handle( 62 + &self, 63 + _handle: &jacquard::types::string::Handle<'_>, 64 + ) -> std::result::Result<Did<'static>, jacquard::identity::resolver::IdentityError> { 65 + Ok(Did::new_static("did:plc:alice").unwrap()) 66 + } 67 + async fn resolve_did_doc( 68 + &self, 69 + _did: &Did<'_>, 70 + ) -> std::result::Result<jacquard::identity::resolver::DidDocResponse, jacquard::identity::resolver::IdentityError> { 71 + let doc = serde_json::json!({ 72 + "id": "did:plc:alice", 73 + "service": [{ 74 + "id": "#pds", 75 + "type": "AtprotoPersonalDataServer", 76 + "serviceEndpoint": "https://pds" 77 + }] 78 + }); 79 + Ok(jacquard::identity::resolver::DidDocResponse { 80 + buffer: Bytes::from(serde_json::to_vec(&doc).unwrap()), 81 + status: StatusCode::OK, 82 + requested: None, 83 + }) 84 + } 85 + } 86 + 87 + #[async_trait::async_trait] 88 + impl OAuthResolver for MockClient { 89 + async fn get_authorization_server_metadata( 90 + &self, 91 + issuer: &url::Url, 92 + ) -> Result<OAuthAuthorizationServerMetadata<'static>, jacquard_oauth::resolver::ResolverError> { 93 + // Return minimal metadata with supported auth method "none" and DPoP support 94 + let mut md = OAuthAuthorizationServerMetadata::default(); 95 + md.issuer = jacquard::CowStr::from(issuer.as_str()); 96 + md.token_endpoint = jacquard::CowStr::from(format!("{}/token", issuer)); 97 + md.authorization_endpoint = jacquard::CowStr::from(format!("{}/authorize", issuer)); 98 + md.require_pushed_authorization_requests = Some(true); 99 + md.pushed_authorization_request_endpoint = 100 + Some(jacquard::CowStr::from(format!("{}/par", issuer))); 101 + md.token_endpoint_auth_methods_supported = Some(vec![jacquard::CowStr::from("none")]); 102 + md.dpop_signing_alg_values_supported = Some(vec![jacquard::CowStr::from("ES256")]); 103 + use jacquard::IntoStatic; 104 + Ok(md.into_static()) 105 + } 106 + 107 + async fn get_resource_server_metadata( 108 + &self, 109 + _pds: &url::Url, 110 + ) -> Result<OAuthAuthorizationServerMetadata<'static>, jacquard_oauth::resolver::ResolverError> { 111 + // Return metadata pointing to the same issuer as above 112 + let mut md = OAuthAuthorizationServerMetadata::default(); 113 + md.issuer = jacquard::CowStr::from("https://issuer"); 114 + md.token_endpoint = jacquard::CowStr::from("https://issuer/token"); 115 + md.authorization_endpoint = jacquard::CowStr::from("https://issuer/authorize"); 116 + md.require_pushed_authorization_requests = Some(true); 117 + md.pushed_authorization_request_endpoint = Some(jacquard::CowStr::from("https://issuer/par")); 118 + md.token_endpoint_auth_methods_supported = Some(vec![jacquard::CowStr::from("none")]); 119 + md.dpop_signing_alg_values_supported = Some(vec![jacquard::CowStr::from("ES256")]); 120 + Ok(md.into_static()) 121 + } 122 + 123 + async fn verify_issuer( 124 + &self, 125 + _server_metadata: &OAuthAuthorizationServerMetadata<'_>, 126 + _sub: &Did<'_>, 127 + ) -> Result<url::Url, jacquard_oauth::resolver::ResolverError> { 128 + Ok(url::Url::parse("https://pds").unwrap()) 129 + } 130 + } 131 + 132 + fn get_session_unauthorized() -> http::Response<Vec<u8>> { 133 + HttpResponse::builder() 134 + .status(StatusCode::UNAUTHORIZED) 135 + .header( 136 + http::header::WWW_AUTHENTICATE, 137 + HeaderValue::from_static("DPoP realm=\"pds\", error=\"invalid_token\""), 138 + ) 139 + .body(Vec::new()) 140 + .unwrap() 141 + } 142 + 143 + fn get_session_unauthorized_body() -> http::Response<Vec<u8>> { 144 + HttpResponse::builder() 145 + .status(StatusCode::UNAUTHORIZED) 146 + .header(http::header::CONTENT_TYPE, "application/json") 147 + .body( 148 + serde_json::to_vec(&serde_json::json!({ 149 + "error":"InvalidToken" 150 + })) 151 + .unwrap(), 152 + ) 153 + .unwrap() 154 + } 155 + 156 + fn token_use_dpop_nonce() -> http::Response<Vec<u8>> { 157 + HttpResponse::builder() 158 + .status(StatusCode::BAD_REQUEST) 159 + .header(http::header::CONTENT_TYPE, "application/json") 160 + .header("DPoP-Nonce", HeaderValue::from_static("n1")) 161 + .body(serde_json::to_vec(&serde_json::json!({"error":"use_dpop_nonce"})).unwrap()) 162 + .unwrap() 163 + } 164 + 165 + fn token_refresh_ok() -> http::Response<Vec<u8>> { 166 + HttpResponse::builder() 167 + .status(StatusCode::OK) 168 + .header(http::header::CONTENT_TYPE, "application/json") 169 + .body( 170 + serde_json::to_vec(&serde_json::json!({ 171 + "access_token":"newacc", 172 + "token_type":"DPoP", 173 + "refresh_token":"newref", 174 + "expires_in": 3600 175 + })) 176 + .unwrap(), 177 + ) 178 + .unwrap() 179 + } 180 + 181 + fn get_session_ok() -> http::Response<Vec<u8>> { 182 + HttpResponse::builder() 183 + .status(StatusCode::OK) 184 + .header(http::header::CONTENT_TYPE, "application/json") 185 + .body( 186 + serde_json::to_vec(&serde_json::json!({ 187 + "did":"did:plc:alice", 188 + "handle":"alice.bsky.social", 189 + "active":true 190 + })) 191 + .unwrap(), 192 + ) 193 + .unwrap() 194 + } 195 + 196 + impl jacquard_oauth::dpop::DpopExt for MockClient {} 197 + 198 + #[tokio::test(flavor = "multi_thread")] 199 + async fn oauth_xrpc_invalid_token_triggers_refresh_and_retries() { 200 + // (reopen test body since we inserted a trait impl) 201 + let client = Arc::new(MockClient::default()); 202 + 203 + client.push(get_session_unauthorized()).await; 204 + client.push(token_use_dpop_nonce()).await; 205 + client.push(token_refresh_ok()).await; 206 + client.push(get_session_ok()).await; 207 + 208 + let mut path = std::env::temp_dir(); 209 + path.push(format!("jacquard-oauth-test-{}.json", std::process::id())); 210 + std::fs::write(&path, "{}").unwrap(); 211 + let store = jacquard::client::FileAuthStore::new(&path); 212 + 213 + let client_data = ClientData { 214 + keyset: None, 215 + config: AtprotoClientMetadata::new_localhost(None, Some(vec![Scope::Atproto])), 216 + }; 217 + use jacquard::IntoStatic; 218 + let session_data = ClientSessionData { 219 + account_did: Did::new_static("did:plc:alice").unwrap(), 220 + session_id: jacquard::CowStr::from("state"), 221 + host_url: url::Url::parse("https://pds").unwrap(), 222 + authserver_url: url::Url::parse("https://issuer").unwrap(), 223 + authserver_token_endpoint: jacquard::CowStr::from("https://issuer/token"), 224 + authserver_revocation_endpoint: None, 225 + scopes: vec![Scope::Atproto], 226 + dpop_data: DpopClientData { 227 + dpop_key: jacquard_oauth::utils::generate_key(&[jacquard::CowStr::from("ES256")]) 228 + .unwrap(), 229 + dpop_authserver_nonce: jacquard::CowStr::from(""), 230 + dpop_host_nonce: jacquard::CowStr::from(""), 231 + }, 232 + token_set: TokenSet { 233 + iss: jacquard::CowStr::from("https://issuer"), 234 + sub: Did::new_static("did:plc:alice").unwrap(), 235 + aud: jacquard::CowStr::from("https://pds"), 236 + scope: None, 237 + refresh_token: Some(jacquard::CowStr::from("rt1")), 238 + access_token: jacquard::CowStr::from("atk1"), 239 + token_type: OAuthTokenType::DPoP, 240 + expires_at: None, 241 + }, 242 + } 243 + .into_static(); 244 + let client_arc = client.clone(); 245 + let registry = Arc::new(SessionRegistry::new(store, client_arc.clone(), client_data)); 246 + // Seed the store so refresh can load the session 247 + let data_store = ClientSessionData { 248 + account_did: Did::new_static("did:plc:alice").unwrap(), 249 + session_id: jacquard::CowStr::from("state"), 250 + host_url: url::Url::parse("https://pds").unwrap(), 251 + authserver_url: url::Url::parse("https://issuer").unwrap(), 252 + authserver_token_endpoint: jacquard::CowStr::from("https://issuer/token"), 253 + authserver_revocation_endpoint: None, 254 + scopes: vec![Scope::Atproto], 255 + dpop_data: DpopClientData { 256 + dpop_key: jacquard_oauth::utils::generate_key(&[jacquard::CowStr::from("ES256")]) 257 + .unwrap(), 258 + dpop_authserver_nonce: jacquard::CowStr::from(""), 259 + dpop_host_nonce: jacquard::CowStr::from(""), 260 + }, 261 + token_set: TokenSet { 262 + iss: jacquard::CowStr::from("https://issuer"), 263 + sub: Did::new_static("did:plc:alice").unwrap(), 264 + aud: jacquard::CowStr::from("https://pds"), 265 + scope: None, 266 + refresh_token: Some(jacquard::CowStr::from("rt1")), 267 + access_token: jacquard::CowStr::from("atk1"), 268 + token_type: OAuthTokenType::DPoP, 269 + expires_at: None, 270 + }, 271 + } 272 + .into_static(); 273 + registry.set(data_store).await.unwrap(); 274 + let session = OAuthSession::new(registry, client_arc, session_data); 275 + 276 + let agent: Agent<_> = Agent::from(session); 277 + let resp = agent 278 + .send(&jacquard::api::com_atproto::server::get_session::GetSession) 279 + .await 280 + .expect("xrpc send ok after auto-refresh"); 281 + assert_eq!(resp.status(), StatusCode::OK); 282 + 283 + // Inspect the request log 284 + let log = client.log.lock().await; 285 + assert_eq!(log.len(), 4, "expected 4 HTTP calls"); 286 + // 0: getSession with old token 287 + assert_eq!(log[0].method(), Method::GET); 288 + assert!(log[0].headers().get(http::header::AUTHORIZATION).unwrap().to_str().unwrap().starts_with("DPoP ")); 289 + assert!(log[0] 290 + .uri() 291 + .to_string() 292 + .ends_with("/xrpc/com.atproto.server.getSession")); 293 + // 1 and 2: token refresh attempts 294 + assert_eq!(log[1].method(), Method::POST); 295 + assert!(log[1].uri().to_string().ends_with("/token")); 296 + assert!(log[1].headers().contains_key("DPoP")); 297 + assert_eq!(log[2].method(), Method::POST); 298 + assert!(log[2].uri().to_string().ends_with("/token")); 299 + assert!(log[2].headers().contains_key("DPoP")); 300 + // 3: retried getSession with new access token 301 + assert_eq!(log[3].method(), Method::GET); 302 + assert!(log[3] 303 + .headers() 304 + .get(http::header::AUTHORIZATION) 305 + .unwrap() 306 + .to_str() 307 + .unwrap() 308 + .starts_with("DPoP newacc")); 309 + 310 + // Cleanup temp file 311 + let _ = std::fs::remove_file(&path); 312 + } 313 + 314 + #[tokio::test(flavor = "multi_thread")] 315 + async fn oauth_xrpc_invalid_token_body_triggers_refresh_and_retries() { 316 + let client = Arc::new(MockClient::default()); 317 + 318 + // Queue responses: initial 401 with JSON body; token refresh 400(use_dpop_nonce); token refresh 200; retry getSession 200 319 + client.push(get_session_unauthorized_body()).await; 320 + client.push(token_use_dpop_nonce()).await; 321 + client.push(token_refresh_ok()).await; 322 + client.push(get_session_ok()).await; 323 + 324 + let mut path = std::env::temp_dir(); 325 + path.push(format!("jacquard-oauth-test-body-{}.json", std::process::id())); 326 + std::fs::write(&path, "{}").unwrap(); 327 + let store = jacquard::client::FileAuthStore::new(&path); 328 + 329 + let client_data = ClientData { 330 + keyset: None, 331 + config: AtprotoClientMetadata::new_localhost(None, Some(vec![Scope::Atproto])), 332 + }; 333 + use jacquard::IntoStatic; 334 + let session_data = ClientSessionData { 335 + account_did: Did::new_static("did:plc:alice").unwrap(), 336 + session_id: jacquard::CowStr::from("state"), 337 + host_url: url::Url::parse("https://pds").unwrap(), 338 + authserver_url: url::Url::parse("https://issuer").unwrap(), 339 + authserver_token_endpoint: jacquard::CowStr::from("https://issuer/token"), 340 + authserver_revocation_endpoint: None, 341 + scopes: vec![Scope::Atproto], 342 + dpop_data: DpopClientData { 343 + dpop_key: jacquard_oauth::utils::generate_key(&[jacquard::CowStr::from("ES256")]) 344 + .unwrap(), 345 + dpop_authserver_nonce: jacquard::CowStr::from(""), 346 + dpop_host_nonce: jacquard::CowStr::from(""), 347 + }, 348 + token_set: TokenSet { 349 + iss: jacquard::CowStr::from("https://issuer"), 350 + sub: Did::new_static("did:plc:alice").unwrap(), 351 + aud: jacquard::CowStr::from("https://pds"), 352 + scope: None, 353 + refresh_token: Some(jacquard::CowStr::from("rt1")), 354 + access_token: jacquard::CowStr::from("atk1"), 355 + token_type: OAuthTokenType::DPoP, 356 + expires_at: None, 357 + }, 358 + } 359 + .into_static(); 360 + let client_arc = client.clone(); 361 + let registry = Arc::new(SessionRegistry::new(store, client_arc.clone(), client_data)); 362 + registry.set(session_data.clone()).await.unwrap(); 363 + let session = OAuthSession::new(registry, client_arc, session_data); 364 + 365 + let agent: Agent<_> = Agent::from(session); 366 + let resp = agent 367 + .send(&jacquard::api::com_atproto::server::get_session::GetSession) 368 + .await 369 + .expect("xrpc send ok after auto-refresh"); 370 + assert_eq!(resp.status(), StatusCode::OK); 371 + 372 + // Cleanup temp file 373 + let _ = std::fs::remove_file(&path); 374 + }
+293
crates/jacquard/tests/oauth_flow.rs
··· 1 + use std::collections::VecDeque; 2 + use std::sync::Arc; 3 + 4 + use bytes::Bytes; 5 + use http::{Response as HttpResponse, StatusCode}; 6 + use jacquard::IntoStatic; 7 + use jacquard::client::Agent; 8 + use jacquard::types::xrpc::XrpcClient; 9 + use jacquard_common::http_client::HttpClient; 10 + use jacquard_oauth::atproto::AtprotoClientMetadata; 11 + use jacquard_oauth::authstore::ClientAuthStore; 12 + use jacquard_oauth::client::OAuthClient; 13 + use jacquard_oauth::resolver::OAuthResolver; 14 + use jacquard_oauth::scopes::Scope; 15 + use jacquard_oauth::session::ClientData; 16 + 17 + #[derive(Clone, Default)] 18 + struct MockClient { 19 + queue: Arc<tokio::sync::Mutex<VecDeque<http::Response<Vec<u8>>>>>, 20 + } 21 + 22 + impl MockClient { 23 + async fn push(&self, resp: http::Response<Vec<u8>>) { 24 + self.queue.lock().await.push_back(resp); 25 + } 26 + } 27 + 28 + impl HttpClient for MockClient { 29 + type Error = std::convert::Infallible; 30 + fn send_http( 31 + &self, 32 + _request: http::Request<Vec<u8>>, 33 + ) -> impl core::future::Future< 34 + Output = core::result::Result<http::Response<Vec<u8>>, Self::Error>, 35 + > + Send { 36 + let queue = self.queue.clone(); 37 + async move { Ok(queue.lock().await.pop_front().expect("no queued response")) } 38 + } 39 + } 40 + 41 + #[async_trait::async_trait] 42 + impl jacquard::identity::resolver::IdentityResolver for MockClient { 43 + fn options(&self) -> &jacquard::identity::resolver::ResolverOptions { 44 + use std::sync::LazyLock; 45 + static OPTS: LazyLock<jacquard::identity::resolver::ResolverOptions> = 46 + LazyLock::new(jacquard::identity::resolver::ResolverOptions::default); 47 + &OPTS 48 + } 49 + async fn resolve_handle( 50 + &self, 51 + _handle: &jacquard::types::string::Handle<'_>, 52 + ) -> std::result::Result< 53 + jacquard::types::did::Did<'static>, 54 + jacquard::identity::resolver::IdentityError, 55 + > { 56 + Ok(jacquard::types::did::Did::new_static("did:plc:alice").unwrap()) 57 + } 58 + async fn resolve_did_doc( 59 + &self, 60 + _did: &jacquard::types::did::Did<'_>, 61 + ) -> std::result::Result< 62 + jacquard::identity::resolver::DidDocResponse, 63 + jacquard::identity::resolver::IdentityError, 64 + > { 65 + let doc = serde_json::json!({ 66 + "id": "did:plc:alice", 67 + "service": [{ 68 + "id": "#pds", 69 + "type": "AtprotoPersonalDataServer", 70 + "serviceEndpoint": "https://pds" 71 + }] 72 + }); 73 + Ok(jacquard::identity::resolver::DidDocResponse { 74 + buffer: Bytes::from(serde_json::to_vec(&doc).unwrap()), 75 + status: StatusCode::OK, 76 + requested: None, 77 + }) 78 + } 79 + } 80 + 81 + #[async_trait::async_trait] 82 + impl OAuthResolver for MockClient { 83 + async fn resolve_oauth( 84 + &self, 85 + _input: &str, 86 + ) -> Result< 87 + ( 88 + jacquard_oauth::types::OAuthAuthorizationServerMetadata<'static>, 89 + Option<jacquard_common::types::did_doc::DidDocument<'static>>, 90 + ), 91 + jacquard_oauth::resolver::ResolverError, 92 + > { 93 + let mut md = jacquard_oauth::types::OAuthAuthorizationServerMetadata::default(); 94 + md.issuer = jacquard::CowStr::from("https://issuer"); 95 + md.authorization_endpoint = jacquard::CowStr::from("https://issuer/authorize"); 96 + md.token_endpoint = jacquard::CowStr::from("https://issuer/token"); 97 + md.require_pushed_authorization_requests = Some(true); 98 + md.pushed_authorization_request_endpoint = 99 + Some(jacquard::CowStr::from("https://issuer/par")); 100 + md.token_endpoint_auth_methods_supported = Some(vec![jacquard::CowStr::from("none")]); 101 + md.dpop_signing_alg_values_supported = Some(vec![jacquard::CowStr::from("ES256")]); 102 + 103 + // Simple DID doc pointing to https://pds 104 + let doc = serde_json::json!({ 105 + "id": "did:plc:alice", 106 + "service": [{ 107 + "id": "#pds", 108 + "type": "AtprotoPersonalDataServer", 109 + "serviceEndpoint": "https://pds" 110 + }] 111 + }); 112 + let buf = Bytes::from(serde_json::to_vec(&doc).unwrap()); 113 + let did_doc_b: jacquard_common::types::did_doc::DidDocument<'_> = 114 + serde_json::from_slice(&buf).unwrap(); 115 + let did_doc = did_doc_b.into_static(); 116 + Ok((md.into_static(), Some(did_doc))) 117 + } 118 + async fn get_authorization_server_metadata( 119 + &self, 120 + issuer: &url::Url, 121 + ) -> Result< 122 + jacquard_oauth::types::OAuthAuthorizationServerMetadata<'static>, 123 + jacquard_oauth::resolver::ResolverError, 124 + > { 125 + let mut md = jacquard_oauth::types::OAuthAuthorizationServerMetadata::default(); 126 + md.issuer = jacquard::CowStr::from(issuer.as_str()); 127 + md.authorization_endpoint = jacquard::CowStr::from(format!("{}/authorize", issuer)); 128 + md.token_endpoint = jacquard::CowStr::from(format!("{}/token", issuer)); 129 + md.require_pushed_authorization_requests = Some(true); 130 + md.pushed_authorization_request_endpoint = 131 + Some(jacquard::CowStr::from(format!("{}/par", issuer))); 132 + md.token_endpoint_auth_methods_supported = Some(vec![jacquard::CowStr::from("none")]); 133 + md.dpop_signing_alg_values_supported = Some(vec![jacquard::CowStr::from("ES256")]); 134 + Ok(md.into_static()) 135 + } 136 + 137 + async fn get_resource_server_metadata( 138 + &self, 139 + _pds: &url::Url, 140 + ) -> Result< 141 + jacquard_oauth::types::OAuthAuthorizationServerMetadata<'static>, 142 + jacquard_oauth::resolver::ResolverError, 143 + > { 144 + let mut md = jacquard_oauth::types::OAuthAuthorizationServerMetadata::default(); 145 + md.issuer = jacquard::CowStr::from("https://issuer/"); 146 + md.authorization_endpoint = jacquard::CowStr::from("https://issuer/authorize"); 147 + md.token_endpoint = jacquard::CowStr::from("https://issuer/token"); 148 + md.require_pushed_authorization_requests = Some(true); 149 + md.pushed_authorization_request_endpoint = 150 + Some(jacquard::CowStr::from("https://issuer/par")); 151 + md.token_endpoint_auth_methods_supported = Some(vec![jacquard::CowStr::from("none")]); 152 + md.dpop_signing_alg_values_supported = Some(vec![jacquard::CowStr::from("ES256")]); 153 + Ok(md.into_static()) 154 + } 155 + } 156 + 157 + impl jacquard_oauth::dpop::DpopExt for MockClient {} 158 + 159 + #[tokio::test(flavor = "multi_thread")] 160 + async fn oauth_end_to_end_mock_flow() { 161 + let client = Arc::new(MockClient::default()); 162 + // Queue responses: PAR 201, token 200, XRPC getSession 200 163 + client 164 + .push( 165 + HttpResponse::builder() 166 + .status(StatusCode::CREATED) 167 + .header(http::header::CONTENT_TYPE, "application/json") 168 + .body( 169 + serde_json::to_vec(&serde_json::json!({ 170 + "request_uri": "urn:par:abc", 171 + "expires_in": 60 172 + })) 173 + .unwrap(), 174 + ) 175 + .unwrap(), 176 + ) 177 + .await; 178 + client 179 + .push( 180 + HttpResponse::builder() 181 + .status(StatusCode::OK) 182 + .header(http::header::CONTENT_TYPE, "application/json") 183 + .header("DPoP-Nonce", http::HeaderValue::from_static("n1")) 184 + .body( 185 + serde_json::to_vec(&serde_json::json!({ 186 + "access_token": "atk1", 187 + "token_type": "DPoP", 188 + "refresh_token": "rt1", 189 + "sub": "did:plc:alice", 190 + "iss": "https://issuer", 191 + "aud": "https://pds", 192 + "expires_in": 3600 193 + })) 194 + .unwrap(), 195 + ) 196 + .unwrap(), 197 + ) 198 + .await; 199 + client 200 + .push( 201 + HttpResponse::builder() 202 + .status(StatusCode::OK) 203 + .header(http::header::CONTENT_TYPE, "application/json") 204 + .body( 205 + serde_json::to_vec(&serde_json::json!({ 206 + "did": "did:plc:alice", 207 + "handle": "alice.bsky.social", 208 + "active": true 209 + })) 210 + .unwrap(), 211 + ) 212 + .unwrap(), 213 + ) 214 + .await; 215 + 216 + // File-backed store for auth state/session 217 + let mut path = std::env::temp_dir(); 218 + path.push(format!("jacquard-oauth-flow-{}.json", std::process::id())); 219 + std::fs::write(&path, "{}").unwrap(); 220 + let store = jacquard::client::FileAuthStore::new(&path); 221 + 222 + let client_data: ClientData<'static> = ClientData { 223 + keyset: None, 224 + config: AtprotoClientMetadata::new_localhost(None, Some(vec![Scope::Atproto])), 225 + }; 226 + let client_arc = client.clone(); 227 + let oauth = OAuthClient::new_from_resolver(store, (*client_arc).clone(), client_data); 228 + 229 + // Build metadata and call PAR to get an AuthRequestData, then save in store 230 + let (server_metadata, identity) = client.resolve_oauth("alice.bsky.social").await.unwrap(); 231 + let metadata = jacquard_oauth::request::OAuthMetadata { 232 + server_metadata, 233 + client_metadata: jacquard_oauth::atproto::atproto_client_metadata( 234 + AtprotoClientMetadata::new_localhost(None, Some(vec![Scope::Atproto])), 235 + &None, 236 + ) 237 + .unwrap() 238 + .into_static(), 239 + keyset: None, 240 + }; 241 + let login_hint = identity.map(|_| jacquard::CowStr::from("alice.bsky.social")); 242 + let mut auth_req = jacquard_oauth::request::par(client.as_ref(), login_hint, None, &metadata) 243 + .await 244 + .unwrap(); 245 + // Construct authorization URL as OAuthClient::start_auth would do 246 + #[derive(serde::Serialize)] 247 + struct Parameters<'s> { 248 + client_id: url::Url, 249 + request_uri: jacquard::CowStr<'s>, 250 + } 251 + let auth_url = format!( 252 + "{}?{}", 253 + metadata.server_metadata.authorization_endpoint, 254 + serde_html_form::to_string(Parameters { 255 + client_id: metadata.client_metadata.client_id.clone(), 256 + request_uri: auth_req.request_uri.clone(), 257 + }) 258 + .unwrap() 259 + ); 260 + assert!(auth_url.contains("/authorize?")); 261 + assert!(auth_url.contains("request_uri")); 262 + // keep state for the callback 263 + let state = auth_req.state.clone(); 264 + oauth 265 + .registry 266 + .store 267 + .save_auth_req_info(&auth_req) 268 + .await 269 + .unwrap(); 270 + 271 + // callback: exchange code, create session 272 + use jacquard_oauth::types::CallbackParams; 273 + let session = oauth 274 + .callback(CallbackParams { 275 + code: jacquard::CowStr::from("code123"), 276 + state: Some(state.clone()), 277 + // Callback compares exact string with metadata.issuer (which is a URL string 278 + // including trailing slash). Use normalized form to match. 279 + iss: Some(jacquard::CowStr::from("https://issuer/")), 280 + }) 281 + .await 282 + .unwrap(); 283 + 284 + // Wrap in Agent and send a resource XRPC call to verify Authorization works 285 + let agent: Agent<_> = Agent::from(session); 286 + let resp = agent 287 + .send(&jacquard::api::com_atproto::server::get_session::GetSession) 288 + .await 289 + .unwrap(); 290 + assert_eq!(resp.status(), StatusCode::OK); 291 + 292 + let _ = std::fs::remove_file(&path); 293 + }
+125
crates/jacquard/tests/restore_pds_cache.rs
··· 1 + use std::sync::Arc; 2 + 3 + use bytes::Bytes; 4 + use http::{Response as HttpResponse, StatusCode}; 5 + use jacquard::client::credential_session::{CredentialSession, SessionKey}; 6 + use jacquard::client::{AtpSession, FileAuthStore}; 7 + use jacquard::identity::resolver::{DidDocResponse, IdentityResolver, ResolverOptions}; 8 + use jacquard::types::did::Did; 9 + use jacquard::types::string::Handle; 10 + use jacquard_common::http_client::HttpClient; 11 + use jacquard_common::session::SessionStore; 12 + use std::fs; 13 + use std::path::PathBuf; 14 + use tokio::sync::RwLock; 15 + use url::Url; 16 + 17 + #[derive(Clone, Default)] 18 + struct MockResolver { 19 + // Count calls to DID doc resolution 20 + did_doc_calls: Arc<RwLock<usize>>, 21 + } 22 + 23 + impl HttpClient for MockResolver { 24 + type Error = std::convert::Infallible; 25 + fn send_http( 26 + &self, 27 + _request: http::Request<Vec<u8>>, 28 + ) -> impl core::future::Future< 29 + Output = core::result::Result<http::Response<Vec<u8>>, Self::Error>, 30 + > + Send { 31 + async { 32 + // Not used in this test 33 + Ok(HttpResponse::builder() 34 + .status(StatusCode::OK) 35 + .body(Vec::new()) 36 + .unwrap()) 37 + } 38 + } 39 + } 40 + 41 + #[async_trait::async_trait] 42 + impl IdentityResolver for MockResolver { 43 + fn options(&self) -> &ResolverOptions { 44 + use std::sync::LazyLock; 45 + static OPTS: LazyLock<ResolverOptions> = LazyLock::new(ResolverOptions::default); 46 + &OPTS 47 + } 48 + async fn resolve_handle( 49 + &self, 50 + _handle: &Handle<'_>, 51 + ) -> std::result::Result<Did<'static>, jacquard::identity::resolver::IdentityError> { 52 + Ok(Did::new_static("did:plc:alice").unwrap()) 53 + } 54 + async fn resolve_did_doc( 55 + &self, 56 + _did: &Did<'_>, 57 + ) -> std::result::Result<DidDocResponse, jacquard::identity::resolver::IdentityError> { 58 + *self.did_doc_calls.write().await += 1; 59 + let doc = serde_json::json!({ 60 + "id": "did:plc:alice", 61 + "service": [{ 62 + "id": "#pds", 63 + "type": "AtprotoPersonalDataServer", 64 + "serviceEndpoint": "https://pds-resolved" 65 + }] 66 + }); 67 + Ok(DidDocResponse { 68 + buffer: Bytes::from(serde_json::to_vec(&doc).unwrap()), 69 + status: StatusCode::OK, 70 + requested: None, 71 + }) 72 + } 73 + } 74 + 75 + fn temp_file() -> PathBuf { 76 + let mut p = std::env::temp_dir(); 77 + p.push(format!("jacquard-test-restore-{}.json", std::process::id())); 78 + p 79 + } 80 + 81 + #[tokio::test] 82 + async fn restore_uses_cached_pds_when_present() { 83 + let path = temp_file(); 84 + fs::write(&path, "{}").unwrap(); 85 + let store = Arc::new(FileAuthStore::new(&path)); 86 + let resolver = Arc::new(MockResolver::default()); 87 + 88 + // Seed an app-password session in the file store 89 + let session = AtpSession { 90 + access_jwt: "acc".into(), 91 + refresh_jwt: "ref".into(), 92 + did: Did::new_static("did:plc:alice").unwrap(), 93 + handle: Handle::new_static("alice.bsky.social").unwrap(), 94 + }; 95 + let key: SessionKey = (session.did.clone(), "session".into()); 96 + jacquard_common::session::SessionStore::set(store.as_ref(), key.clone(), session) 97 + .await 98 + .unwrap(); 99 + // Verify it is persisted 100 + assert!(SessionStore::get(store.as_ref(), &key).await.is_some()); 101 + // Persist PDS endpoint cache to avoid DID resolution on restore 102 + store 103 + .set_atp_pds(&key, &Url::parse("https://pds-cached").unwrap()) 104 + .unwrap(); 105 + assert_eq!( 106 + store 107 + .get_atp_pds(&key) 108 + .ok() 109 + .flatten() 110 + .expect("pds cached") 111 + .as_str(), 112 + "https://pds-cached/" 113 + ); 114 + 115 + let session = CredentialSession::new(store.clone(), resolver.clone()); 116 + // Restore should pick cached PDS and NOT call resolve_did_doc 117 + session 118 + .restore(Did::new_static("did:plc:alice").unwrap(), "session".into()) 119 + .await 120 + .expect("restore ok"); 121 + assert_eq!(session.endpoint().await.as_str(), "https://pds-cached/"); 122 + 123 + // Cleanup 124 + let _ = fs::remove_file(&path); 125 + }