Alternative ATProto PDS implementation

lint

Changed files
+276 -246
src
+28 -20
src/auth.rs
··· 5 5 }; 6 6 use axum::{extract::FromRequestParts, http::StatusCode}; 7 7 use base64::Engine as _; 8 - use sha2::{Digest, Sha256}; 8 + use sha2::{Digest as _, Sha256}; 9 9 10 10 use crate::{AppState, Error, error::ErrorMessage}; 11 11 ··· 39 39 Error::with_status(StatusCode::UNAUTHORIZED, anyhow!("no authorization header")) 40 40 })?; 41 41 42 - let auth_str = auth_header.to_str().map_err(|_| { 42 + let auth_str = auth_header.to_str().map_err(|e| { 43 43 Error::with_status( 44 44 StatusCode::UNAUTHORIZED, 45 - anyhow!("authorization header should be valid utf-8"), 45 + anyhow!("authorization header should be valid utf-8").context(e), 46 46 ) 47 47 })?; 48 48 ··· 52 52 53 53 // Handle different token types 54 54 if auth_str.starts_with("Bearer ") || auth_str.starts_with("DPoP ") { 55 - let token = auth_str.splitn(2, ' ').nth(1).unwrap(); 55 + let token = auth_str 56 + .split_once(' ') 57 + .expect("Auth string should have a space") 58 + .1; 59 + 56 60 if has_dpop { 57 61 // Process DPoP token - the Authorization header contains the access token 58 62 // and the DPoP header contains the proof 59 - let dpop_token = dpop_header.unwrap().to_str().map_err(|_| { 60 - Error::with_status( 61 - StatusCode::UNAUTHORIZED, 62 - anyhow!("DPoP header should be valid utf-8"), 63 - ) 64 - })?; 63 + let dpop_token = dpop_header 64 + .expect("DPoP header should exist") 65 + .to_str() 66 + .map_err(|e| { 67 + Error::with_status( 68 + StatusCode::UNAUTHORIZED, 69 + anyhow!("DPoP header should be valid utf-8").context(e), 70 + ) 71 + })?; 65 72 66 73 return validate_dpop_token(token, dpop_token, parts, state).await; 67 - } else { 68 - // Standard Bearer token 69 - return validate_bearer_token(token, state).await; 70 74 } 75 + 76 + // Standard Bearer token 77 + return validate_bearer_token(token, state).await; 71 78 } 72 79 73 80 // If we reach here, no valid authorization method was found ··· 139 146 } 140 147 } 141 148 149 + #[expect(clippy::too_many_lines, reason = "validating dpop has many loc")] 142 150 /// Validate a DPoP token and proof 143 151 async fn validate_dpop_token( 144 152 access_token: &str, ··· 165 173 } 166 174 167 175 // Check token expiration 168 - if let Some(exp) = claims.get("exp").and_then(|v| v.as_i64()) { 176 + if let Some(exp) = claims.get("exp").and_then(serde_json::Value::as_i64) { 169 177 let now = chrono::Utc::now().timestamp(); 170 178 if now >= exp { 171 179 return Err(Error::with_message( ··· 187 195 188 196 // Decode header 189 197 let dpop_header_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD 190 - .decode(dpop_parts[0]) 198 + .decode(dpop_parts.first().context("header part missing")?) 191 199 .context("failed to decode DPoP header")?; 192 200 193 201 let dpop_header: serde_json::Value = ··· 203 211 204 212 // Decode claims 205 213 let dpop_claims_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD 206 - .decode(dpop_parts[1]) 214 + .decode(dpop_parts.get(1).context("claims part missing")?) 207 215 .context("failed to decode DPoP claims")?; 208 216 209 217 let dpop_claims: serde_json::Value = ··· 280 288 for prop in required_props { 281 289 let value = jwk 282 290 .get(prop) 283 - .context(format!("JWK missing required property: {}", prop))?; 284 - drop(canonical_jwk.insert(prop.to_string(), value.clone())); 291 + .context(format!("JWK missing required property: {prop}"))?; 292 + drop(canonical_jwk.insert((*prop).to_owned(), value.clone())); 285 293 } 286 294 287 295 // Serialize with no whitespace ··· 336 344 // Get expiry from token or default to 60 seconds 337 345 let exp = dpop_claims 338 346 .get("exp") 339 - .and_then(|e| e.as_i64()) 340 - .unwrap_or_else(|| timestamp + 60); 347 + .and_then(serde_json::Value::as_i64) 348 + .unwrap_or_else(|| timestamp.checked_add(60).unwrap_or(timestamp)); 341 349 342 350 _ = sqlx::query!( 343 351 r#"
+5 -2
src/endpoints/sync.rs
··· 234 234 active, 235 235 status, 236 236 did: input.did.clone(), 237 - rev: Some(atrium_api::types::string::Tid::new(r.rev).unwrap()), 237 + rev: Some( 238 + atrium_api::types::string::Tid::new(r.rev).expect("should be able to convert Tid"), 239 + ), 238 240 } 239 241 .into(), 240 242 )) ··· 371 373 head: atrium_api::types::string::Cid::new( 372 374 Cid::from_str(&r.root).expect("should be a valid CID"), 373 375 ), 374 - rev: atrium_api::types::string::Tid::new(r.rev).unwrap(), 376 + rev: atrium_api::types::string::Tid::new(r.rev) 377 + .expect("should be able to convert Tid"), 375 378 status: None, 376 379 } 377 380 .into()
+4 -4
src/firehose.rs
··· 81 81 }, 82 82 } 83 83 84 - impl Into<sync::subscribe_repos::RepoOp> for RepoOp { 85 - fn into(self) -> sync::subscribe_repos::RepoOp { 86 - let (action, cid, prev, path) = match self { 84 + impl From<RepoOp> for sync::subscribe_repos::RepoOp { 85 + fn from(val: RepoOp) -> Self { 86 + let (action, cid, prev, path) = match val { 87 87 RepoOp::Create { cid, path } => ("create", Some(cid), None, path), 88 88 RepoOp::Update { cid, path, prev } => ("update", Some(cid), Some(prev), path), 89 89 RepoOp::Delete { path, prev } => ("delete", None, Some(prev), path), ··· 131 131 prev_data: val.pcid.map(atrium_api::types::CidLink), 132 132 rebase: false, 133 133 repo: val.did, 134 - rev: Tid::new(val.rev).unwrap(), 134 + rev: Tid::new(val.rev).expect("should be valid revision"), 135 135 seq: 0, 136 136 since: None, 137 137 time: Datetime::now(),
+104 -95
src/oauth.rs
··· 3 3 use crate::metrics::AUTH_FAILED; 4 4 use crate::{AppConfig, AppState, Client, Db, Error, Result, SigningKey}; 5 5 use anyhow::{Context as _, anyhow}; 6 - use argon2::{Argon2, PasswordHash, PasswordVerifier}; 7 - use atrium_crypto::keypair::Did; 6 + use argon2::{Argon2, PasswordHash, PasswordVerifier as _}; 7 + use atrium_crypto::keypair::Did as _; 8 8 use axum::response::Redirect; 9 9 use axum::{ 10 10 Json, Router, extract, ··· 13 13 response::IntoResponse, 14 14 routing::{get, post}, 15 15 }; 16 - use base64::Engine; 16 + use base64::Engine as _; 17 17 use metrics::counter; 18 18 use rand::distributions::Alphanumeric; 19 - use rand::{Rng, thread_rng}; 19 + use rand::{Rng as _, thread_rng}; 20 20 use serde::{Deserialize, Serialize}; 21 21 use serde_json::{Value, json}; 22 - use sha2::Digest; 22 + use sha2::Digest as _; 23 23 use std::collections::{HashMap, HashSet}; 24 24 25 25 /// JWK thumbprint required properties for each key type (RFC7638) ··· 31 31 ("RSA", &["e", "kty", "n"]), 32 32 ]; 33 33 34 + /// JWT ID used record for tracking used JTIs to prevent replay attacks 35 + #[derive(Debug, Serialize, Deserialize)] 36 + struct JtiRecord { 37 + expires_at: i64, 38 + issuer: String, 39 + jti: String, 40 + } 41 + 34 42 /// Parses a JWT without validation and returns header and claims 35 43 fn parse_jwt(token: &str) -> Result<(Value, Value)> { 36 44 let parts: Vec<&str> = token.split('.').collect(); ··· 42 50 } 43 51 44 52 let header_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD 45 - .decode(parts[0]) 53 + .decode(parts.first().expect("should have JWT header")) 46 54 .context("Failed to decode JWT header")?; 47 55 48 56 let header: Value = 49 57 serde_json::from_slice(&header_bytes).context("Failed to parse JWT header as JSON")?; 50 58 51 59 let claims_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD 52 - .decode(parts[1]) 60 + .decode(parts.get(1).expect("should have JWT claims")) 53 61 .context("Failed to decode JWT claims")?; 54 62 55 63 let claims: Value = ··· 75 83 // Find required properties for this key type 76 84 let required_props = JWK_REQUIRED_PROPS 77 85 .iter() 78 - .find(|(kty, _)| *kty == key_type) 79 - .map(|(_, props)| *props) 80 - .context(format!("Unsupported key type: {}", key_type))?; 86 + .find(|&&(kty, _)| kty == key_type) 87 + .map(|&(_, props)| props) 88 + .context(anyhow!("Unsupported key type: {key_type}"))?; 81 89 82 90 // Build a new JWK with only the required properties 83 91 let mut canonical_jwk = serde_json::Map::new(); 84 92 85 - for prop in required_props { 93 + for &prop in required_props { 86 94 let value = jwk 87 95 .get(prop) 88 - .context(format!("JWK missing required property: {}", prop))?; 89 - drop(canonical_jwk.insert(prop.to_string(), value.clone())); 96 + .context(anyhow!("JWK missing required property: {prop}"))?; 97 + drop(canonical_jwk.insert((*prop).to_string(), value.clone())); 90 98 } 91 99 92 100 // Serialize with no whitespace ··· 149 157 }))) 150 158 } 151 159 152 - /// Fetch and validate client metadata from client_id URL 160 + /// Fetch and validate client metadata from `client_id` URL 153 161 async fn fetch_client_metadata(client: &Client, client_id: &str) -> Result<Value> { 154 162 // Handle localhost development 155 163 if client_id.starts_with("http://localhost") { ··· 167 175 }); 168 176 169 177 // Extract redirect_uri from query params if available 170 - let redirect_uris = if let Some(query) = client_url.query() { 171 - let pairs: HashMap<_, _> = url::form_urlencoded::parse(query.as_bytes()).collect(); 172 - if let Some(uri) = pairs.get("redirect_uri") { 173 - vec![json!(uri)] 174 - } else { 178 + let redirect_uris = client_url.query().map_or_else( 179 + || { 175 180 vec![ 176 181 json!("http://127.0.0.1/callback"), 177 182 json!("http://[::1]/callback"), 178 183 ] 179 - } 180 - } else { 181 - vec![ 182 - json!("http://127.0.0.1/callback"), 183 - json!("http://[::1]/callback"), 184 - ] 185 - }; 184 + }, 185 + |query| { 186 + let pairs: HashMap<_, _> = url::form_urlencoded::parse(query.as_bytes()).collect(); 187 + pairs.get("redirect_uri").map_or_else( 188 + || { 189 + vec![ 190 + json!("http://127.0.0.1/callback"), 191 + json!("http://[::1]/callback"), 192 + ] 193 + }, 194 + |uri| vec![json!(uri)], 195 + ) 196 + }, 197 + ); 186 198 187 - metadata["redirect_uris"] = json!(redirect_uris); 199 + if let Some(redirect_uris_value) = metadata.as_object_mut() { 200 + drop(redirect_uris_value.insert("redirect_uris".to_owned(), json!(redirect_uris))); 201 + } 202 + 188 203 return Ok(metadata); 189 204 } 190 205 ··· 221 236 // Validate DPoP tokens requirement 222 237 if !metadata 223 238 .get("dpop_bound_access_tokens") 224 - .and_then(|v| v.as_bool()) 239 + .and_then(Value::as_bool) 225 240 .unwrap_or(false) 226 241 { 227 242 return Err(Error::with_status( ··· 233 248 Ok(metadata) 234 249 } 235 250 236 - /// JWT ID used record for tracking used JTIs to prevent replay attacks 237 - #[derive(Debug, Serialize, Deserialize)] 238 - struct JtiRecord { 239 - jti: String, 240 - issuer: String, 241 - expires_at: i64, 242 - } 243 - 244 251 /// Pushed Authorization Request endpoint 245 252 /// POST `/oauth/par` 253 + #[expect(clippy::too_many_lines)] 246 254 async fn par( 247 255 State(db): State<Db>, 248 256 State(client): State<Client>, ··· 296 304 .and_then(|uris| uris.as_array()) 297 305 .context("client metadata missing redirect_uris")?; 298 306 299 - let uri_valid = allowed_uris.iter().any(|uri| { 300 - uri.as_str() 301 - .map_or(false, |uri_str| uri_str == provided_uri) 302 - }); 307 + let uri_valid = allowed_uris 308 + .iter() 309 + .any(|uri| uri.as_str().is_some_and(|uri_str| uri_str == provided_uri)); 303 310 304 311 if !uri_valid { 305 312 return Err(Error::with_status( ··· 310 317 } else if client_metadata 311 318 .get("redirect_uris") 312 319 .and_then(|uris| uris.as_array()) 313 - .map_or(0, |uris| uris.len()) 320 + .map_or(0, Vec::len) 314 321 != 1 315 322 { 316 323 return Err(Error::with_status( ··· 340 347 .take(32) 341 348 .map(char::from) 342 349 .collect::<String>(); 343 - let request_uri = format!("urn:ietf:params:oauth:request_uri:req-{}", request_id); 350 + let request_uri = format!("urn:ietf:params:oauth:request_uri:req-{request_id}"); 344 351 345 352 // Store request data in the database 346 353 let now = chrono::Utc::now(); ··· 378 385 379 386 Ok(Json(json!({ 380 387 "request_uri": request_uri, 381 - "expires_in": 300 // 5 minutes 388 + "expires_in": 300_i32 // 5 minutes 382 389 }))) 383 390 } 384 391 ··· 482 489 483 490 /// OAuth Authorization Sign-in endpoint 484 491 /// POST `/oauth/authorize/sign-in` 492 + #[expect(clippy::too_many_lines)] 485 493 async fn authorize_signin( 486 494 State(db): State<Db>, 487 495 State(config): State<AppConfig>, 488 496 State(client): State<Client>, 489 497 extract::Form(form_data): extract::Form<HashMap<String, String>>, 490 498 ) -> Result<impl IntoResponse> { 499 + use std::fmt::Write as _; 500 + 491 501 // Extract form data 492 502 let username = form_data.get("username").context("username is required")?; 493 503 let password = form_data.get("password").context("password is required")?; ··· 539 549 .context("failed to query database")? 540 550 .context("user not found")?; 541 551 542 - // Verify password 543 - match Argon2::default().verify_password( 552 + // Verify password - fixed to use equality check instead of pattern matching 553 + if Argon2::default().verify_password( 544 554 password.as_bytes(), 545 555 &PasswordHash::new(account.password.as_str()).context("invalid password hash in db")?, 546 - ) { 547 - Ok(()) => {} 548 - Err(_) => { 549 - counter!(AUTH_FAILED).increment(1); 550 - return Err(Error::with_status( 551 - StatusCode::UNAUTHORIZED, 552 - anyhow!("invalid credentials"), 553 - )); 554 - } 556 + ) == Ok(()) 557 + { 558 + } else { 559 + counter!(AUTH_FAILED).increment(1); 560 + return Err(Error::with_status( 561 + StatusCode::UNAUTHORIZED, 562 + anyhow!("invalid credentials"), 563 + )); 555 564 } 556 565 557 566 // Generate authorization code ··· 562 571 .collect::<String>(); 563 572 564 573 // Determine redirect URI to use 565 - let redirect_uri = if let Some(uri) = &par_request.redirect_uri { 574 + let redirect_uri = if let Some(uri) = par_request.redirect_uri.as_ref() { 566 575 uri.clone() 567 576 } else { 568 577 let client_metadata = fetch_client_metadata(&client, client_id).await?; ··· 572 581 .and_then(|uris| uris.first()) 573 582 .and_then(|uri| uri.as_str()) 574 583 .context("No redirect_uri available")? 575 - .to_string() 584 + .to_owned() 576 585 }; 577 586 578 587 // Store the authorization code ··· 615 624 }); 616 625 617 626 // Build redirect URL 618 - let mut redirect_url = redirect_uri; 619 - match par_request.response_mode { 620 - None => redirect_url.push_str("?"), // Default to query 621 - Some(response_mode) => match response_mode.as_str() { 622 - "query" => redirect_url.push_str("?"), 623 - "fragment" => redirect_url.push_str("#"), 624 - _ => redirect_url.push_str("?"), // Default to query 625 - }, 626 - }; 627 - redirect_url.push_str(&format!("state={}", urlencoding::encode(&state))); 627 + let mut redirect_target = redirect_uri; 628 + match par_request.response_mode.as_deref() { 629 + Some("fragment") => redirect_target.push('#'), 630 + _ => redirect_target.push('?'), 631 + } 632 + 633 + write!(redirect_target, "state={}", urlencoding::encode(&state)).unwrap(); 628 634 let host_name = format!("https://{}", &config.host_name); 629 - redirect_url.push_str(&format!("&iss={}", urlencoding::encode(&host_name))); 630 - redirect_url.push_str(&format!("&code={}", urlencoding::encode(&code))); 631 - Ok(Redirect::to(&redirect_url)) 635 + write!(redirect_target, "&iss={}", urlencoding::encode(&host_name)).unwrap(); 636 + write!(redirect_target, "&code={}", urlencoding::encode(&code)).unwrap(); 637 + Ok(Redirect::to(&redirect_target)) 632 638 } 633 639 634 640 /// Verify a DPoP proof and return the JWK thumbprint ··· 650 656 let (header, claims) = parse_jwt(dpop_token)?; 651 657 652 658 // Verify "typ" is "dpop+jwt" 653 - if header.get("typ").and_then(|t| t.as_str()) != Some("dpop+jwt") { 659 + if header.get("typ").and_then(Value::as_str) != Some("dpop+jwt") { 654 660 return Err(Error::with_status( 655 661 StatusCode::BAD_REQUEST, 656 662 anyhow!("Invalid DPoP token type"), ··· 660 666 // Verify required claims 661 667 let jti = claims 662 668 .get("jti") 663 - .and_then(|j| j.as_str()) 669 + .and_then(Value::as_str) 664 670 .context("Missing jti claim in DPoP token")?; 665 671 666 672 // Check for token expiration 673 + #[expect(clippy::arithmetic_side_effects)] 667 674 let exp = claims 668 675 .get("exp") 669 - .and_then(|e| e.as_i64()) 676 + .and_then(Value::as_i64) 670 677 .unwrap_or_else(|| chrono::Utc::now().timestamp() + 60); // Default 60s if not specified 671 678 672 679 let now = chrono::Utc::now().timestamp(); ··· 678 685 } 679 686 680 687 // Check htm (HTTP method) claim 681 - if claims.get("htm").and_then(|m| m.as_str()) != Some(http_method) { 688 + if claims.get("htm").and_then(Value::as_str) != Some(http_method) { 682 689 return Err(Error::with_status( 683 690 StatusCode::BAD_REQUEST, 684 691 anyhow!("Invalid htm claim in DPoP token"), ··· 686 693 } 687 694 688 695 // Check htu (HTTP URI) claim 689 - if claims.get("htu").and_then(|u| u.as_str()) != Some(http_uri) { 696 + if claims.get("htu").and_then(Value::as_str) != Some(http_uri) { 690 697 return Err(Error::with_status( 691 698 StatusCode::BAD_REQUEST, 692 699 anyhow!( 693 700 "Invalid htu claim in DPoP token: expected {}, got {}", 694 701 http_uri, 695 - claims.get("htu").and_then(|u| u.as_str()).unwrap_or("None") 702 + claims.get("htu").and_then(Value::as_str).unwrap_or("None") 696 703 ), 697 704 )); 698 705 } ··· 739 746 .context("failed to store JTI")?; 740 747 741 748 // Cleanup expired JTIs periodically (1% chance on each request) 742 - if thread_rng().gen_range(0..100) == 0 { 749 + if thread_rng().gen_range(0_i32..100_i32) == 0_i32 { 743 750 _ = sqlx::query!(r#"DELETE FROM oauth_used_jtis WHERE expires_at < ?"#, now) 744 751 .execute(db) 745 752 .await ··· 749 756 Ok(thumbprint) 750 757 } 751 758 752 - /// Verify a code_verifier against stored code_challenge 759 + /// Verify a `code_verifier` against stored `code_challenge` 753 760 fn verify_pkce(code_verifier: &str, stored_challenge: &str, method: &str) -> Result<()> { 754 761 // Only S256 is supported currently 755 762 if method != "S256" { ··· 778 785 /// OAuth token endpoint 779 786 /// - POST `/oauth/token` 780 787 /// 781 - /// Handles both authorization_code and refresh_token grants 788 + /// Handles both `authorization_code` and `refresh_token` grants 789 + #[expect(clippy::too_many_lines)] 782 790 async fn token( 783 791 State(db): State<Db>, 784 792 State(skey): State<SigningKey>, ··· 857 865 // Generate tokens 858 866 let now = chrono::Utc::now().timestamp(); 859 867 let access_token_expires_in = 3600; // 1 hour 868 + #[expect(clippy::arithmetic_side_effects)] 860 869 let access_token_expires_at = now + access_token_expires_in; 861 - let refresh_token_expires_at = now + 2592000; // 30 days 870 + #[expect(clippy::arithmetic_side_effects)] 871 + let refresh_token_expires_at = now + 2_592_000; // 30 days 862 872 863 873 // Create access token 864 874 let access_token_claims = json!({ ··· 961 971 // Generate new tokens 962 972 let now = chrono::Utc::now().timestamp(); 963 973 let access_token_expires_in = 3600; // 1 hour 974 + #[expect(clippy::arithmetic_side_effects)] 964 975 let access_token_expires_at = now + access_token_expires_in; 965 - let refresh_token_expires_at = now + 2592000; // 30 days 976 + #[expect(clippy::arithmetic_side_effects)] 977 + let refresh_token_expires_at = now + 2_592_000; // 30 days 966 978 967 979 // Create access token 968 980 let access_token_claims = json!({ ··· 1043 1055 // For a real implementation, you would construct a proper JWK 1044 1056 // with all the required fields based on the key type 1045 1057 1046 - let key_did = skey.did(); 1058 + let did_string = skey.did(); 1047 1059 1048 1060 // Extract the key ID from the DID string 1049 1061 // did:key:z... format, where z... is the multibase-encoded public key 1050 - let key_id = key_did.strip_prefix("did:key:").unwrap_or(&key_did); 1062 + let key_id = did_string.strip_prefix("did:key:").unwrap_or(&did_string); 1051 1063 1052 1064 let jwk = json!({ 1053 1065 "kty": "EC", ··· 1077 1089 ) -> Result<Json<Value>> { 1078 1090 // Extract required parameters 1079 1091 let token = form_data.get("token").context("token is required")?; 1080 - let refresh_token_string = "refresh_token".to_string(); 1092 + let refresh_token_string = "refresh_token".to_owned(); 1081 1093 let token_type_hint = form_data 1082 1094 .get("token_type_hint") 1083 1095 .unwrap_or(&refresh_token_string); ··· 1117 1129 let token_type_hint = form_data.get("token_type_hint"); 1118 1130 1119 1131 // Parse the token 1120 - let (typ, claims) = match crate::auth::verify(&skey.did(), token) { 1121 - Ok(result) => result, 1122 - Err(_) => { 1123 - // Per RFC7662, invalid tokens return { "active": false } 1124 - return Ok(Json(json!({"active": false}))); 1125 - } 1132 + let Ok((typ, claims)) = crate::auth::verify(&skey.did(), token) else { 1133 + // Per RFC7662, invalid tokens return { "active": false } 1134 + return Ok(Json(json!({"active": false}))); 1126 1135 }; 1127 1136 1128 1137 // Check token type ··· 1143 1152 } 1144 1153 1145 1154 // Check expiration 1146 - if let Some(exp) = claims.get("exp").and_then(|v| v.as_i64()) { 1155 + if let Some(exp) = claims.get("exp").and_then(Value::as_i64) { 1147 1156 let now = chrono::Utc::now().timestamp(); 1148 1157 if now >= exp { 1149 1158 return Ok(Json(json!({"active": false}))); ··· 1169 1178 } 1170 1179 1171 1180 // Token is valid, return introspection info 1172 - let subject = claims.get("sub").and_then(|v| v.as_str()); 1173 - let client_id = claims.get("aud").and_then(|v| v.as_str()); 1174 - let scope = claims.get("scope").and_then(|v| v.as_str()); 1175 - let expiration = claims.get("exp").and_then(|v| v.as_i64()); 1176 - let issued_at = claims.get("iat").and_then(|v| v.as_i64()); 1181 + let subject = claims.get("sub").and_then(Value::as_str); 1182 + let client_id = claims.get("aud").and_then(Value::as_str); 1183 + let scope = claims.get("scope").and_then(Value::as_str); 1184 + let expiration = claims.get("exp").and_then(Value::as_i64); 1185 + let issued_at = claims.get("iat").and_then(Value::as_i64); 1177 1186 1178 1187 Ok(Json(json!({ 1179 1188 "active": true,
+135 -125
src/tests.rs
··· 1 1 //! Testing utilities for the PDS. 2 - 2 + #![expect(clippy::arbitrary_source_item_ordering)] 3 3 use std::{ 4 4 net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener}, 5 5 path::PathBuf, 6 - sync::Arc, 7 6 time::{Duration, Instant}, 8 7 }; 9 8 10 - use anyhow::{Context as _, Result}; 9 + use anyhow::Result; 11 10 use atrium_api::{ 12 11 com::atproto::server, 13 - types::{ 14 - Unknown, 15 - string::{AtIdentifier, Did, Handle, Nsid, RecordKey}, 16 - }, 12 + types::string::{AtIdentifier, Did, Handle, Nsid, RecordKey}, 17 13 }; 18 14 use figment::{Figment, providers::Format as _}; 19 15 use futures::future::join_all; 20 - use rand::{Rng, thread_rng}; 21 16 use serde::{Deserialize, Serialize}; 22 17 use tokio::sync::OnceCell; 23 18 use uuid::Uuid; 24 19 25 - use crate::{AppState, auth::AuthenticatedUser, config::AppConfig}; 20 + use crate::config::AppConfig; 26 21 27 22 /// Global test state, created once for all tests. 28 23 pub(crate) static TEST_STATE: OnceCell<TestState> = OnceCell::const_new(); ··· 49 44 50 45 impl Drop for TempDir { 51 46 fn drop(&mut self) { 52 - let _ = std::fs::remove_dir_all(&self.path); 47 + drop(std::fs::remove_dir_all(&self.path)); 53 48 } 54 49 } 55 50 56 51 /// Test state for the application. 57 52 pub(crate) struct TestState { 58 - /// The temporary directory for test data. 59 - temp_dir: TempDir, 60 53 /// The address the test server is listening on. 61 54 address: SocketAddr, 55 + /// The HTTP client. 56 + client: reqwest::Client, 62 57 /// The application configuration. 63 58 config: AppConfig, 64 - /// The HTTP client. 65 - client: reqwest::Client, 59 + /// The temporary directory for test data. 60 + #[expect(dead_code)] 61 + temp_dir: TempDir, 66 62 } 67 63 68 64 impl TestState { 69 - /// Create a new test state. 70 - async fn new() -> Result<Self> { 71 - // Create a temporary directory for test data 72 - let temp_dir = TempDir::new()?; 65 + /// Get a base URL for the test server. 66 + pub(crate) fn base_url(&self) -> String { 67 + format!("http://{}", self.address) 68 + } 69 + 70 + /// Create a test account. 71 + pub(crate) async fn create_test_account(&self) -> Result<TestAccount> { 72 + // Create the account 73 + let handle = "test.handle"; 74 + let response = self 75 + .client 76 + .post(format!( 77 + "http://{}/xrpc/com.atproto.server.createAccount", 78 + self.address 79 + )) 80 + .json(&server::create_account::InputData { 81 + did: None, 82 + verification_code: None, 83 + verification_phone: None, 84 + email: Some(format!("{}@example.com", &handle)), 85 + handle: Handle::new(handle.to_owned()).expect("should be able to create handle"), 86 + password: Some("password123".to_owned()), 87 + invite_code: None, 88 + recovery_key: None, 89 + plc_op: None, 90 + }) 91 + .send() 92 + .await?; 73 93 74 - // Find a free port 75 - let listener = TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))?; 76 - let address = listener.local_addr()?; 77 - drop(listener); 94 + let account: server::create_account::Output = response.json().await?; 78 95 96 + Ok(TestAccount { 97 + handle: handle.to_owned(), 98 + did: account.did.to_string(), 99 + access_token: account.access_jwt.clone(), 100 + refresh_token: account.refresh_jwt.clone(), 101 + }) 102 + } 103 + 104 + /// Create a new test state. 105 + #[expect(clippy::unused_async)] 106 + async fn new() -> Result<Self> { 79 107 // Configure the test app 80 108 #[derive(Serialize, Deserialize)] 81 109 struct TestConfigInput { 110 + db: Option<String>, 82 111 host_name: Option<String>, 83 - db: Option<String>, 112 + key: Option<PathBuf>, 84 113 listen_address: Option<SocketAddr>, 85 - key: Option<PathBuf>, 86 114 test: Option<bool>, 87 115 } 116 + // Create a temporary directory for test data 117 + let temp_dir = TempDir::new()?; 118 + 119 + // Find a free port 120 + let listener = TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))?; 121 + let address = listener.local_addr()?; 122 + drop(listener); 88 123 89 124 let test_config = TestConfigInput { 125 + db: Some(format!("sqlite://{}/test.db", temp_dir.path().display())), 90 126 host_name: Some(format!("localhost:{}", address.port())), 91 - db: Some(format!("sqlite://{}/test.db", temp_dir.path().display())), 127 + key: Some(temp_dir.path().join("test.key")), 92 128 listen_address: Some(address), 93 - key: Some(temp_dir.path().join("test.key")), 94 129 test: Some(true), 95 130 }; 96 131 ··· 130 165 .build()?; 131 166 132 167 Ok(Self { 133 - temp_dir, 134 168 address, 135 - config, 136 169 client, 170 + config, 171 + temp_dir, 137 172 }) 138 173 } 139 174 ··· 144 179 let address = self.address; 145 180 146 181 // Start the application in a background task 147 - tokio::spawn(async move { 182 + let _handle = tokio::spawn(async move { 148 183 // Set up the application 149 184 use crate::*; 150 185 151 186 // Initialize metrics (noop in test mode) 152 - let _ = metrics::setup(None); 187 + drop(metrics::setup(None)); 153 188 154 189 // Create client 155 190 let simple_client = reqwest::Client::builder() ··· 158 193 .context("failed to build requester client")?; 159 194 let client = reqwest_middleware::ClientBuilder::new(simple_client.clone()) 160 195 .with(http_cache_reqwest::Cache(http_cache_reqwest::HttpCache { 161 - mode: http_cache_reqwest::CacheMode::Default, 162 - manager: http_cache_reqwest::MokaManager::default(), 163 - options: http_cache_reqwest::HttpCacheOptions::default(), 196 + mode: CacheMode::Default, 197 + manager: MokaManager::default(), 198 + options: HttpCacheOptions::default(), 164 199 })) 165 200 .build(); 166 201 167 202 // Create a test keypair 168 - std::fs::create_dir_all(&config.key.parent().context("should have parent")?)?; 203 + std::fs::create_dir_all(config.key.parent().context("should have parent")?)?; 169 204 let (skey, rkey) = { 170 - let skey = 171 - atrium_crypto::keypair::Secp256k1Keypair::create(&mut rand::thread_rng()); 172 - let rkey = 173 - atrium_crypto::keypair::Secp256k1Keypair::create(&mut rand::thread_rng()); 205 + let skey = Secp256k1Keypair::create(&mut rand::thread_rng()); 206 + let rkey = Secp256k1Keypair::create(&mut rand::thread_rng()); 174 207 175 208 let keys = KeyData { 176 209 skey: skey.export(), ··· 186 219 }; 187 220 188 221 // Set up database 189 - let opts = sqlx::sqlite::SqliteConnectOptions::from_str(&config.db) 222 + let opts = SqliteConnectOptions::from_str(&config.db) 190 223 .context("failed to parse database options")? 191 224 .create_if_missing(true); 192 - let db = sqlx::SqlitePool::connect_with(opts).await?; 225 + let db = SqlitePool::connect_with(opts).await?; 193 226 194 227 sqlx::migrate!() 195 228 .run(&db) ··· 212 245 }; 213 246 214 247 // Create the router 215 - let app = axum::Router::new() 216 - .route("/", axum::routing::get(crate::index)) 217 - .merge(crate::oauth::routes()) 248 + let app = Router::new() 249 + .route("/", get(index)) 250 + .merge(oauth::routes()) 218 251 .nest( 219 252 "/xrpc", 220 - crate::endpoints::routes() 221 - .merge(crate::actor_endpoints::routes()) 222 - .fallback(crate::service_proxy), 253 + endpoints::routes() 254 + .merge(actor_endpoints::routes()) 255 + .fallback(service_proxy), 223 256 ) 224 - .layer(tower_http::cors::CorsLayer::permissive()) 225 - .layer(tower_http::trace::TraceLayer::new_for_http()) 257 + .layer(CorsLayer::permissive()) 258 + .layer(TraceLayer::new_for_http()) 226 259 .with_state(app_state); 227 260 228 - println!("Test server listening on {address}"); 229 - 230 261 // Listen for connections 231 - let listener = tokio::net::TcpListener::bind(&address) 262 + let listener = TcpListener::bind(&address) 232 263 .await 233 264 .context("failed to bind address")?; 234 265 ··· 242 273 243 274 Ok(()) 244 275 } 245 - 246 - /// Create a test account. 247 - pub async fn create_test_account(&self) -> Result<TestAccount> { 248 - let handle = "test.handle"; 249 - println!("Creating test account with handle: {}", handle); 250 - 251 - // Create the account 252 - let response = self 253 - .client 254 - .post(&format!( 255 - "http://{}/xrpc/com.atproto.server.createAccount", 256 - self.address 257 - )) 258 - .json(&server::create_account::InputData { 259 - did: None, 260 - verification_code: None, 261 - verification_phone: None, 262 - email: Some(format!("{}@example.com", &handle)), 263 - handle: Handle::new(handle.to_owned()).expect("should be able to create handle"), 264 - password: Some("password123".to_string()), 265 - invite_code: None, 266 - recovery_key: None, 267 - plc_op: None, 268 - }) 269 - .send() 270 - .await?; 271 - 272 - let account: server::create_account::Output = response.json().await?; 273 - 274 - Ok(TestAccount { 275 - handle: handle.to_owned(), 276 - did: account.did.to_string(), 277 - access_token: account.access_jwt.clone(), 278 - refresh_token: account.refresh_jwt.clone(), 279 - }) 280 - } 281 - 282 - /// Get a base URL for the test server. 283 - pub fn base_url(&self) -> String { 284 - format!("http://{}", self.address) 285 - } 286 276 } 287 277 288 278 /// A test account that can be used for testing. 289 - pub struct TestAccount { 290 - /// The account handle. 291 - pub handle: String, 292 - /// The account DID. 293 - pub did: String, 279 + pub(crate) struct TestAccount { 294 280 /// The access token for the account. 295 - pub access_token: String, 281 + pub(crate) access_token: String, 282 + /// The account DID. 283 + pub(crate) did: String, 284 + /// The account handle. 285 + pub(crate) handle: String, 296 286 /// The refresh token for the account. 297 - pub refresh_token: String, 287 + #[expect(dead_code)] 288 + pub(crate) refresh_token: String, 298 289 } 299 290 300 291 /// Initialize the test state. 301 - pub async fn init_test_state() -> Result<&'static TestState> { 302 - TEST_STATE 303 - .get_or_try_init(|| async { 304 - let state = TestState::new().await?; 305 - state.start_app().await?; 306 - Ok(state) 307 - }) 308 - .await 292 + pub(crate) async fn init_test_state() -> Result<&'static TestState> { 293 + async fn init_test_state() -> std::result::Result<TestState, anyhow::Error> { 294 + let state = TestState::new().await?; 295 + state.start_app().await?; 296 + Ok(state) 297 + } 298 + TEST_STATE.get_or_try_init(init_test_state).await 309 299 } 310 300 311 301 /// Create a record benchmark that creates records and measures the time it takes. 312 - pub async fn create_record_benchmark(count: usize, concurrent: usize) -> Result<Duration> { 302 + #[expect( 303 + clippy::arithmetic_side_effects, 304 + clippy::integer_division, 305 + clippy::integer_division_remainder_used, 306 + clippy::use_debug, 307 + clippy::print_stdout 308 + )] 309 + pub(crate) async fn create_record_benchmark(count: usize, concurrent: usize) -> Result<Duration> { 313 310 // Initialize the test state 314 311 let state = init_test_state().await?; 315 312 ··· 341 338 let record_idx = batch_idx * batch_size + i; 342 339 343 340 let result = client 344 - .post(&format!("{}/xrpc/com.atproto.repo.createRecord", base_url)) 345 - .header("Authorization", format!("Bearer {}", access_token)) 341 + .post(format!("{base_url}/xrpc/com.atproto.repo.createRecord")) 342 + .header("Authorization", format!("Bearer {access_token}")) 346 343 .json(&atrium_api::com::atproto::repo::create_record::InputData { 347 - repo: AtIdentifier::Did(Did::new(account_did.clone()).unwrap()), 348 - collection: Nsid::new("app.bsky.feed.post".to_string()).unwrap(), 349 - rkey: Some(RecordKey::new(format!("test-{}", record_idx)).unwrap()), 344 + repo: AtIdentifier::Did(Did::new(account_did.clone()).expect("valid DID")), 345 + collection: Nsid::new("app.bsky.feed.post".to_owned()).expect("valid NSID"), 346 + rkey: Some( 347 + RecordKey::new(format!("test-{record_idx}")).expect("valid record key"), 348 + ), 350 349 validate: None, 351 350 record: serde_json::from_str( 352 351 &serde_json::json!({ 353 352 "$type": "app.bsky.feed.post", 354 - "text": format!("Test post {} from {}", record_idx, account_handle), 353 + "text": format!("Test post {record_idx} from {account_handle}"), 355 354 "createdAt": chrono::Utc::now().to_rfc3339(), 356 355 }) 357 356 .to_string(), 358 357 ) 359 - .unwrap(), 358 + .expect("valid JSON record"), 360 359 swap_commit: None, 361 360 }) 362 361 .send() ··· 364 363 365 364 let request_duration = request_start.elapsed(); 366 365 if record_idx % 10 == 0 { 367 - println!("Created record {} in {:?}", record_idx, request_duration); 366 + println!("Created record {record_idx} in {request_duration:?}"); 368 367 } 369 368 results.push(result); 370 369 } ··· 379 378 let results = join_all(handles).await; 380 379 381 380 // Check for errors 382 - for result in results { 383 - let batch_results = result?; 384 - for result in batch_results { 385 - match result { 381 + for batch_result in results { 382 + let batch_responses = batch_result?; 383 + for response_result in batch_responses { 384 + match response_result { 386 385 Ok(response) => { 387 386 if !response.status().is_success() { 388 387 return Err(anyhow::anyhow!( ··· 403 402 } 404 403 405 404 #[cfg(test)] 405 + #[expect(clippy::module_inception, clippy::use_debug, clippy::print_stdout)] 406 406 mod tests { 407 407 use super::*; 408 + use anyhow::anyhow; 408 409 409 410 #[tokio::test] 410 411 async fn test_create_account() -> Result<()> { ··· 412 413 let account = state.create_test_account().await?; 413 414 414 415 println!("Created test account: {}", account.handle); 415 - assert!(!account.handle.is_empty()); 416 - assert!(!account.did.is_empty()); 417 - assert!(!account.access_token.is_empty()); 416 + if account.handle.is_empty() { 417 + return Err(anyhow::anyhow!("Account handle is empty")); 418 + } 419 + if account.did.is_empty() { 420 + return Err(anyhow::anyhow!("Account DID is empty")); 421 + } 422 + if account.access_token.is_empty() { 423 + return Err(anyhow::anyhow!("Account access token is empty")); 424 + } 418 425 419 426 Ok(()) 420 427 } ··· 423 430 async fn test_create_record_benchmark() -> Result<()> { 424 431 let duration = create_record_benchmark(100, 1).await?; 425 432 426 - println!("Created 100 records in {:?}", duration); 427 - assert!(duration.as_secs() < 100, "Benchmark took too long"); 433 + println!("Created 100 records in {duration:?}"); 434 + 435 + if duration.as_secs() >= 100 { 436 + return Err(anyhow!("Benchmark took too long")); 437 + } 428 438 429 439 Ok(()) 430 440 }