A better Rust ATProto crate

deserialization fixes and service auth update

Orual 1864232b aae802b5

+15
Cargo.lock
··· 1817 1817 "axum", 1818 1818 "axum-macros", 1819 1819 "axum-test", 1820 + "base64 0.22.1", 1820 1821 "bytes", 1822 + "chrono", 1821 1823 "jacquard", 1822 1824 "jacquard-common 0.5.1", 1825 + "jacquard-derive 0.5.1", 1826 + "jacquard-identity 0.5.1", 1827 + "k256", 1823 1828 "miette", 1829 + "multibase", 1830 + "rand 0.8.5", 1831 + "reqwest", 1824 1832 "serde", 1825 1833 "serde_html_form", 1826 1834 "serde_ipld_dagcbor", ··· 1828 1836 "thiserror 2.0.17", 1829 1837 "tokio", 1830 1838 "tokio-test", 1839 + "tower", 1831 1840 "tower-http", 1832 1841 "tracing", 1833 1842 "tracing-subscriber", 1843 + "urlencoding", 1834 1844 ] 1835 1845 1836 1846 [[package]] ··· 1861 1871 "serde_ipld_dagcbor", 1862 1872 "serde_json", 1863 1873 "serde_with", 1874 + "signature", 1864 1875 "smol_str", 1865 1876 "thiserror 2.0.17", 1866 1877 "tokio", ··· 2120 2131 checksum = "f6e3919bbaa2945715f0bb6d3934a173d1e9a59ac23767fbaaef277265a7411b" 2121 2132 dependencies = [ 2122 2133 "cfg-if", 2134 + "ecdsa", 2123 2135 "elliptic-curve", 2136 + "once_cell", 2137 + "sha2", 2138 + "signature", 2124 2139 ] 2125 2140 2126 2141 [[package]]
+15
crates/jacquard-axum/Cargo.toml
··· 26 26 bytes.workspace = true 27 27 jacquard = { version = "0.5", path = "../jacquard" } 28 28 jacquard-common = { version = "0.5", path = "../jacquard-common", features = ["reqwest-client"] } 29 + jacquard-derive = { version = "0.5.1", path = "../jacquard-derive" } 30 + jacquard-identity = { version = "0.5", path = "../jacquard-identity", optional = true } 29 31 miette.workspace = true 32 + multibase = { version = "0.9.1", optional = true } 30 33 serde.workspace = true 31 34 serde_html_form.workspace = true 32 35 serde_ipld_dagcbor.workspace = true ··· 36 39 tower-http = { version = "0.6.6", features = ["trace", "tracing"] } 37 40 tracing = "0.1.41" 38 41 tracing-subscriber = { version = "0.3.20", features = ["env-filter", "time"] } 42 + urlencoding.workspace = true 43 + 44 + [features] 45 + default = ["service-auth"] 46 + service-auth = ["jacquard-common/service-auth", "dep:jacquard-identity", "dep:multibase"] 39 47 40 48 [dev-dependencies] 41 49 axum-test = "18.1.0" 50 + base64.workspace = true 51 + chrono.workspace = true 52 + k256 = { version = "0.13", features = ["ecdsa"] } 53 + rand = "0.8" 54 + reqwest.workspace = true 55 + serde_json.workspace = true 42 56 tokio-test = "0.4.4" 57 + tower = { version = "0.5", features = ["util"] }
+76
crates/jacquard-axum/src/did_web.rs
··· 1 + //! Helper for serving did:web DID documents 2 + //! 3 + //! did:web DIDs resolve to HTTPS endpoints serving DID documents. This module 4 + //! provides a router that serves your service's DID document at `/.well-known/did.json`. 5 + //! 6 + //! # Example 7 + //! 8 + //! ```no_run 9 + //! use axum::Router; 10 + //! use jacquard_axum::did_web::did_web_router; 11 + //! use jacquard_common::types::did_doc::DidDocument; 12 + //! 13 + //! #[tokio::main] 14 + //! async fn main() { 15 + //! // Your DID document (typically loaded from config or generated) 16 + //! let did_doc: DidDocument = serde_json::from_str(r#"{ 17 + //! "id": "did:web:feedgen.example.com", 18 + //! "verificationMethod": [{ 19 + //! "id": "did:web:feedgen.example.com#atproto", 20 + //! "type": "Multikey", 21 + //! "controller": "did:web:feedgen.example.com", 22 + //! "publicKeyMultibase": "zQ3sh..." 23 + //! }] 24 + //! }"#).unwrap(); 25 + //! 26 + //! let app = Router::new() 27 + //! .merge(did_web_router(did_doc)); 28 + //! 29 + //! let listener = tokio::net::TcpListener::bind("0.0.0.0:443") 30 + //! .await 31 + //! .unwrap(); 32 + //! axum::serve(listener, app).await.unwrap(); 33 + //! } 34 + //! ``` 35 + 36 + use axum::{ 37 + Json, Router, 38 + http::{HeaderValue, StatusCode, header}, 39 + response::IntoResponse, 40 + routing::get, 41 + }; 42 + use jacquard_common::types::did_doc::DidDocument; 43 + 44 + /// Create a router that serves a DID document at `/.well-known/did.json` 45 + /// 46 + /// Returns a Router that can be merged into your main application router. 47 + /// The DID document is cloned on each request. 48 + /// 49 + /// # Example 50 + /// 51 + /// ```no_run 52 + /// use axum::Router; 53 + /// use jacquard_axum::did_web::did_web_router; 54 + /// use jacquard_common::types::did_doc::DidDocument; 55 + /// 56 + /// # async fn example(did_doc: DidDocument<'static>) { 57 + /// let app = Router::new() 58 + /// .merge(did_web_router(did_doc)); 59 + /// # } 60 + /// ``` 61 + pub fn did_web_router(did_doc: DidDocument<'static>) -> Router { 62 + Router::new().route( 63 + "/.well-known/did.json", 64 + get(move || async move { 65 + ( 66 + StatusCode::OK, 67 + [( 68 + header::CONTENT_TYPE, 69 + HeaderValue::from_static("application/did+json"), 70 + )], 71 + Json(did_doc.clone()), 72 + ) 73 + .into_response() 74 + }), 75 + ) 76 + }
+24 -3
crates/jacquard-axum/src/lib.rs
··· 45 45 //! The extractor deserializes to borrowed types first, then converts to `'static` via 46 46 //! [`IntoStatic`], avoiding the DeserializeOwned requirement of the Json axum extractor and similar. 47 47 48 + pub mod did_web; 49 + #[cfg(feature = "service-auth")] 50 + pub mod service_auth; 51 + 48 52 use axum::{ 49 53 Json, Router, 50 54 body::Bytes, ··· 102 106 } 103 107 XrpcMethod::Query => { 104 108 if let Some(path_query) = req.uri().path_and_query() { 105 - let query = path_query.query().unwrap_or(""); 106 - let value: R::Request<'_> = 107 - serde_html_form::from_str::<R::Request<'_>>(query).map_err(|e| { 109 + // TODO: see if we can eliminate this now that we've fixed the deserialize impls for string types 110 + let query = 111 + urlencoding::decode(path_query.query().unwrap_or("")).map_err(|e| { 108 112 ( 109 113 StatusCode::BAD_REQUEST, 110 114 [( ··· 118 122 ) 119 123 .into_response() 120 124 })?; 125 + let value: R::Request<'_> = serde_html_form::from_str::<R::Request<'_>>( 126 + query.as_ref(), 127 + ) 128 + .map_err(|e| { 129 + ( 130 + StatusCode::BAD_REQUEST, 131 + [( 132 + header::CONTENT_TYPE, 133 + HeaderValue::from_static("application/json"), 134 + )], 135 + Json(json!({ 136 + "error": "InvalidRequest", 137 + "message": format!("failed to decode request: {}", e) 138 + })), 139 + ) 140 + .into_response() 141 + })?; 121 142 Ok(ExtractXrpc(value.into_static())) 122 143 } else { 123 144 Err((
+516
crates/jacquard-axum/src/service_auth.rs
··· 1 + //! Service authentication extractor and middleware 2 + //! 3 + //! # Example 4 + //! 5 + //! ```no_run 6 + //! use axum::{Router, routing::get}; 7 + //! use jacquard_axum::service_auth::{ServiceAuthConfig, ExtractServiceAuth}; 8 + //! use jacquard_identity::JacquardResolver; 9 + //! use jacquard_identity::resolver::ResolverOptions; 10 + //! use jacquard_common::types::string::Did; 11 + //! 12 + //! async fn handler( 13 + //! ExtractServiceAuth(auth): ExtractServiceAuth, 14 + //! ) -> String { 15 + //! format!("Authenticated as {}", auth.did()) 16 + //! } 17 + //! 18 + //! #[tokio::main] 19 + //! async fn main() { 20 + //! let resolver = JacquardResolver::new( 21 + //! reqwest::Client::new(), 22 + //! ResolverOptions::default(), 23 + //! ); 24 + //! let config = ServiceAuthConfig::new( 25 + //! Did::new_static("did:web:feedgen.example.com").unwrap(), 26 + //! resolver, 27 + //! ); 28 + //! 29 + //! let app = Router::new() 30 + //! .route("/xrpc/app.bsky.feed.getFeedSkeleton", get(handler)) 31 + //! .with_state(config); 32 + //! 33 + //! let listener = tokio::net::TcpListener::bind("0.0.0.0:3000") 34 + //! .await 35 + //! .unwrap(); 36 + //! axum::serve(listener, app).await.unwrap(); 37 + //! } 38 + //! ``` 39 + 40 + use axum::{ 41 + Json, 42 + extract::FromRequestParts, 43 + http::{HeaderValue, StatusCode, header, request::Parts}, 44 + middleware::Next, 45 + response::{IntoResponse, Response}, 46 + }; 47 + use jacquard_common::{ 48 + CowStr, IntoStatic, 49 + service_auth::{self, PublicKey}, 50 + types::{ 51 + did_doc::VerificationMethod, 52 + string::{Did, Nsid}, 53 + }, 54 + }; 55 + use jacquard_identity::resolver::IdentityResolver; 56 + use serde_json::json; 57 + use std::sync::Arc; 58 + use thiserror::Error; 59 + 60 + /// Trait for providing service authentication configuration. 61 + /// 62 + /// This trait allows custom state types to provide service auth configuration 63 + /// without requiring `ServiceAuthConfig<R>` directly. 64 + pub trait ServiceAuth { 65 + /// The identity resolver type 66 + type Resolver: IdentityResolver; 67 + 68 + /// Get the service DID (expected audience) 69 + fn service_did(&self) -> &Did<'_>; 70 + 71 + /// Get a reference to the identity resolver 72 + fn resolver(&self) -> &Self::Resolver; 73 + 74 + /// Whether to require the `lxm` (method binding) field 75 + fn require_lxm(&self) -> bool; 76 + } 77 + 78 + /// Configuration for service auth verification. 79 + /// 80 + /// This should be stored in your Axum app state and will be extracted 81 + /// by the `ExtractServiceAuth` extractor. 82 + pub struct ServiceAuthConfig<R> { 83 + /// The DID of your service (the expected audience) 84 + service_did: Did<'static>, 85 + /// Identity resolver for fetching DID documents 86 + resolver: Arc<R>, 87 + /// Whether to require the `lxm` (method binding) field 88 + require_lxm: bool, 89 + } 90 + 91 + impl<R> Clone for ServiceAuthConfig<R> { 92 + fn clone(&self) -> Self { 93 + Self { 94 + service_did: self.service_did.clone(), 95 + resolver: Arc::clone(&self.resolver), 96 + require_lxm: self.require_lxm, 97 + } 98 + } 99 + } 100 + 101 + impl<R: IdentityResolver> ServiceAuthConfig<R> { 102 + /// Create a new service auth config. 103 + /// 104 + /// This enables `lxm` (method binding). If you need backward compatibility, 105 + /// use `ServiceAuthConfig::new_legacy()` 106 + pub fn new(service_did: Did<'static>, resolver: R) -> Self { 107 + Self { 108 + service_did, 109 + resolver: Arc::new(resolver), 110 + require_lxm: true, 111 + } 112 + } 113 + 114 + /// Create a new service auth config. 115 + /// 116 + /// `lxm` (method binding) is disabled for backwards compatibility 117 + pub fn new_legacy(service_did: Did<'static>, resolver: R) -> Self { 118 + Self { 119 + service_did, 120 + resolver: Arc::new(resolver), 121 + require_lxm: false, 122 + } 123 + } 124 + 125 + /// Set whether to require the `lxm` field (method binding). 126 + /// 127 + /// When enabled, the JWT must contain an `lxm` field matching the requested endpoint. 128 + /// This prevents token reuse across different methods. 129 + pub fn require_lxm(mut self, require: bool) -> Self { 130 + self.require_lxm = require; 131 + self 132 + } 133 + 134 + /// Get the service DID. 135 + pub fn service_did(&self) -> &Did<'static> { 136 + &self.service_did 137 + } 138 + 139 + /// Get a reference to the identity resolver. 140 + pub fn resolver(&self) -> &R { 141 + &self.resolver 142 + } 143 + } 144 + 145 + impl<R: IdentityResolver> ServiceAuth for ServiceAuthConfig<R> { 146 + type Resolver = R; 147 + 148 + fn service_did(&self) -> &Did<'_> { 149 + &self.service_did 150 + } 151 + 152 + fn resolver(&self) -> &Self::Resolver { 153 + &self.resolver 154 + } 155 + 156 + fn require_lxm(&self) -> bool { 157 + self.require_lxm 158 + } 159 + } 160 + 161 + /// Verified service authentication information. 162 + /// 163 + /// This is the result of successfully verifying a service auth JWT. 164 + /// This type is extracted by the `ExtractServiceAuth` extractor. 165 + #[derive(Debug, Clone, jacquard_derive::IntoStatic)] 166 + pub struct VerifiedServiceAuth<'a> { 167 + /// The authenticated user's DID (from `iss` claim) 168 + did: Did<'a>, 169 + /// The audience (should match your service DID) 170 + aud: Did<'a>, 171 + /// The lexicon method NSID, if present 172 + lxm: Option<Nsid<'a>>, 173 + /// JWT ID (nonce), if present 174 + jti: Option<CowStr<'a>>, 175 + } 176 + 177 + impl<'a> VerifiedServiceAuth<'a> { 178 + /// Get the authenticated user's DID. 179 + pub fn did(&self) -> &Did<'a> { 180 + &self.did 181 + } 182 + 183 + /// Get the audience (your service DID). 184 + pub fn aud(&self) -> &Did<'a> { 185 + &self.aud 186 + } 187 + 188 + /// Get the lexicon method NSID, if present. 189 + pub fn lxm(&self) -> Option<&Nsid<'a>> { 190 + self.lxm.as_ref() 191 + } 192 + 193 + /// Get the JWT ID (nonce), if present. 194 + /// 195 + /// You can use this for replay protection by tracking seen JTIs 196 + /// until their expiration time. 197 + pub fn jti(&self) -> Option<&str> { 198 + self.jti.as_ref().map(|j| j.as_ref()) 199 + } 200 + } 201 + 202 + /// Axum extractor for service authentication. 203 + /// 204 + /// This extracts and verifies a service auth JWT from the Authorization header, 205 + /// resolving the issuer's DID to verify the signature. 206 + /// 207 + /// # Example 208 + /// 209 + /// ```no_run 210 + /// use axum::{Router, routing::get}; 211 + /// use jacquard_axum::service_auth::{ServiceAuthConfig, ExtractServiceAuth}; 212 + /// use jacquard_identity::JacquardResolver; 213 + /// use jacquard_identity::resolver::ResolverOptions; 214 + /// use jacquard_common::types::string::Did; 215 + /// 216 + /// async fn handler( 217 + /// ExtractServiceAuth(auth): ExtractServiceAuth, 218 + /// ) -> String { 219 + /// format!("Authenticated as {}", auth.did()) 220 + /// } 221 + /// 222 + /// #[tokio::main] 223 + /// async fn main() { 224 + /// let resolver = JacquardResolver::new( 225 + /// reqwest::Client::new(), 226 + /// ResolverOptions::default(), 227 + /// ); 228 + /// let config = ServiceAuthConfig::new( 229 + /// Did::new_static("did:web:feedgen.example.com").unwrap(), 230 + /// resolver, 231 + /// ); 232 + /// 233 + /// let app = Router::new() 234 + /// .route("/xrpc/app.bsky.feed.getFeedSkeleton", get(handler)) 235 + /// .with_state(config); 236 + /// 237 + /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000") 238 + /// .await 239 + /// .unwrap(); 240 + /// axum::serve(listener, app).await.unwrap(); 241 + /// } 242 + /// ``` 243 + pub struct ExtractServiceAuth(pub VerifiedServiceAuth<'static>); 244 + 245 + /// Errors that can occur during service auth verification. 246 + #[derive(Debug, Error, miette::Diagnostic)] 247 + pub enum ServiceAuthError { 248 + /// Authorization header is missing 249 + #[error("missing Authorization header")] 250 + MissingAuthHeader, 251 + 252 + /// Authorization header is malformed (not "Bearer <token>") 253 + #[error("invalid Authorization header format")] 254 + InvalidAuthHeader, 255 + 256 + /// JWT parsing or verification failed 257 + #[error("JWT verification failed: {0}")] 258 + JwtError(#[from] service_auth::ServiceAuthError), 259 + 260 + /// DID resolution failed 261 + #[error("failed to resolve DID {did}: {source}")] 262 + DidResolutionFailed { 263 + did: Did<'static>, 264 + #[source] 265 + source: Box<dyn std::error::Error + Send + Sync>, 266 + }, 267 + 268 + /// No valid signing key found in DID document 269 + #[error("no valid signing key found in DID document for {0}")] 270 + NoSigningKey(Did<'static>), 271 + 272 + /// Method binding required but missing 273 + #[error("lxm (method binding) is required but missing from token")] 274 + MethodBindingRequired, 275 + 276 + /// Invalid key format 277 + #[error("invalid key format: {0}")] 278 + InvalidKey(String), 279 + } 280 + 281 + impl IntoResponse for ServiceAuthError { 282 + fn into_response(self) -> Response { 283 + let (status, error_code, message) = match &self { 284 + ServiceAuthError::MissingAuthHeader => { 285 + (StatusCode::UNAUTHORIZED, "AuthMissing", self.to_string()) 286 + } 287 + ServiceAuthError::InvalidAuthHeader => { 288 + (StatusCode::UNAUTHORIZED, "AuthMissing", self.to_string()) 289 + } 290 + ServiceAuthError::JwtError(_) => ( 291 + StatusCode::UNAUTHORIZED, 292 + "AuthenticationRequired", 293 + self.to_string(), 294 + ), 295 + ServiceAuthError::DidResolutionFailed { .. } => ( 296 + StatusCode::UNAUTHORIZED, 297 + "AuthenticationRequired", 298 + self.to_string(), 299 + ), 300 + ServiceAuthError::NoSigningKey(_) => ( 301 + StatusCode::UNAUTHORIZED, 302 + "AuthenticationRequired", 303 + self.to_string(), 304 + ), 305 + ServiceAuthError::MethodBindingRequired => ( 306 + StatusCode::UNAUTHORIZED, 307 + "AuthenticationRequired", 308 + self.to_string(), 309 + ), 310 + ServiceAuthError::InvalidKey(_) => ( 311 + StatusCode::UNAUTHORIZED, 312 + "AuthenticationRequired", 313 + self.to_string(), 314 + ), 315 + }; 316 + 317 + tracing::warn!("Service auth failed: {}", message); 318 + 319 + ( 320 + status, 321 + [( 322 + header::CONTENT_TYPE, 323 + HeaderValue::from_static("application/json"), 324 + )], 325 + Json(json!({ 326 + "error": error_code, 327 + "message": message, 328 + })), 329 + ) 330 + .into_response() 331 + } 332 + } 333 + 334 + impl<S> FromRequestParts<S> for ExtractServiceAuth 335 + where 336 + S: ServiceAuth + Send + Sync, 337 + S::Resolver: Send + Sync, 338 + { 339 + type Rejection = ServiceAuthError; 340 + 341 + fn from_request_parts( 342 + parts: &mut Parts, 343 + state: &S, 344 + ) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send { 345 + async move { 346 + // Extract Authorization header 347 + let auth_header = parts 348 + .headers 349 + .get(header::AUTHORIZATION) 350 + .ok_or(ServiceAuthError::MissingAuthHeader)?; 351 + 352 + // Parse Bearer token 353 + let auth_str = auth_header 354 + .to_str() 355 + .map_err(|_| ServiceAuthError::InvalidAuthHeader)?; 356 + 357 + let token = auth_str 358 + .strip_prefix("Bearer ") 359 + .ok_or(ServiceAuthError::InvalidAuthHeader)?; 360 + 361 + // Parse JWT 362 + let parsed = service_auth::parse_jwt(token)?; 363 + 364 + // Get claims for DID resolution 365 + let claims = parsed.claims(); 366 + 367 + // Resolve DID to get signing key (do this before checking claims) 368 + let did_doc = state 369 + .resolver() 370 + .resolve_did_doc(&claims.iss) 371 + .await 372 + .map_err(|e| ServiceAuthError::DidResolutionFailed { 373 + did: claims.iss.clone().into_static(), 374 + source: Box::new(e), 375 + })?; 376 + 377 + // Parse the DID document response to get verification methods 378 + let doc = did_doc 379 + .parse() 380 + .map_err(|e| ServiceAuthError::DidResolutionFailed { 381 + did: claims.iss.clone().into_static(), 382 + source: Box::new(e), 383 + })?; 384 + 385 + // Extract signing key from DID document 386 + let verification_methods = doc 387 + .verification_method 388 + .as_deref() 389 + .ok_or_else(|| ServiceAuthError::NoSigningKey(claims.iss.clone().into_static()))?; 390 + 391 + let signing_key = extract_signing_key(verification_methods) 392 + .ok_or_else(|| ServiceAuthError::NoSigningKey(claims.iss.clone().into_static()))?; 393 + 394 + // Verify signature FIRST - if this fails, nothing else matters 395 + service_auth::verify_signature(&parsed, &signing_key)?; 396 + 397 + // Now validate claims (audience, expiration, etc.) 398 + claims.validate(state.service_did())?; 399 + 400 + // Check method binding if required 401 + if state.require_lxm() && claims.lxm.is_none() { 402 + return Err(ServiceAuthError::MethodBindingRequired); 403 + } 404 + 405 + // All checks passed - return verified auth 406 + Ok(ExtractServiceAuth(VerifiedServiceAuth { 407 + did: claims.iss.clone().into_static(), 408 + aud: claims.aud.clone().into_static(), 409 + lxm: claims.lxm.as_ref().map(|l| l.clone().into_static()), 410 + jti: claims.jti.as_ref().map(|j| j.clone().into_static()), 411 + })) 412 + } 413 + } 414 + } 415 + 416 + /// Extract the signing key from a DID document's verification methods. 417 + /// 418 + /// This looks for a key with type "atproto" or the first available key 419 + /// if no atproto-specific key is found. 420 + fn extract_signing_key(methods: &[VerificationMethod]) -> Option<PublicKey> { 421 + // First try to find an atproto-specific key 422 + let atproto_method = methods 423 + .iter() 424 + .find(|m| m.r#type.as_ref() == "Multikey" || m.r#type.as_ref() == "atproto"); 425 + 426 + let method = atproto_method.or_else(|| methods.first())?; 427 + 428 + // Parse the multikey 429 + let public_key_multibase = method.public_key_multibase.as_ref()?; 430 + 431 + // Decode multibase 432 + let (_, key_bytes) = multibase::decode(public_key_multibase.as_ref()).ok()?; 433 + 434 + // First two bytes are the multicodec prefix 435 + if key_bytes.len() < 2 { 436 + return None; 437 + } 438 + 439 + let codec = &key_bytes[..2]; 440 + let key_material = &key_bytes[2..]; 441 + 442 + match codec { 443 + // p256-pub (0x1200) 444 + [0x80, 0x24] => PublicKey::from_p256_bytes(key_material).ok(), 445 + // secp256k1-pub (0xe7) 446 + [0xe7, 0x01] => PublicKey::from_k256_bytes(key_material).ok(), 447 + _ => None, 448 + } 449 + } 450 + 451 + /// Middleware for verifying service authentication on all requests. 452 + /// 453 + /// This middleware extracts and verifies the service auth JWT, then adds the 454 + /// `VerifiedServiceAuth` to request extensions for downstream handlers to access. 455 + /// 456 + /// # Example 457 + /// 458 + /// ```no_run 459 + /// use axum::{Router, routing::get, middleware, Extension}; 460 + /// use jacquard_axum::service_auth::{ServiceAuthConfig, service_auth_middleware}; 461 + /// use jacquard_identity::JacquardResolver; 462 + /// use jacquard_identity::resolver::ResolverOptions; 463 + /// use jacquard_common::types::string::Did; 464 + /// 465 + /// async fn handler( 466 + /// Extension(auth): Extension<jacquard_axum::service_auth::VerifiedServiceAuth<'static>>, 467 + /// ) -> String { 468 + /// format!("Authenticated as {}", auth.did()) 469 + /// } 470 + /// 471 + /// #[tokio::main] 472 + /// async fn main() { 473 + /// let resolver = JacquardResolver::new( 474 + /// reqwest::Client::new(), 475 + /// ResolverOptions::default(), 476 + /// ); 477 + /// let config = ServiceAuthConfig::new( 478 + /// Did::new_static("did:web:feedgen.example.com").unwrap(), 479 + /// resolver, 480 + /// ); 481 + /// 482 + /// let app = Router::new() 483 + /// .route("/xrpc/app.bsky.feed.getFeedSkeleton", get(handler)) 484 + /// .layer(middleware::from_fn_with_state( 485 + /// config.clone(), 486 + /// service_auth_middleware::<ServiceAuthConfig<JacquardResolver>>, 487 + /// )) 488 + /// .with_state(config); 489 + /// 490 + /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000") 491 + /// .await 492 + /// .unwrap(); 493 + /// axum::serve(listener, app).await.unwrap(); 494 + /// } 495 + /// ``` 496 + pub async fn service_auth_middleware<S>( 497 + state: axum::extract::State<S>, 498 + mut req: axum::extract::Request, 499 + next: Next, 500 + ) -> Result<Response, ServiceAuthError> 501 + where 502 + S: ServiceAuth + Send + Sync + Clone, 503 + S::Resolver: Send + Sync, 504 + { 505 + // Extract auth from request parts 506 + let (mut parts, body) = req.into_parts(); 507 + let ExtractServiceAuth(auth) = 508 + ExtractServiceAuth::from_request_parts(&mut parts, &state.0).await?; 509 + 510 + // Add auth to extensions 511 + parts.extensions.insert(auth); 512 + 513 + // Reconstruct request and continue 514 + req = axum::extract::Request::from_parts(parts, body); 515 + Ok(next.run(req).await) 516 + }
+543
crates/jacquard-axum/tests/service_auth_tests.rs
··· 1 + use axum::{ 2 + Extension, Router, 3 + body::Body, 4 + extract::Request, 5 + http::{StatusCode, header}, 6 + middleware, 7 + routing::get, 8 + }; 9 + use base64::Engine; 10 + use base64::engine::general_purpose::URL_SAFE_NO_PAD; 11 + use bytes::Bytes; 12 + use jacquard_axum::service_auth::{ 13 + ExtractServiceAuth, ServiceAuthConfig, VerifiedServiceAuth, service_auth_middleware, 14 + }; 15 + use jacquard_common::{ 16 + CowStr, IntoStatic, 17 + service_auth::JwtHeader, 18 + types::{ 19 + did::Did, 20 + did_doc::{DidDocument, VerificationMethod}, 21 + }, 22 + }; 23 + use jacquard_identity::resolver::{ 24 + DidDocResponse, IdentityError, IdentityResolver, ResolverOptions, 25 + }; 26 + use reqwest::StatusCode as ReqwestStatusCode; 27 + use serde_json::json; 28 + use std::future::Future; 29 + use tower::ServiceExt; 30 + 31 + // Test helper: create a signed JWT 32 + fn create_test_jwt( 33 + iss: &str, 34 + aud: &str, 35 + exp: i64, 36 + lxm: Option<&str>, 37 + signing_key: &k256::ecdsa::SigningKey, 38 + ) -> String { 39 + use k256::ecdsa::signature::Signer; 40 + 41 + let header = JwtHeader { 42 + alg: CowStr::new_static("ES256K"), 43 + typ: CowStr::new_static("JWT"), 44 + }; 45 + 46 + let mut claims_json = json!({ 47 + "iss": iss, 48 + "aud": aud, 49 + "exp": exp, 50 + "iat": chrono::Utc::now().timestamp(), 51 + }); 52 + 53 + if let Some(lxm_val) = lxm { 54 + claims_json["lxm"] = json!(lxm_val); 55 + } 56 + 57 + let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 58 + let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&claims_json).unwrap()); 59 + 60 + let signing_input = format!("{}.{}", header_b64, payload_b64); 61 + 62 + let signature: k256::ecdsa::Signature = signing_key.sign(signing_input.as_bytes()); 63 + let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); 64 + 65 + format!("{}.{}", signing_input, signature_b64) 66 + } 67 + 68 + // Test helper: create DID document with k256 key 69 + fn create_test_did_doc(did: &str, public_key: &k256::ecdsa::VerifyingKey) -> DidDocument<'static> { 70 + use std::collections::BTreeMap; 71 + 72 + // Encode as compressed SEC1 73 + let encoded_point = public_key.to_encoded_point(true); 74 + let key_bytes = encoded_point.as_bytes(); 75 + 76 + // Multicodec prefix for secp256k1-pub (0xe701) 77 + let mut multicodec_bytes = vec![0xe7, 0x01]; 78 + multicodec_bytes.extend_from_slice(key_bytes); 79 + 80 + // Multibase encode (base58btc = 'z') 81 + let multibase_key = multibase::encode(multibase::Base::Base58Btc, &multicodec_bytes); 82 + 83 + DidDocument { 84 + id: Did::new_owned(did).unwrap().into_static(), 85 + also_known_as: None, 86 + verification_method: Some(vec![VerificationMethod { 87 + id: CowStr::Owned(format!("{}#atproto", did).into()), 88 + r#type: CowStr::new_static("Multikey"), 89 + controller: Some(CowStr::Owned(did.into())), 90 + public_key_multibase: Some(CowStr::Owned(multibase_key.into())), 91 + extra_data: BTreeMap::new(), 92 + }]), 93 + service: None, 94 + extra_data: BTreeMap::new(), 95 + } 96 + } 97 + 98 + // Mock resolver for tests 99 + #[derive(Clone)] 100 + struct MockResolver { 101 + did_doc: DidDocument<'static>, 102 + options: ResolverOptions, 103 + } 104 + 105 + impl MockResolver { 106 + fn new(did_doc: DidDocument<'static>) -> Self { 107 + Self { 108 + did_doc, 109 + options: ResolverOptions::default(), 110 + } 111 + } 112 + } 113 + 114 + impl IdentityResolver for MockResolver { 115 + fn options(&self) -> &ResolverOptions { 116 + &self.options 117 + } 118 + 119 + fn resolve_handle( 120 + &self, 121 + _handle: &jacquard_common::types::string::Handle<'_>, 122 + ) -> impl Future<Output = Result<Did<'static>, IdentityError>> + Send { 123 + async { Err(IdentityError::InvalidWellKnown) } 124 + } 125 + 126 + fn resolve_did_doc( 127 + &self, 128 + _did: &Did<'_>, 129 + ) -> impl Future<Output = Result<DidDocResponse, IdentityError>> + Send { 130 + let doc = self.did_doc.clone(); 131 + async move { 132 + let json = serde_json::to_vec(&doc).unwrap(); 133 + Ok(DidDocResponse { 134 + buffer: Bytes::from(json), 135 + status: ReqwestStatusCode::OK, 136 + requested: Some(doc.id.clone()), 137 + }) 138 + } 139 + } 140 + } 141 + 142 + #[tokio::test] 143 + async fn test_extractor_with_valid_jwt() { 144 + // Generate keypair 145 + let signing_key = k256::ecdsa::SigningKey::random(&mut rand::thread_rng()); 146 + let verifying_key = signing_key.verifying_key(); 147 + 148 + // Create test DID and JWT 149 + let user_did = "did:plc:test123"; 150 + let service_did = "did:web:feedgen.example.com"; 151 + let exp = chrono::Utc::now().timestamp() + 300; 152 + 153 + // JWT with lxm 154 + let jwt = create_test_jwt( 155 + user_did, 156 + service_did, 157 + exp, 158 + Some("app.bsky.feed.getFeedSkeleton"), 159 + &signing_key, 160 + ); 161 + 162 + // Create mock resolver 163 + let did_doc = create_test_did_doc(user_did, verifying_key); 164 + let resolver = MockResolver::new(did_doc); 165 + 166 + // Create config (default: require_lxm = true) 167 + let config = ServiceAuthConfig::new(Did::new_static(service_did).unwrap(), resolver); 168 + 169 + // Create handler 170 + async fn handler(ExtractServiceAuth(auth): ExtractServiceAuth) -> String { 171 + format!("Authenticated as {}", auth.did()) 172 + } 173 + 174 + let app = Router::new() 175 + .route("/test", get(handler)) 176 + .with_state(config); 177 + 178 + // Create request with JWT 179 + let request = Request::builder() 180 + .uri("/test") 181 + .header(header::AUTHORIZATION, format!("Bearer {}", jwt)) 182 + .body(Body::empty()) 183 + .unwrap(); 184 + 185 + // Send request 186 + let response = app.oneshot(request).await.unwrap(); 187 + 188 + assert_eq!(response.status(), StatusCode::OK); 189 + 190 + let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX) 191 + .await 192 + .unwrap(); 193 + let body = String::from_utf8(body_bytes.to_vec()).unwrap(); 194 + 195 + assert_eq!(body, format!("Authenticated as {}", user_did)); 196 + } 197 + 198 + #[tokio::test] 199 + async fn test_extractor_with_expired_jwt() { 200 + let signing_key = k256::ecdsa::SigningKey::random(&mut rand::thread_rng()); 201 + let verifying_key = signing_key.verifying_key(); 202 + 203 + let user_did = "did:plc:test123"; 204 + let service_did = "did:web:feedgen.example.com"; 205 + let exp = chrono::Utc::now().timestamp() - 300; // Expired 206 + 207 + let jwt = create_test_jwt(user_did, service_did, exp, None, &signing_key); 208 + 209 + let did_doc = create_test_did_doc(user_did, verifying_key); 210 + let resolver = MockResolver::new(did_doc); 211 + 212 + let config = ServiceAuthConfig::new(Did::new_static(service_did).unwrap(), resolver); 213 + 214 + async fn handler(ExtractServiceAuth(auth): ExtractServiceAuth) -> String { 215 + format!("Authenticated as {}", auth.did()) 216 + } 217 + 218 + let app = Router::new() 219 + .route("/test", get(handler)) 220 + .with_state(config); 221 + 222 + let request = Request::builder() 223 + .uri("/test") 224 + .header(header::AUTHORIZATION, format!("Bearer {}", jwt)) 225 + .body(Body::empty()) 226 + .unwrap(); 227 + 228 + let response = app.oneshot(request).await.unwrap(); 229 + 230 + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); 231 + } 232 + 233 + #[tokio::test] 234 + async fn test_extractor_with_wrong_audience() { 235 + let signing_key = k256::ecdsa::SigningKey::random(&mut rand::thread_rng()); 236 + let verifying_key = signing_key.verifying_key(); 237 + 238 + let user_did = "did:plc:test123"; 239 + let service_did = "did:web:feedgen.example.com"; 240 + let wrong_aud = "did:web:other.example.com"; 241 + let exp = chrono::Utc::now().timestamp() + 300; 242 + 243 + let jwt = create_test_jwt(user_did, wrong_aud, exp, None, &signing_key); 244 + 245 + let did_doc = create_test_did_doc(user_did, verifying_key); 246 + let resolver = MockResolver::new(did_doc); 247 + 248 + let config = ServiceAuthConfig::new(Did::new_static(service_did).unwrap(), resolver); 249 + 250 + async fn handler(ExtractServiceAuth(auth): ExtractServiceAuth) -> String { 251 + format!("Authenticated as {}", auth.did()) 252 + } 253 + 254 + let app = Router::new() 255 + .route("/test", get(handler)) 256 + .with_state(config); 257 + 258 + let request = Request::builder() 259 + .uri("/test") 260 + .header(header::AUTHORIZATION, format!("Bearer {}", jwt)) 261 + .body(Body::empty()) 262 + .unwrap(); 263 + 264 + let response = app.oneshot(request).await.unwrap(); 265 + 266 + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); 267 + } 268 + 269 + #[tokio::test] 270 + async fn test_extractor_missing_auth_header() { 271 + let signing_key = k256::ecdsa::SigningKey::random(&mut rand::thread_rng()); 272 + let verifying_key = signing_key.verifying_key(); 273 + 274 + let user_did = "did:plc:test123"; 275 + let service_did = "did:web:feedgen.example.com"; 276 + 277 + let did_doc = create_test_did_doc(user_did, verifying_key); 278 + let resolver = MockResolver::new(did_doc); 279 + 280 + let config = ServiceAuthConfig::new(Did::new_static(service_did).unwrap(), resolver); 281 + 282 + async fn handler(ExtractServiceAuth(auth): ExtractServiceAuth) -> String { 283 + format!("Authenticated as {}", auth.did()) 284 + } 285 + 286 + let app = Router::new() 287 + .route("/test", get(handler)) 288 + .with_state(config); 289 + 290 + let request = Request::builder().uri("/test").body(Body::empty()).unwrap(); 291 + 292 + let response = app.oneshot(request).await.unwrap(); 293 + 294 + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); 295 + } 296 + 297 + #[tokio::test] 298 + async fn test_middleware_with_valid_jwt() { 299 + let signing_key = k256::ecdsa::SigningKey::random(&mut rand::thread_rng()); 300 + let verifying_key = signing_key.verifying_key(); 301 + 302 + let user_did = "did:plc:test123"; 303 + let service_did = "did:web:feedgen.example.com"; 304 + let exp = chrono::Utc::now().timestamp() + 300; 305 + 306 + // JWT with lxm 307 + let jwt = create_test_jwt( 308 + user_did, 309 + service_did, 310 + exp, 311 + Some("app.bsky.feed.getFeedSkeleton"), 312 + &signing_key, 313 + ); 314 + 315 + let did_doc = create_test_did_doc(user_did, verifying_key); 316 + let resolver = MockResolver::new(did_doc); 317 + 318 + // Create config (default: require_lxm = true) 319 + let config = ServiceAuthConfig::new(Did::new_static(service_did).unwrap(), resolver); 320 + 321 + async fn handler(Extension(auth): Extension<VerifiedServiceAuth<'static>>) -> String { 322 + format!("Authenticated as {}", auth.did()) 323 + } 324 + 325 + let app = Router::new() 326 + .route("/test", get(handler)) 327 + .layer(middleware::from_fn_with_state( 328 + config.clone(), 329 + service_auth_middleware::<ServiceAuthConfig<MockResolver>>, 330 + )) 331 + .with_state(config); 332 + 333 + let request = Request::builder() 334 + .uri("/test") 335 + .header(header::AUTHORIZATION, format!("Bearer {}", jwt)) 336 + .body(Body::empty()) 337 + .unwrap(); 338 + 339 + let response = app.oneshot(request).await.unwrap(); 340 + 341 + assert_eq!(response.status(), StatusCode::OK); 342 + 343 + let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX) 344 + .await 345 + .unwrap(); 346 + let body = String::from_utf8(body_bytes.to_vec()).unwrap(); 347 + 348 + assert_eq!(body, format!("Authenticated as {}", user_did)); 349 + } 350 + 351 + #[tokio::test] 352 + async fn test_require_lxm() { 353 + let signing_key = k256::ecdsa::SigningKey::random(&mut rand::thread_rng()); 354 + let verifying_key = signing_key.verifying_key(); 355 + 356 + let user_did = "did:plc:test123"; 357 + let service_did = "did:web:feedgen.example.com"; 358 + let exp = chrono::Utc::now().timestamp() + 300; 359 + 360 + // JWT without lxm 361 + let jwt = create_test_jwt(user_did, service_did, exp, None, &signing_key); 362 + 363 + let did_doc = create_test_did_doc(user_did, verifying_key); 364 + let resolver = MockResolver::new(did_doc); 365 + 366 + let config = 367 + ServiceAuthConfig::new(Did::new_static(service_did).unwrap(), resolver).require_lxm(true); 368 + 369 + async fn handler(ExtractServiceAuth(auth): ExtractServiceAuth) -> String { 370 + format!("Authenticated as {}", auth.did()) 371 + } 372 + 373 + let app = Router::new() 374 + .route("/test", get(handler)) 375 + .with_state(config); 376 + 377 + let request = Request::builder() 378 + .uri("/test") 379 + .header(header::AUTHORIZATION, format!("Bearer {}", jwt)) 380 + .body(Body::empty()) 381 + .unwrap(); 382 + 383 + let response = app.oneshot(request).await.unwrap(); 384 + 385 + // Should fail because lxm is required but missing 386 + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); 387 + } 388 + 389 + #[tokio::test] 390 + async fn test_with_lxm_present() { 391 + let signing_key = k256::ecdsa::SigningKey::random(&mut rand::thread_rng()); 392 + let verifying_key = signing_key.verifying_key(); 393 + 394 + let user_did = "did:plc:test123"; 395 + let service_did = "did:web:feedgen.example.com"; 396 + let exp = chrono::Utc::now().timestamp() + 300; 397 + 398 + // JWT with lxm 399 + let jwt = create_test_jwt( 400 + user_did, 401 + service_did, 402 + exp, 403 + Some("app.bsky.feed.getFeedSkeleton"), 404 + &signing_key, 405 + ); 406 + 407 + let did_doc = create_test_did_doc(user_did, verifying_key); 408 + let resolver = MockResolver::new(did_doc); 409 + 410 + let config = 411 + ServiceAuthConfig::new(Did::new_static(service_did).unwrap(), resolver).require_lxm(true); 412 + 413 + async fn handler(ExtractServiceAuth(auth): ExtractServiceAuth) -> String { 414 + format!( 415 + "Authenticated as {} for {}", 416 + auth.did(), 417 + auth.lxm().unwrap() 418 + ) 419 + } 420 + 421 + let app = Router::new() 422 + .route("/test", get(handler)) 423 + .with_state(config); 424 + 425 + let request = Request::builder() 426 + .uri("/test") 427 + .header(header::AUTHORIZATION, format!("Bearer {}", jwt)) 428 + .body(Body::empty()) 429 + .unwrap(); 430 + 431 + let response = app.oneshot(request).await.unwrap(); 432 + 433 + assert_eq!(response.status(), StatusCode::OK); 434 + 435 + let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX) 436 + .await 437 + .unwrap(); 438 + let body = String::from_utf8(body_bytes.to_vec()).unwrap(); 439 + 440 + assert_eq!( 441 + body, 442 + format!( 443 + "Authenticated as {} for app.bsky.feed.getFeedSkeleton", 444 + user_did 445 + ) 446 + ); 447 + } 448 + 449 + #[tokio::test] 450 + async fn test_legacy_without_lxm() { 451 + let signing_key = k256::ecdsa::SigningKey::random(&mut rand::thread_rng()); 452 + let verifying_key = signing_key.verifying_key(); 453 + 454 + let user_did = "did:plc:test123"; 455 + let service_did = "did:web:feedgen.example.com"; 456 + let exp = chrono::Utc::now().timestamp() + 300; 457 + 458 + // JWT without lxm 459 + let jwt = create_test_jwt(user_did, service_did, exp, None, &signing_key); 460 + 461 + let did_doc = create_test_did_doc(user_did, verifying_key); 462 + let resolver = MockResolver::new(did_doc); 463 + 464 + // Legacy config: lxm not required 465 + let config = 466 + ServiceAuthConfig::new(Did::new_static(service_did).unwrap(), resolver).require_lxm(false); 467 + 468 + async fn handler(ExtractServiceAuth(auth): ExtractServiceAuth) -> String { 469 + format!("Authenticated as {}", auth.did()) 470 + } 471 + 472 + let app = Router::new() 473 + .route("/test", get(handler)) 474 + .with_state(config); 475 + 476 + let request = Request::builder() 477 + .uri("/test") 478 + .header(header::AUTHORIZATION, format!("Bearer {}", jwt)) 479 + .body(Body::empty()) 480 + .unwrap(); 481 + 482 + let response = app.oneshot(request).await.unwrap(); 483 + 484 + // Should succeed because lxm is not required 485 + assert_eq!(response.status(), StatusCode::OK); 486 + 487 + let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX) 488 + .await 489 + .unwrap(); 490 + let body = String::from_utf8(body_bytes.to_vec()).unwrap(); 491 + 492 + assert_eq!(body, format!("Authenticated as {}", user_did)); 493 + } 494 + 495 + #[tokio::test] 496 + async fn test_invalid_signature() { 497 + // Real JWT token from did:plc:uc7pehijmk5jrllip4cglxdd with bogus signature 498 + let token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NksifQ.eyJpYXQiOjE3NjAzOTMyMzUsImlzcyI6ImRpZDpwbGM6dWM3cGVoaWptazVqcmxsaXA0Y2dseGRkIiwiYXVkIjoiZGlkOndlYjpkZXYucGRzbW9vdmVyLmNvbSIsImV4cCI6MTc2MDM5MzI5NSwibHhtIjoiY29tLnBkc21vb3Zlci5iYWNrdXAuc2lnblVwIiwianRpIjoiMTk0MDQzMzQyNmMyNTNlZjhmNmYxZDJjZWE1YzI0NGMifQ.h5BrgYE"; 499 + 500 + // Real DID document for did:plc:uc7pehijmk5jrllip4cglxdd 501 + let did_doc_json = r##"{ 502 + "id": "did:plc:uc7pehijmk5jrllip4cglxdd", 503 + "alsoKnownAs": ["at://bailey.skeetcentral.com"], 504 + "verificationMethod": [{ 505 + "controller": "did:plc:uc7pehijmk5jrllip4cglxdd", 506 + "id": "did:plc:uc7pehijmk5jrllip4cglxdd#atproto", 507 + "publicKeyMultibase": "zQ3shNBS3N4EB3vX5G1HoxFkS8tDLFXUHaV85rHQZgVM88rM5", 508 + "type": "Multikey" 509 + }], 510 + "service": [{ 511 + "id": "#atproto_pds", 512 + "serviceEndpoint": "https://skeetcentral.com", 513 + "type": "AtprotoPersonalDataServer" 514 + }] 515 + }"##; 516 + 517 + let did_doc: DidDocument = serde_json::from_str(did_doc_json).unwrap(); 518 + let resolver = MockResolver::new(did_doc); 519 + 520 + let config = ServiceAuthConfig::new( 521 + Did::new_static("did:web:dev.pdsmoover.com").unwrap(), 522 + resolver, 523 + ); 524 + 525 + async fn handler(ExtractServiceAuth(auth): ExtractServiceAuth) -> String { 526 + format!("Authenticated as {}", auth.did()) 527 + } 528 + 529 + let app = Router::new() 530 + .route("/test", get(handler)) 531 + .with_state(config); 532 + 533 + let request = Request::builder() 534 + .uri("/test") 535 + .header(header::AUTHORIZATION, format!("Bearer {}", token)) 536 + .body(Body::empty()) 537 + .unwrap(); 538 + 539 + let response = app.oneshot(request).await.unwrap(); 540 + 541 + // Should fail due to invalid signature 542 + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); 543 + }
+5 -3
crates/jacquard-common/Cargo.toml
··· 39 39 reqwest = { workspace = true, optional = true, features = ["charset", "http2", "json", "system-proxy", "gzip", "rustls-tls"] } 40 40 serde_ipld_dagcbor.workspace = true 41 41 trait-variant.workspace = true 42 + signature = { version = "2", optional = true } 42 43 43 44 [features] 44 - default = [] 45 + default = ["service-auth", "reqwest-client", "crypto"] 45 46 crypto = [] 46 47 crypto-ed25519 = ["crypto", "dep:ed25519-dalek"] 47 - crypto-k256 = ["crypto", "dep:k256"] 48 - crypto-p256 = ["crypto", "dep:p256"] 48 + crypto-k256 = ["crypto", "dep:k256", "k256/ecdsa"] 49 + crypto-p256 = ["crypto", "dep:p256", "p256/ecdsa"] 50 + service-auth = ["crypto-k256", "crypto-p256", "dep:signature"] 49 51 reqwest-client = ["dep:reqwest"] 50 52 51 53 [dependencies.ed25519-dalek]
+47 -36
crates/jacquard-common/src/cowstr.rs
··· 1 - use serde::{Deserialize, Serialize}; 1 + use serde::{Deserialize, Deserializer, Serialize}; 2 2 use smol_str::SmolStr; 3 3 use std::{ 4 4 borrow::Cow, ··· 283 283 } 284 284 } 285 285 286 + /// Deserialization helper for things that wrap a CowStr 287 + pub struct CowStrVisitor; 288 + 289 + impl<'de> serde::de::Visitor<'de> for CowStrVisitor { 290 + type Value = CowStr<'de>; 291 + 292 + #[inline] 293 + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { 294 + write!(formatter, "a string") 295 + } 296 + 297 + #[inline] 298 + fn visit_str<E>(self, v: &str) -> Result<Self::Value, E> 299 + where 300 + E: serde::de::Error, 301 + { 302 + Ok(CowStr::copy_from_str(v)) 303 + } 304 + 305 + #[inline] 306 + fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E> 307 + where 308 + E: serde::de::Error, 309 + { 310 + Ok(CowStr::Borrowed(v)) 311 + } 312 + 313 + #[inline] 314 + fn visit_string<E>(self, v: String) -> Result<Self::Value, E> 315 + where 316 + E: serde::de::Error, 317 + { 318 + Ok(v.into()) 319 + } 320 + } 321 + 286 322 impl<'de, 'a> Deserialize<'de> for CowStr<'a> 287 323 where 288 324 'de: 'a, ··· 292 328 where 293 329 D: serde::Deserializer<'de>, 294 330 { 295 - struct CowStrVisitor; 296 - 297 - impl<'de> serde::de::Visitor<'de> for CowStrVisitor { 298 - type Value = CowStr<'de>; 299 - 300 - #[inline] 301 - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { 302 - write!(formatter, "a string") 303 - } 304 - 305 - #[inline] 306 - fn visit_str<E>(self, v: &str) -> Result<Self::Value, E> 307 - where 308 - E: serde::de::Error, 309 - { 310 - Ok(CowStr::copy_from_str(v)) 311 - } 312 - 313 - #[inline] 314 - fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E> 315 - where 316 - E: serde::de::Error, 317 - { 318 - Ok(CowStr::Borrowed(v)) 319 - } 320 - 321 - #[inline] 322 - fn visit_string<E>(self, v: String) -> Result<Self::Value, E> 323 - where 324 - E: serde::de::Error, 325 - { 326 - Ok(v.into()) 327 - } 328 - } 329 - 330 331 deserializer.deserialize_str(CowStrVisitor) 331 332 } 333 + } 334 + 335 + /// Serde helper for deserializing stuff when you want an owned version 336 + pub fn deserialize_owned<'de, T, D>(deserializer: D) -> Result<<T as IntoStatic>::Output, D::Error> 337 + where 338 + T: Deserialize<'de> + IntoStatic, 339 + D: Deserializer<'de>, 340 + { 341 + let value = T::deserialize(deserializer)?; 342 + Ok(value.into_static()) 332 343 } 333 344 334 345 /// Convert to a CowStr.
+3
crates/jacquard-common/src/lib.rs
··· 213 213 pub mod macros; 214 214 /// Generic session storage traits and utilities. 215 215 pub mod session; 216 + /// Service authentication JWT parsing and verification. 217 + #[cfg(feature = "service-auth")] 218 + pub mod service_auth; 216 219 /// Baseline fundamental AT Protocol data types. 217 220 pub mod types; 218 221 // XRPC protocol types and traits
+480
crates/jacquard-common/src/service_auth.rs
··· 1 + //! Service authentication JWT parsing and verification for AT Protocol. 2 + //! 3 + //! Service auth is atproto's inter-service authentication mechanism. When a backend 4 + //! service (feed generator, labeler, etc.) receives requests, the PDS signs a 5 + //! short-lived JWT with the user's signing key and includes it as a Bearer token. 6 + //! 7 + //! # JWT Structure 8 + //! 9 + //! - Header: `alg` (ES256K for k256, ES256 for p256), `typ` ("JWT") 10 + //! - Payload: 11 + //! - `iss`: user's DID (issuer) 12 + //! - `aud`: target service DID (audience) 13 + //! - `exp`: expiration unix timestamp 14 + //! - `iat`: issued at unix timestamp 15 + //! - `jti`: random nonce (128-bit hex) for replay protection 16 + //! - `lxm`: lexicon method NSID (method binding) 17 + //! - Signature: signed with user's signing key from DID doc (ES256 or ES256K) 18 + 19 + use crate::CowStr; 20 + use crate::IntoStatic; 21 + use crate::types::string::{Did, Nsid}; 22 + use base64::Engine; 23 + use base64::engine::general_purpose::URL_SAFE_NO_PAD; 24 + use ouroboros::self_referencing; 25 + use serde::{Deserialize, Serialize}; 26 + use signature::Verifier; 27 + use smol_str::SmolStr; 28 + use smol_str::format_smolstr; 29 + use thiserror::Error; 30 + 31 + #[cfg(feature = "crypto-p256")] 32 + use p256::ecdsa::{Signature as P256Signature, VerifyingKey as P256VerifyingKey}; 33 + 34 + #[cfg(feature = "crypto-k256")] 35 + use k256::ecdsa::{Signature as K256Signature, VerifyingKey as K256VerifyingKey}; 36 + 37 + /// Errors that can occur during JWT parsing and verification. 38 + #[derive(Debug, Error, miette::Diagnostic)] 39 + pub enum ServiceAuthError { 40 + /// JWT format is invalid (not three base64-encoded parts separated by dots) 41 + #[error("malformed JWT: {0}")] 42 + MalformedToken(CowStr<'static>), 43 + 44 + /// Base64 decoding failed 45 + #[error("base64 decode error: {0}")] 46 + Base64Decode(#[from] base64::DecodeError), 47 + 48 + /// JSON parsing failed 49 + #[error("JSON parsing error: {0}")] 50 + JsonParse(#[from] serde_json::Error), 51 + 52 + /// Signature verification failed 53 + #[error("invalid signature")] 54 + InvalidSignature, 55 + 56 + /// Unsupported algorithm 57 + #[error("unsupported algorithm: {alg}")] 58 + UnsupportedAlgorithm { 59 + /// Algorithm name from JWT header 60 + alg: SmolStr, 61 + }, 62 + 63 + /// Token has expired 64 + #[error("token expired at {exp} (current time: {now})")] 65 + Expired { 66 + /// Expiration timestamp from token 67 + exp: i64, 68 + /// Current timestamp 69 + now: i64, 70 + }, 71 + 72 + /// Audience mismatch 73 + #[error("audience mismatch: expected {expected}, got {actual}")] 74 + AudienceMismatch { 75 + /// Expected audience DID 76 + expected: Did<'static>, 77 + /// Actual audience DID in token 78 + actual: Did<'static>, 79 + }, 80 + 81 + /// Method mismatch (lxm field) 82 + #[error("method mismatch: expected {expected}, got {actual:?}")] 83 + MethodMismatch { 84 + /// Expected method NSID 85 + expected: Nsid<'static>, 86 + /// Actual method NSID in token (if any) 87 + actual: Option<Nsid<'static>>, 88 + }, 89 + 90 + /// Missing required field 91 + #[error("missing required field: {0}")] 92 + MissingField(&'static str), 93 + 94 + /// Crypto error 95 + #[error("crypto error: {0}")] 96 + Crypto(CowStr<'static>), 97 + } 98 + 99 + /// JWT header for service auth tokens. 100 + #[derive(Debug, Clone, Serialize, Deserialize)] 101 + pub struct JwtHeader<'a> { 102 + /// Algorithm used for signing 103 + #[serde(borrow)] 104 + pub alg: CowStr<'a>, 105 + /// Type (always "JWT") 106 + #[serde(borrow)] 107 + pub typ: CowStr<'a>, 108 + } 109 + 110 + impl IntoStatic for JwtHeader<'_> { 111 + type Output = JwtHeader<'static>; 112 + 113 + fn into_static(self) -> Self::Output { 114 + JwtHeader { 115 + alg: self.alg.into_static(), 116 + typ: self.typ.into_static(), 117 + } 118 + } 119 + } 120 + 121 + /// Service authentication claims. 122 + /// 123 + /// These are the payload fields in a service auth JWT. 124 + #[derive(Debug, Clone, Serialize, Deserialize)] 125 + pub struct ServiceAuthClaims<'a> { 126 + /// Issuer (user's DID) 127 + #[serde(borrow)] 128 + pub iss: Did<'a>, 129 + 130 + /// Audience (target service DID) 131 + #[serde(borrow)] 132 + pub aud: Did<'a>, 133 + 134 + /// Expiration time (unix timestamp) 135 + pub exp: i64, 136 + 137 + /// Issued at (unix timestamp) 138 + pub iat: i64, 139 + 140 + /// JWT ID (nonce for replay protection) 141 + #[serde(borrow, skip_serializing_if = "Option::is_none")] 142 + pub jti: Option<CowStr<'a>>, 143 + 144 + /// Lexicon method NSID (method binding) 145 + #[serde(borrow, skip_serializing_if = "Option::is_none")] 146 + pub lxm: Option<Nsid<'a>>, 147 + } 148 + 149 + impl<'a> IntoStatic for ServiceAuthClaims<'a> { 150 + type Output = ServiceAuthClaims<'static>; 151 + 152 + fn into_static(self) -> Self::Output { 153 + ServiceAuthClaims { 154 + iss: self.iss.into_static(), 155 + aud: self.aud.into_static(), 156 + exp: self.exp, 157 + iat: self.iat, 158 + jti: self.jti.map(|j| j.into_static()), 159 + lxm: self.lxm.map(|l| l.into_static()), 160 + } 161 + } 162 + } 163 + 164 + impl<'a> ServiceAuthClaims<'a> { 165 + /// Validate the claims against expected values. 166 + /// 167 + /// Checks: 168 + /// - Audience matches expected DID 169 + /// - Token is not expired 170 + pub fn validate(&self, expected_aud: &Did) -> Result<(), ServiceAuthError> { 171 + // Check audience 172 + if self.aud.as_str() != expected_aud.as_str() { 173 + return Err(ServiceAuthError::AudienceMismatch { 174 + expected: expected_aud.clone().into_static(), 175 + actual: self.aud.clone().into_static(), 176 + }); 177 + } 178 + 179 + // Check expiration 180 + if self.is_expired() { 181 + let now = chrono::Utc::now().timestamp(); 182 + return Err(ServiceAuthError::Expired { exp: self.exp, now }); 183 + } 184 + 185 + Ok(()) 186 + } 187 + 188 + /// Check if the token has expired. 189 + pub fn is_expired(&self) -> bool { 190 + let now = chrono::Utc::now().timestamp(); 191 + self.exp <= now 192 + } 193 + 194 + /// Check if the method (lxm) matches the expected NSID. 195 + pub fn check_method(&self, nsid: &Nsid) -> bool { 196 + self.lxm 197 + .as_ref() 198 + .map(|lxm| lxm.as_str() == nsid.as_str()) 199 + .unwrap_or(false) 200 + } 201 + 202 + /// Require that the method (lxm) matches the expected NSID. 203 + pub fn require_method(&self, nsid: &Nsid) -> Result<(), ServiceAuthError> { 204 + if !self.check_method(nsid) { 205 + return Err(ServiceAuthError::MethodMismatch { 206 + expected: nsid.clone().into_static(), 207 + actual: self.lxm.as_ref().map(|l| l.clone().into_static()), 208 + }); 209 + } 210 + Ok(()) 211 + } 212 + } 213 + 214 + /// Parsed JWT components. 215 + /// 216 + /// This struct owns the decoded buffers and parsed components using ouroboros 217 + /// self-referencing. The header and claims borrow from their respective buffers. 218 + #[self_referencing] 219 + pub struct ParsedJwt { 220 + /// Decoded header buffer (owned) 221 + header_buf: Vec<u8>, 222 + /// Decoded payload buffer (owned) 223 + payload_buf: Vec<u8>, 224 + /// Original token string for signing_input 225 + token: String, 226 + /// Signature bytes 227 + signature: Vec<u8>, 228 + /// Parsed header borrowing from header_buf 229 + #[borrows(header_buf)] 230 + #[covariant] 231 + header: JwtHeader<'this>, 232 + /// Parsed claims borrowing from payload_buf 233 + #[borrows(payload_buf)] 234 + #[covariant] 235 + claims: ServiceAuthClaims<'this>, 236 + } 237 + 238 + impl ParsedJwt { 239 + /// Get the signing input (header.payload) for signature verification. 240 + pub fn signing_input(&self) -> &[u8] { 241 + self.with_token(|token| { 242 + let dot_pos = token.find('.').unwrap(); 243 + let second_dot_pos = token[dot_pos + 1..].find('.').unwrap() + dot_pos + 1; 244 + token[..second_dot_pos].as_bytes() 245 + }) 246 + } 247 + 248 + /// Get a reference to the header. 249 + pub fn header(&self) -> &JwtHeader<'_> { 250 + self.borrow_header() 251 + } 252 + 253 + /// Get a reference to the claims. 254 + pub fn claims(&self) -> &ServiceAuthClaims<'_> { 255 + self.borrow_claims() 256 + } 257 + 258 + /// Get a reference to the signature. 259 + pub fn signature(&self) -> &[u8] { 260 + self.borrow_signature() 261 + } 262 + 263 + /// Get owned header with 'static lifetime. 264 + pub fn into_header(self) -> JwtHeader<'static> { 265 + self.with_header(|header| header.clone().into_static()) 266 + } 267 + 268 + /// Get owned claims with 'static lifetime. 269 + pub fn into_claims(self) -> ServiceAuthClaims<'static> { 270 + self.with_claims(|claims| claims.clone().into_static()) 271 + } 272 + } 273 + 274 + /// Parse a JWT token into its components without verifying the signature. 275 + /// 276 + /// This extracts and decodes all JWT components. The header and claims are parsed 277 + /// and borrow from their respective owned buffers using ouroboros self-referencing. 278 + pub fn parse_jwt(token: &str) -> Result<ParsedJwt, ServiceAuthError> { 279 + let parts: Vec<&str> = token.split('.').collect(); 280 + if parts.len() != 3 { 281 + return Err(ServiceAuthError::MalformedToken(CowStr::new_static( 282 + "JWT must have exactly 3 parts separated by dots", 283 + ))); 284 + } 285 + 286 + let header_b64 = parts[0]; 287 + let payload_b64 = parts[1]; 288 + let signature_b64 = parts[2]; 289 + 290 + // Decode all components 291 + let header_buf = URL_SAFE_NO_PAD.decode(header_b64)?; 292 + let payload_buf = URL_SAFE_NO_PAD.decode(payload_b64)?; 293 + let signature = URL_SAFE_NO_PAD.decode(signature_b64)?; 294 + 295 + // Validate that buffers contain valid JSON for their types 296 + // We parse once here to validate, then again in the builder (unavoidable with ouroboros) 297 + let _header: JwtHeader = serde_json::from_slice(&header_buf)?; 298 + let _claims: ServiceAuthClaims = serde_json::from_slice(&payload_buf)?; 299 + 300 + Ok(ParsedJwtBuilder { 301 + header_buf, 302 + payload_buf, 303 + token: token.to_string(), 304 + signature, 305 + header_builder: |buf| { 306 + // Safe: we validated this succeeds above 307 + serde_json::from_slice(buf).expect("header was validated") 308 + }, 309 + claims_builder: |buf| { 310 + // Safe: we validated this succeeds above 311 + serde_json::from_slice(buf).expect("claims were validated") 312 + }, 313 + } 314 + .build()) 315 + } 316 + 317 + /// Public key types for signature verification. 318 + #[derive(Debug, Clone)] 319 + pub enum PublicKey { 320 + /// P-256 (ES256) public key 321 + #[cfg(feature = "crypto-p256")] 322 + P256(P256VerifyingKey), 323 + 324 + /// secp256k1 (ES256K) public key 325 + #[cfg(feature = "crypto-k256")] 326 + K256(K256VerifyingKey), 327 + } 328 + 329 + impl PublicKey { 330 + /// Create a P-256 public key from compressed or uncompressed bytes. 331 + #[cfg(feature = "crypto-p256")] 332 + pub fn from_p256_bytes(bytes: &[u8]) -> Result<Self, ServiceAuthError> { 333 + let key = P256VerifyingKey::from_sec1_bytes(bytes).map_err(|e| { 334 + ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!("invalid P-256 key: {}", e))) 335 + })?; 336 + Ok(PublicKey::P256(key)) 337 + } 338 + 339 + /// Create a secp256k1 public key from compressed or uncompressed bytes. 340 + #[cfg(feature = "crypto-k256")] 341 + pub fn from_k256_bytes(bytes: &[u8]) -> Result<Self, ServiceAuthError> { 342 + let key = K256VerifyingKey::from_sec1_bytes(bytes).map_err(|e| { 343 + ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!("invalid K-256 key: {}", e))) 344 + })?; 345 + Ok(PublicKey::K256(key)) 346 + } 347 + } 348 + 349 + /// Verify a JWT signature using the provided public key. 350 + /// 351 + /// The algorithm is determined by the JWT header and must match the public key type. 352 + pub fn verify_signature( 353 + parsed: &ParsedJwt, 354 + public_key: &PublicKey, 355 + ) -> Result<(), ServiceAuthError> { 356 + let alg = parsed.header().alg.as_str(); 357 + let signing_input = parsed.signing_input(); 358 + let signature = parsed.signature(); 359 + 360 + match (alg, public_key) { 361 + #[cfg(feature = "crypto-p256")] 362 + ("ES256", PublicKey::P256(key)) => { 363 + let sig = P256Signature::from_slice(signature).map_err(|e| { 364 + ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!( 365 + "invalid ES256 signature: {}", 366 + e 367 + ))) 368 + })?; 369 + key.verify(signing_input, &sig) 370 + .map_err(|_| ServiceAuthError::InvalidSignature)?; 371 + Ok(()) 372 + } 373 + 374 + #[cfg(feature = "crypto-k256")] 375 + ("ES256K", PublicKey::K256(key)) => { 376 + let sig = K256Signature::from_slice(signature).map_err(|e| { 377 + ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!( 378 + "invalid ES256K signature: {}", 379 + e 380 + ))) 381 + })?; 382 + key.verify(signing_input, &sig) 383 + .map_err(|_| ServiceAuthError::InvalidSignature)?; 384 + Ok(()) 385 + } 386 + 387 + _ => Err(ServiceAuthError::UnsupportedAlgorithm { 388 + alg: SmolStr::new(alg), 389 + }), 390 + } 391 + } 392 + 393 + /// Parse and verify a service auth JWT in one step, returning owned claims. 394 + /// 395 + /// This is a convenience function that combines parsing and signature verification. 396 + pub fn verify_service_jwt( 397 + token: &str, 398 + public_key: &PublicKey, 399 + ) -> Result<ServiceAuthClaims<'static>, ServiceAuthError> { 400 + let parsed = parse_jwt(token)?; 401 + verify_signature(&parsed, public_key)?; 402 + Ok(parsed.into_claims()) 403 + } 404 + 405 + #[cfg(test)] 406 + mod tests { 407 + use super::*; 408 + 409 + #[test] 410 + fn test_parse_jwt_invalid_format() { 411 + let result = parse_jwt("not.a.valid.jwt.with.too.many.parts"); 412 + assert!(matches!(result, Err(ServiceAuthError::MalformedToken(_)))); 413 + } 414 + 415 + #[test] 416 + fn test_claims_expiration() { 417 + let now = chrono::Utc::now().timestamp(); 418 + let expired_claims = ServiceAuthClaims { 419 + iss: Did::new("did:plc:test").unwrap(), 420 + aud: Did::new("did:web:example.com").unwrap(), 421 + exp: now - 100, 422 + iat: now - 200, 423 + jti: None, 424 + lxm: None, 425 + }; 426 + 427 + assert!(expired_claims.is_expired()); 428 + 429 + let valid_claims = ServiceAuthClaims { 430 + iss: Did::new("did:plc:test").unwrap(), 431 + aud: Did::new("did:web:example.com").unwrap(), 432 + exp: now + 100, 433 + iat: now, 434 + jti: None, 435 + lxm: None, 436 + }; 437 + 438 + assert!(!valid_claims.is_expired()); 439 + } 440 + 441 + #[test] 442 + fn test_audience_validation() { 443 + let now = chrono::Utc::now().timestamp(); 444 + let claims = ServiceAuthClaims { 445 + iss: Did::new("did:plc:test").unwrap(), 446 + aud: Did::new("did:web:example.com").unwrap(), 447 + exp: now + 100, 448 + iat: now, 449 + jti: None, 450 + lxm: None, 451 + }; 452 + 453 + let expected_aud = Did::new("did:web:example.com").unwrap(); 454 + assert!(claims.validate(&expected_aud).is_ok()); 455 + 456 + let wrong_aud = Did::new("did:web:wrong.com").unwrap(); 457 + assert!(matches!( 458 + claims.validate(&wrong_aud), 459 + Err(ServiceAuthError::AudienceMismatch { .. }) 460 + )); 461 + } 462 + 463 + #[test] 464 + fn test_method_check() { 465 + let claims = ServiceAuthClaims { 466 + iss: Did::new("did:plc:test").unwrap(), 467 + aud: Did::new("did:web:example.com").unwrap(), 468 + exp: chrono::Utc::now().timestamp() + 100, 469 + iat: chrono::Utc::now().timestamp(), 470 + jti: None, 471 + lxm: Some(Nsid::new("app.bsky.feed.getFeedSkeleton").unwrap()), 472 + }; 473 + 474 + let expected = Nsid::new("app.bsky.feed.getFeedSkeleton").unwrap(); 475 + assert!(claims.check_method(&expected)); 476 + 477 + let wrong = Nsid::new("app.bsky.feed.getTimeline").unwrap(); 478 + assert!(!claims.check_method(&wrong)); 479 + } 480 + }
-2
crates/jacquard-common/src/types.rs
··· 24 24 pub mod integer; 25 25 /// Language tag types per BCP 47 26 26 pub mod language; 27 - /// CID link wrapper for JSON serialization 28 - pub mod link; 29 27 /// Namespaced Identifier (NSID) types and validation 30 28 pub mod nsid; 31 29 /// Record key types and validation
+6 -1
crates/jacquard-common/src/types/aturi.rs
··· 330 330 .as_ref() 331 331 .and_then(|p| p.rkey.as_ref()) 332 332 } 333 + 334 + /// Fallible constructor, validates, borrows from input if possible 335 + pub fn new_cow(uri: CowStr<'u>) -> Result<Self, AtStrError> { 336 + Self::try_from(uri) 337 + } 333 338 } 334 339 335 340 impl AtUri<'static> { ··· 615 620 D: Deserializer<'de>, 616 621 { 617 622 let value = Deserialize::deserialize(deserializer)?; 618 - Self::new(value).map_err(D::Error::custom) 623 + Self::new_cow(value).map_err(D::Error::custom) 619 624 } 620 625 } 621 626
+6 -1
crates/jacquard-common/src/types/blob.rs
··· 150 150 Ok(Self(mime_type)) 151 151 } 152 152 153 + /// Fallible constructor, validates, borrows from input if possible 154 + pub fn new_cow(mime_type: CowStr<'m>) -> Result<MimeType<'m>, &'static str> { 155 + Self::from_cowstr(mime_type) 156 + } 157 + 153 158 /// Infallible constructor for trusted MIME type strings 154 159 pub fn raw(mime_type: &'m str) -> Self { 155 160 Self(CowStr::Borrowed(mime_type)) ··· 190 195 D: Deserializer<'de>, 191 196 { 192 197 let value = Deserialize::deserialize(deserializer)?; 193 - Self::new(value).map_err(D::Error::custom) 198 + Self::new_cow(value).map_err(D::Error::custom) 194 199 } 195 200 } 196 201
+32 -10
crates/jacquard-common/src/types/cid.rs
··· 2 2 pub use cid::Cid as IpldCid; 3 3 use serde::{Deserialize, Deserializer, Serialize, Serializer, de::Visitor}; 4 4 use smol_str::ToSmolStr; 5 - use std::{convert::Infallible, fmt, marker::PhantomData, ops::Deref, str::FromStr}; 5 + use std::{convert::Infallible, fmt, ops::Deref, str::FromStr}; 6 6 7 7 /// CID codec for AT Protocol (raw) 8 8 pub const ATP_CID_CODEC: u64 = 0x55; ··· 19 19 /// This type supports both string and parsed IPLD forms, with string caching 20 20 /// for the parsed form to optimize serialization. 21 21 /// 22 - /// Deserialization automatically detects the format (bytes trigger IPLD parsing). 22 + /// # Validation 23 + /// 24 + /// String deserialization does NOT validate CIDs. This is intentional for performance: 25 + /// CID strings from AT Protocol endpoints are generally trustworthy, so validation 26 + /// is deferred until needed. Use `to_ipld()` to parse and validate, or `is_valid()` 27 + /// to check without parsing. 28 + /// 29 + /// Byte deserialization (CBOR) parses immediately since the data is already in binary form. 23 30 #[derive(Debug, Clone, PartialEq, Eq, Hash)] 24 31 pub enum Cid<'c> { 25 32 /// Parsed IPLD CID with cached string representation ··· 100 107 Cid::Str(cow_str) => cow_str.as_ref(), 101 108 } 102 109 } 110 + 111 + /// Check if the CID string is valid without parsing 112 + /// 113 + /// Returns `true` if the CID is already parsed (`Ipld` variant) or if 114 + /// the string can be successfully parsed as an IPLD CID. 115 + pub fn is_valid(&self) -> bool { 116 + match self { 117 + Cid::Ipld { .. } => true, 118 + Cid::Str(s) => IpldCid::try_from(s.as_ref()).is_ok(), 119 + } 120 + } 103 121 } 104 122 105 123 impl std::fmt::Display for Cid<'_> { ··· 155 173 where 156 174 D: Deserializer<'de>, 157 175 { 158 - struct StringOrBytes<T>(PhantomData<fn() -> T>); 176 + struct CidVisitor; 159 177 160 - impl<'de, T> Visitor<'de> for StringOrBytes<T> 161 - where 162 - T: Deserialize<'de> + FromStr<Err = Infallible> + From<IpldCid>, 163 - { 164 - type Value = T; 178 + impl<'de> Visitor<'de> for CidVisitor { 179 + type Value = Cid<'de>; 165 180 166 181 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { 167 182 formatter.write_str("either valid IPLD CID bytes or a str") 168 183 } 169 184 185 + fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E> 186 + where 187 + E: serde::de::Error, 188 + { 189 + Ok(Cid::str(v)) 190 + } 191 + 170 192 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E> 171 193 where 172 194 E: serde::de::Error, ··· 179 201 E: serde::de::Error, 180 202 { 181 203 let hash = cid::multihash::Multihash::from_bytes(v).map_err(|e| E::custom(e))?; 182 - Ok(T::from(IpldCid::new_v1(ATP_CID_CODEC, hash))) 204 + Ok(Cid::ipld(IpldCid::new_v1(ATP_CID_CODEC, hash))) 183 205 } 184 206 } 185 207 186 - deserializer.deserialize_any(StringOrBytes(PhantomData)) 208 + deserializer.deserialize_any(CidVisitor) 187 209 } 188 210 } 189 211
+23 -1
crates/jacquard-common/src/types/did.rs
··· 54 54 } 55 55 } 56 56 57 + /// Fallible constructor, validates, borrows from input if possible 58 + /// 59 + /// May allocate for a long DID with an at:// prefix, otherwise borrows. 60 + pub fn new_cow(did: CowStr<'d>) -> Result<Self, AtStrError> { 61 + let did = if let Some(did) = did.strip_prefix("at://") { 62 + CowStr::copy_from_str(did) 63 + } else { 64 + did 65 + }; 66 + if did.len() > 2048 { 67 + Err(AtStrError::too_long("did", &did, 2048, did.len())) 68 + } else if !DID_REGEX.is_match(&did) { 69 + Err(AtStrError::regex( 70 + "did", 71 + &did, 72 + SmolStr::new_static("invalid"), 73 + )) 74 + } else { 75 + Ok(Self(did)) 76 + } 77 + } 78 + 57 79 /// Fallible constructor, validates, takes ownership 58 80 pub fn new_owned(did: impl AsRef<str>) -> Result<Self, AtStrError> { 59 81 let did = did.as_ref(); ··· 144 166 D: Deserializer<'de>, 145 167 { 146 168 let value = Deserialize::deserialize(deserializer)?; 147 - Self::new(value).map_err(D::Error::custom) 169 + Self::new_cow(value).map_err(D::Error::custom) 148 170 } 149 171 } 150 172
+29 -1
crates/jacquard-common/src/types/handle.rs
··· 108 108 Ok(Self(CowStr::new_static(handle))) 109 109 } 110 110 } 111 + 112 + /// Fallible constructor, validates, borrows from input if possible 113 + /// 114 + /// May allocate for a long handle with an at:// or @ prefix, otherwise borrows. 115 + /// Accepts (and strips) preceding '@' or 'at://' if present 116 + pub fn new_cow(handle: CowStr<'h>) -> Result<Self, AtStrError> { 117 + let handle = if let Some(stripped) = handle.strip_prefix("at://") { 118 + CowStr::copy_from_str(stripped) 119 + } else if let Some(stripped) = handle.strip_prefix('@') { 120 + CowStr::copy_from_str(stripped) 121 + } else { 122 + handle 123 + }; 124 + if handle.len() > 253 { 125 + Err(AtStrError::too_long("handle", &handle, 253, handle.len())) 126 + } else if !HANDLE_REGEX.is_match(&handle) { 127 + Err(AtStrError::regex( 128 + "handle", 129 + &handle, 130 + SmolStr::new_static("invalid"), 131 + )) 132 + } else if ends_with(&handle, DISALLOWED_TLDS) { 133 + Err(AtStrError::disallowed("handle", &handle, DISALLOWED_TLDS)) 134 + } else { 135 + Ok(Self(handle)) 136 + } 137 + } 138 + 111 139 /// Infallible constructor for when you *know* the string is a valid handle. 112 140 /// Will panic on invalid handles. If you're manually decoding atproto records 113 141 /// or API values you know are valid (rather than using serde), this is the one to use. ··· 179 207 D: Deserializer<'de>, 180 208 { 181 209 let value = Deserialize::deserialize(deserializer)?; 182 - Self::new(value).map_err(D::Error::custom) 210 + Self::new_cow(value).map_err(D::Error::custom) 183 211 } 184 212 } 185 213
+9
crates/jacquard-common/src/types/ident.rs
··· 54 54 } 55 55 } 56 56 57 + /// Fallible constructor, validates, borrows from input if possible 58 + pub fn new_cow(ident: CowStr<'i>) -> Result<Self, AtStrError> { 59 + if let Ok(did) = Did::new_cow(ident.clone()) { 60 + Ok(AtIdentifier::Did(did)) 61 + } else { 62 + Ok(AtIdentifier::Handle(Handle::new_cow(ident)?)) 63 + } 64 + } 65 + 57 66 /// Infallible constructor for when you *know* the string is a valid identifier. 58 67 /// Will panic on invalid identifiers. If you're manually decoding atproto records 59 68 /// or API values you know are valid (rather than using serde), this is the one to use.
-1
crates/jacquard-common/src/types/link.rs
··· 1 - // strongref, blobref(s), cid links
+17 -2
crates/jacquard-common/src/types/nsid.rs
··· 81 81 } 82 82 } 83 83 84 + /// Fallible constructor, validates, borrows from input if possible 85 + pub fn new_cow(nsid: CowStr<'n>) -> Result<Self, AtStrError> { 86 + if nsid.len() > 317 { 87 + Err(AtStrError::too_long("nsid", &nsid, 317, nsid.len())) 88 + } else if !NSID_REGEX.is_match(&nsid) { 89 + Err(AtStrError::regex( 90 + "nsid", 91 + &nsid, 92 + SmolStr::new_static("invalid"), 93 + )) 94 + } else { 95 + Ok(Self(nsid)) 96 + } 97 + } 98 + 84 99 /// Infallible constructor for when you *know* the string is a valid NSID. 85 100 /// Will panic on invalid NSIDs. If you're manually decoding atproto records 86 101 /// or API values you know are valid (rather than using serde), this is the one to use. ··· 148 163 where 149 164 D: Deserializer<'de>, 150 165 { 151 - let value: &str = Deserialize::deserialize(deserializer)?; 152 - Self::new(value).map_err(D::Error::custom) 166 + let value = Deserialize::deserialize(deserializer)?; 167 + Self::new_cow(value).map_err(D::Error::custom) 153 168 } 154 169 } 155 170
+17 -2
crates/jacquard-common/src/types/recordkey.rs
··· 137 137 } 138 138 } 139 139 140 + /// Fallible constructor, validates, borrows from input if possible 141 + pub fn new_cow(rkey: CowStr<'r>) -> Result<Self, AtStrError> { 142 + if [".", ".."].contains(&rkey.as_ref()) { 143 + Err(AtStrError::disallowed("record-key", &rkey, &[".", ".."])) 144 + } else if !RKEY_REGEX.is_match(&rkey) { 145 + Err(AtStrError::regex( 146 + "record-key", 147 + &rkey, 148 + SmolStr::new_static("doesn't match 'any' schema"), 149 + )) 150 + } else { 151 + Ok(Self(rkey)) 152 + } 153 + } 154 + 140 155 /// Infallible constructor for when you *know* the string is a valid rkey. 141 156 /// Will panic on invalid rkeys. If you're manually decoding atproto records 142 157 /// or API values you know are valid (rather than using serde), this is the one to use. ··· 200 215 where 201 216 D: Deserializer<'de>, 202 217 { 203 - let value: &str = Deserialize::deserialize(deserializer)?; 204 - Self::new(value).map_err(D::Error::custom) 218 + let value = Deserialize::deserialize(deserializer)?; 219 + Self::new_cow(value).map_err(D::Error::custom) 205 220 } 206 221 } 207 222
+30 -7
crates/jacquard-common/src/types/uri.rs
··· 1 - use serde::{Deserialize, Deserializer, Serialize, Serializer}; 2 - use smol_str::ToSmolStr; 3 - use url::Url; 4 - 5 1 use crate::{ 6 2 CowStr, IntoStatic, 7 3 types::{aturi::AtUri, cid::Cid, did::Did, string::AtStrError}, 8 4 }; 5 + use serde::{Deserialize, Deserializer, Serialize, Serializer}; 6 + use smol_str::ToSmolStr; 7 + use std::str::FromStr; 8 + use url::Url; 9 9 10 10 /// Generic URI with type-specific parsing 11 11 /// ··· 55 55 } else if uri.starts_with("wss://") { 56 56 Ok(Uri::Https(Url::parse(uri)?)) 57 57 } else if uri.starts_with("ipld://") { 58 - Ok(Uri::Cid(Cid::new(uri.as_bytes())?)) 58 + Ok(Uri::Cid( 59 + Cid::from_str(uri.strip_prefix("ipld://").unwrap_or(uri.as_ref())).unwrap(), 60 + )) 59 61 } else { 60 62 Ok(Uri::Any(CowStr::Borrowed(uri))) 61 63 } ··· 73 75 } else if uri.starts_with("wss://") { 74 76 Ok(Uri::Https(Url::parse(uri)?)) 75 77 } else if uri.starts_with("ipld://") { 76 - Ok(Uri::Cid(Cid::new_owned(uri.as_bytes())?)) 78 + Ok(Uri::Cid( 79 + Cid::from_str(uri.strip_prefix("ipld://").unwrap_or(uri.as_ref())).unwrap(), 80 + )) 77 81 } else { 78 82 Ok(Uri::Any(CowStr::Owned(uri.to_smolstr()))) 79 83 } 80 84 } 81 85 86 + /// Parse a URI from a CowStr, borrowing where possible 87 + pub fn new_cow(uri: CowStr<'u>) -> Result<Self, UriParseError> { 88 + if uri.starts_with("did:") { 89 + Ok(Uri::Did(Did::new_cow(uri)?)) 90 + } else if uri.starts_with("at://") { 91 + Ok(Uri::At(AtUri::new_cow(uri)?)) 92 + } else if uri.starts_with("https://") { 93 + Ok(Uri::Https(Url::parse(uri.as_ref())?)) 94 + } else if uri.starts_with("wss://") { 95 + Ok(Uri::Https(Url::parse(uri.as_ref())?)) 96 + } else if uri.starts_with("ipld://") { 97 + Ok(Uri::Cid( 98 + Cid::from_str(uri.strip_prefix("ipld://").unwrap_or(uri.as_str())).unwrap(), 99 + )) 100 + } else { 101 + Ok(Uri::Any(uri)) 102 + } 103 + } 104 + 82 105 /// Get the URI as a string slice 83 106 pub fn as_str(&self) -> &str { 84 107 match self { ··· 111 134 { 112 135 use serde::de::Error; 113 136 let value = Deserialize::deserialize(deserializer)?; 114 - Self::new(value).map_err(D::Error::custom) 137 + Self::new_cow(value).map_err(D::Error::custom) 115 138 } 116 139 } 117 140
+1 -1
crates/jacquard-common/src/xrpc.rs
··· 238 238 fn send<R>( 239 239 &self, 240 240 request: R, 241 - ) -> impl Future<Output = XrpcResult<Response<<R as XrpcRequest>::Response>>> 241 + ) -> impl Future<Output = XrpcResult<Response<<R as XrpcRequest>::Response>>> + Send 242 242 where 243 243 R: XrpcRequest + Send + Sync, 244 244 <R as XrpcRequest>::Response: Send + Sync;
+7 -5
crates/jacquard-identity/src/lib.rs
··· 84 84 use jacquard_common::{IntoStatic, types::string::Handle}; 85 85 use percent_encoding::percent_decode_str; 86 86 use reqwest::StatusCode; 87 + use std::sync::Arc; 87 88 use url::{ParseError, Url}; 88 89 89 90 #[cfg(feature = "dns")] 90 91 use hickory_resolver::{TokioAsyncResolver, config::ResolverConfig}; 91 92 92 93 /// Default resolver implementation with configurable fallback order. 94 + #[derive(Clone)] 93 95 pub struct JacquardResolver { 94 96 http: reqwest::Client, 95 97 opts: ResolverOptions, 96 98 #[cfg(feature = "dns")] 97 - dns: Option<TokioAsyncResolver>, 99 + dns: Option<Arc<TokioAsyncResolver>>, 98 100 } 99 101 100 102 impl JacquardResolver { ··· 114 116 Self { 115 117 http, 116 118 opts, 117 - dns: Some(TokioAsyncResolver::tokio( 119 + dns: Some(Arc::new(TokioAsyncResolver::tokio( 118 120 ResolverConfig::default(), 119 121 Default::default(), 120 - )), 122 + ))), 121 123 } 122 124 } 123 125 124 126 #[cfg(feature = "dns")] 125 127 /// Add default DNS resolution to the resolver 126 128 pub fn with_system_dns(mut self) -> Self { 127 - self.dns = Some(TokioAsyncResolver::tokio( 129 + self.dns = Some(Arc::new(TokioAsyncResolver::tokio( 128 130 ResolverConfig::default(), 129 131 Default::default(), 130 - )); 132 + ))); 131 133 self 132 134 } 133 135
+28 -8
crates/jacquard-identity/src/resolver.rs
··· 10 10 //! and optionally validate the document `id` against the requested DID. 11 11 12 12 use std::collections::BTreeMap; 13 + use std::marker::Sync; 13 14 use std::str::FromStr; 14 15 15 16 use bon::Builder; ··· 333 334 fn resolve_handle( 334 335 &self, 335 336 handle: &Handle<'_>, 336 - ) -> impl Future<Output = Result<Did<'static>, IdentityError>>; 337 + ) -> impl Future<Output = Result<Did<'static>, IdentityError>> + Send 338 + where 339 + Self: Sync; 337 340 338 341 /// Resolve DID document 339 342 fn resolve_did_doc( 340 343 &self, 341 344 did: &Did<'_>, 342 - ) -> impl Future<Output = Result<DidDocResponse, IdentityError>>; 345 + ) -> impl Future<Output = Result<DidDocResponse, IdentityError>> + Send 346 + where 347 + Self: Sync; 343 348 344 349 /// Resolve DID doc from an identifier 345 350 fn resolve_ident( 346 351 &self, 347 352 actor: &AtIdentifier<'_>, 348 - ) -> impl Future<Output = Result<DidDocResponse, IdentityError>> { 353 + ) -> impl Future<Output = Result<DidDocResponse, IdentityError>> + Send 354 + where 355 + Self: Sync, 356 + { 349 357 async move { 350 358 match actor { 351 359 AtIdentifier::Did(did) => self.resolve_did_doc(&did).await, ··· 361 369 fn resolve_ident_owned( 362 370 &self, 363 371 actor: &AtIdentifier<'_>, 364 - ) -> impl Future<Output = Result<DidDocument<'static>, IdentityError>> { 372 + ) -> impl Future<Output = Result<DidDocument<'static>, IdentityError>> + Send 373 + where 374 + Self: Sync, 375 + { 365 376 async move { 366 377 match actor { 367 378 AtIdentifier::Did(did) => self.resolve_did_doc_owned(&did).await, ··· 377 388 fn resolve_did_doc_owned( 378 389 &self, 379 390 did: &Did<'_>, 380 - ) -> impl Future<Output = Result<DidDocument<'static>, IdentityError>> { 391 + ) -> impl Future<Output = Result<DidDocument<'static>, IdentityError>> + Send 392 + where 393 + Self: Sync, 394 + { 381 395 async { self.resolve_did_doc(did).await?.into_owned() } 382 396 } 383 397 /// Return the PDS url for a DID 384 - fn pds_for_did(&self, did: &Did<'_>) -> impl Future<Output = Result<Url, IdentityError>> { 398 + fn pds_for_did(&self, did: &Did<'_>) -> impl Future<Output = Result<Url, IdentityError>> + Send 399 + where 400 + Self: Sync, 401 + { 385 402 async { 386 403 let resp = self.resolve_did_doc(did).await?; 387 404 let doc = resp.parse()?; ··· 401 418 fn pds_for_handle( 402 419 &self, 403 420 handle: &Handle<'_>, 404 - ) -> impl Future<Output = Result<(Did<'static>, Url), IdentityError>> { 421 + ) -> impl Future<Output = Result<(Did<'static>, Url), IdentityError>> + Send 422 + where 423 + Self: Sync, 424 + { 405 425 async { 406 426 let did = self.resolve_handle(handle).await?; 407 427 let pds = self.pds_for_did(&did).await?; ··· 410 430 } 411 431 } 412 432 413 - impl<T: IdentityResolver> IdentityResolver for std::sync::Arc<T> { 433 + impl<T: IdentityResolver + Sync> IdentityResolver for std::sync::Arc<T> { 414 434 fn options(&self) -> &ResolverOptions { 415 435 self.as_ref().options() 416 436 }
+21 -6
crates/jacquard-oauth/src/resolver.rs
··· 120 120 &self, 121 121 server_metadata: &OAuthAuthorizationServerMetadata<'_>, 122 122 sub: &Did<'_>, 123 - ) -> impl std::future::Future<Output = Result<Url, ResolverError>> { 123 + ) -> impl std::future::Future<Output = Result<Url, ResolverError>> + Send 124 + where 125 + Self: Sync, 126 + { 124 127 async { 125 128 let (metadata, identity) = self.resolve_from_identity(sub).await?; 126 129 if !issuer_equivalent(&metadata.issuer, &server_metadata.issuer) { ··· 144 147 ), 145 148 ResolverError, 146 149 >, 147 - > { 150 + > + Send 151 + where 152 + Self: Sync, 153 + { 148 154 // Allow using an entryway, or PDS url, directly as login input (e.g. 149 155 // when the user forgot their handle, or when the handle does not 150 156 // resolve to a DID) ··· 161 167 fn resolve_from_service( 162 168 &self, 163 169 input: &Url, 164 - ) -> impl Future<Output = Result<OAuthAuthorizationServerMetadata<'static>, ResolverError>> 170 + ) -> impl Future<Output = Result<OAuthAuthorizationServerMetadata<'static>, ResolverError>> + Send 171 + where 172 + Self: Sync, 165 173 { 166 174 async { 167 175 // Assume first that input is a PDS URL (as required by ATPROTO) ··· 183 191 ), 184 192 ResolverError, 185 193 >, 186 - > { 194 + > + Send 195 + where 196 + Self: Sync, 197 + { 187 198 async { 188 199 let actor = AtIdentifier::new(input) 189 200 .map_err(|e| ResolverError::AtIdentifier(format!("{:?}", e)))?; ··· 199 210 fn get_authorization_server_metadata( 200 211 &self, 201 212 issuer: &Url, 202 - ) -> impl Future<Output = Result<OAuthAuthorizationServerMetadata<'static>, ResolverError>> 213 + ) -> impl Future<Output = Result<OAuthAuthorizationServerMetadata<'static>, ResolverError>> + Send 214 + where 215 + Self: Sync, 203 216 { 204 217 async { 205 218 let mut md = resolve_authorization_server(self, issuer).await?; ··· 211 224 fn get_resource_server_metadata( 212 225 &self, 213 226 pds: &Url, 214 - ) -> impl Future<Output = Result<OAuthAuthorizationServerMetadata<'static>, ResolverError>> 227 + ) -> impl Future<Output = Result<OAuthAuthorizationServerMetadata<'static>, ResolverError>> + Send 228 + where 229 + Self: Sync, 215 230 { 216 231 async move { 217 232 let rs_metadata = resolve_protected_resource_info(self, pds).await?;
+1 -1
crates/jacquard/src/client/credential_session.rs
··· 149 149 impl<S, T> CredentialSession<S, T> 150 150 where 151 151 S: SessionStore<SessionKey, AtpSession>, 152 - T: HttpClient + IdentityResolver + XrpcExt, 152 + T: HttpClient + IdentityResolver + XrpcExt + Sync + Send, 153 153 { 154 154 /// Resolve the user's PDS and create an app-password session. 155 155 ///