Alternative ATProto PDS implementation
1//! Testing utilities for the PDS.
2#![expect(clippy::arbitrary_source_item_ordering)]
3use std::{
4 net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener},
5 path::PathBuf,
6 time::{Duration, Instant},
7};
8
9use anyhow::Result;
10use atrium_api::{
11 com::atproto::server,
12 types::string::{AtIdentifier, Did, Handle, Nsid, RecordKey},
13};
14use figment::{Figment, providers::Format as _};
15use futures::future::join_all;
16use serde::{Deserialize, Serialize};
17use tokio::sync::OnceCell;
18use uuid::Uuid;
19
20use crate::config::AppConfig;
21
22/// Global test state, created once for all tests.
23pub(crate) static TEST_STATE: OnceCell<TestState> = OnceCell::const_new();
24
25/// A temporary test directory that will be cleaned up when the struct is dropped.
26struct TempDir {
27 /// The path to the directory.
28 path: PathBuf,
29}
30
31impl TempDir {
32 /// Create a new temporary directory.
33 fn new() -> Result<Self> {
34 let path = std::env::temp_dir().join(format!("bluepds-test-{}", Uuid::new_v4()));
35 std::fs::create_dir_all(&path)?;
36 Ok(Self { path })
37 }
38
39 /// Get the path to the directory.
40 fn path(&self) -> &PathBuf {
41 &self.path
42 }
43}
44
45impl Drop for TempDir {
46 fn drop(&mut self) {
47 drop(std::fs::remove_dir_all(&self.path));
48 }
49}
50
51/// Test state for the application.
52pub(crate) struct TestState {
53 /// The address the test server is listening on.
54 address: SocketAddr,
55 /// The HTTP client.
56 client: reqwest::Client,
57 /// The application configuration.
58 config: AppConfig,
59 /// The temporary directory for test data.
60 #[expect(dead_code)]
61 temp_dir: TempDir,
62}
63
64impl TestState {
65 /// Get a base URL for the test server.
66 pub(crate) fn base_url(&self) -> String {
67 format!("http://{}", self.address)
68 }
69
70 /// Create a test account.
71 pub(crate) async fn create_test_account(&self) -> Result<TestAccount> {
72 // Create the account
73 let handle = "test.handle";
74 let response = self
75 .client
76 .post(format!(
77 "http://{}/xrpc/com.atproto.server.createAccount",
78 self.address
79 ))
80 .json(&server::create_account::InputData {
81 did: None,
82 verification_code: None,
83 verification_phone: None,
84 email: Some(format!("{}@example.com", &handle)),
85 handle: Handle::new(handle.to_owned()).expect("should be able to create handle"),
86 password: Some("password123".to_owned()),
87 invite_code: None,
88 recovery_key: None,
89 plc_op: None,
90 })
91 .send()
92 .await?;
93
94 let account: server::create_account::Output = response.json().await?;
95
96 Ok(TestAccount {
97 handle: handle.to_owned(),
98 did: account.did.to_string(),
99 access_token: account.access_jwt.clone(),
100 refresh_token: account.refresh_jwt.clone(),
101 })
102 }
103
104 /// Create a new test state.
105 #[expect(clippy::unused_async)]
106 async fn new() -> Result<Self> {
107 // Configure the test app
108 #[derive(Serialize, Deserialize)]
109 struct TestConfigInput {
110 db: Option<String>,
111 host_name: Option<String>,
112 key: Option<PathBuf>,
113 listen_address: Option<SocketAddr>,
114 test: Option<bool>,
115 }
116 // Create a temporary directory for test data
117 let temp_dir = TempDir::new()?;
118
119 // Find a free port
120 let listener = TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))?;
121 let address = listener.local_addr()?;
122 drop(listener);
123
124 let test_config = TestConfigInput {
125 db: Some(format!("sqlite://{}/test.db", temp_dir.path().display())),
126 host_name: Some(format!("localhost:{}", address.port())),
127 key: Some(temp_dir.path().join("test.key")),
128 listen_address: Some(address),
129 test: Some(true),
130 };
131
132 let config: AppConfig = Figment::new()
133 .admerge(figment::providers::Toml::file("default.toml"))
134 .admerge(figment::providers::Env::prefixed("BLUEPDS_"))
135 .merge(figment::providers::Serialized::defaults(test_config))
136 .merge(
137 figment::providers::Toml::string(
138 r#"
139 [firehose]
140 relays = []
141
142 [repo]
143 path = "repo"
144
145 [plc]
146 path = "plc"
147
148 [blob]
149 path = "blob"
150 limit = 10485760 # 10 MB
151 "#,
152 )
153 .nested(),
154 )
155 .extract()?;
156
157 // Create directories
158 std::fs::create_dir_all(temp_dir.path().join("repo"))?;
159 std::fs::create_dir_all(temp_dir.path().join("plc"))?;
160 std::fs::create_dir_all(temp_dir.path().join("blob"))?;
161
162 // Create client
163 let client = reqwest::Client::builder()
164 .timeout(Duration::from_secs(30))
165 .build()?;
166
167 Ok(Self {
168 address,
169 client,
170 config,
171 temp_dir,
172 })
173 }
174
175 /// Start the application in a background task.
176 async fn start_app(&self) -> Result<()> {
177 // Get a reference to the config that can be moved into the task
178 let config = self.config.clone();
179 let address = self.address;
180
181 // Start the application in a background task
182 let _handle = tokio::spawn(async move {
183 // Set up the application
184 use crate::*;
185
186 // Initialize metrics (noop in test mode)
187 drop(metrics::setup(None));
188
189 // Create client
190 let simple_client = reqwest::Client::builder()
191 .user_agent(APP_USER_AGENT)
192 .build()
193 .context("failed to build requester client")?;
194 let client = reqwest_middleware::ClientBuilder::new(simple_client.clone())
195 .with(http_cache_reqwest::Cache(http_cache_reqwest::HttpCache {
196 mode: CacheMode::Default,
197 manager: MokaManager::default(),
198 options: HttpCacheOptions::default(),
199 }))
200 .build();
201
202 // Create a test keypair
203 std::fs::create_dir_all(config.key.parent().context("should have parent")?)?;
204 let (skey, rkey) = {
205 let skey = Secp256k1Keypair::create(&mut rand::thread_rng());
206 let rkey = Secp256k1Keypair::create(&mut rand::thread_rng());
207
208 let keys = KeyData {
209 skey: skey.export(),
210 rkey: rkey.export(),
211 };
212
213 let mut f =
214 std::fs::File::create(&config.key).context("failed to create key file")?;
215 serde_ipld_dagcbor::to_writer(&mut f, &keys)
216 .context("failed to serialize crypto keys")?;
217
218 (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey)))
219 };
220
221 // Set up database
222 let opts = SqliteConnectOptions::from_str(&config.db)
223 .context("failed to parse database options")?
224 .create_if_missing(true);
225 let db = SqlitePool::connect_with(opts).await?;
226
227 sqlx::migrate!()
228 .run(&db)
229 .await
230 .context("failed to apply migrations")?;
231
232 // Create firehose
233 let (_fh, fhp) = firehose::spawn(client.clone(), config.clone());
234
235 // Create the application state
236 let app_state = AppState {
237 cred: azure_identity::DefaultAzureCredential::new()?,
238 config: config.clone(),
239 db: db.clone(),
240 client: client.clone(),
241 simple_client,
242 firehose: fhp,
243 signing_key: skey,
244 rotation_key: rkey,
245 };
246
247 // Create the router
248 let app = Router::new()
249 .route("/", get(index))
250 .merge(oauth::routes())
251 .nest(
252 "/xrpc",
253 endpoints::routes()
254 .merge(actor_endpoints::routes())
255 .fallback(service_proxy),
256 )
257 .layer(CorsLayer::permissive())
258 .layer(TraceLayer::new_for_http())
259 .with_state(app_state);
260
261 // Listen for connections
262 let listener = TcpListener::bind(&address)
263 .await
264 .context("failed to bind address")?;
265
266 axum::serve(listener, app.into_make_service())
267 .await
268 .context("failed to serve app")
269 });
270
271 // Give the server a moment to start
272 tokio::time::sleep(Duration::from_millis(500)).await;
273
274 Ok(())
275 }
276}
277
278/// A test account that can be used for testing.
279pub(crate) struct TestAccount {
280 /// The access token for the account.
281 pub(crate) access_token: String,
282 /// The account DID.
283 pub(crate) did: String,
284 /// The account handle.
285 pub(crate) handle: String,
286 /// The refresh token for the account.
287 #[expect(dead_code)]
288 pub(crate) refresh_token: String,
289}
290
291/// Initialize the test state.
292pub(crate) async fn init_test_state() -> Result<&'static TestState> {
293 async fn init_test_state() -> std::result::Result<TestState, anyhow::Error> {
294 let state = TestState::new().await?;
295 state.start_app().await?;
296 Ok(state)
297 }
298 TEST_STATE.get_or_try_init(init_test_state).await
299}
300
301/// Create a record benchmark that creates records and measures the time it takes.
302#[expect(
303 clippy::arithmetic_side_effects,
304 clippy::integer_division,
305 clippy::integer_division_remainder_used,
306 clippy::use_debug,
307 clippy::print_stdout
308)]
309pub(crate) async fn create_record_benchmark(count: usize, concurrent: usize) -> Result<Duration> {
310 // Initialize the test state
311 let state = init_test_state().await?;
312
313 // Create a test account
314 let account = state.create_test_account().await?;
315
316 // Create the client with authorization
317 let client = reqwest::Client::builder()
318 .timeout(Duration::from_secs(30))
319 .build()?;
320
321 let start = Instant::now();
322
323 // Split the work into batches
324 let mut handles = Vec::new();
325 for batch_idx in 0..concurrent {
326 let batch_size = count / concurrent;
327 let client = client.clone();
328 let base_url = state.base_url();
329 let account_did = account.did.clone();
330 let account_handle = account.handle.clone();
331 let access_token = account.access_token.clone();
332
333 let handle = tokio::spawn(async move {
334 let mut results = Vec::new();
335
336 for i in 0..batch_size {
337 let request_start = Instant::now();
338 let record_idx = batch_idx * batch_size + i;
339
340 let result = client
341 .post(format!("{base_url}/xrpc/com.atproto.repo.createRecord"))
342 .header("Authorization", format!("Bearer {access_token}"))
343 .json(&atrium_api::com::atproto::repo::create_record::InputData {
344 repo: AtIdentifier::Did(Did::new(account_did.clone()).expect("valid DID")),
345 collection: Nsid::new("app.bsky.feed.post".to_owned()).expect("valid NSID"),
346 rkey: Some(
347 RecordKey::new(format!("test-{record_idx}")).expect("valid record key"),
348 ),
349 validate: None,
350 record: serde_json::from_str(
351 &serde_json::json!({
352 "$type": "app.bsky.feed.post",
353 "text": format!("Test post {record_idx} from {account_handle}"),
354 "createdAt": chrono::Utc::now().to_rfc3339(),
355 })
356 .to_string(),
357 )
358 .expect("valid JSON record"),
359 swap_commit: None,
360 })
361 .send()
362 .await;
363
364 // Fetch the record we just created
365 let get_response = client
366 .get(format!(
367 "{base_url}/xrpc/com.atproto.sync.getRecord?did={account_did}&collection=app.bsky.feed.post&rkey={record_idx}"
368 ))
369 .header("Authorization", format!("Bearer {access_token}"))
370 .send()
371 .await;
372 if get_response.is_err() {
373 println!("Failed to fetch record {record_idx}: {get_response:?}");
374 results.push(get_response);
375 continue;
376 }
377
378 let request_duration = request_start.elapsed();
379 if record_idx % 10 == 0 {
380 println!("Created record {record_idx} in {request_duration:?}");
381 }
382 results.push(result);
383 }
384
385 results
386 });
387
388 handles.push(handle);
389 }
390
391 // Wait for all batches to complete
392 let results = join_all(handles).await;
393
394 // Check for errors
395 for batch_result in results {
396 let batch_responses = batch_result?;
397 for response_result in batch_responses {
398 match response_result {
399 Ok(response) => {
400 if !response.status().is_success() {
401 return Err(anyhow::anyhow!(
402 "Failed to create record: {}",
403 response.status()
404 ));
405 }
406 }
407 Err(err) => {
408 return Err(anyhow::anyhow!("Failed to create record: {}", err));
409 }
410 }
411 }
412 }
413
414 let duration = start.elapsed();
415 Ok(duration)
416}
417
418#[cfg(test)]
419#[expect(clippy::module_inception, clippy::use_debug, clippy::print_stdout)]
420mod tests {
421 use super::*;
422 use anyhow::anyhow;
423
424 #[tokio::test]
425 async fn test_create_account() -> Result<()> {
426 return Ok(());
427 #[expect(unreachable_code, reason = "Disabled")]
428 let state = init_test_state().await?;
429 let account = state.create_test_account().await?;
430
431 println!("Created test account: {}", account.handle);
432 if account.handle.is_empty() {
433 return Err(anyhow::anyhow!("Account handle is empty"));
434 }
435 if account.did.is_empty() {
436 return Err(anyhow::anyhow!("Account DID is empty"));
437 }
438 if account.access_token.is_empty() {
439 return Err(anyhow::anyhow!("Account access token is empty"));
440 }
441
442 Ok(())
443 }
444
445 #[tokio::test]
446 async fn test_create_record_benchmark() -> Result<()> {
447 return Ok(());
448 #[expect(unreachable_code, reason = "Disabled")]
449 let duration = create_record_benchmark(100, 1).await?;
450
451 println!("Created 100 records in {duration:?}");
452
453 if duration.as_secs() >= 10 {
454 return Err(anyhow!("Benchmark took too long"));
455 }
456
457 Ok(())
458 }
459}