Alternative ATProto PDS implementation
1//! Testing utilities for the PDS. 2#![expect(clippy::arbitrary_source_item_ordering)] 3use std::{ 4 net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener}, 5 path::PathBuf, 6 time::{Duration, Instant}, 7}; 8 9use anyhow::Result; 10use atrium_api::{ 11 com::atproto::server, 12 types::string::{AtIdentifier, Did, Handle, Nsid, RecordKey}, 13}; 14use figment::{Figment, providers::Format as _}; 15use futures::future::join_all; 16use serde::{Deserialize, Serialize}; 17use tokio::sync::OnceCell; 18use uuid::Uuid; 19 20use crate::config::AppConfig; 21 22/// Global test state, created once for all tests. 23pub(crate) static TEST_STATE: OnceCell<TestState> = OnceCell::const_new(); 24 25/// A temporary test directory that will be cleaned up when the struct is dropped. 26struct TempDir { 27 /// The path to the directory. 28 path: PathBuf, 29} 30 31impl TempDir { 32 /// Create a new temporary directory. 33 fn new() -> Result<Self> { 34 let path = std::env::temp_dir().join(format!("bluepds-test-{}", Uuid::new_v4())); 35 std::fs::create_dir_all(&path)?; 36 Ok(Self { path }) 37 } 38 39 /// Get the path to the directory. 40 fn path(&self) -> &PathBuf { 41 &self.path 42 } 43} 44 45impl Drop for TempDir { 46 fn drop(&mut self) { 47 drop(std::fs::remove_dir_all(&self.path)); 48 } 49} 50 51/// Test state for the application. 52pub(crate) struct TestState { 53 /// The address the test server is listening on. 54 address: SocketAddr, 55 /// The HTTP client. 56 client: reqwest::Client, 57 /// The application configuration. 58 config: AppConfig, 59 /// The temporary directory for test data. 60 #[expect(dead_code)] 61 temp_dir: TempDir, 62} 63 64impl TestState { 65 /// Get a base URL for the test server. 66 pub(crate) fn base_url(&self) -> String { 67 format!("http://{}", self.address) 68 } 69 70 /// Create a test account. 
71 pub(crate) async fn create_test_account(&self) -> Result<TestAccount> { 72 // Create the account 73 let handle = "test.handle"; 74 let response = self 75 .client 76 .post(format!( 77 "http://{}/xrpc/com.atproto.server.createAccount", 78 self.address 79 )) 80 .json(&server::create_account::InputData { 81 did: None, 82 verification_code: None, 83 verification_phone: None, 84 email: Some(format!("{}@example.com", &handle)), 85 handle: Handle::new(handle.to_owned()).expect("should be able to create handle"), 86 password: Some("password123".to_owned()), 87 invite_code: None, 88 recovery_key: None, 89 plc_op: None, 90 }) 91 .send() 92 .await?; 93 94 let account: server::create_account::Output = response.json().await?; 95 96 Ok(TestAccount { 97 handle: handle.to_owned(), 98 did: account.did.to_string(), 99 access_token: account.access_jwt.clone(), 100 refresh_token: account.refresh_jwt.clone(), 101 }) 102 } 103 104 /// Create a new test state. 105 #[expect(clippy::unused_async)] 106 async fn new() -> Result<Self> { 107 // Configure the test app 108 #[derive(Serialize, Deserialize)] 109 struct TestConfigInput { 110 db: Option<String>, 111 host_name: Option<String>, 112 key: Option<PathBuf>, 113 listen_address: Option<SocketAddr>, 114 test: Option<bool>, 115 } 116 // Create a temporary directory for test data 117 let temp_dir = TempDir::new()?; 118 119 // Find a free port 120 let listener = TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))?; 121 let address = listener.local_addr()?; 122 drop(listener); 123 124 let test_config = TestConfigInput { 125 db: Some(format!("sqlite://{}/test.db", temp_dir.path().display())), 126 host_name: Some(format!("localhost:{}", address.port())), 127 key: Some(temp_dir.path().join("test.key")), 128 listen_address: Some(address), 129 test: Some(true), 130 }; 131 132 let config: AppConfig = Figment::new() 133 .admerge(figment::providers::Toml::file("default.toml")) 134 
.admerge(figment::providers::Env::prefixed("BLUEPDS_")) 135 .merge(figment::providers::Serialized::defaults(test_config)) 136 .merge( 137 figment::providers::Toml::string( 138 r#" 139 [firehose] 140 relays = [] 141 142 [repo] 143 path = "repo" 144 145 [plc] 146 path = "plc" 147 148 [blob] 149 path = "blob" 150 limit = 10485760 # 10 MB 151 "#, 152 ) 153 .nested(), 154 ) 155 .extract()?; 156 157 // Create directories 158 std::fs::create_dir_all(temp_dir.path().join("repo"))?; 159 std::fs::create_dir_all(temp_dir.path().join("plc"))?; 160 std::fs::create_dir_all(temp_dir.path().join("blob"))?; 161 162 // Create client 163 let client = reqwest::Client::builder() 164 .timeout(Duration::from_secs(30)) 165 .build()?; 166 167 Ok(Self { 168 address, 169 client, 170 config, 171 temp_dir, 172 }) 173 } 174 175 /// Start the application in a background task. 176 async fn start_app(&self) -> Result<()> { 177 // Get a reference to the config that can be moved into the task 178 let config = self.config.clone(); 179 let address = self.address; 180 181 // Start the application in a background task 182 let _handle = tokio::spawn(async move { 183 // Set up the application 184 use crate::*; 185 186 // Initialize metrics (noop in test mode) 187 drop(metrics::setup(None)); 188 189 // Create client 190 let simple_client = reqwest::Client::builder() 191 .user_agent(APP_USER_AGENT) 192 .build() 193 .context("failed to build requester client")?; 194 let client = reqwest_middleware::ClientBuilder::new(simple_client.clone()) 195 .with(http_cache_reqwest::Cache(http_cache_reqwest::HttpCache { 196 mode: CacheMode::Default, 197 manager: MokaManager::default(), 198 options: HttpCacheOptions::default(), 199 })) 200 .build(); 201 202 // Create a test keypair 203 std::fs::create_dir_all(config.key.parent().context("should have parent")?)?; 204 let (skey, rkey) = { 205 let skey = Secp256k1Keypair::create(&mut rand::thread_rng()); 206 let rkey = Secp256k1Keypair::create(&mut rand::thread_rng()); 207 
208 let keys = KeyData { 209 skey: skey.export(), 210 rkey: rkey.export(), 211 }; 212 213 let mut f = 214 std::fs::File::create(&config.key).context("failed to create key file")?; 215 serde_ipld_dagcbor::to_writer(&mut f, &keys) 216 .context("failed to serialize crypto keys")?; 217 218 (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey))) 219 }; 220 221 // Set up database 222 let opts = SqliteConnectOptions::from_str(&config.db) 223 .context("failed to parse database options")? 224 .create_if_missing(true); 225 let db = SqlitePool::connect_with(opts).await?; 226 227 sqlx::migrate!() 228 .run(&db) 229 .await 230 .context("failed to apply migrations")?; 231 232 // Create firehose 233 let (_fh, fhp) = firehose::spawn(client.clone(), config.clone()); 234 235 // Create the application state 236 let app_state = AppState { 237 cred: azure_identity::DefaultAzureCredential::new()?, 238 config: config.clone(), 239 db: db.clone(), 240 client: client.clone(), 241 simple_client, 242 firehose: fhp, 243 signing_key: skey, 244 rotation_key: rkey, 245 }; 246 247 // Create the router 248 let app = Router::new() 249 .route("/", get(index)) 250 .merge(oauth::routes()) 251 .nest( 252 "/xrpc", 253 endpoints::routes() 254 .merge(actor_endpoints::routes()) 255 .fallback(service_proxy), 256 ) 257 .layer(CorsLayer::permissive()) 258 .layer(TraceLayer::new_for_http()) 259 .with_state(app_state); 260 261 // Listen for connections 262 let listener = TcpListener::bind(&address) 263 .await 264 .context("failed to bind address")?; 265 266 axum::serve(listener, app.into_make_service()) 267 .await 268 .context("failed to serve app") 269 }); 270 271 // Give the server a moment to start 272 tokio::time::sleep(Duration::from_millis(500)).await; 273 274 Ok(()) 275 } 276} 277 278/// A test account that can be used for testing. 279pub(crate) struct TestAccount { 280 /// The access token for the account. 281 pub(crate) access_token: String, 282 /// The account DID. 
283 pub(crate) did: String, 284 /// The account handle. 285 pub(crate) handle: String, 286 /// The refresh token for the account. 287 #[expect(dead_code)] 288 pub(crate) refresh_token: String, 289} 290 291/// Initialize the test state. 292pub(crate) async fn init_test_state() -> Result<&'static TestState> { 293 async fn init_test_state() -> std::result::Result<TestState, anyhow::Error> { 294 let state = TestState::new().await?; 295 state.start_app().await?; 296 Ok(state) 297 } 298 TEST_STATE.get_or_try_init(init_test_state).await 299} 300 301/// Create a record benchmark that creates records and measures the time it takes. 302#[expect( 303 clippy::arithmetic_side_effects, 304 clippy::integer_division, 305 clippy::integer_division_remainder_used, 306 clippy::use_debug, 307 clippy::print_stdout 308)] 309pub(crate) async fn create_record_benchmark(count: usize, concurrent: usize) -> Result<Duration> { 310 // Initialize the test state 311 let state = init_test_state().await?; 312 313 // Create a test account 314 let account = state.create_test_account().await?; 315 316 // Create the client with authorization 317 let client = reqwest::Client::builder() 318 .timeout(Duration::from_secs(30)) 319 .build()?; 320 321 let start = Instant::now(); 322 323 // Split the work into batches 324 let mut handles = Vec::new(); 325 for batch_idx in 0..concurrent { 326 let batch_size = count / concurrent; 327 let client = client.clone(); 328 let base_url = state.base_url(); 329 let account_did = account.did.clone(); 330 let account_handle = account.handle.clone(); 331 let access_token = account.access_token.clone(); 332 333 let handle = tokio::spawn(async move { 334 let mut results = Vec::new(); 335 336 for i in 0..batch_size { 337 let request_start = Instant::now(); 338 let record_idx = batch_idx * batch_size + i; 339 340 let result = client 341 .post(format!("{base_url}/xrpc/com.atproto.repo.createRecord")) 342 .header("Authorization", format!("Bearer {access_token}")) 343 
.json(&atrium_api::com::atproto::repo::create_record::InputData { 344 repo: AtIdentifier::Did(Did::new(account_did.clone()).expect("valid DID")), 345 collection: Nsid::new("app.bsky.feed.post".to_owned()).expect("valid NSID"), 346 rkey: Some( 347 RecordKey::new(format!("test-{record_idx}")).expect("valid record key"), 348 ), 349 validate: None, 350 record: serde_json::from_str( 351 &serde_json::json!({ 352 "$type": "app.bsky.feed.post", 353 "text": format!("Test post {record_idx} from {account_handle}"), 354 "createdAt": chrono::Utc::now().to_rfc3339(), 355 }) 356 .to_string(), 357 ) 358 .expect("valid JSON record"), 359 swap_commit: None, 360 }) 361 .send() 362 .await; 363 364 // Fetch the record we just created 365 let get_response = client 366 .get(format!( 367 "{base_url}/xrpc/com.atproto.sync.getRecord?did={account_did}&collection=app.bsky.feed.post&rkey={record_idx}" 368 )) 369 .header("Authorization", format!("Bearer {access_token}")) 370 .send() 371 .await; 372 if get_response.is_err() { 373 println!("Failed to fetch record {record_idx}: {get_response:?}"); 374 results.push(get_response); 375 continue; 376 } 377 378 let request_duration = request_start.elapsed(); 379 if record_idx % 10 == 0 { 380 println!("Created record {record_idx} in {request_duration:?}"); 381 } 382 results.push(result); 383 } 384 385 results 386 }); 387 388 handles.push(handle); 389 } 390 391 // Wait for all batches to complete 392 let results = join_all(handles).await; 393 394 // Check for errors 395 for batch_result in results { 396 let batch_responses = batch_result?; 397 for response_result in batch_responses { 398 match response_result { 399 Ok(response) => { 400 if !response.status().is_success() { 401 return Err(anyhow::anyhow!( 402 "Failed to create record: {}", 403 response.status() 404 )); 405 } 406 } 407 Err(err) => { 408 return Err(anyhow::anyhow!("Failed to create record: {}", err)); 409 } 410 } 411 } 412 } 413 414 let duration = start.elapsed(); 415 Ok(duration) 416} 417 
#[cfg(test)]
#[expect(clippy::module_inception, clippy::use_debug, clippy::print_stdout)]
mod tests {
    use super::*;
    use anyhow::anyhow;

    /// Creates an account on the shared test server and checks the returned fields.
    ///
    /// Ignored by default (previously disabled via an early `return`); run explicitly with
    /// `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "Disabled"]
    async fn test_create_account() -> Result<()> {
        let state = init_test_state().await?;
        let account = state.create_test_account().await?;

        println!("Created test account: {}", account.handle);
        if account.handle.is_empty() {
            return Err(anyhow!("Account handle is empty"));
        }
        if account.did.is_empty() {
            return Err(anyhow!("Account DID is empty"));
        }
        if account.access_token.is_empty() {
            return Err(anyhow!("Account access token is empty"));
        }

        Ok(())
    }

    /// Creates 100 records on a single task and requires the run to finish within 10 seconds.
    ///
    /// Ignored by default (previously disabled via an early `return`); run explicitly with
    /// `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "Disabled"]
    async fn test_create_record_benchmark() -> Result<()> {
        let duration = create_record_benchmark(100, 1).await?;

        println!("Created 100 records in {duration:?}");

        if duration.as_secs() >= 10 {
            return Err(anyhow!("Benchmark took too long"));
        }

        Ok(())
    }
}