// Alternative ATProto PDS implementation
// (source viewed at ref `oauth`, 16 kB)
1use anyhow::{Context as _, anyhow}; 2use atrium_crypto::{ 3 keypair::{Did as _, Secp256k1Keypair}, 4 verify::Verifier, 5}; 6use axum::{extract::FromRequestParts, http::StatusCode}; 7use base64::Engine as _; 8use diesel::prelude::*; 9use sha2::{Digest as _, Sha256}; 10 11use crate::{ 12 error::{Error, ErrorMessage}, 13 serve::AppState, 14}; 15 16/// Request extractor for authenticated users. 17/// If specified in an API endpoint, this guarantees the API can only be called 18/// by an authenticated user. 19pub(crate) struct AuthenticatedUser { 20 /// The DID of the authenticated user. 21 did: String, 22} 23 24impl AuthenticatedUser { 25 /// Get the DID of the authenticated user. 26 pub(crate) fn did(&self) -> String { 27 self.did.clone() 28 } 29} 30 31impl FromRequestParts<AppState> for AuthenticatedUser { 32 type Rejection = Error; 33 34 async fn from_request_parts( 35 parts: &mut axum::http::request::Parts, 36 state: &AppState, 37 ) -> Result<Self, Self::Rejection> { 38 // Check for authorization header (either for Bearer or DPoP tokens) 39 let auth_header = parts 40 .headers 41 .get(axum::http::header::AUTHORIZATION) 42 .ok_or_else(|| { 43 Error::with_status(StatusCode::UNAUTHORIZED, anyhow!("no authorization header")) 44 })?; 45 46 let auth_str = auth_header.to_str().map_err(|e| { 47 Error::with_status( 48 StatusCode::UNAUTHORIZED, 49 anyhow!("authorization header should be valid utf-8").context(e), 50 ) 51 })?; 52 53 // Check for DPoP header 54 let dpop_header = parts.headers.get("dpop"); 55 let has_dpop = dpop_header.is_some(); 56 57 // Handle different token types 58 if auth_str.starts_with("Bearer ") || auth_str.starts_with("DPoP ") { 59 let token = auth_str 60 .split_once(' ') 61 .expect("Auth string should have a space") 62 .1; 63 64 if has_dpop { 65 // Process DPoP token - the Authorization header contains the access token 66 // and the DPoP header contains the proof 67 let dpop_token = dpop_header 68 .expect("DPoP header should exist") 69 .to_str() 70 
.map_err(|e| { 71 Error::with_status( 72 StatusCode::UNAUTHORIZED, 73 anyhow!("DPoP header should be valid utf-8").context(e), 74 ) 75 })?; 76 77 return validate_dpop_token(token, dpop_token, parts, state).await; 78 } 79 80 // Standard Bearer token 81 return validate_bearer_token(token, state).await; 82 } 83 84 // If we reach here, no valid authorization method was found 85 Err(Error::with_status( 86 StatusCode::UNAUTHORIZED, 87 anyhow!("unsupported authorization method"), 88 )) 89 } 90} 91 92/// Validate a standard Bearer token 93async fn validate_bearer_token(token: &str, state: &AppState) -> Result<AuthenticatedUser, Error> { 94 // Validate JWT token 95 let (typ, claims) = verify(&state.signing_key.did(), token) 96 .map_err(|e| { 97 Error::with_status( 98 StatusCode::UNAUTHORIZED, 99 e.context("failed to verify bearer token"), 100 ) 101 }) 102 .context("token auth should verify")?; 103 104 // Ensure this is an authentication token. 105 if typ != "at+jwt" { 106 return Err(Error::with_status( 107 StatusCode::UNAUTHORIZED, 108 anyhow!("invalid token type: {typ}"), 109 )); 110 } 111 112 // Ensure we are in the audience field. 
113 if let Some(aud) = claims.get("aud").and_then(serde_json::Value::as_str) { 114 if aud != format!("did:web:{}", state.config.host_name) { 115 return Err(Error::with_status( 116 StatusCode::UNAUTHORIZED, 117 anyhow!("invalid audience: {aud}"), 118 )); 119 } 120 } 121 122 // Check token expiration 123 if let Some(exp) = claims.get("exp").and_then(serde_json::Value::as_i64) { 124 let now = chrono::Utc::now().timestamp(); 125 if now >= exp { 126 return Err(Error::with_message( 127 StatusCode::BAD_REQUEST, 128 anyhow!("token has expired"), 129 ErrorMessage::new("ExpiredToken", "Token has expired"), 130 )); 131 } 132 } 133 134 // Extract subject (DID) 135 if let Some(did) = claims.get("sub").and_then(serde_json::Value::as_str) { 136 use crate::schema::pds::account::dsl as AccountSchema; 137 let did_clone = did.to_owned(); 138 139 let _did = state 140 .db 141 .get() 142 .await 143 .expect("failed to get db connection") 144 .interact(move |conn| { 145 AccountSchema::account 146 .filter(AccountSchema::did.eq(did_clone)) 147 .select(AccountSchema::did) 148 .first::<String>(conn) 149 }) 150 .await 151 .expect("failed to query account"); 152 153 Ok(AuthenticatedUser { 154 did: did.to_owned(), 155 }) 156 } else { 157 Err(Error::with_status( 158 StatusCode::UNAUTHORIZED, 159 anyhow!("invalid authorization token: missing subject"), 160 )) 161 } 162} 163 164#[expect(clippy::too_many_lines, reason = "validating dpop has many loc")] 165/// Validate a DPoP token and proof 166async fn validate_dpop_token( 167 access_token: &str, 168 dpop_token: &str, 169 parts: &axum::http::request::Parts, 170 state: &AppState, 171) -> Result<AuthenticatedUser, Error> { 172 // Step 1: Parse and validate the access token 173 let (typ, claims) = verify(&state.signing_key.did(), access_token) 174 .map_err(|e| { 175 Error::with_status( 176 StatusCode::UNAUTHORIZED, 177 e.context("failed to verify access token"), 178 ) 179 }) 180 .context("access token auth should verify")?; 181 182 // Ensure this is an 
access token JWT 183 if typ != "at+jwt" { 184 return Err(Error::with_status( 185 StatusCode::UNAUTHORIZED, 186 anyhow!("invalid token type: {typ}"), 187 )); 188 } 189 190 // Check token expiration 191 if let Some(exp) = claims.get("exp").and_then(serde_json::Value::as_i64) { 192 let now = chrono::Utc::now().timestamp(); 193 if now >= exp { 194 return Err(Error::with_message( 195 StatusCode::BAD_REQUEST, 196 anyhow!("token has expired"), 197 ErrorMessage::new("ExpiredToken", "Token has expired"), 198 )); 199 } 200 } 201 202 // Step 2: Parse and validate the DPoP proof 203 let dpop_parts: Vec<&str> = dpop_token.split('.').collect(); 204 if dpop_parts.len() != 3 { 205 return Err(Error::with_status( 206 StatusCode::UNAUTHORIZED, 207 anyhow!("invalid DPoP token format"), 208 )); 209 } 210 211 // Decode header 212 let dpop_header_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD 213 .decode(dpop_parts.first().context("header part missing")?) 214 .context("failed to decode DPoP header")?; 215 216 let dpop_header: serde_json::Value = 217 serde_json::from_slice(&dpop_header_bytes).context("failed to parse DPoP header")?; 218 219 // Check typ is "dpop+jwt" 220 if dpop_header.get("typ").and_then(|v| v.as_str()) != Some("dpop+jwt") { 221 return Err(Error::with_status( 222 StatusCode::UNAUTHORIZED, 223 anyhow!("invalid DPoP token type"), 224 )); 225 } 226 227 // Decode claims 228 let dpop_claims_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD 229 .decode(dpop_parts.get(1).context("claims part missing")?) 
230 .context("failed to decode DPoP claims")?; 231 232 let dpop_claims: serde_json::Value = 233 serde_json::from_slice(&dpop_claims_bytes).context("failed to parse DPoP claims")?; 234 235 // Check HTTP method 236 if dpop_claims.get("htm").and_then(|v| v.as_str()) != Some(parts.method.as_str()) { 237 return Err(Error::with_status( 238 StatusCode::UNAUTHORIZED, 239 anyhow!("DPoP token HTTP method mismatch"), 240 )); 241 } 242 243 // Check HTTP URI 244 let expected_uri = format!( 245 "https://{}/xrpc{}", 246 state.config.host_name, 247 parts.uri.path_and_query().expect("path and query to exist") 248 ); 249 if dpop_claims.get("htu").and_then(|v| v.as_str()) != Some(&expected_uri) { 250 return Err(Error::with_status( 251 StatusCode::UNAUTHORIZED, 252 anyhow!( 253 "DPoP token HTTP URI mismatch: expected {}, got {}", 254 expected_uri, 255 dpop_claims 256 .get("htu") 257 .and_then(|v| v.as_str()) 258 .unwrap_or("None") 259 ), 260 )); 261 } 262 263 // Verify access token hash (ath) 264 if let Some(ath) = dpop_claims.get("ath").and_then(|v| v.as_str()) { 265 let mut hasher = Sha256::new(); 266 hasher.update(access_token.as_bytes()); 267 let token_hash = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hasher.finalize()); 268 269 if ath != token_hash { 270 return Err(Error::with_status( 271 StatusCode::UNAUTHORIZED, 272 anyhow!("DPoP access token hash mismatch"), 273 )); 274 } 275 } 276 277 // Extract JWK from DPoP header 278 let jwk = dpop_header 279 .get("jwk") 280 .context("missing jwk in DPoP header")?; 281 282 // Calculate JWK thumbprint - RFC7638 compliant 283 let key_type = jwk 284 .get("kty") 285 .and_then(|v| v.as_str()) 286 .context("JWK missing kty property")?; 287 288 // Define required properties for each key type 289 let required_props = match key_type { 290 "EC" => &["crv", "kty", "x", "y"][..], 291 "RSA" => &["e", "kty", "n"][..], 292 _ => { 293 return Err(Error::with_status( 294 StatusCode::UNAUTHORIZED, 295 anyhow!("Unsupported JWK key type: {}", 
key_type), 296 )); 297 } 298 }; 299 300 // Build a new JWK with only the required properties 301 let mut canonical_jwk = serde_json::Map::new(); 302 303 for prop in required_props { 304 let value = jwk 305 .get(prop) 306 .context(format!("JWK missing required property: {prop}"))?; 307 drop(canonical_jwk.insert((*prop).to_owned(), value.clone())); 308 } 309 310 // Serialize with no whitespace 311 let canonical_json = serde_json::to_string(&serde_json::Value::Object(canonical_jwk)) 312 .context("Failed to serialize canonical JWK")?; 313 314 // Hash the canonical representation 315 let mut hasher = Sha256::new(); 316 hasher.update(canonical_json.as_bytes()); 317 let calculated_thumbprint = 318 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hasher.finalize()); 319 320 // Get JWK thumbprint from the access token "cnf" claim 321 let jkt = claims 322 .get("cnf") 323 .and_then(|v| v.as_object()) 324 .and_then(|o| o.get("jkt")) 325 .and_then(|v| v.as_str()) 326 .context("missing or invalid 'cnf.jkt' in access token")?; 327 328 // Verify JWK thumbprint 329 if calculated_thumbprint != jkt { 330 return Err(Error::with_status( 331 StatusCode::UNAUTHORIZED, 332 anyhow!("JWK thumbprint mismatch"), 333 )); 334 } 335 336 // Check JTI replay using the database 337 let jti = dpop_claims 338 .get("jti") 339 .and_then(|j| j.as_str()) 340 .context("Missing jti claim in DPoP token")?; 341 342 let timestamp = chrono::Utc::now().timestamp(); 343 344 use crate::schema::pds::oauth_used_jtis::dsl as JtiSchema; 345 346 // Check if JTI has been used before 347 let jti_string = jti.to_owned(); 348 let jti_used = state 349 .db 350 .get() 351 .await 352 .expect("failed to get db connection") 353 .interact(move |conn| { 354 JtiSchema::oauth_used_jtis 355 .filter(JtiSchema::jti.eq(jti_string)) 356 .count() 357 .get_result::<i64>(conn) 358 }) 359 .await 360 .expect("failed to query JTI") 361 .expect("failed to get JTI count"); 362 363 if jti_used > 0 { 364 return Err(Error::with_status( 365 
StatusCode::UNAUTHORIZED, 366 anyhow!("DPoP token has been replayed"), 367 )); 368 } 369 370 // Store the JTI to prevent replay attacks 371 // Get expiry from token or default to 60 seconds 372 let exp = dpop_claims 373 .get("exp") 374 .and_then(serde_json::Value::as_i64) 375 .unwrap_or_else(|| timestamp.checked_add(60).unwrap_or(timestamp)); 376 377 // Convert SQLx INSERT to Diesel 378 let jti_str = jti.to_owned(); 379 let thumbprint_str = calculated_thumbprint.to_string(); 380 let _ = state 381 .db 382 .get() 383 .await 384 .expect("failed to get db connection") 385 .interact(move |conn| { 386 diesel::insert_into(JtiSchema::oauth_used_jtis) 387 .values(( 388 JtiSchema::jti.eq(jti_str), 389 JtiSchema::issuer.eq(thumbprint_str), 390 JtiSchema::created_at.eq(timestamp), 391 JtiSchema::expires_at.eq(exp), 392 )) 393 .execute(conn) 394 }) 395 .await 396 .expect("failed to insert JTI") 397 .expect("failed to insert JTI"); 398 399 // Extract subject (DID) from access token 400 if let Some(did) = claims.get("sub").and_then(|v| v.as_str()) { 401 use crate::schema::pds::account::dsl as AccountSchema; 402 403 let did_clone = did.to_owned(); 404 405 let _did = state 406 .db 407 .get() 408 .await 409 .expect("failed to get db connection") 410 .interact(move |conn| { 411 AccountSchema::account 412 .filter(AccountSchema::did.eq(did_clone)) 413 .select(AccountSchema::did) 414 .first::<String>(conn) 415 }) 416 .await 417 .expect("failed to query account") 418 .expect("failed to get account"); 419 420 Ok(AuthenticatedUser { 421 did: did.to_owned(), 422 }) 423 } else { 424 Err(Error::with_status( 425 StatusCode::UNAUTHORIZED, 426 anyhow!("invalid access token: missing subject"), 427 )) 428 } 429} 430 431/// Cryptographically sign a JSON web token with the specified key. 
432pub(crate) fn sign( 433 key: &Secp256k1Keypair, 434 typ: &str, 435 claims: &serde_json::Value, 436) -> anyhow::Result<String> { 437 // RFC 9068 438 let hdr = serde_json::json!({ 439 "typ": typ, 440 "alg": "ES256K", // Secp256k1Keypair 441 }); 442 let hdr = base64::prelude::BASE64_URL_SAFE_NO_PAD 443 .encode(serde_json::to_vec(&hdr).context("failed to encode claims")?); 444 let claims = base64::prelude::BASE64_URL_SAFE_NO_PAD 445 .encode(serde_json::to_vec(&claims).context("failed to encode claims")?); 446 let sig = key 447 .sign(format!("{hdr}.{claims}").as_bytes()) 448 .context("failed to sign jwt")?; 449 let sig = base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(sig); 450 451 Ok(format!("{hdr}.{claims}.{sig}")) 452} 453 454/// Cryptographically verify a JSON web token's validity using the specified public key. 455pub(crate) fn verify(key: &str, token: &str) -> anyhow::Result<(String, serde_json::Value)> { 456 let mut parts = token.splitn(3, '.'); 457 let hdr = parts.next().context("no header")?; 458 let claims = parts.next().context("no claims")?; 459 let sig = base64::prelude::BASE64_URL_SAFE_NO_PAD 460 .decode(parts.next().context("no sig")?) 
461 .context("failed to decode signature")?; 462 463 let (alg, key) = atrium_crypto::did::parse_did_key(key).context("failed to decode key")?; 464 Verifier::default() 465 .verify(alg, &key, format!("{hdr}.{claims}").as_bytes(), &sig) 466 .context("failed to verify jwt")?; 467 468 let hdr = base64::prelude::BASE64_URL_SAFE_NO_PAD 469 .decode(hdr) 470 .context("failed to decode hdr")?; 471 let hdr = 472 serde_json::from_slice::<serde_json::Value>(&hdr).context("failed to parse hdr as json")?; 473 let typ = hdr 474 .get("typ") 475 .and_then(serde_json::Value::as_str) 476 .context("hdr is invalid")?; 477 478 let claims = base64::prelude::BASE64_URL_SAFE_NO_PAD 479 .decode(claims) 480 .context("failed to decode claims")?; 481 let claims = serde_json::from_slice(&claims).context("failed to parse claims as json")?; 482 483 Ok((typ.to_owned(), claims)) 484}