Alternative ATProto PDS implementation
6
fork

Configure Feed

Select the types of activity you want to include in your feed.

at 8f63824a1b85560fa30b225b60b901689fb84d23 481 lines 16 kB view raw
1use anyhow::{Context as _, anyhow}; 2use atrium_crypto::{ 3 keypair::{Did as _, Secp256k1Keypair}, 4 verify::Verifier, 5}; 6use axum::{extract::FromRequestParts, http::StatusCode}; 7use base64::Engine as _; 8use diesel::prelude::*; 9use sha2::{Digest as _, Sha256}; 10 11use crate::{AppState, Error, error::ErrorMessage}; 12 13/// Request extractor for authenticated users. 14/// If specified in an API endpoint, this guarantees the API can only be called 15/// by an authenticated user. 16pub(crate) struct AuthenticatedUser { 17 /// The DID of the authenticated user. 18 did: String, 19} 20 21impl AuthenticatedUser { 22 /// Get the DID of the authenticated user. 23 pub(crate) fn did(&self) -> String { 24 self.did.clone() 25 } 26} 27 28impl FromRequestParts<AppState> for AuthenticatedUser { 29 type Rejection = Error; 30 31 async fn from_request_parts( 32 parts: &mut axum::http::request::Parts, 33 state: &AppState, 34 ) -> Result<Self, Self::Rejection> { 35 // Check for authorization header (either for Bearer or DPoP tokens) 36 let auth_header = parts 37 .headers 38 .get(axum::http::header::AUTHORIZATION) 39 .ok_or_else(|| { 40 Error::with_status(StatusCode::UNAUTHORIZED, anyhow!("no authorization header")) 41 })?; 42 43 let auth_str = auth_header.to_str().map_err(|e| { 44 Error::with_status( 45 StatusCode::UNAUTHORIZED, 46 anyhow!("authorization header should be valid utf-8").context(e), 47 ) 48 })?; 49 50 // Check for DPoP header 51 let dpop_header = parts.headers.get("dpop"); 52 let has_dpop = dpop_header.is_some(); 53 54 // Handle different token types 55 if auth_str.starts_with("Bearer ") || auth_str.starts_with("DPoP ") { 56 let token = auth_str 57 .split_once(' ') 58 .expect("Auth string should have a space") 59 .1; 60 61 if has_dpop { 62 // Process DPoP token - the Authorization header contains the access token 63 // and the DPoP header contains the proof 64 let dpop_token = dpop_header 65 .expect("DPoP header should exist") 66 .to_str() 67 .map_err(|e| { 68 Error::with_status( 69 StatusCode::UNAUTHORIZED, 70 anyhow!("DPoP header should be valid utf-8").context(e), 71 ) 72 })?; 73 74 return validate_dpop_token(token, dpop_token, parts, state).await; 75 } 76 77 // Standard Bearer token 78 return validate_bearer_token(token, state).await; 79 } 80 81 // If we reach here, no valid authorization method was found 82 Err(Error::with_status( 83 StatusCode::UNAUTHORIZED, 84 anyhow!("unsupported authorization method"), 85 )) 86 } 87} 88 89/// Validate a standard Bearer token 90async fn validate_bearer_token(token: &str, state: &AppState) -> Result<AuthenticatedUser, Error> { 91 // Validate JWT token 92 let (typ, claims) = verify(&state.signing_key.did(), token) 93 .map_err(|e| { 94 Error::with_status( 95 StatusCode::UNAUTHORIZED, 96 e.context("failed to verify bearer token"), 97 ) 98 }) 99 .context("token auth should verify")?; 100 101 // Ensure this is an authentication token. 102 if typ != "at+jwt" { 103 return Err(Error::with_status( 104 StatusCode::UNAUTHORIZED, 105 anyhow!("invalid token type: {typ}"), 106 )); 107 } 108 109 // Ensure we are in the audience field. 110 if let Some(aud) = claims.get("aud").and_then(serde_json::Value::as_str) { 111 if aud != format!("did:web:{}", state.config.host_name) { 112 return Err(Error::with_status( 113 StatusCode::UNAUTHORIZED, 114 anyhow!("invalid audience: {aud}"), 115 )); 116 } 117 } 118 119 // Check token expiration 120 if let Some(exp) = claims.get("exp").and_then(serde_json::Value::as_i64) { 121 let now = chrono::Utc::now().timestamp(); 122 if now >= exp { 123 return Err(Error::with_message( 124 StatusCode::BAD_REQUEST, 125 anyhow!("token has expired"), 126 ErrorMessage::new("ExpiredToken", "Token has expired"), 127 )); 128 } 129 } 130 131 // Extract subject (DID) 132 if let Some(did) = claims.get("sub").and_then(serde_json::Value::as_str) { 133 use crate::schema::pds::account::dsl as AccountSchema; 134 let did_clone = did.to_owned(); 135 136 let _did = state 137 .db 138 .get() 139 .await 140 .expect("failed to get db connection") 141 .interact(move |conn| { 142 AccountSchema::account 143 .filter(AccountSchema::did.eq(did_clone)) 144 .select(AccountSchema::did) 145 .first::<String>(conn) 146 }) 147 .await 148 .expect("failed to query account"); 149 150 Ok(AuthenticatedUser { 151 did: did.to_owned(), 152 }) 153 } else { 154 Err(Error::with_status( 155 StatusCode::UNAUTHORIZED, 156 anyhow!("invalid authorization token: missing subject"), 157 )) 158 } 159} 160 161#[expect(clippy::too_many_lines, reason = "validating dpop has many loc")] 162/// Validate a DPoP token and proof 163async fn validate_dpop_token( 164 access_token: &str, 165 dpop_token: &str, 166 parts: &axum::http::request::Parts, 167 state: &AppState, 168) -> Result<AuthenticatedUser, Error> { 169 // Step 1: Parse and validate the access token 170 let (typ, claims) = verify(&state.signing_key.did(), access_token) 171 .map_err(|e| { 172 Error::with_status( 173 StatusCode::UNAUTHORIZED, 174 e.context("failed to verify access token"), 175 ) 176 }) 177 .context("access token auth should verify")?; 178 179 // Ensure this is an access token JWT 180 if typ != "at+jwt" { 181 return Err(Error::with_status( 182 StatusCode::UNAUTHORIZED, 183 anyhow!("invalid token type: {typ}"), 184 )); 185 } 186 187 // Check token expiration 188 if let Some(exp) = claims.get("exp").and_then(serde_json::Value::as_i64) { 189 let now = chrono::Utc::now().timestamp(); 190 if now >= exp { 191 return Err(Error::with_message( 192 StatusCode::BAD_REQUEST, 193 anyhow!("token has expired"), 194 ErrorMessage::new("ExpiredToken", "Token has expired"), 195 )); 196 } 197 } 198 199 // Step 2: Parse and validate the DPoP proof 200 let dpop_parts: Vec<&str> = dpop_token.split('.').collect(); 201 if dpop_parts.len() != 3 { 202 return Err(Error::with_status( 203 StatusCode::UNAUTHORIZED, 204 anyhow!("invalid DPoP token format"), 205 )); 206 } 207 208 // Decode header 209 let dpop_header_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD 210 .decode(dpop_parts.first().context("header part missing")?) 211 .context("failed to decode DPoP header")?; 212 213 let dpop_header: serde_json::Value = 214 serde_json::from_slice(&dpop_header_bytes).context("failed to parse DPoP header")?; 215 216 // Check typ is "dpop+jwt" 217 if dpop_header.get("typ").and_then(|v| v.as_str()) != Some("dpop+jwt") { 218 return Err(Error::with_status( 219 StatusCode::UNAUTHORIZED, 220 anyhow!("invalid DPoP token type"), 221 )); 222 } 223 224 // Decode claims 225 let dpop_claims_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD 226 .decode(dpop_parts.get(1).context("claims part missing")?) 227 .context("failed to decode DPoP claims")?; 228 229 let dpop_claims: serde_json::Value = 230 serde_json::from_slice(&dpop_claims_bytes).context("failed to parse DPoP claims")?; 231 232 // Check HTTP method 233 if dpop_claims.get("htm").and_then(|v| v.as_str()) != Some(parts.method.as_str()) { 234 return Err(Error::with_status( 235 StatusCode::UNAUTHORIZED, 236 anyhow!("DPoP token HTTP method mismatch"), 237 )); 238 } 239 240 // Check HTTP URI 241 let expected_uri = format!( 242 "https://{}/xrpc{}", 243 state.config.host_name, 244 parts.uri.path_and_query().expect("path and query to exist") 245 ); 246 if dpop_claims.get("htu").and_then(|v| v.as_str()) != Some(&expected_uri) { 247 return Err(Error::with_status( 248 StatusCode::UNAUTHORIZED, 249 anyhow!( 250 "DPoP token HTTP URI mismatch: expected {}, got {}", 251 expected_uri, 252 dpop_claims 253 .get("htu") 254 .and_then(|v| v.as_str()) 255 .unwrap_or("None") 256 ), 257 )); 258 } 259 260 // Verify access token hash (ath) 261 if let Some(ath) = dpop_claims.get("ath").and_then(|v| v.as_str()) { 262 let mut hasher = Sha256::new(); 263 hasher.update(access_token.as_bytes()); 264 let token_hash = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hasher.finalize()); 265 266 if ath != token_hash { 267 return Err(Error::with_status( 268 StatusCode::UNAUTHORIZED, 269 anyhow!("DPoP access token hash mismatch"), 270 )); 271 } 272 } 273 274 // Extract JWK from DPoP header 275 let jwk = dpop_header 276 .get("jwk") 277 .context("missing jwk in DPoP header")?; 278 279 // Calculate JWK thumbprint - RFC7638 compliant 280 let key_type = jwk 281 .get("kty") 282 .and_then(|v| v.as_str()) 283 .context("JWK missing kty property")?; 284 285 // Define required properties for each key type 286 let required_props = match key_type { 287 "EC" => &["crv", "kty", "x", "y"][..], 288 "RSA" => &["e", "kty", "n"][..], 289 _ => { 290 return Err(Error::with_status( 291 StatusCode::UNAUTHORIZED, 292 anyhow!("Unsupported JWK key type: {}", key_type), 293 )); 294 } 295 }; 296 297 // Build a new JWK with only the required properties 298 let mut canonical_jwk = serde_json::Map::new(); 299 300 for prop in required_props { 301 let value = jwk 302 .get(prop) 303 .context(format!("JWK missing required property: {prop}"))?; 304 drop(canonical_jwk.insert((*prop).to_owned(), value.clone())); 305 } 306 307 // Serialize with no whitespace 308 let canonical_json = serde_json::to_string(&serde_json::Value::Object(canonical_jwk)) 309 .context("Failed to serialize canonical JWK")?; 310 311 // Hash the canonical representation 312 let mut hasher = Sha256::new(); 313 hasher.update(canonical_json.as_bytes()); 314 let calculated_thumbprint = 315 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hasher.finalize()); 316 317 // Get JWK thumbprint from the access token "cnf" claim 318 let jkt = claims 319 .get("cnf") 320 .and_then(|v| v.as_object()) 321 .and_then(|o| o.get("jkt")) 322 .and_then(|v| v.as_str()) 323 .context("missing or invalid 'cnf.jkt' in access token")?; 324 325 // Verify JWK thumbprint 326 if calculated_thumbprint != jkt { 327 return Err(Error::with_status( 328 StatusCode::UNAUTHORIZED, 329 anyhow!("JWK thumbprint mismatch"), 330 )); 331 } 332 333 // Check JTI replay using the database 334 let jti = dpop_claims 335 .get("jti") 336 .and_then(|j| j.as_str()) 337 .context("Missing jti claim in DPoP token")?; 338 339 let timestamp = chrono::Utc::now().timestamp(); 340 341 use crate::schema::pds::oauth_used_jtis::dsl as JtiSchema; 342 343 // Check if JTI has been used before 344 let jti_string = jti.to_owned(); 345 let jti_used = state 346 .db 347 .get() 348 .await 349 .expect("failed to get db connection") 350 .interact(move |conn| { 351 JtiSchema::oauth_used_jtis 352 .filter(JtiSchema::jti.eq(jti_string)) 353 .count() 354 .get_result::<i64>(conn) 355 }) 356 .await 357 .expect("failed to query JTI") 358 .expect("failed to get JTI count"); 359 360 if jti_used > 0 { 361 return Err(Error::with_status( 362 StatusCode::UNAUTHORIZED, 363 anyhow!("DPoP token has been replayed"), 364 )); 365 } 366 367 // Store the JTI to prevent replay attacks 368 // Get expiry from token or default to 60 seconds 369 let exp = dpop_claims 370 .get("exp") 371 .and_then(serde_json::Value::as_i64) 372 .unwrap_or_else(|| timestamp.checked_add(60).unwrap_or(timestamp)); 373 374 // Convert SQLx INSERT to Diesel 375 let jti_str = jti.to_owned(); 376 let thumbprint_str = calculated_thumbprint.to_string(); 377 let _ = state 378 .db 379 .get() 380 .await 381 .expect("failed to get db connection") 382 .interact(move |conn| { 383 diesel::insert_into(JtiSchema::oauth_used_jtis) 384 .values(( 385 JtiSchema::jti.eq(jti_str), 386 JtiSchema::issuer.eq(thumbprint_str), 387 JtiSchema::created_at.eq(timestamp), 388 JtiSchema::expires_at.eq(exp), 389 )) 390 .execute(conn) 391 }) 392 .await 393 .expect("failed to insert JTI") 394 .expect("failed to insert JTI"); 395 396 // Extract subject (DID) from access token 397 if let Some(did) = claims.get("sub").and_then(|v| v.as_str()) { 398 use crate::schema::pds::account::dsl as AccountSchema; 399 400 let did_clone = did.to_owned(); 401 402 let _did = state 403 .db 404 .get() 405 .await 406 .expect("failed to get db connection") 407 .interact(move |conn| { 408 AccountSchema::account 409 .filter(AccountSchema::did.eq(did_clone)) 410 .select(AccountSchema::did) 411 .first::<String>(conn) 412 }) 413 .await 414 .expect("failed to query account") 415 .expect("failed to get account"); 416 417 Ok(AuthenticatedUser { 418 did: did.to_owned(), 419 }) 420 } else { 421 Err(Error::with_status( 422 StatusCode::UNAUTHORIZED, 423 anyhow!("invalid access token: missing subject"), 424 )) 425 } 426} 427 428/// Cryptographically sign a JSON web token with the specified key. 429pub(crate) fn sign( 430 key: &Secp256k1Keypair, 431 typ: &str, 432 claims: &serde_json::Value, 433) -> anyhow::Result<String> { 434 // RFC 9068 435 let hdr = serde_json::json!({ 436 "typ": typ, 437 "alg": "ES256K", // Secp256k1Keypair 438 }); 439 let hdr = base64::prelude::BASE64_URL_SAFE_NO_PAD 440 .encode(serde_json::to_vec(&hdr).context("failed to encode claims")?); 441 let claims = base64::prelude::BASE64_URL_SAFE_NO_PAD 442 .encode(serde_json::to_vec(&claims).context("failed to encode claims")?); 443 let sig = key 444 .sign(format!("{hdr}.{claims}").as_bytes()) 445 .context("failed to sign jwt")?; 446 let sig = base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(sig); 447 448 Ok(format!("{hdr}.{claims}.{sig}")) 449} 450 451/// Cryptographically verify a JSON web token's validity using the specified public key. 452pub(crate) fn verify(key: &str, token: &str) -> anyhow::Result<(String, serde_json::Value)> { 453 let mut parts = token.splitn(3, '.'); 454 let hdr = parts.next().context("no header")?; 455 let claims = parts.next().context("no claims")?; 456 let sig = base64::prelude::BASE64_URL_SAFE_NO_PAD 457 .decode(parts.next().context("no sig")?) 458 .context("failed to decode signature")?; 459 460 let (alg, key) = atrium_crypto::did::parse_did_key(key).context("failed to decode key")?; 461 Verifier::default() 462 .verify(alg, &key, format!("{hdr}.{claims}").as_bytes(), &sig) 463 .context("failed to verify jwt")?; 464 465 let hdr = base64::prelude::BASE64_URL_SAFE_NO_PAD 466 .decode(hdr) 467 .context("failed to decode hdr")?; 468 let hdr = 469 serde_json::from_slice::<serde_json::Value>(&hdr).context("failed to parse hdr as json")?; 470 let typ = hdr 471 .get("typ") 472 .and_then(serde_json::Value::as_str) 473 .context("hdr is invalid")?; 474 475 let claims = base64::prelude::BASE64_URL_SAFE_NO_PAD 476 .decode(claims) 477 .context("failed to decode claims")?; 478 let claims = serde_json::from_slice(&claims).context("failed to parse claims as json")?; 479 480 Ok((typ.to_owned(), claims)) 481}