Alternative ATProto PDS implementation

cleanup

+1 -1
Cargo.toml
···
 # unstable-features = "allow"
 # # Temporary Allows
 dead_code = "allow"
-unused_imports = "allow"
+# unused_imports = "allow"

 [lints.clippy]
 # Groups
+17 -6
src/account_manager/helpers/account.rs
···
 use thiserror::Error;

 use diesel::dsl::{LeftJoinOn, exists, not};
-use diesel::helper_types::{Eq, IntoBoxed};
+use diesel::helper_types::Eq;

 #[derive(Error, Debug)]
 pub enum AccountHelperError {
···
         })
         .await
         .expect("Failed to delete actor")?;
-    let did = did.to_owned();
+    let did_clone = did.to_owned();
     _ = db
         .get()
         .await?
         .interact(move |conn| {
             _ = delete(EmailTokenSchema::email_token)
-                .filter(EmailTokenSchema::did.eq(&did))
+                .filter(EmailTokenSchema::did.eq(&did_clone))
                 .execute(conn)?;
             _ = delete(RefreshTokenSchema::refresh_token)
-                .filter(RefreshTokenSchema::did.eq(&did))
+                .filter(RefreshTokenSchema::did.eq(&did_clone))
                 .execute(conn)?;
             _ = delete(AccountSchema::account)
-                .filter(AccountSchema::did.eq(&did))
+                .filter(AccountSchema::did.eq(&did_clone))
                 .execute(conn)?;
             delete(ActorSchema::actor)
-                .filter(ActorSchema::did.eq(&did))
+                .filter(ActorSchema::did.eq(&did_clone))
                 .execute(conn)
         })
         .await
         .expect("Failed to delete account")?;
+
+    let data_repo_file = format!("data/repo/{}.db", did.to_owned());
+    let data_blob_path = format!("data/blob/{}", did);
+    let data_blob_path = std::path::Path::new(&data_blob_path);
+    let data_repo_file = std::path::Path::new(&data_repo_file);
+    if data_repo_file.exists() {
+        std::fs::remove_file(data_repo_file)?;
+    };
+    if data_blob_path.exists() {
+        std::fs::remove_dir_all(data_blob_path)?;
+    };
     Ok(())
 }
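Note on the `did_clone` rename above: `interact` takes a `move` closure, so the captured string is moved into it. Re-shadowing `did` (as the old code did) would leave no binding behind for the filesystem cleanup now added at the end of the function. A minimal sketch of the pattern, with illustrative names:

    fn main() {
        let did = String::from("did:plc:example");

        // The `move` closure takes ownership of its captures, so it gets a
        // clone under a different name instead of shadowing `did`.
        let did_clone = did.clone();
        let delete_rows = move || println!("deleting rows for {did_clone}");
        delete_rows();

        // `did` is still usable for the file cleanup that follows.
        println!("removing data/repo/{did}.db");
    }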
+18 -15
src/account_manager/mod.rs
···
 //! blacksky-algorithms/rsky is licensed under the Apache License 2.0
 //!
 //! Modified for SQLite backend
-use crate::ActorPools;
 use crate::account_manager::helpers::account::{
     AccountStatus, ActorAccount, AvailabilityFlags, GetAccountAdminStatusOutput,
 };
···
 use crate::account_manager::helpers::invite::CodeDetail;
 use crate::account_manager::helpers::password::UpdateUserPasswordOpts;
 use crate::models::pds::EmailTokenPurpose;
+use crate::serve::ActorStorage;
 use anyhow::Result;
-use axum::extract::FromRef;
 use chrono::DateTime;
 use chrono::offset::Utc as UtcOffset;
 use cidv10::Cid;
-use deadpool_diesel::sqlite::Pool;
 use diesel::*;
 use futures::try_join;
 use helpers::{account, auth, email_token, invite, password, repo};
···
     pub async fn create_account(
         &self,
         opts: CreateAccountOpts,
-        actor_pools: &mut std::collections::HashMap<String, ActorPools>,
+        actor_pools: &mut std::collections::HashMap<String, ActorStorage>,
     ) -> Result<(String, String)> {
         let CreateAccountOpts {
             did,
···
         let did_path = did
             .strip_prefix("did:plc:")
             .ok_or_else(|| anyhow::anyhow!("Invalid DID"))?;
-        let path_repo = format!("sqlite://{}", did_path);
+        let repo_path = format!("sqlite://data/repo/{}.db", did_path);
         let actor_repo_pool =
-            crate::establish_pool(path_repo.as_str()).expect("Failed to establish pool");
-        let path_blob = path_repo.replace("repo", "blob");
-        let actor_blob_pool = crate::establish_pool(&path_blob).expect("Failed to establish pool");
-        let actor_pool = ActorPools {
+            crate::db::establish_pool(repo_path.as_str()).expect("Failed to establish pool");
+        let blob_path = std::path::Path::new("data/blob").to_path_buf();
+        let actor_pool = ActorStorage {
             repo: actor_repo_pool,
-            blob: actor_blob_pool,
+            blob: blob_path.clone(),
         };
-        actor_pools
-            .insert(did.clone(), actor_pool)
-            .expect("Failed to insert actor pools");
+        let blob_path = blob_path.join(did_path);
+        tokio::fs::create_dir_all(&blob_path)
+            .await
+            .map_err(|_| anyhow::anyhow!("Failed to create blob path"))?;
+        drop(
+            actor_pools
+                .insert(did.clone(), actor_pool)
+                .expect("Failed to insert actor pools"),
+        );
         let db = actor_pools
             .get(&did)
             .ok_or_else(|| anyhow::anyhow!("Actor not found"))?
···
         did: String,
         cid: Cid,
         rev: String,
-        actor_pools: &std::collections::HashMap<String, ActorPools>,
+        actor_pools: &std::collections::HashMap<String, ActorStorage>,
     ) -> Result<()> {
         let db = actor_pools
             .get(&did)
···
     pub async fn delete_account(
         &self,
         did: &str,
-        actor_pools: &std::collections::HashMap<String, ActorPools>,
+        actor_pools: &std::collections::HashMap<String, ActorStorage>,
     ) -> Result<()> {
         let db = actor_pools
             .get(did)
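The commit swaps the old `ActorPools` (two SQLite pools) for `ActorStorage`. Its definition sits outside the hunks shown here, but from the usage above it pairs the repo's connection pool with a plain blob directory; a sketch of the assumed shape:

    use std::path::PathBuf;

    use deadpool_diesel::sqlite::Pool;

    /// Assumed shape: the repository stays in SQLite, while blobs now live
    /// on the filesystem under `data/blob/<did>`.
    #[derive(Clone)]
    pub struct ActorStorage {
        /// Connection pool for the actor's repository database.
        pub repo: Pool,
        /// Base directory for the actor's blob files.
        pub blob: PathBuf,
    }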
+20 -7
src/actor_endpoints.rs
···
 /// We shouldn't have to know about any bsky endpoints to store private user data.
 /// This will _very likely_ be changed in the future.
 use atrium_api::app::bsky::actor;
-use axum::{Json, routing::post};
+use axum::{
+    Json, Router,
+    extract::State,
+    routing::{get, post},
+};
 use constcat::concat;
 use diesel::prelude::*;

-use crate::actor_store::ActorStore;
+use crate::{actor_store::ActorStore, auth::AuthenticatedUser};

-use super::*;
+use super::serve::*;

 async fn put_preferences(
     user: AuthenticatedUser,
-    State(actor_pools): State<std::collections::HashMap<String, ActorPools>>,
+    State(actor_pools): State<std::collections::HashMap<String, ActorStorage>>,
     Json(input): Json<actor::put_preferences::Input>,
 ) -> Result<()> {
     let did = user.did();
-    let json_string =
-        serde_json::to_string(&input.preferences).context("failed to serialize preferences")?;
+    // let json_string =
+    //     serde_json::to_string(&input.preferences).context("failed to serialize preferences")?;

     // let conn = &mut actor_pools
     //     .get(&did)
···
     //     .context("failed to update user preferences")
     // });
     todo!("Use actor_store's preferences writer instead");
+    // let mut actor_store = ActorStore::from_actor_pools(&did, &actor_pools).await;
+    // let values = actor::defs::Preferences {
+    //     private_prefs: Some(json_string),
+    //     ..Default::default()
+    // };
+    // let namespace = actor::defs::PreferencesNamespace::Private;
+    // let scope = actor::defs::PreferencesScope::User;
+    // actor_store.pref.put_preferences(values, namespace, scope);
+
     Ok(())
 }

 async fn get_preferences(
     user: AuthenticatedUser,
-    State(actor_pools): State<std::collections::HashMap<String, ActorPools>>,
+    State(actor_pools): State<std::collections::HashMap<String, ActorStorage>>,
 ) -> Result<Json<actor::get_preferences::Output>> {
     let did = user.did();
     // let conn = &mut actor_pools
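Handlers in this file pull the per-actor map straight out of shared state, which works because the app state derives axum's `FromRef`. A standalone sketch of that pattern, assuming axum with the `macros` feature and simplified types:

    use std::collections::HashMap;

    use axum::{
        Router,
        extract::{FromRef, State},
        routing::get,
    };

    #[derive(Clone)]
    struct ActorStorage; // stand-in for the real repo/blob handle

    // Deriving `FromRef` lets handlers extract any single field as substate.
    #[derive(Clone, FromRef)]
    struct AppState {
        actors: HashMap<String, ActorStorage>,
    }

    // The handler asks for only the substate it needs, not all of `AppState`.
    async fn count_actors(State(actors): State<HashMap<String, ActorStorage>>) -> String {
        actors.len().to_string()
    }

    fn app() -> Router {
        Router::new()
            .route("/count", get(count_actors))
            .with_state(AppState {
                actors: HashMap::new(),
            })
    }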
+3 -3
src/actor_store/blob.rs
···
 use rsky_repo::types::{PreparedBlobRef, PreparedWrite};
 use std::str::FromStr as _;

-use super::sql_blob::{BlobStoreSql, ByteStream};
+use super::blob_fs::{BlobStoreFs, ByteStream};

 pub struct GetBlobOutput {
     pub size: i32,
···
 /// Handles blob operations for an actor store
 pub struct BlobReader {
     /// SQL-based blob storage
-    pub blobstore: BlobStoreSql,
+    pub blobstore: BlobStoreFs,
     /// DID of the actor
     pub did: String,
     /// Database connection
···
 impl BlobReader {
     /// Create a new blob reader
     pub fn new(
-        blobstore: BlobStoreSql,
+        blobstore: BlobStoreFs,
         db: deadpool_diesel::Pool<
             deadpool_diesel::Manager<SqliteConnection>,
             deadpool_diesel::sqlite::Object,
+4 -5
src/actor_store/blob_fs.rs
···
         let first_level = if cid_str.len() >= 10 {
             &cid_str[0..10]
         } else {
-            &cid_str
+            "short"
         };

         let second_level = if cid_str.len() >= 20 {
             &cid_str[10..20]
         } else {
-            "default"
+            "short"
         };

         self.base_dir
···
             async_fs::create_dir_all(parent).await?;
         }

-        // Copy first, then delete source after success
-        _ = async_fs::copy(&mov.from, &mov.to).await?;
-        async_fs::remove_file(&mov.from).await?;
+        // Move the file
+        async_fs::rename(&mov.from, &mov.to).await?;

         debug!("Moved blob: {:?} -> {:?}", mov.from, mov.to);
         Ok(())
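One caveat with the switch from copy-then-delete to a single `rename`: a rename only succeeds when source and destination sit on the same filesystem. If temporary uploads could ever live on a different mount than `data/blob`, a fallback keeps the move working. A sketch using `std::fs` (the async calls in this file mirror these):

    use std::{fs, io, path::Path};

    /// Move a file, falling back to copy + remove when `rename` fails,
    /// e.g. on a cross-filesystem move. Sketch only; the error kind
    /// returned for cross-device links varies by platform.
    fn move_file(from: &Path, to: &Path) -> io::Result<()> {
        match fs::rename(from, to) {
            Ok(()) => Ok(()),
            Err(_) => {
                // Copy first, then delete the source after success.
                let _ = fs::copy(from, to)?;
                fs::remove_file(from)
            }
        }
    }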
+6 -6
src/actor_store/mod.rs
···
 use tokio::sync::RwLock;

 use blob::BlobReader;
+use blob_fs::BlobStoreFs;
 use preference::PreferenceReader;
 use record::RecordReader;
-use sql_blob::BlobStoreSql;
 use sql_repo::SqlRepoReader;

-use crate::ActorPools;
+use crate::serve::ActorStorage;

 #[derive(Debug)]
 enum FormatCommitError {
···

 // Combination of RepoReader/Transactor, BlobReader/Transactor, SqlRepoReader/Transactor
 impl ActorStore {
-    /// Concrete reader of an individual repo (hence BlobStoreSql which takes `did` param)
+    /// Concrete reader of an individual repo (hence BlobStoreFs which takes `did` param)
     pub fn new(
         did: String,
-        blobstore: BlobStoreSql,
+        blobstore: BlobStoreFs,
         db: deadpool_diesel::Pool<
             deadpool_diesel::Manager<SqliteConnection>,
             deadpool_diesel::sqlite::Object,
···
     /// Create a new ActorStore taking ActorPools HashMap as input
     pub async fn from_actor_pools(
         did: &String,
-        hashmap_actor_pools: &std::collections::HashMap<String, ActorPools>,
+        hashmap_actor_pools: &std::collections::HashMap<String, ActorStorage>,
     ) -> Self {
         let actor_pool = hashmap_actor_pools
             .get(did)
             .expect("Actor pool not found")
             .clone();
-        let blobstore = BlobStoreSql::new(did.clone(), actor_pool.blob);
+        let blobstore = BlobStoreFs::new(did.clone(), actor_pool.blob);
         let conn = actor_pool
             .repo
             .clone()
+6 -17
src/apis/com/atproto/repo/apply_writes.rs
···
 //! Apply a batch transaction of repository creates, updates, and deletes. Requires auth, implemented by PDS.
-use crate::SharedSequencer;
+use crate::account_manager::AccountManager;
 use crate::account_manager::helpers::account::AvailabilityFlags;
-use crate::account_manager::{AccountManager, AccountManagerCreator, SharedAccountManager};
 use crate::{
-    ActorPools, AppState, SigningKey,
-    actor_store::{ActorStore, sql_blob::BlobStoreSql},
+    actor_store::ActorStore,
     auth::AuthenticatedUser,
-    config::AppConfig,
-    error::{ApiError, ErrorMessage},
+    error::ApiError,
+    serve::{ActorStorage, AppState},
 };
 use anyhow::{Result, bail};
-use axum::{
-    Json, Router,
-    body::Body,
-    extract::{Query, Request, State},
-    http::{self, StatusCode},
-    routing::{get, post},
-};
+use axum::{Json, extract::State};
 use cidv10::Cid;
-use deadpool_diesel::sqlite::Pool;
 use futures::stream::{self, StreamExt};
 use rsky_lexicon::com::atproto::repo::{ApplyWritesInput, ApplyWritesInputRefWrite};
-use rsky_pds::auth_verifier::AccessStandardIncludeChecks;
 use rsky_pds::repo::prepare::{
     PrepareCreateOpts, PrepareDeleteOpts, PrepareUpdateOpts, prepare_create, prepare_delete,
     prepare_update,
···
 use rsky_pds::sequencer::Sequencer;
 use rsky_repo::types::PreparedWrite;
 use std::str::FromStr;
-use std::sync::Arc;
 use tokio::sync::RwLock;

 async fn inner_apply_writes(
     body: ApplyWritesInput,
     user: AuthenticatedUser,
     sequencer: &RwLock<Sequencer>,
-    actor_pools: std::collections::HashMap<String, ActorPools>,
+    actor_pools: std::collections::HashMap<String, ActorStorage>,
     account_manager: &RwLock<AccountManager>,
 ) -> Result<()> {
     let tx: ApplyWritesInput = body;
+1 -1
src/apis/com/atproto/repo/mod.rs
···
 use axum::{Router, routing::post};
 use constcat::concat;

-use crate::AppState;
+use crate::serve::AppState;

 pub mod apply_writes;
 // pub mod create_record;
+1 -1
src/apis/mod.rs
···
 use axum::{Json, Router, routing::get};
 use serde_json::json;

-use crate::{AppState, Result};
+use crate::serve::{AppState, Result};

 /// Health check endpoint. Returns name and version of the service.
 pub(crate) async fn health() -> Result<Json<serde_json::Value>> {
+4 -1
src/auth.rs
···
 use diesel::prelude::*;
 use sha2::{Digest as _, Sha256};

-use crate::{AppState, Error, error::ErrorMessage};
+use crate::{
+    error::{Error, ErrorMessage},
+    serve::AppState,
+};

 /// Request extractor for authenticated users.
 /// If specified in an API endpoint, this guarantees the API can only be called
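For context on the doc comment above: `AuthenticatedUser` acts as an axum extractor, so merely listing it in a handler's signature is what enforces authentication. A simplified standalone sketch, assuming axum 0.8 (where `FromRequestParts` is a native async trait) and a hypothetical header check in place of the crate's real credential validation:

    use axum::{
        extract::FromRequestParts,
        http::{StatusCode, request::Parts},
    };

    /// Stand-in for the crate's extractor type.
    struct AuthenticatedUser {
        did: String,
    }

    impl<S: Send + Sync> FromRequestParts<S> for AuthenticatedUser {
        type Rejection = (StatusCode, &'static str);

        async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
            // Hypothetical check: the real implementation validates the
            // Authorization credentials against the PDS database.
            match parts.headers.get("authorization") {
                Some(_) => Ok(Self {
                    did: "did:plc:example".to_owned(),
                }),
                None => Err((StatusCode::UNAUTHORIZED, "missing credentials")),
            }
        }
    }

    fn main() {
        // Extractors are driven by axum rather than called directly; this
        // keeps the sketch compiling as a standalone file.
        let _ = |user: AuthenticatedUser| user.did;
    }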
+1 -1
src/did.rs
···
 use serde::{Deserialize, Serialize};
 use url::Url;

-use crate::Client;
+use crate::serve::Client;

 /// URL whitelist for DID document resolution.
 const ALLOWED_URLS: &[&str] = &["bsky.app", "bsky.chat"];
+12 -11
src/error.rs
···

 impl ApiError {
     /// Get the appropriate HTTP status code for this error
-    fn status_code(&self) -> StatusCode {
+    const fn status_code(&self) -> StatusCode {
         match self {
             Self::RuntimeError => StatusCode::INTERNAL_SERVER_ERROR,
             Self::InvalidLogin
···
             Self::BadRequest(error, _) => error,
             Self::AuthRequiredError(_) => "AuthRequiredError",
         }
-        .to_string()
+        .to_owned()
     }

     /// Get the user-facing error message
···
             Self::BadRequest(_, msg) => msg,
             Self::AuthRequiredError(msg) => msg,
         }
-        .to_string()
+        .to_owned()
     }
 }

 impl From<Error> for ApiError {
     fn from(_value: Error) -> Self {
-        ApiError::RuntimeError
+        Self::RuntimeError
     }
 }

 impl From<handle::errors::Error> for ApiError {
     fn from(value: handle::errors::Error) -> Self {
         match value.kind {
-            ErrorKind::InvalidHandle => ApiError::InvalidHandle,
-            ErrorKind::HandleNotAvailable => ApiError::HandleNotAvailable,
-            ErrorKind::UnsupportedDomain => ApiError::UnsupportedDomain,
-            ErrorKind::InternalError => ApiError::RuntimeError,
+            ErrorKind::InvalidHandle => Self::InvalidHandle,
+            ErrorKind::HandleNotAvailable => Self::HandleNotAvailable,
+            ErrorKind::UnsupportedDomain => Self::UnsupportedDomain,
+            ErrorKind::InternalError => Self::RuntimeError,
         }
     }
 }
···
         let error_type = self.error_type();
         let message = self.message();

-        // Log the error for debugging
-        error!("API Error: {}: {}", error_type, message);
+        if cfg!(debug_assertions) {
+            error!("API Error: {}: {}", error_type, message);
+        }

         // Create the error message and serialize to JSON
         let error_message = ErrorMessage::new(error_type, message);
         let body = serde_json::to_string(&error_message).unwrap_or_else(|_| {
-            r#"{"error":"InternalServerError","message":"Error serializing response"}"#.to_string()
+            r#"{"error":"InternalServerError","message":"Error serializing response"}"#.to_owned()
         });

         // Build the response
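A quick illustration of the `cfg!(debug_assertions)` gate introduced above: `cfg!` expands to a compile-time boolean, so release builds see `if false { ... }` and the optimizer drops the logging branch entirely. A minimal sketch:

    fn main() {
        // `cfg!` yields a `bool` constant: true in debug builds, false
        // when compiled with `--release`.
        if cfg!(debug_assertions) {
            eprintln!("debug build: API errors will be logged");
        } else {
            eprintln!("release build: API error logging suppressed");
        }
    }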
-426
src/firehose.rs
···
-//! The firehose module.
-use std::{collections::VecDeque, time::Duration};
-
-use anyhow::{Result, bail};
-use atrium_api::{
-    com::atproto::sync::{self},
-    types::string::{Datetime, Did, Tid},
-};
-use atrium_repo::Cid;
-use axum::extract::ws::{Message, WebSocket};
-use metrics::{counter, gauge};
-use rand::Rng as _;
-use serde::{Serialize, ser::SerializeMap as _};
-use tracing::{debug, error, info, warn};
-
-use crate::{
-    Client,
-    config::AppConfig,
-    metrics::{FIREHOSE_HISTORY, FIREHOSE_LISTENERS, FIREHOSE_MESSAGES, FIREHOSE_SEQUENCE},
-};
-
-enum FirehoseMessage {
-    Broadcast(sync::subscribe_repos::Message),
-    Connect(Box<(WebSocket, Option<i64>)>),
-}
-
-enum FrameHeader {
-    Error,
-    Message(String),
-}
-
-impl Serialize for FrameHeader {
-    #[expect(clippy::question_mark_used, reason = "returns a Result")]
-    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
-    where
-        S: serde::Serializer,
-    {
-        let mut map = serializer.serialize_map(None)?;
-
-        match *self {
-            Self::Message(ref s) => {
-                map.serialize_key("op")?;
-                map.serialize_value(&1_i32)?;
-                map.serialize_key("t")?;
-                map.serialize_value(s.as_str())?;
-            }
-            Self::Error => {
-                map.serialize_key("op")?;
-                map.serialize_value(&-1_i32)?;
-            }
-        }
-
-        map.end()
-    }
-}
-
-/// A repository operation.
-pub(crate) enum RepoOp {
-    /// Create a new record.
-    Create {
-        /// The CID of the record.
-        cid: Cid,
-        /// The path of the record.
-        path: String,
-    },
-    /// Delete an existing record.
-    Delete {
-        /// The path of the record.
-        path: String,
-        /// The previous CID of the record.
-        prev: Cid,
-    },
-    /// Update an existing record.
-    Update {
-        /// The CID of the record.
-        cid: Cid,
-        /// The path of the record.
-        path: String,
-        /// The previous CID of the record.
-        prev: Cid,
-    },
-}
-
-impl From<RepoOp> for sync::subscribe_repos::RepoOp {
-    fn from(val: RepoOp) -> Self {
-        let (action, cid, prev, path) = match val {
-            RepoOp::Create { cid, path } => ("create", Some(cid), None, path),
-            RepoOp::Update { cid, path, prev } => ("update", Some(cid), Some(prev), path),
-            RepoOp::Delete { path, prev } => ("delete", None, Some(prev), path),
-        };
-
-        sync::subscribe_repos::RepoOpData {
-            action: action.to_owned(),
-            cid: cid.map(atrium_api::types::CidLink),
-            prev: prev.map(atrium_api::types::CidLink),
-            path,
-        }
-        .into()
-    }
-}
-
-/// A commit to the repository.
-pub(crate) struct Commit {
-    /// Blobs that were created in this commit.
-    pub blobs: Vec<Cid>,
-    /// The car file containing the commit blocks.
-    pub car: Vec<u8>,
-    /// The CID of the commit.
-    pub cid: Cid,
-    /// The DID of the repository changed.
-    pub did: Did,
-    /// The operations performed in this commit.
-    pub ops: Vec<RepoOp>,
-    /// The previous commit's CID (if applicable).
-    pub pcid: Option<Cid>,
-    /// The revision of the commit.
-    pub rev: String,
-}
-
-impl From<Commit> for sync::subscribe_repos::Commit {
-    fn from(val: Commit) -> Self {
-        sync::subscribe_repos::CommitData {
-            blobs: val
-                .blobs
-                .into_iter()
-                .map(atrium_api::types::CidLink)
-                .collect::<Vec<_>>(),
-            blocks: val.car,
-            commit: atrium_api::types::CidLink(val.cid),
-            ops: val.ops.into_iter().map(Into::into).collect::<Vec<_>>(),
-            prev_data: val.pcid.map(atrium_api::types::CidLink),
-            rebase: false,
-            repo: val.did,
-            rev: Tid::new(val.rev).expect("should be valid revision"),
-            seq: 0,
-            since: None,
-            time: Datetime::now(),
-            too_big: false,
-        }
-        .into()
-    }
-}
-
-/// A firehose producer. This is used to transmit messages to the firehose for broadcast.
-#[derive(Clone, Debug)]
-pub(crate) struct FirehoseProducer {
-    /// The channel to send messages to the firehose.
-    tx: tokio::sync::mpsc::Sender<FirehoseMessage>,
-}
-
-impl FirehoseProducer {
-    /// Broadcast an `#account` event.
-    pub(crate) async fn account(&self, account: impl Into<sync::subscribe_repos::Account>) {
-        drop(
-            self.tx
-                .send(FirehoseMessage::Broadcast(
-                    sync::subscribe_repos::Message::Account(Box::new(account.into())),
-                ))
-                .await,
-        );
-    }
-    /// Handle client connection.
-    pub(crate) async fn client_connection(&self, ws: WebSocket, cursor: Option<i64>) {
-        drop(
-            self.tx
-                .send(FirehoseMessage::Connect(Box::new((ws, cursor))))
-                .await,
-        );
-    }
-    /// Broadcast a `#commit` event.
-    pub(crate) async fn commit(&self, commit: impl Into<sync::subscribe_repos::Commit>) {
-        drop(
-            self.tx
-                .send(FirehoseMessage::Broadcast(
-                    sync::subscribe_repos::Message::Commit(Box::new(commit.into())),
-                ))
-                .await,
-        );
-    }
-    /// Broadcast an `#identity` event.
-    pub(crate) async fn identity(&self, identity: impl Into<sync::subscribe_repos::Identity>) {
-        drop(
-            self.tx
-                .send(FirehoseMessage::Broadcast(
-                    sync::subscribe_repos::Message::Identity(Box::new(identity.into())),
-                ))
-                .await,
-        );
-    }
-}
-
-#[expect(
-    clippy::as_conversions,
-    clippy::cast_possible_truncation,
-    clippy::cast_sign_loss,
-    clippy::cast_precision_loss,
-    clippy::arithmetic_side_effects
-)]
-/// Convert a `usize` to a `f64`.
-const fn convert_usize_f64(x: usize) -> Result<f64, &'static str> {
-    let result = x as f64;
-    if result as usize - x > 0 {
-        return Err("cannot convert");
-    }
-    Ok(result)
-}
-
-/// Serialize a message.
-fn serialize_message(seq: u64, mut msg: sync::subscribe_repos::Message) -> (&'static str, Vec<u8>) {
-    let mut dummy_seq = 0_i64;
-    #[expect(clippy::pattern_type_mismatch)]
-    let (ty, nseq) = match &mut msg {
-        sync::subscribe_repos::Message::Account(m) => ("#account", &mut m.seq),
-        sync::subscribe_repos::Message::Commit(m) => ("#commit", &mut m.seq),
-        sync::subscribe_repos::Message::Identity(m) => ("#identity", &mut m.seq),
-        sync::subscribe_repos::Message::Sync(m) => ("#sync", &mut m.seq),
-        sync::subscribe_repos::Message::Info(_m) => ("#info", &mut dummy_seq),
-    };
-    // Set the sequence number.
-    *nseq = i64::try_from(seq).expect("should find seq");
-
-    let hdr = FrameHeader::Message(ty.to_owned());
-
-    let mut frame = Vec::new();
-    serde_ipld_dagcbor::to_writer(&mut frame, &hdr).expect("should serialize header");
-    serde_ipld_dagcbor::to_writer(&mut frame, &msg).expect("should serialize message");
-
-    (ty, frame)
-}
-
-/// Broadcast a message out to all clients.
-async fn broadcast_message(clients: &mut Vec<WebSocket>, msg: Message) -> Result<()> {
-    counter!(FIREHOSE_MESSAGES).increment(1);
-
-    for i in (0..clients.len()).rev() {
-        let client = clients.get_mut(i).expect("should find client");
-        if let Err(e) = client.send(msg.clone()).await {
-            debug!("Firehose client disconnected: {e}");
-            drop(clients.remove(i));
-        }
-    }
-
-    gauge!(FIREHOSE_LISTENERS)
-        .set(convert_usize_f64(clients.len()).expect("should find clients length"));
-    Ok(())
-}
-
-/// Handle a new connection from a websocket client created by subscribeRepos.
-async fn handle_connect(
-    mut ws: WebSocket,
-    seq: u64,
-    history: &VecDeque<(u64, &str, sync::subscribe_repos::Message)>,
-    cursor: Option<i64>,
-) -> Result<WebSocket> {
-    if let Some(cursor) = cursor {
-        let mut frame = Vec::new();
-        let cursor = u64::try_from(cursor);
-        if cursor.is_err() {
-            tracing::warn!("cursor is not a valid u64");
-            return Ok(ws);
-        }
-        let cursor = cursor.expect("should be valid u64");
-        // Cursor specified; attempt to backfill the consumer.
-        if cursor > seq {
-            let hdr = FrameHeader::Error;
-            let msg = sync::subscribe_repos::Error::FutureCursor(Some(format!(
-                "cursor {cursor} is greater than the current sequence number {seq}"
-            )));
-            serde_ipld_dagcbor::to_writer(&mut frame, &hdr).expect("should serialize header");
-            serde_ipld_dagcbor::to_writer(&mut frame, &msg).expect("should serialize message");
-            // Drop the connection.
-            drop(ws.send(Message::binary(frame)).await);
-            bail!(
-                "connection dropped: cursor {cursor} is greater than the current sequence number {seq}"
-            );
-        }
-
-        for &(historical_seq, ty, ref msg) in history {
-            if cursor > historical_seq {
-                continue;
-            }
-            let hdr = FrameHeader::Message(ty.to_owned());
-            serde_ipld_dagcbor::to_writer(&mut frame, &hdr).expect("should serialize header");
-            serde_ipld_dagcbor::to_writer(&mut frame, msg).expect("should serialize message");
-            if let Err(e) = ws.send(Message::binary(frame.clone())).await {
-                debug!("Firehose client disconnected during backfill: {e}");
-                break;
-            }
-            // Clear out the frame to begin a new one.
-            frame.clear();
-        }
-    }
-
-    Ok(ws)
-}
-
-/// Reconnect to upstream relays.
-pub(crate) async fn reconnect_relays(client: &Client, config: &AppConfig) {
-    // Avoid connecting to upstream relays in test mode.
-    if config.test {
-        return;
-    }
-
-    info!("attempting to reconnect to upstream relays");
-    for relay in &config.firehose.relays {
-        let Some(host) = relay.host_str() else {
-            warn!("relay {} has no host specified", relay);
-            continue;
-        };
-
-        let r = client
-            .post(format!("https://{host}/xrpc/com.atproto.sync.requestCrawl"))
-            .json(&serde_json::json!({
-                "hostname": format!("https://{}", config.host_name)
-            }))
-            .send()
-            .await;
-
-        let r = match r {
-            Ok(r) => r,
-            Err(e) => {
-                error!("failed to hit upstream relay {host}: {e}");
-                continue;
-            }
-        };
-
-        let s = r.status();
-        if let Err(e) = r.error_for_status_ref() {
-            error!("failed to hit upstream relay {host}: {e}");
-        }
-
-        let b = r.json::<serde_json::Value>().await;
-        if let Ok(b) = b {
-            info!("relay {host}: {} {}", s, b);
-        } else {
-            info!("relay {host}: {}", s);
-        }
-    }
-}
-
-/// The main entrypoint for the firehose.
-///
-/// This will broadcast all updates in this PDS out to anyone who is listening.
-///
-/// Reference: <https://atproto.com/specs/sync>
-pub(crate) fn spawn(
-    client: Client,
-    config: AppConfig,
-) -> (tokio::task::JoinHandle<()>, FirehoseProducer) {
-    let (tx, mut rx) = tokio::sync::mpsc::channel(1000);
-    let handle = tokio::spawn(async move {
-        fn time_since_inception() -> u64 {
-            chrono::Utc::now()
-                .timestamp_micros()
-                .checked_sub(1_743_442_000_000_000)
-                .expect("should not wrap")
-                .unsigned_abs()
-        }
-        let mut clients: Vec<WebSocket> = Vec::new();
-        let mut history = VecDeque::with_capacity(1000);
-        let mut seq = time_since_inception();
-
-        loop {
-            if let Ok(msg) = tokio::time::timeout(Duration::from_secs(30), rx.recv()).await {
-                match msg {
-                    Some(FirehoseMessage::Broadcast(msg)) => {
-                        let (ty, by) = serialize_message(seq, msg.clone());
-
-                        history.push_back((seq, ty, msg));
-                        gauge!(FIREHOSE_HISTORY).set(
-                            convert_usize_f64(history.len()).expect("should find history length"),
-                        );
-
-                        info!(
-                            "Broadcasting message {} {} to {} clients",
-                            seq,
-                            ty,
-                            clients.len()
-                        );
-
-                        counter!(FIREHOSE_SEQUENCE).absolute(seq);
-                        let now = time_since_inception();
-                        if now > seq {
-                            seq = now;
-                        } else {
-                            seq = seq.checked_add(1).expect("should not wrap");
-                        }
-
-                        drop(broadcast_message(&mut clients, Message::binary(by)).await);
-                    }
-                    Some(FirehoseMessage::Connect(ws_cursor)) => {
-                        let (ws, cursor) = *ws_cursor;
-                        match handle_connect(ws, seq, &history, cursor).await {
-                            Ok(r) => {
-                                gauge!(FIREHOSE_LISTENERS).increment(1_i32);
-                                clients.push(r);
-                            }
-                            Err(e) => {
-                                error!("failed to connect new client: {e}");
-                            }
-                        }
-                    }
-                    // All producers have been destroyed.
-                    None => break,
-                }
-            } else {
-                if clients.is_empty() {
-                    reconnect_relays(&client, &config).await;
-                }
-
-                let contents = rand::thread_rng()
-                    .sample_iter(rand::distributions::Alphanumeric)
-                    .take(15)
-                    .map(char::from)
-                    .collect::<String>();
-
-                // Send a websocket ping message.
-                // Reference: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API/Writing_WebSocket_servers#pings_and_pongs_the_heartbeat_of_websockets
-                let message = Message::Ping(axum::body::Bytes::from_owner(contents));
-                drop(broadcast_message(&mut clients, message).await);
-            }
-        }
-    });
-
-    (handle, FirehoseProducer { tx })
-}
+3 -438
src/lib.rs
···
 mod db;
 mod did;
 pub mod error;
-mod firehose;
 mod metrics;
-mod mmap;
 mod models;
 mod oauth;
-mod plc;
 mod schema;
+mod serve;
 mod service_proxy;
-#[cfg(test)]
-mod tests;

-use account_manager::{AccountManager, SharedAccountManager};
-use anyhow::{Context as _, anyhow};
-use atrium_api::types::string::Did;
-use atrium_crypto::keypair::{Export as _, Secp256k1Keypair};
-use auth::AuthenticatedUser;
-use axum::{
-    Router,
-    body::Body,
-    extract::{FromRef, Request, State},
-    http::{self, HeaderMap, Response, StatusCode, Uri},
-    response::IntoResponse,
-    routing::get,
-};
-use azure_core::credentials::TokenCredential;
-use clap::Parser;
-use clap_verbosity_flag::{InfoLevel, Verbosity, log::LevelFilter};
-use config::AppConfig;
-use db::establish_pool;
-use deadpool_diesel::sqlite::Pool;
-use diesel::prelude::*;
-use diesel_migrations::{EmbeddedMigrations, embed_migrations};
-pub use error::Error;
-use figment::{Figment, providers::Format as _};
-use firehose::FirehoseProducer;
-use http_cache_reqwest::{CacheMode, HttpCacheOptions, MokaManager};
-use rand::Rng as _;
-use rsky_pds::{crawlers::Crawlers, sequencer::Sequencer};
-use serde::{Deserialize, Serialize};
-use service_proxy::service_proxy;
-use std::{
-    net::{IpAddr, Ipv4Addr, SocketAddr},
-    path::PathBuf,
-    str::FromStr as _,
-    sync::Arc,
-};
-use tokio::{net::TcpListener, sync::RwLock};
-use tower_http::{cors::CorsLayer, trace::TraceLayer};
-use tracing::{info, warn};
-use uuid::Uuid;
-
-/// The application user agent. Concatenates the package name and version. e.g. `bluepds/0.0.0`.
-pub const APP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),);
-
-/// Embedded migrations
-pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations");
-
-/// The application-wide result type.
-pub type Result<T> = std::result::Result<T, Error>;
-/// The reqwest client type with middleware.
-pub type Client = reqwest_middleware::ClientWithMiddleware;
-
-/// The Shared Sequencer which requests crawls from upstream relays and emits events to the firehose.
-pub struct SharedSequencer {
-    /// The sequencer instance.
-    pub sequencer: RwLock<Sequencer>,
-}
-
-#[expect(
-    clippy::arbitrary_source_item_ordering,
-    reason = "serialized data might be structured"
-)]
-#[derive(Serialize, Deserialize, Debug, Clone)]
-/// The key data structure.
-struct KeyData {
-    /// Primary signing key for all repo operations.
-    skey: Vec<u8>,
-    /// Primary signing (rotation) key for all PLC operations.
-    rkey: Vec<u8>,
-}
-
-// FIXME: We should use P256Keypair instead. SecP256K1 is primarily used for cryptocurrencies,
-// and the implementations of this algorithm are much more limited as compared to P256.
-//
-// Reference: https://soatok.blog/2022/05/19/guidance-for-choosing-an-elliptic-curve-signature-algorithm-in-2022/
-#[derive(Clone)]
-/// The signing key for PLC/DID operations.
-pub struct SigningKey(Arc<Secp256k1Keypair>);
-#[derive(Clone)]
-/// The rotation key for PLC operations.
-pub struct RotationKey(Arc<Secp256k1Keypair>);
-
-impl std::ops::Deref for SigningKey {
-    type Target = Secp256k1Keypair;
-
-    fn deref(&self) -> &Self::Target {
-        &self.0
-    }
-}
-
-impl SigningKey {
-    /// Import from a private key.
-    pub fn import(key: &[u8]) -> Result<Self> {
-        let key = Secp256k1Keypair::import(key).context("failed to import signing key")?;
-        Ok(Self(Arc::new(key)))
-    }
-}
-
-impl std::ops::Deref for RotationKey {
-    type Target = Secp256k1Keypair;
-
-    fn deref(&self) -> &Self::Target {
-        &self.0
-    }
-}
-
-#[derive(Parser, Debug, Clone)]
-/// Command line arguments.
-pub struct Args {
-    /// Path to the configuration file
-    #[arg(short, long, default_value = "default.toml")]
-    pub config: PathBuf,
-    /// The verbosity level.
-    #[command(flatten)]
-    pub verbosity: Verbosity<InfoLevel>,
-}
-
-/// The actor pools for the database connections.
-pub struct ActorPools {
-    /// The database connection pool for the actor's repository.
-    pub repo: Pool,
-    /// The database connection pool for the actor's blobs.
-    pub blob: Pool,
-}
-
-impl Clone for ActorPools {
-    fn clone(&self) -> Self {
-        Self {
-            repo: self.repo.clone(),
-            blob: self.blob.clone(),
-        }
-    }
-}
-
-#[expect(clippy::arbitrary_source_item_ordering, reason = "arbitrary")]
-#[derive(Clone, FromRef)]
-pub struct AppState {
-    /// The application configuration.
-    pub config: AppConfig,
-    /// The main database connection pool. Used for common PDS data, like invite codes.
-    pub db: Pool,
-    /// Actor-specific database connection pools. Hashed by DID.
-    pub db_actors: std::collections::HashMap<String, ActorPools>,
-
-    /// The HTTP client with middleware.
-    pub client: Client,
-    /// The simple HTTP client.
-    pub simple_client: reqwest::Client,
-    /// The firehose producer.
-    pub sequencer: Arc<SharedSequencer>,
-    /// The account manager.
-    pub account_manager: Arc<SharedAccountManager>,
-
-    /// The signing key.
-    pub signing_key: SigningKey,
-    /// The rotation key.
-    pub rotation_key: RotationKey,
-}
+pub use serve::run;

 /// The index (/) route.
-async fn index() -> impl IntoResponse {
+async fn index() -> impl axum::response::IntoResponse {
     r"
 __ __
 /\ \__ /\ \__
···
 Protocol: https://atproto.com
 "
 }
-
-/// The main application entry point.
-#[expect(
-    clippy::cognitive_complexity,
-    clippy::too_many_lines,
-    unused_qualifications,
-    reason = "main function has high complexity"
-)]
-pub async fn run() -> anyhow::Result<()> {
-    let args = Args::parse();
-
-    // Set up trace logging to console and account for the user-provided verbosity flag.
-    if args.verbosity.log_level_filter() != LevelFilter::Off {
-        let lvl = match args.verbosity.log_level_filter() {
-            LevelFilter::Error => tracing::Level::ERROR,
-            LevelFilter::Warn => tracing::Level::WARN,
-            LevelFilter::Info | LevelFilter::Off => tracing::Level::INFO,
-            LevelFilter::Debug => tracing::Level::DEBUG,
-            LevelFilter::Trace => tracing::Level::TRACE,
-        };
-        tracing_subscriber::fmt().with_max_level(lvl).init();
-    }
-
-    if !args.config.exists() {
-        // Throw up a warning if the config file does not exist.
-        //
-        // This is not fatal because users can specify all configuration settings via
-        // the environment, but the most likely scenario here is that a user accidentally
-        // omitted the config file for some reason (e.g. forgot to mount it into Docker).
-        warn!(
-            "configuration file {} does not exist",
-            args.config.display()
-        );
-    }
-
-    // Read and parse the user-provided configuration.
-    let config: AppConfig = Figment::new()
-        .admerge(figment::providers::Toml::file(args.config))
-        .admerge(figment::providers::Env::prefixed("BLUEPDS_"))
-        .extract()
-        .context("failed to load configuration")?;
-
-    if config.test {
-        warn!("BluePDS starting up in TEST mode.");
-        warn!("This means the application will not federate with the rest of the network.");
-        warn!(
-            "If you want to turn this off, either set `test` to false in the config or define `BLUEPDS_TEST = false`"
-        );
-    }
-
-    // Initialize metrics reporting.
-    metrics::setup(config.metrics.as_ref()).context("failed to set up metrics exporter")?;
-
-    // Create a reqwest client that will be used for all outbound requests.
-    let simple_client = reqwest::Client::builder()
-        .user_agent(APP_USER_AGENT)
-        .build()
-        .context("failed to build requester client")?;
-    let client = reqwest_middleware::ClientBuilder::new(simple_client.clone())
-        .with(http_cache_reqwest::Cache(http_cache_reqwest::HttpCache {
-            mode: CacheMode::Default,
-            manager: MokaManager::default(),
-            options: HttpCacheOptions::default(),
-        }))
-        .build();
-
-    tokio::fs::create_dir_all(&config.key.parent().context("should have parent")?)
-        .await
-        .context("failed to create key directory")?;
-
-    // Check if crypto keys exist. If not, create new ones.
-    let (skey, rkey) = if let Ok(f) = std::fs::File::open(&config.key) {
-        let keys: KeyData = serde_ipld_dagcbor::from_reader(std::io::BufReader::new(f))
-            .context("failed to deserialize crypto keys")?;
-
-        let skey = Secp256k1Keypair::import(&keys.skey).context("failed to import signing key")?;
-        let rkey = Secp256k1Keypair::import(&keys.rkey).context("failed to import rotation key")?;
-
-        (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey)))
-    } else {
-        info!("signing keys not found, generating new ones");
-
-        let skey = Secp256k1Keypair::create(&mut rand::thread_rng());
-        let rkey = Secp256k1Keypair::create(&mut rand::thread_rng());
-
-        let keys = KeyData {
-            skey: skey.export(),
-            rkey: rkey.export(),
-        };
-
-        let mut f = std::fs::File::create(&config.key).context("failed to create key file")?;
-        serde_ipld_dagcbor::to_writer(&mut f, &keys).context("failed to serialize crypto keys")?;
-
-        (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey)))
-    };
-
-    tokio::fs::create_dir_all(&config.repo.path).await?;
-    tokio::fs::create_dir_all(&config.plc.path).await?;
-    tokio::fs::create_dir_all(&config.blob.path).await?;
-
-    // Create a database connection manager and pool for the main database.
-    let pool =
-        establish_pool(&config.db).context("failed to establish database connection pool")?;
-    // Create a dictionary of database connection pools for each actor.
-    let mut actor_pools = std::collections::HashMap::new();
-    // let mut actor_blob_pools = std::collections::HashMap::new();
-    // We'll determine actors by looking in the data/repo dir for .db files.
-    let mut actor_dbs = tokio::fs::read_dir(&config.repo.path)
-        .await
-        .context("failed to read repo directory")?;
-    while let Some(entry) = actor_dbs
-        .next_entry()
-        .await
-        .context("failed to read repo dir")?
-    {
-        let path = entry.path();
-        if path.extension().and_then(|s| s.to_str()) == Some("db") {
-            let did_path = path
-                .file_stem()
-                .and_then(|s| s.to_str())
-                .context("failed to get actor DID")?;
-            let did = Did::from_str(&format!("did:plc:{}", did_path))
-                .expect("should be able to parse actor DID");
-
-            // Create a new database connection manager and pool for the actor.
-            // The path for the SQLite connection needs to look like "sqlite://data/repo/<actor>.db"
-            let path_repo = format!("sqlite://{}", did_path);
-            let actor_repo_pool =
-                establish_pool(&path_repo).context("failed to create database connection pool")?;
-            // Create a new database connection manager and pool for the actor blobs.
-            // The path for the SQLite connection needs to look like "sqlite://data/blob/<actor>.db"
-            let path_blob = path_repo.replace("repo", "blob");
-            let actor_blob_pool =
-                establish_pool(&path_blob).context("failed to create database connection pool")?;
-            drop(actor_pools.insert(
-                did.to_string(),
-                ActorPools {
-                    repo: actor_repo_pool,
-                    blob: actor_blob_pool,
-                },
-            ));
-        }
-    }
-    // Apply pending migrations
-    // let conn = pool.get().await?;
-    // conn.run_pending_migrations(MIGRATIONS)
-    //     .expect("should be able to run migrations");
-
-    let hostname = config.host_name.clone();
-    let crawlers: Vec<String> = config
-        .firehose
-        .relays
-        .iter()
-        .map(|s| s.to_string())
-        .collect();
-    let sequencer = Arc::new(SharedSequencer {
-        sequencer: RwLock::new(Sequencer::new(
-            Crawlers::new(hostname, crawlers.clone()),
-            None,
-        )),
-    });
-    let account_manager = SharedAccountManager {
-        account_manager: RwLock::new(AccountManager::new(pool.clone())),
-    };
-
-    let addr = config
-        .listen_address
-        .unwrap_or(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8000));
-
-    let app = Router::new()
-        .route("/", get(index))
-        .merge(oauth::routes())
-        .nest(
-            "/xrpc",
-            apis::routes()
-                .merge(actor_endpoints::routes())
-                .fallback(service_proxy),
-        )
-        // .layer(RateLimitLayer::new(30, Duration::from_secs(30)))
-        .layer(CorsLayer::permissive())
-        .layer(TraceLayer::new_for_http())
-        .with_state(AppState {
-            config: config.clone(),
-            db: pool.clone(),
-            db_actors: actor_pools.clone(),
-            client: client.clone(),
-            simple_client,
-            sequencer: sequencer.clone(),
-            account_manager: Arc::new(account_manager),
-            signing_key: skey,
-            rotation_key: rkey,
-        });
-
-    info!("listening on {addr}");
-    info!("connect to: http://127.0.0.1:{}", addr.port());
-
-    // Determine whether or not this was the first startup (i.e. no accounts exist and no invite codes were created).
-    // If so, create an invite code and share it via the console.
-    let conn = pool.get().await.context("failed to get db connection")?;
-
-    #[derive(QueryableByName)]
-    struct TotalCount {
-        #[diesel(sql_type = diesel::sql_types::Integer)]
-        total_count: i32,
-    }
-
-    let result = conn.interact(move |conn| {
-        diesel::sql_query(
-            "SELECT (SELECT COUNT(*) FROM account) + (SELECT COUNT(*) FROM invite_code) AS total_count",
-        )
-        .get_result::<TotalCount>(conn)
-    })
-    .await
-    .expect("should be able to query database")?;
-
-    let c = result.total_count;
-
-    #[expect(clippy::print_stdout)]
-    if c == 0 {
-        let uuid = Uuid::new_v4().to_string();
-
-        use crate::models::pds as models;
-        use crate::schema::pds::invite_code::dsl as InviteCode;
-        let uuid_clone = uuid.clone();
-        drop(
-            conn.interact(move |conn| {
-                diesel::insert_into(InviteCode::invite_code)
-                    .values(models::InviteCode {
-                        code: uuid_clone,
-                        available_uses: 1,
-                        disabled: 0,
-                        for_account: "None".to_owned(),
-                        created_by: "None".to_owned(),
-                        created_at: "None".to_owned(),
-                    })
-                    .execute(conn)
-                    .context("failed to create new invite code")
-            })
-            .await
-            .expect("should be able to create invite code"),
-        );
-
-        // N.B: This is a sensitive message, so we're bypassing `tracing` here and
-        // logging it directly to console.
-        println!("=====================================");
-        println!(" FIRST STARTUP ");
-        println!("=====================================");
-        println!("Use this code to create an account:");
-        println!("{uuid}");
-        println!("=====================================");
-    }
-
-    let listener = TcpListener::bind(&addr)
-        .await
-        .context("failed to bind address")?;
-
-    // Serve the app, and request crawling from upstream relays.
-    let serve = tokio::spawn(async move {
-        axum::serve(listener, app.into_make_service())
-            .await
-            .context("failed to serve app")
-    });
-
-    // Now that the app is live, request a crawl from upstream relays.
-    let mut background_sequencer = sequencer.sequencer.write().await.clone();
-    drop(tokio::spawn(
-        async move { background_sequencer.start().await },
-    ));
-
-    serve
-        .await
-        .map_err(Into::into)
-        .and_then(|r| r)
-        .context("failed to serve app")
-}
+1 -3
src/main.rs
···
 //! BluePDS binary entry point.

 use anyhow::Context as _;
-use clap::Parser;

 #[tokio::main(flavor = "multi_thread")]
 async fn main() -> anyhow::Result<()> {
-    // Parse command line arguments and call into the library's run function
     bluepds::run().await.context("failed to run application")
-}
+}
-274
src/mmap.rs
···
-#![allow(clippy::arbitrary_source_item_ordering)]
-use std::io::{ErrorKind, Read as _, Seek as _, Write as _};
-
-#[cfg(unix)]
-use std::os::fd::AsRawFd as _;
-#[cfg(windows)]
-use std::os::windows::io::AsRawHandle;
-
-use memmap2::{MmapMut, MmapOptions};
-
-pub(crate) struct MappedFile {
-    /// The underlying file handle.
-    file: std::fs::File,
-    /// The length of the file.
-    len: u64,
-    /// The mapped memory region.
-    map: MmapMut,
-    /// Our current offset into the file.
-    off: u64,
-}
-
-impl MappedFile {
-    pub(crate) fn new(mut f: std::fs::File) -> std::io::Result<Self> {
-        let len = f.seek(std::io::SeekFrom::End(0))?;
-
-        #[cfg(windows)]
-        let raw = f.as_raw_handle();
-        #[cfg(unix)]
-        let raw = f.as_raw_fd();
-
-        #[expect(unsafe_code)]
-        Ok(Self {
-            // SAFETY:
-            // All file-backed memory map constructors are marked \
-            // unsafe because of the potential for Undefined Behavior (UB) \
-            // using the map if the underlying file is subsequently modified, in or out of process.
-            map: unsafe { MmapOptions::new().map_mut(raw)? },
-            file: f,
-            len,
-            off: 0,
-        })
-    }
-
-    /// Resize the memory-mapped file. This will reallocate the memory mapping.
-    #[expect(unsafe_code)]
-    fn resize(&mut self, len: u64) -> std::io::Result<()> {
-        // Resize the file.
-        self.file.set_len(len)?;
-
-        #[cfg(windows)]
-        let raw = self.file.as_raw_handle();
-        #[cfg(unix)]
-        let raw = self.file.as_raw_fd();
-
-        // SAFETY:
-        // All file-backed memory map constructors are marked \
-        // unsafe because of the potential for Undefined Behavior (UB) \
-        // using the map if the underlying file is subsequently modified, in or out of process.
-        self.map = unsafe { MmapOptions::new().map_mut(raw)? };
-        self.len = len;
-
-        Ok(())
-    }
-}
-
-impl std::io::Read for MappedFile {
-    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
-        if self.off == self.len {
-            // If we're at EOF, return an EOF error code. `Ok(0)` tends to trip up some implementations.
-            return Err(std::io::Error::new(ErrorKind::UnexpectedEof, "eof"));
-        }
-
-        // Calculate the number of bytes we're going to read.
-        let remaining_bytes = self.len.saturating_sub(self.off);
-        let buf_len = u64::try_from(buf.len()).unwrap_or(u64::MAX);
-        let len = usize::try_from(std::cmp::min(remaining_bytes, buf_len)).unwrap_or(usize::MAX);
-
-        let off = usize::try_from(self.off).map_err(|e| {
-            std::io::Error::new(
-                ErrorKind::InvalidInput,
-                format!("offset too large for this platform: {e}"),
-            )
-        })?;
-
-        if let (Some(dest), Some(src)) = (
-            buf.get_mut(..len),
-            self.map.get(off..off.saturating_add(len)),
-        ) {
-            dest.copy_from_slice(src);
-            self.off = self.off.saturating_add(u64::try_from(len).unwrap_or(0));
-            Ok(len)
-        } else {
-            Err(std::io::Error::new(
-                ErrorKind::InvalidInput,
-                "invalid buffer range",
-            ))
-        }
-    }
-}
-
-impl std::io::Write for MappedFile {
-    fn flush(&mut self) -> std::io::Result<()> {
-        // This is done by the system.
-        Ok(())
-    }
-    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
-        // Determine if we need to resize the file.
-        let buf_len = u64::try_from(buf.len()).map_err(|e| {
-            std::io::Error::new(
-                ErrorKind::InvalidInput,
-                format!("buffer length too large for this platform: {e}"),
-            )
-        })?;
-
-        if self.off.saturating_add(buf_len) >= self.len {
-            self.resize(self.off.saturating_add(buf_len))?;
-        }
-
-        let off = usize::try_from(self.off).map_err(|e| {
-            std::io::Error::new(
-                ErrorKind::InvalidInput,
-                format!("offset too large for this platform: {e}"),
-            )
-        })?;
-        let len = buf.len();
-
-        if let Some(dest) = self.map.get_mut(off..off.saturating_add(len)) {
-            dest.copy_from_slice(buf);
-            self.off = self.off.saturating_add(buf_len);
-            Ok(len)
-        } else {
-            Err(std::io::Error::new(
-                ErrorKind::InvalidInput,
-                "invalid buffer range",
-            ))
-        }
-    }
-}
-
-impl std::io::Seek for MappedFile {
-    fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
-        let off = match pos {
-            std::io::SeekFrom::Start(i) => i,
-            std::io::SeekFrom::End(i) => {
-                if i <= 0 {
-                    // If i is negative or zero, we're seeking backwards from the end
-                    // or exactly at the end
-                    self.len.saturating_sub(i.unsigned_abs())
-                } else {
-                    // If i is positive, we're seeking beyond the end, which is allowed
-                    // but requires extending the file
-                    self.len.saturating_add(i.unsigned_abs())
-                }
-            }
-            std::io::SeekFrom::Current(i) => {
-                if i >= 0 {
-                    self.off.saturating_add(i.unsigned_abs())
-                } else {
-                    self.off.saturating_sub(i.unsigned_abs())
-                }
-            }
-        };
-
-        // If the offset is beyond EOF, extend the file to the new size.
-        if off > self.len {
-            self.resize(off)?;
-        }
-
-        self.off = off;
-        Ok(off)
-    }
-}
-
-impl tokio::io::AsyncRead for MappedFile {
-    fn poll_read(
-        mut self: std::pin::Pin<&mut Self>,
-        _cx: &mut std::task::Context<'_>,
-        buf: &mut tokio::io::ReadBuf<'_>,
-    ) -> std::task::Poll<std::io::Result<()>> {
-        let wbuf = buf.initialize_unfilled();
-        let len = wbuf.len();
-
-        std::task::Poll::Ready(match self.read(wbuf) {
-            Ok(_) => {
-                buf.advance(len);
-                Ok(())
-            }
-            Err(e) => Err(e),
-        })
-    }
-}
-
-impl tokio::io::AsyncWrite for MappedFile {
-    fn poll_flush(
-        self: std::pin::Pin<&mut Self>,
-        _cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<Result<(), std::io::Error>> {
-        std::task::Poll::Ready(Ok(()))
-    }
-
-    fn poll_shutdown(
-        self: std::pin::Pin<&mut Self>,
-        _cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<Result<(), std::io::Error>> {
-        std::task::Poll::Ready(Ok(()))
-    }
-
-    fn poll_write(
-        mut self: std::pin::Pin<&mut Self>,
-        _cx: &mut std::task::Context<'_>,
-        buf: &[u8],
-    ) -> std::task::Poll<Result<usize, std::io::Error>> {
-        std::task::Poll::Ready(self.write(buf))
-    }
-}
-
-impl tokio::io::AsyncSeek for MappedFile {
-    fn poll_complete(
-        self: std::pin::Pin<&mut Self>,
-        _cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<std::io::Result<u64>> {
-        std::task::Poll::Ready(Ok(self.off))
-    }
-
-    fn start_seek(
-        mut self: std::pin::Pin<&mut Self>,
-        position: std::io::SeekFrom,
-    ) -> std::io::Result<()> {
-        self.seek(position).map(|_p| ())
-    }
-}
-
-#[cfg(test)]
-mod test {
-    use rand::Rng as _;
-    use std::io::Write as _;
-
-    use super::*;
-
-    #[test]
-    fn basic_rw() {
-        let tmp = std::env::temp_dir().join(
-            rand::thread_rng()
-                .sample_iter(rand::distributions::Alphanumeric)
-                .take(10)
-                .map(char::from)
-                .collect::<String>(),
-        );
-
-        let mut m = MappedFile::new(
-            std::fs::File::options()
-                .create(true)
-                .truncate(true)
-                .read(true)
-                .write(true)
-                .open(&tmp)
-                .expect("Failed to open temporary file"),
-        )
-        .expect("Failed to create MappedFile");
-
-        m.write_all(b"abcd123").expect("Failed to write data");
-        let _: u64 = m
-            .seek(std::io::SeekFrom::Start(0))
-            .expect("Failed to seek to start");
-
-        let mut buf = [0_u8; 7];
-        m.read_exact(&mut buf).expect("Failed to read data");
-
-        assert_eq!(&buf, b"abcd123");
-
-        drop(m);
-        std::fs::remove_file(tmp).expect("Failed to remove temporary file");
-    }
-}
+3 -1
src/oauth.rs
···
 //! OAuth endpoints
 #![allow(unnameable_types, unused_qualifications)]
+use crate::config::AppConfig;
+use crate::error::Error;
 use crate::metrics::AUTH_FAILED;
-use crate::{AppConfig, AppState, Client, Error, Result, SigningKey};
+use crate::serve::{AppState, Client, Result, SigningKey};
 use anyhow::{Context as _, anyhow};
 use argon2::{Argon2, PasswordHash, PasswordVerifier as _};
 use atrium_crypto::keypair::Did as _;
-114
src/plc.rs
···
-//! PLC operations.
-use std::collections::HashMap;
-
-use anyhow::{Context as _, bail};
-use base64::Engine as _;
-use serde::{Deserialize, Serialize};
-use tracing::debug;
-
-use crate::{Client, RotationKey};
-
-/// The URL of the public PLC directory.
-const PLC_DIRECTORY: &str = "https://plc.directory/";
-
-#[derive(Debug, Deserialize, Serialize, Clone)]
-#[serde(rename_all = "camelCase", tag = "type")]
-/// A PLC service.
-pub(crate) enum PlcService {
-    #[serde(rename = "AtprotoPersonalDataServer")]
-    /// A personal data server.
-    Pds {
-        /// The URL of the PDS.
-        endpoint: String,
-    },
-}
-
-#[expect(
-    clippy::arbitrary_source_item_ordering,
-    reason = "serialized data might be structured"
-)]
-#[derive(Debug, Deserialize, Serialize, Clone)]
-#[serde(rename_all = "camelCase")]
-pub(crate) struct PlcOperation {
-    #[serde(rename = "type")]
-    pub typ: String,
-    pub rotation_keys: Vec<String>,
-    pub verification_methods: HashMap<String, String>,
-    pub also_known_as: Vec<String>,
-    pub services: HashMap<String, PlcService>,
-    pub prev: Option<String>,
-}
-
-impl PlcOperation {
-    /// Sign an operation with the provided signature.
-    pub(crate) fn sign(self, sig: Vec<u8>) -> SignedPlcOperation {
-        SignedPlcOperation {
-            typ: self.typ,
-            rotation_keys: self.rotation_keys,
-            verification_methods: self.verification_methods,
-            also_known_as: self.also_known_as,
-            services: self.services,
-            prev: self.prev,
-            sig: base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(sig),
-        }
-    }
-}
-
-#[expect(
-    clippy::arbitrary_source_item_ordering,
-    reason = "serialized data might be structured"
-)]
-#[derive(Debug, Deserialize, Serialize, Clone)]
-#[serde(rename_all = "camelCase")]
-/// A signed PLC operation.
-pub(crate) struct SignedPlcOperation {
-    #[serde(rename = "type")]
-    pub typ: String,
-    pub rotation_keys: Vec<String>,
-    pub verification_methods: HashMap<String, String>,
-    pub also_known_as: Vec<String>,
-    pub services: HashMap<String, PlcService>,
-    pub prev: Option<String>,
-    pub sig: String,
-}
-
-pub(crate) fn sign_op(rkey: &RotationKey, op: PlcOperation) -> anyhow::Result<SignedPlcOperation> {
-    let bytes = serde_ipld_dagcbor::to_vec(&op).context("failed to encode op")?;
-    let bytes = rkey.sign(&bytes).context("failed to sign op")?;
-
-    Ok(op.sign(bytes))
-}
-
-/// Submit a PLC operation to the public directory.
-pub(crate) async fn submit(
-    client: &Client,
-    did: &str,
-    op: &SignedPlcOperation,
-) -> anyhow::Result<()> {
-    debug!(
-        "submitting {} {}",
-        did,
-        serde_json::to_string(&op).context("should serialize")?
-    );
-
-    let res = client
-        .post(format!("{PLC_DIRECTORY}{did}"))
-        .json(&op)
-        .send()
-        .await
-        .context("failed to send directory request")?;
-
-    if res.status().is_success() {
-        Ok(())
-    } else {
-        let e = res
-            .json::<serde_json::Value>()
-            .await
-            .context("failed to read error response")?;
-
-        bail!(
-            "error from PLC directory: {}",
-            serde_json::to_string(&e).context("should serialize")?
-        );
-    }
-}
+415
src/serve.rs
···
1 + use super::account_manager::{AccountManager, SharedAccountManager};
2 + use super::config::AppConfig;
3 + use super::db::establish_pool;
4 + pub use super::error::Error;
5 + use super::service_proxy::service_proxy;
6 + use anyhow::Context as _;
7 + use atrium_api::types::string::Did;
8 + use atrium_crypto::keypair::{Export as _, Secp256k1Keypair};
9 + use axum::{Router, extract::FromRef, routing::get};
10 + use clap::Parser;
11 + use clap_verbosity_flag::{InfoLevel, Verbosity, log::LevelFilter};
12 + use deadpool_diesel::sqlite::Pool;
13 + use diesel::prelude::*;
14 + use diesel_migrations::{EmbeddedMigrations, embed_migrations};
15 + use figment::{Figment, providers::Format as _};
16 + use http_cache_reqwest::{CacheMode, HttpCacheOptions, MokaManager};
17 + use rsky_pds::{crawlers::Crawlers, sequencer::Sequencer};
18 + use serde::{Deserialize, Serialize};
19 + use std::{
20 + net::{IpAddr, Ipv4Addr, SocketAddr},
21 + path::PathBuf,
22 + str::FromStr as _,
23 + sync::Arc,
24 + };
25 + use tokio::{net::TcpListener, sync::RwLock};
26 + use tower_http::{cors::CorsLayer, trace::TraceLayer};
27 + use tracing::{info, warn};
28 + use uuid::Uuid;
29 +
30 + /// The application user agent. Concatenates the package name and version, e.g. `bluepds/0.0.0`.
31 + pub const APP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),);
32 +
33 + /// Embedded database migrations.
34 + pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations");
35 +
36 + /// The application-wide result type.
37 + pub type Result<T> = std::result::Result<T, Error>;
38 + /// The reqwest client type with middleware.
39 + pub type Client = reqwest_middleware::ClientWithMiddleware;
40 +
41 + /// The shared sequencer, which requests crawls from upstream relays and emits events to the firehose.
42 + pub struct SharedSequencer {
43 + /// The sequencer instance.
44 + pub sequencer: RwLock<Sequencer>,
45 + }
46 +
47 + #[expect(
48 + clippy::arbitrary_source_item_ordering,
49 + reason = "serialized data might be structured"
50 + )]
51 + #[derive(Serialize, Deserialize, Debug, Clone)]
52 + /// Signing and rotation key material persisted to disk.
53 + struct KeyData {
54 + /// Primary signing key for all repo operations.
55 + skey: Vec<u8>,
56 + /// Primary signing (rotation) key for all PLC operations.
57 + rkey: Vec<u8>,
58 + }
59 +
60 + // FIXME: We should use P256Keypair instead. secp256k1 is primarily used for cryptocurrencies,
61 + // and the implementations of this algorithm are much more limited compared to P256.
62 + //
63 + // Reference: https://soatok.blog/2022/05/19/guidance-for-choosing-an-elliptic-curve-signature-algorithm-in-2022/
64 + #[derive(Clone)]
65 + /// The signing key for repo operations.
66 + pub struct SigningKey(Arc<Secp256k1Keypair>);
67 + #[derive(Clone)]
68 + /// The rotation key for PLC operations.
69 + pub struct RotationKey(Arc<Secp256k1Keypair>);
70 +
71 + impl std::ops::Deref for SigningKey {
72 + type Target = Secp256k1Keypair;
73 +
74 + fn deref(&self) -> &Self::Target {
75 + &self.0
76 + }
77 + }
78 +
79 + impl SigningKey {
80 + /// Import from a private key.
81 + pub fn import(key: &[u8]) -> Result<Self> {
82 + let key = Secp256k1Keypair::import(key).context("failed to import signing key")?;
83 + Ok(Self(Arc::new(key)))
84 + }
85 + }
86 +
87 + impl std::ops::Deref for RotationKey {
88 + type Target = Secp256k1Keypair;
89 +
90 + fn deref(&self) -> &Self::Target {
91 + &self.0
92 + }
93 + }
94 +
95 + #[derive(Parser, Debug, Clone)]
96 + /// Command line arguments.
97 + pub struct Args {
98 + /// Path to the configuration file.
99 + #[arg(short, long, default_value = "default.toml")]
100 + pub config: PathBuf,
101 + /// The verbosity level.
102 + #[command(flatten)]
103 + pub verbosity: Verbosity<InfoLevel>,
104 + }
105 +
106 + /// Per-actor storage: a repo database connection pool plus the blob directory path.
107 + pub struct ActorStorage {
108 + /// The database connection pool for the actor's repository.
109 + pub repo: Pool,
110 + /// The file storage path for the actor's blobs.
111 + pub blob: PathBuf,
112 + }
113 +
114 + impl Clone for ActorStorage {
115 + fn clone(&self) -> Self {
116 + Self {
117 + repo: self.repo.clone(),
118 + blob: self.blob.clone(),
119 + }
120 + }
121 + }
122 +
123 + #[expect(clippy::arbitrary_source_item_ordering, reason = "arbitrary")]
124 + #[derive(Clone, FromRef)]
125 + /// The application state, shared across all routes.
126 + pub struct AppState {
127 + /// The application configuration.
128 + pub(crate) config: AppConfig,
129 + /// The main database connection pool. Used for common PDS data, like invite codes.
130 + pub db: Pool,
131 + /// Per-actor storage (repo pool and blob path), keyed by DID.
132 + pub db_actors: std::collections::HashMap<String, ActorStorage>,
133 +
134 + /// The HTTP client with middleware.
135 + pub client: Client,
136 + /// The simple HTTP client.
137 + pub simple_client: reqwest::Client,
138 + /// The shared sequencer, which produces firehose events.
139 + pub sequencer: Arc<SharedSequencer>,
140 + /// The account manager.
141 + pub account_manager: Arc<SharedAccountManager>,
142 +
143 + /// The signing key.
144 + pub signing_key: SigningKey,
145 + /// The rotation key.
146 + pub rotation_key: RotationKey,
147 + }
148 +
149 + /// The main application entry point.
150 + #[expect(
151 + clippy::cognitive_complexity,
152 + clippy::too_many_lines,
153 + unused_qualifications,
154 + reason = "main function has high complexity"
155 + )]
156 + pub async fn run() -> anyhow::Result<()> {
157 + let args = Args::parse();
158 +
159 + // Set up trace logging to console and account for the user-provided verbosity flag.
160 + if args.verbosity.log_level_filter() != LevelFilter::Off {
161 + let lvl = match args.verbosity.log_level_filter() {
162 + LevelFilter::Error => tracing::Level::ERROR,
163 + LevelFilter::Warn => tracing::Level::WARN,
164 + LevelFilter::Info | LevelFilter::Off => tracing::Level::INFO,
165 + LevelFilter::Debug => tracing::Level::DEBUG,
166 + LevelFilter::Trace => tracing::Level::TRACE,
167 + };
168 + tracing_subscriber::fmt().with_max_level(lvl).init();
169 + }
170 +
171 + if !args.config.exists() {
172 + // Throw up a warning if the config file does not exist.
173 + //
174 + // This is not fatal because users can specify all configuration settings via
175 + // the environment, but the most likely scenario here is that a user accidentally
176 + // omitted the config file for some reason (e.g. forgot to mount it into Docker).
177 + warn!(
178 + "configuration file {} does not exist",
179 + args.config.display()
180 + );
181 + }
182 +
183 + // Read and parse the user-provided configuration.
184 + let config: AppConfig = Figment::new() 185 + .admerge(figment::providers::Toml::file(args.config)) 186 + .admerge(figment::providers::Env::prefixed("BLUEPDS_")) 187 + .extract() 188 + .context("failed to load configuration")?; 189 + 190 + if config.test { 191 + warn!("BluePDS starting up in TEST mode."); 192 + warn!("This means the application will not federate with the rest of the network."); 193 + warn!( 194 + "If you want to turn this off, either set `test` to false in the config or define `BLUEPDS_TEST = false`" 195 + ); 196 + } 197 + 198 + // Initialize metrics reporting. 199 + super::metrics::setup(config.metrics.as_ref()).context("failed to set up metrics exporter")?; 200 + 201 + // Create a reqwest client that will be used for all outbound requests. 202 + let simple_client = reqwest::Client::builder() 203 + .user_agent(APP_USER_AGENT) 204 + .build() 205 + .context("failed to build requester client")?; 206 + let client = reqwest_middleware::ClientBuilder::new(simple_client.clone()) 207 + .with(http_cache_reqwest::Cache(http_cache_reqwest::HttpCache { 208 + mode: CacheMode::Default, 209 + manager: MokaManager::default(), 210 + options: HttpCacheOptions::default(), 211 + })) 212 + .build(); 213 + 214 + tokio::fs::create_dir_all(&config.key.parent().context("should have parent")?) 215 + .await 216 + .context("failed to create key directory")?; 217 + 218 + // Check if crypto keys exist. If not, create new ones. 219 + let (skey, rkey) = if let Ok(f) = std::fs::File::open(&config.key) { 220 + let keys: KeyData = serde_ipld_dagcbor::from_reader(std::io::BufReader::new(f)) 221 + .context("failed to deserialize crypto keys")?; 222 + 223 + let skey = Secp256k1Keypair::import(&keys.skey).context("failed to import signing key")?; 224 + let rkey = Secp256k1Keypair::import(&keys.rkey).context("failed to import rotation key")?; 225 + 226 + (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey))) 227 + } else { 228 + info!("signing keys not found, generating new ones"); 229 + 230 + let skey = Secp256k1Keypair::create(&mut rand::thread_rng()); 231 + let rkey = Secp256k1Keypair::create(&mut rand::thread_rng()); 232 + 233 + let keys = KeyData { 234 + skey: skey.export(), 235 + rkey: rkey.export(), 236 + }; 237 + 238 + let mut f = std::fs::File::create(&config.key).context("failed to create key file")?; 239 + serde_ipld_dagcbor::to_writer(&mut f, &keys).context("failed to serialize crypto keys")?; 240 + 241 + (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey))) 242 + }; 243 + 244 + tokio::fs::create_dir_all(&config.repo.path).await?; 245 + tokio::fs::create_dir_all(&config.plc.path).await?; 246 + tokio::fs::create_dir_all(&config.blob.path).await?; 247 + 248 + // Create a database connection manager and pool for the main database. 249 + let pool = 250 + establish_pool(&config.db).context("failed to establish database connection pool")?; 251 + 252 + // Create a dictionary of database connection pools for each actor. 253 + let mut actor_pools = std::collections::HashMap::new(); 254 + // We'll determine actors by looking in the data/repo dir for .db files. 255 + let mut actor_dbs = tokio::fs::read_dir(&config.repo.path) 256 + .await 257 + .context("failed to read repo directory")?; 258 + while let Some(entry) = actor_dbs 259 + .next_entry() 260 + .await 261 + .context("failed to read repo dir")? 
262 + { 263 + let path = entry.path(); 264 + if path.extension().and_then(|s| s.to_str()) == Some("db") { 265 + let actor_repo_pool = establish_pool(&format!("sqlite://{}", path.display())) 266 + .context("failed to create database connection pool")?; 267 + 268 + let did = Did::from_str(&format!( 269 + "did:plc:{}", 270 + path.file_stem() 271 + .and_then(|s| s.to_str()) 272 + .context("failed to get actor DID")? 273 + )) 274 + .expect("should be able to parse actor DID") 275 + .to_string(); 276 + let blob_path = config.blob.path.to_path_buf(); 277 + let actor_storage = ActorStorage { 278 + repo: actor_repo_pool, 279 + blob: blob_path.clone(), 280 + }; 281 + drop(actor_pools.insert(did, actor_storage)); 282 + } 283 + } 284 + // Apply pending migrations 285 + // let conn = pool.get().await?; 286 + // conn.run_pending_migrations(MIGRATIONS) 287 + // .expect("should be able to run migrations"); 288 + 289 + let hostname = config.host_name.clone(); 290 + let crawlers: Vec<String> = config 291 + .firehose 292 + .relays 293 + .iter() 294 + .map(|s| s.to_string()) 295 + .collect(); 296 + let sequencer = Arc::new(SharedSequencer { 297 + sequencer: RwLock::new(Sequencer::new( 298 + Crawlers::new(hostname, crawlers.clone()), 299 + None, 300 + )), 301 + }); 302 + let account_manager = SharedAccountManager { 303 + account_manager: RwLock::new(AccountManager::new(pool.clone())), 304 + }; 305 + 306 + let addr = config 307 + .listen_address 308 + .unwrap_or(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8000)); 309 + 310 + let app = Router::new() 311 + .route("/", get(super::index)) 312 + .merge(super::oauth::routes()) 313 + .nest( 314 + "/xrpc", 315 + super::apis::routes() 316 + .merge(super::actor_endpoints::routes()) 317 + .fallback(service_proxy), 318 + ) 319 + // .layer(RateLimitLayer::new(30, Duration::from_secs(30))) 320 + .layer(CorsLayer::permissive()) 321 + .layer(TraceLayer::new_for_http()) 322 + .with_state(AppState { 323 + config: config.clone(), 324 + db: pool.clone(), 325 + db_actors: actor_pools.clone(), 326 + client: client.clone(), 327 + simple_client, 328 + sequencer: sequencer.clone(), 329 + account_manager: Arc::new(account_manager), 330 + signing_key: skey, 331 + rotation_key: rkey, 332 + }); 333 + 334 + info!("listening on {addr}"); 335 + info!("connect to: http://127.0.0.1:{}", addr.port()); 336 + 337 + // Determine whether or not this was the first startup (i.e. no accounts exist and no invite codes were created). 338 + // If so, create an invite code and share it via the console. 
339 + let conn = pool.get().await.context("failed to get db connection")?; 340 + 341 + #[derive(QueryableByName)] 342 + struct TotalCount { 343 + #[diesel(sql_type = diesel::sql_types::Integer)] 344 + total_count: i32, 345 + } 346 + 347 + let result = conn.interact(move |conn| { 348 + diesel::sql_query( 349 + "SELECT (SELECT COUNT(*) FROM account) + (SELECT COUNT(*) FROM invite_code) AS total_count", 350 + ) 351 + .get_result::<TotalCount>(conn) 352 + }) 353 + .await 354 + .expect("should be able to query database")?; 355 + 356 + let c = result.total_count; 357 + 358 + #[expect(clippy::print_stdout)] 359 + if c == 0 { 360 + let uuid = Uuid::new_v4().to_string(); 361 + 362 + use crate::models::pds as models; 363 + use crate::schema::pds::invite_code::dsl as InviteCode; 364 + let uuid_clone = uuid.clone(); 365 + drop( 366 + conn.interact(move |conn| { 367 + diesel::insert_into(InviteCode::invite_code) 368 + .values(models::InviteCode { 369 + code: uuid_clone, 370 + available_uses: 1, 371 + disabled: 0, 372 + for_account: "None".to_owned(), 373 + created_by: "None".to_owned(), 374 + created_at: "None".to_owned(), 375 + }) 376 + .execute(conn) 377 + .context("failed to create new invite code") 378 + }) 379 + .await 380 + .expect("should be able to create invite code"), 381 + ); 382 + 383 + // N.B: This is a sensitive message, so we're bypassing `tracing` here and 384 + // logging it directly to console. 385 + println!("====================================="); 386 + println!(" FIRST STARTUP "); 387 + println!("====================================="); 388 + println!("Use this code to create an account:"); 389 + println!("{uuid}"); 390 + println!("====================================="); 391 + } 392 + 393 + let listener = TcpListener::bind(&addr) 394 + .await 395 + .context("failed to bind address")?; 396 + 397 + // Serve the app, and request crawling from upstream relays. 398 + let serve = tokio::spawn(async move { 399 + axum::serve(listener, app.into_make_service()) 400 + .await 401 + .context("failed to serve app") 402 + }); 403 + 404 + // Now that the app is live, request a crawl from upstream relays. 405 + let mut background_sequencer = sequencer.sequencer.write().await.clone(); 406 + drop(tokio::spawn( 407 + async move { background_sequencer.start().await }, 408 + )); 409 + 410 + serve 411 + .await 412 + .map_err(Into::into) 413 + .and_then(|r| r) 414 + .context("failed to serve app") 415 + }
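One consequence of `#[derive(Clone, FromRef)]` on `AppState` worth noting: axum handlers can extract any single field as its own `State<T>` instead of taking the whole state. An illustrative handler (not part of this change; the route is made up):

    use atrium_crypto::keypair::Did as _;
    use axum::extract::State;

    // Hypothetical debug route: report the server's signing-key DID.
    // Works because `SigningKey: Clone` and `FromRef<AppState>` is derived.
    async fn signing_key_did(State(skey): State<SigningKey>) -> String {
        // `did()` comes from atrium_crypto's Did trait, via Deref to Secp256k1Keypair.
        skey.did()
    }

Wiring it up would look like `.route("/debug/signing-key", get(signing_key_did))`; handlers that need only one field this way avoid cloning the entire state.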
+6 -26
src/service_proxy.rs
··· 3 3 /// Reference: <https://atproto.com/specs/xrpc#service-proxying> 4 4 use anyhow::{Context as _, anyhow}; 5 5 use atrium_api::types::string::Did; 6 - use atrium_crypto::keypair::{Export as _, Secp256k1Keypair}; 7 6 use axum::{ 8 - Router, 9 7 body::Body, 10 - extract::{FromRef, Request, State}, 8 + extract::{Request, State}, 11 9 http::{self, HeaderMap, Response, StatusCode, Uri}, 12 - response::IntoResponse, 13 - routing::get, 14 10 }; 15 - use azure_core::credentials::TokenCredential; 16 - use clap::Parser; 17 - use clap_verbosity_flag::{InfoLevel, Verbosity, log::LevelFilter}; 18 - use deadpool_diesel::sqlite::Pool; 19 - use diesel::prelude::*; 20 - use diesel_migrations::{EmbeddedMigrations, embed_migrations}; 21 - use figment::{Figment, providers::Format as _}; 22 - use http_cache_reqwest::{CacheMode, HttpCacheOptions, MokaManager}; 23 11 use rand::Rng as _; 24 - use serde::{Deserialize, Serialize}; 25 - use std::{ 26 - net::{IpAddr, Ipv4Addr, SocketAddr}, 27 - path::PathBuf, 28 - str::FromStr as _, 29 - sync::Arc, 30 - }; 31 - use tokio::net::TcpListener; 32 - use tower_http::{cors::CorsLayer, trace::TraceLayer}; 33 - use tracing::{info, warn}; 34 - use uuid::Uuid; 12 + use std::str::FromStr as _; 35 13 36 - use super::{Client, Error, Result}; 37 - use crate::{AuthenticatedUser, SigningKey}; 14 + use super::{ 15 + auth::AuthenticatedUser, 16 + serve::{Client, Error, Result, SigningKey}, 17 + }; 38 18 39 19 pub(super) async fn service_proxy( 40 20 uri: Uri,
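Per the service-proxying spec referenced above, the target arrives in the `atproto-proxy` request header as a DID plus a service fragment, e.g. `did:web:api.bsky.app#bsky_appview`. A sketch of splitting that header (illustrative, not the module's actual parsing):

    // Split "did:web:api.bsky.app#bsky_appview" into the DID and the fragment
    // used to select a service endpoint from the target's DID document.
    fn split_proxy_header(value: &str) -> Option<(&str, &str)> {
        let (did, fragment) = value.split_once('#')?;
        (!did.is_empty() && !fragment.is_empty()).then_some((did, fragment))
    }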
-459
src/tests.rs
··· 1 - //! Testing utilities for the PDS. 2 - #![expect(clippy::arbitrary_source_item_ordering)] 3 - use std::{ 4 - net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener}, 5 - path::PathBuf, 6 - time::{Duration, Instant}, 7 - }; 8 - 9 - use anyhow::Result; 10 - use atrium_api::{ 11 - com::atproto::server, 12 - types::string::{AtIdentifier, Did, Handle, Nsid, RecordKey}, 13 - }; 14 - use figment::{Figment, providers::Format as _}; 15 - use futures::future::join_all; 16 - use serde::{Deserialize, Serialize}; 17 - use tokio::sync::OnceCell; 18 - use uuid::Uuid; 19 - 20 - use crate::config::AppConfig; 21 - 22 - /// Global test state, created once for all tests. 23 - pub(crate) static TEST_STATE: OnceCell<TestState> = OnceCell::const_new(); 24 - 25 - /// A temporary test directory that will be cleaned up when the struct is dropped. 26 - struct TempDir { 27 - /// The path to the directory. 28 - path: PathBuf, 29 - } 30 - 31 - impl TempDir { 32 - /// Create a new temporary directory. 33 - fn new() -> Result<Self> { 34 - let path = std::env::temp_dir().join(format!("bluepds-test-{}", Uuid::new_v4())); 35 - std::fs::create_dir_all(&path)?; 36 - Ok(Self { path }) 37 - } 38 - 39 - /// Get the path to the directory. 40 - fn path(&self) -> &PathBuf { 41 - &self.path 42 - } 43 - } 44 - 45 - impl Drop for TempDir { 46 - fn drop(&mut self) { 47 - drop(std::fs::remove_dir_all(&self.path)); 48 - } 49 - } 50 - 51 - /// Test state for the application. 52 - pub(crate) struct TestState { 53 - /// The address the test server is listening on. 54 - address: SocketAddr, 55 - /// The HTTP client. 56 - client: reqwest::Client, 57 - /// The application configuration. 58 - config: AppConfig, 59 - /// The temporary directory for test data. 60 - #[expect(dead_code)] 61 - temp_dir: TempDir, 62 - } 63 - 64 - impl TestState { 65 - /// Get a base URL for the test server. 66 - pub(crate) fn base_url(&self) -> String { 67 - format!("http://{}", self.address) 68 - } 69 - 70 - /// Create a test account. 71 - pub(crate) async fn create_test_account(&self) -> Result<TestAccount> { 72 - // Create the account 73 - let handle = "test.handle"; 74 - let response = self 75 - .client 76 - .post(format!( 77 - "http://{}/xrpc/com.atproto.server.createAccount", 78 - self.address 79 - )) 80 - .json(&server::create_account::InputData { 81 - did: None, 82 - verification_code: None, 83 - verification_phone: None, 84 - email: Some(format!("{}@example.com", &handle)), 85 - handle: Handle::new(handle.to_owned()).expect("should be able to create handle"), 86 - password: Some("password123".to_owned()), 87 - invite_code: None, 88 - recovery_key: None, 89 - plc_op: None, 90 - }) 91 - .send() 92 - .await?; 93 - 94 - let account: server::create_account::Output = response.json().await?; 95 - 96 - Ok(TestAccount { 97 - handle: handle.to_owned(), 98 - did: account.did.to_string(), 99 - access_token: account.access_jwt.clone(), 100 - refresh_token: account.refresh_jwt.clone(), 101 - }) 102 - } 103 - 104 - /// Create a new test state. 
105 - #[expect(clippy::unused_async)] 106 - async fn new() -> Result<Self> { 107 - // Configure the test app 108 - #[derive(Serialize, Deserialize)] 109 - struct TestConfigInput { 110 - db: Option<String>, 111 - host_name: Option<String>, 112 - key: Option<PathBuf>, 113 - listen_address: Option<SocketAddr>, 114 - test: Option<bool>, 115 - } 116 - // Create a temporary directory for test data 117 - let temp_dir = TempDir::new()?; 118 - 119 - // Find a free port 120 - let listener = TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))?; 121 - let address = listener.local_addr()?; 122 - drop(listener); 123 - 124 - let test_config = TestConfigInput { 125 - db: Some(format!("sqlite://{}/test.db", temp_dir.path().display())), 126 - host_name: Some(format!("localhost:{}", address.port())), 127 - key: Some(temp_dir.path().join("test.key")), 128 - listen_address: Some(address), 129 - test: Some(true), 130 - }; 131 - 132 - let config: AppConfig = Figment::new() 133 - .admerge(figment::providers::Toml::file("default.toml")) 134 - .admerge(figment::providers::Env::prefixed("BLUEPDS_")) 135 - .merge(figment::providers::Serialized::defaults(test_config)) 136 - .merge( 137 - figment::providers::Toml::string( 138 - r#" 139 - [firehose] 140 - relays = [] 141 - 142 - [repo] 143 - path = "repo" 144 - 145 - [plc] 146 - path = "plc" 147 - 148 - [blob] 149 - path = "blob" 150 - limit = 10485760 # 10 MB 151 - "#, 152 - ) 153 - .nested(), 154 - ) 155 - .extract()?; 156 - 157 - // Create directories 158 - std::fs::create_dir_all(temp_dir.path().join("repo"))?; 159 - std::fs::create_dir_all(temp_dir.path().join("plc"))?; 160 - std::fs::create_dir_all(temp_dir.path().join("blob"))?; 161 - 162 - // Create client 163 - let client = reqwest::Client::builder() 164 - .timeout(Duration::from_secs(30)) 165 - .build()?; 166 - 167 - Ok(Self { 168 - address, 169 - client, 170 - config, 171 - temp_dir, 172 - }) 173 - } 174 - 175 - /// Start the application in a background task. 
176 - async fn start_app(&self) -> Result<()> { 177 - // // Get a reference to the config that can be moved into the task 178 - // let config = self.config.clone(); 179 - // let address = self.address; 180 - 181 - // // Start the application in a background task 182 - // let _handle = tokio::spawn(async move { 183 - // // Set up the application 184 - // use crate::*; 185 - 186 - // // Initialize metrics (noop in test mode) 187 - // drop(metrics::setup(None)); 188 - 189 - // // Create client 190 - // let simple_client = reqwest::Client::builder() 191 - // .user_agent(APP_USER_AGENT) 192 - // .build() 193 - // .context("failed to build requester client")?; 194 - // let client = reqwest_middleware::ClientBuilder::new(simple_client.clone()) 195 - // .with(http_cache_reqwest::Cache(http_cache_reqwest::HttpCache { 196 - // mode: CacheMode::Default, 197 - // manager: MokaManager::default(), 198 - // options: HttpCacheOptions::default(), 199 - // })) 200 - // .build(); 201 - 202 - // // Create a test keypair 203 - // std::fs::create_dir_all(config.key.parent().context("should have parent")?)?; 204 - // let (skey, rkey) = { 205 - // let skey = Secp256k1Keypair::create(&mut rand::thread_rng()); 206 - // let rkey = Secp256k1Keypair::create(&mut rand::thread_rng()); 207 - 208 - // let keys = KeyData { 209 - // skey: skey.export(), 210 - // rkey: rkey.export(), 211 - // }; 212 - 213 - // let mut f = 214 - // std::fs::File::create(&config.key).context("failed to create key file")?; 215 - // serde_ipld_dagcbor::to_writer(&mut f, &keys) 216 - // .context("failed to serialize crypto keys")?; 217 - 218 - // (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey))) 219 - // }; 220 - 221 - // // Set up database 222 - // let opts = SqliteConnectOptions::from_str(&config.db) 223 - // .context("failed to parse database options")? 224 - // .create_if_missing(true); 225 - // let db = SqliteDbConn::connect_with(opts).await?; 226 - 227 - // sqlx::migrate!() 228 - // .run(&db) 229 - // .await 230 - // .context("failed to apply migrations")?; 231 - 232 - // // Create firehose 233 - // let (_fh, fhp) = firehose::spawn(client.clone(), config.clone()); 234 - 235 - // // Create the application state 236 - // let app_state = AppState { 237 - // cred: azure_identity::DefaultAzureCredential::new()?, 238 - // config: config.clone(), 239 - // db: db.clone(), 240 - // client: client.clone(), 241 - // simple_client, 242 - // firehose: fhp, 243 - // signing_key: skey, 244 - // rotation_key: rkey, 245 - // }; 246 - 247 - // // Create the router 248 - // let app = Router::new() 249 - // .route("/", get(index)) 250 - // .merge(oauth::routes()) 251 - // .nest( 252 - // "/xrpc", 253 - // endpoints::routes() 254 - // .merge(actor_endpoints::routes()) 255 - // .fallback(service_proxy), 256 - // ) 257 - // .layer(CorsLayer::permissive()) 258 - // .layer(TraceLayer::new_for_http()) 259 - // .with_state(app_state); 260 - 261 - // // Listen for connections 262 - // let listener = TcpListener::bind(&address) 263 - // .await 264 - // .context("failed to bind address")?; 265 - 266 - // axum::serve(listener, app.into_make_service()) 267 - // .await 268 - // .context("failed to serve app") 269 - // }); 270 - 271 - // // Give the server a moment to start 272 - // tokio::time::sleep(Duration::from_millis(500)).await; 273 - 274 - Ok(()) 275 - } 276 - } 277 - 278 - /// A test account that can be used for testing. 279 - pub(crate) struct TestAccount { 280 - /// The access token for the account. 
281 - pub(crate) access_token: String, 282 - /// The account DID. 283 - pub(crate) did: String, 284 - /// The account handle. 285 - pub(crate) handle: String, 286 - /// The refresh token for the account. 287 - #[expect(dead_code)] 288 - pub(crate) refresh_token: String, 289 - } 290 - 291 - /// Initialize the test state. 292 - pub(crate) async fn init_test_state() -> Result<&'static TestState> { 293 - async fn init_test_state() -> std::result::Result<TestState, anyhow::Error> { 294 - let state = TestState::new().await?; 295 - state.start_app().await?; 296 - Ok(state) 297 - } 298 - TEST_STATE.get_or_try_init(init_test_state).await 299 - } 300 - 301 - /// Create a record benchmark that creates records and measures the time it takes. 302 - #[expect( 303 - clippy::arithmetic_side_effects, 304 - clippy::integer_division, 305 - clippy::integer_division_remainder_used, 306 - clippy::use_debug, 307 - clippy::print_stdout 308 - )] 309 - pub(crate) async fn create_record_benchmark(count: usize, concurrent: usize) -> Result<Duration> { 310 - // Initialize the test state 311 - let state = init_test_state().await?; 312 - 313 - // Create a test account 314 - let account = state.create_test_account().await?; 315 - 316 - // Create the client with authorization 317 - let client = reqwest::Client::builder() 318 - .timeout(Duration::from_secs(30)) 319 - .build()?; 320 - 321 - let start = Instant::now(); 322 - 323 - // Split the work into batches 324 - let mut handles = Vec::new(); 325 - for batch_idx in 0..concurrent { 326 - let batch_size = count / concurrent; 327 - let client = client.clone(); 328 - let base_url = state.base_url(); 329 - let account_did = account.did.clone(); 330 - let account_handle = account.handle.clone(); 331 - let access_token = account.access_token.clone(); 332 - 333 - let handle = tokio::spawn(async move { 334 - let mut results = Vec::new(); 335 - 336 - for i in 0..batch_size { 337 - let request_start = Instant::now(); 338 - let record_idx = batch_idx * batch_size + i; 339 - 340 - let result = client 341 - .post(format!("{base_url}/xrpc/com.atproto.repo.createRecord")) 342 - .header("Authorization", format!("Bearer {access_token}")) 343 - .json(&atrium_api::com::atproto::repo::create_record::InputData { 344 - repo: AtIdentifier::Did(Did::new(account_did.clone()).expect("valid DID")), 345 - collection: Nsid::new("app.bsky.feed.post".to_owned()).expect("valid NSID"), 346 - rkey: Some( 347 - RecordKey::new(format!("test-{record_idx}")).expect("valid record key"), 348 - ), 349 - validate: None, 350 - record: serde_json::from_str( 351 - &serde_json::json!({ 352 - "$type": "app.bsky.feed.post", 353 - "text": format!("Test post {record_idx} from {account_handle}"), 354 - "createdAt": chrono::Utc::now().to_rfc3339(), 355 - }) 356 - .to_string(), 357 - ) 358 - .expect("valid JSON record"), 359 - swap_commit: None, 360 - }) 361 - .send() 362 - .await; 363 - 364 - // Fetch the record we just created 365 - let get_response = client 366 - .get(format!( 367 - "{base_url}/xrpc/com.atproto.sync.getRecord?did={account_did}&collection=app.bsky.feed.post&rkey={record_idx}" 368 - )) 369 - .header("Authorization", format!("Bearer {access_token}")) 370 - .send() 371 - .await; 372 - if get_response.is_err() { 373 - println!("Failed to fetch record {record_idx}: {get_response:?}"); 374 - results.push(get_response); 375 - continue; 376 - } 377 - 378 - let request_duration = request_start.elapsed(); 379 - if record_idx % 10 == 0 { 380 - println!("Created record {record_idx} in {request_duration:?}"); 381 - } 
382 - results.push(result); 383 - } 384 - 385 - results 386 - }); 387 - 388 - handles.push(handle); 389 - } 390 - 391 - // Wait for all batches to complete 392 - let results = join_all(handles).await; 393 - 394 - // Check for errors 395 - for batch_result in results { 396 - let batch_responses = batch_result?; 397 - for response_result in batch_responses { 398 - match response_result { 399 - Ok(response) => { 400 - if !response.status().is_success() { 401 - return Err(anyhow::anyhow!( 402 - "Failed to create record: {}", 403 - response.status() 404 - )); 405 - } 406 - } 407 - Err(err) => { 408 - return Err(anyhow::anyhow!("Failed to create record: {}", err)); 409 - } 410 - } 411 - } 412 - } 413 - 414 - let duration = start.elapsed(); 415 - Ok(duration) 416 - } 417 - 418 - #[cfg(test)] 419 - #[expect(clippy::module_inception, clippy::use_debug, clippy::print_stdout)] 420 - mod tests { 421 - use super::*; 422 - use anyhow::anyhow; 423 - 424 - #[tokio::test] 425 - async fn test_create_account() -> Result<()> { 426 - return Ok(()); 427 - #[expect(unreachable_code, reason = "Disabled")] 428 - let state = init_test_state().await?; 429 - let account = state.create_test_account().await?; 430 - 431 - println!("Created test account: {}", account.handle); 432 - if account.handle.is_empty() { 433 - return Err(anyhow::anyhow!("Account handle is empty")); 434 - } 435 - if account.did.is_empty() { 436 - return Err(anyhow::anyhow!("Account DID is empty")); 437 - } 438 - if account.access_token.is_empty() { 439 - return Err(anyhow::anyhow!("Account access token is empty")); 440 - } 441 - 442 - Ok(()) 443 - } 444 - 445 - #[tokio::test] 446 - async fn test_create_record_benchmark() -> Result<()> { 447 - return Ok(()); 448 - #[expect(unreachable_code, reason = "Disabled")] 449 - let duration = create_record_benchmark(100, 1).await?; 450 - 451 - println!("Created 100 records in {duration:?}"); 452 - 453 - if duration.as_secs() >= 10 { 454 - return Err(anyhow!("Benchmark took too long")); 455 - } 456 - 457 - Ok(()) 458 - } 459 - }
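With the old harness removed, a hand-rolled smoke test against a dev instance could look like this (a sketch only; it assumes a server already running on 127.0.0.1:8000 and that `com.atproto.server.describeServer` is routed):

    #[tokio::test]
    async fn smoke_describe_server() -> anyhow::Result<()> {
        let body: serde_json::Value = reqwest::Client::new()
            .get("http://127.0.0.1:8000/xrpc/com.atproto.server.describeServer")
            .send()
            .await?
            .error_for_status()?
            .json()
            .await?;
        // availableUserDomains is a required field of describeServer's output.
        assert!(body.get("availableUserDomains").is_some());
        Ok(())
    }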