// Alternative ATProto PDS implementation.
1//! PDS implementation. 2mod actor_store; 3mod auth; 4mod config; 5mod db; 6mod did; 7mod endpoints; 8mod error; 9mod firehose; 10mod metrics; 11mod mmap; 12mod oauth; 13mod plc; 14mod storage; 15#[cfg(test)] 16mod tests; 17 18/// HACK: store private user preferences in the PDS. 19/// 20/// We shouldn't have to know about any bsky endpoints to store private user data. 21/// This will _very likely_ be changed in the future. 22mod actor_endpoints { 23 use atrium_api::app::bsky::actor; 24 use axum::{Json, routing::post}; 25 use constcat::concat; 26 27 use super::*; 28 29 async fn put_preferences( 30 user: AuthenticatedUser, 31 State(db): State<Db>, 32 Json(input): Json<actor::put_preferences::Input>, 33 ) -> Result<()> { 34 let did = user.did(); 35 let prefs = sqlx::types::Json(input.preferences.clone()); 36 _ = sqlx::query!( 37 r#"UPDATE accounts SET private_prefs = ? WHERE did = ?"#, 38 prefs, 39 did 40 ) 41 .execute(&db) 42 .await 43 .context("failed to update user preferences")?; 44 45 Ok(()) 46 } 47 48 async fn get_preferences( 49 user: AuthenticatedUser, 50 State(db): State<Db>, 51 ) -> Result<Json<actor::get_preferences::Output>> { 52 let did = user.did(); 53 let json: Option<sqlx::types::Json<actor::defs::Preferences>> = 54 sqlx::query_scalar("SELECT private_prefs FROM accounts WHERE did = ?") 55 .bind(did) 56 .fetch_one(&db) 57 .await 58 .context("failed to fetch preferences")?; 59 60 if let Some(prefs) = json { 61 Ok(Json( 62 actor::get_preferences::OutputData { 63 preferences: prefs.0, 64 } 65 .into(), 66 )) 67 } else { 68 Ok(Json( 69 actor::get_preferences::OutputData { 70 preferences: Vec::new(), 71 } 72 .into(), 73 )) 74 } 75 } 76 77 /// Register all actor endpoints. 
78 pub(crate) fn routes() -> Router<AppState> { 79 // AP /xrpc/app.bsky.actor.putPreferences 80 // AG /xrpc/app.bsky.actor.getPreferences 81 Router::new() 82 .route( 83 concat!("/", actor::put_preferences::NSID), 84 post(put_preferences), 85 ) 86 .route( 87 concat!("/", actor::get_preferences::NSID), 88 get(get_preferences), 89 ) 90 } 91} 92 93use anyhow::{Context as _, anyhow}; 94use atrium_api::types::string::Did; 95use atrium_crypto::keypair::{Export as _, Secp256k1Keypair}; 96use auth::AuthenticatedUser; 97use axum::{ 98 Router, 99 body::Body, 100 extract::{FromRef, Request, State}, 101 http::{self, HeaderMap, Response, StatusCode, Uri}, 102 response::IntoResponse, 103 routing::get, 104}; 105use azure_core::credentials::TokenCredential; 106use clap::Parser; 107use clap_verbosity_flag::{InfoLevel, Verbosity, log::LevelFilter}; 108use config::AppConfig; 109#[expect(clippy::pub_use, clippy::useless_attribute)] 110pub use error::Error; 111use figment::{Figment, providers::Format as _}; 112use firehose::FirehoseProducer; 113use http_cache_reqwest::{CacheMode, HttpCacheOptions, MokaManager}; 114use rand::Rng as _; 115use serde::{Deserialize, Serialize}; 116use sqlx::{SqlitePool, sqlite::SqliteConnectOptions}; 117use std::{ 118 net::{IpAddr, Ipv4Addr, SocketAddr}, 119 path::PathBuf, 120 str::FromStr as _, 121 sync::Arc, 122}; 123use tokio::net::TcpListener; 124use tower_http::{cors::CorsLayer, trace::TraceLayer}; 125use tracing::{info, warn}; 126use uuid::Uuid; 127 128/// The application user agent. Concatenates the package name and version. e.g. `bluepds/0.0.0`. 129pub const APP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),); 130 131/// The application-wide result type. 132pub type Result<T> = std::result::Result<T, Error>; 133/// The reqwest client type with middleware. 134pub type Client = reqwest_middleware::ClientWithMiddleware; 135/// The database connection pool. 136pub type Db = SqlitePool; 137/// The Azure credential type. 
138pub type Cred = Arc<dyn TokenCredential>; 139 140#[expect( 141 clippy::arbitrary_source_item_ordering, 142 reason = "serialized data might be structured" 143)] 144#[derive(Serialize, Deserialize, Debug, Clone)] 145/// The key data structure. 146struct KeyData { 147 /// Primary signing key for all repo operations. 148 skey: Vec<u8>, 149 /// Primary signing (rotation) key for all PLC operations. 150 rkey: Vec<u8>, 151} 152 153// FIXME: We should use P256Keypair instead. SecP256K1 is primarily used for cryptocurrencies, 154// and the implementations of this algorithm are much more limited as compared to P256. 155// 156// Reference: https://soatok.blog/2022/05/19/guidance-for-choosing-an-elliptic-curve-signature-algorithm-in-2022/ 157#[derive(Clone)] 158/// The signing key for PLC/DID operations. 159pub struct SigningKey(Arc<Secp256k1Keypair>); 160#[derive(Clone)] 161/// The rotation key for PLC operations. 162pub struct RotationKey(Arc<Secp256k1Keypair>); 163 164impl std::ops::Deref for SigningKey { 165 type Target = Secp256k1Keypair; 166 167 fn deref(&self) -> &Self::Target { 168 &self.0 169 } 170} 171 172impl SigningKey { 173 /// Import from a private key. 174 pub fn import(key: &[u8]) -> Result<Self> { 175 let key = Secp256k1Keypair::import(key).context("failed to import signing key")?; 176 Ok(Self(Arc::new(key))) 177 } 178} 179 180impl std::ops::Deref for RotationKey { 181 type Target = Secp256k1Keypair; 182 183 fn deref(&self) -> &Self::Target { 184 &self.0 185 } 186} 187 188#[derive(Parser, Debug, Clone)] 189/// Command line arguments. 190struct Args { 191 /// Path to the configuration file 192 #[arg(short, long, default_value = "default.toml")] 193 config: PathBuf, 194 /// The verbosity level. 195 #[command(flatten)] 196 verbosity: Verbosity<InfoLevel>, 197} 198 199#[expect(clippy::arbitrary_source_item_ordering, reason = "arbitrary")] 200#[derive(Clone, FromRef)] 201struct AppState { 202 /// The application configuration. 
203 config: AppConfig, 204 /// The Azure credential. 205 cred: Cred, 206 /// The database connection pool. 207 db: Db, 208 209 /// The HTTP client with middleware. 210 client: Client, 211 /// The simple HTTP client. 212 simple_client: reqwest::Client, 213 /// The firehose producer. 214 firehose: FirehoseProducer, 215 216 /// The signing key. 217 signing_key: SigningKey, 218 /// The rotation key. 219 rotation_key: RotationKey, 220} 221 222/// The index (/) route. 223async fn index() -> impl IntoResponse { 224 r" 225 __ __ 226 /\ \__ /\ \__ 227 __ \ \ ,_\ _____ _ __ ___\ \ ,_\ ___ 228 /'__'\ \ \ \/ /\ '__'\/\''__\/ __'\ \ \/ / __'\ 229 /\ \L\.\_\ \ \_\ \ \L\ \ \ \//\ \L\ \ \ \_/\ \L\ \ 230 \ \__/.\_\\ \__\\ \ ,__/\ \_\\ \____/\ \__\ \____/ 231 \/__/\/_/ \/__/ \ \ \/ \/_/ \/___/ \/__/\/___/ 232 \ \_\ 233 \/_/ 234 235 236This is an AT Protocol Personal Data Server (aka, an atproto PDS) 237 238Most API routes are under /xrpc/ 239 240 Code: https://github.com/DrChat/bluepds 241 Protocol: https://atproto.com 242 " 243} 244 245/// Service proxy. 
246/// 247/// Reference: <https://atproto.com/specs/xrpc#service-proxying> 248async fn service_proxy( 249 uri: Uri, 250 user: AuthenticatedUser, 251 State(skey): State<SigningKey>, 252 State(client): State<reqwest::Client>, 253 headers: HeaderMap, 254 request: Request<Body>, 255) -> Result<Response<Body>> { 256 let url_path = uri.path_and_query().context("invalid service proxy url")?; 257 let lxm = url_path 258 .path() 259 .strip_prefix("/") 260 .with_context(|| format!("invalid service proxy url prefix: {}", url_path.path()))?; 261 262 let user_did = user.did(); 263 let (did, id) = match headers.get("atproto-proxy") { 264 Some(val) => { 265 let val = 266 std::str::from_utf8(val.as_bytes()).context("proxy header not valid utf-8")?; 267 268 let (did, id) = val.split_once('#').context("invalid proxy header")?; 269 270 let did = 271 Did::from_str(did).map_err(|e| anyhow!("atproto proxy not a valid DID: {e}"))?; 272 273 (did, format!("#{id}")) 274 } 275 // HACK: Assume the bluesky appview by default. 276 None => ( 277 Did::new("did:web:api.bsky.app".to_owned()) 278 .expect("service proxy should be a valid DID"), 279 "#bsky_appview".to_owned(), 280 ), 281 }; 282 283 let did_doc = did::resolve(&Client::new(client.clone(), []), did.clone()) 284 .await 285 .with_context(|| format!("failed to resolve did document {}", did.as_str()))?; 286 287 let Some(service) = did_doc.service.iter().find(|s| s.id == id) else { 288 return Err(Error::with_status( 289 StatusCode::BAD_REQUEST, 290 anyhow!("could not find resolve service #{id}"), 291 )); 292 }; 293 294 let target_url: url::Url = service 295 .service_endpoint 296 .join(&format!("/xrpc{url_path}")) 297 .context("failed to construct target url")?; 298 299 let exp = (chrono::Utc::now().checked_add_signed(chrono::Duration::minutes(1))) 300 .context("should be valid expiration datetime")? 
301 .timestamp(); 302 let jti = rand::thread_rng() 303 .sample_iter(rand::distributions::Alphanumeric) 304 .take(10) 305 .map(char::from) 306 .collect::<String>(); 307 308 // Mint a bearer token by signing a JSON web token. 309 // https://github.com/DavidBuchanan314/millipds/blob/5c7529a739d394e223c0347764f1cf4e8fd69f94/src/millipds/appview_proxy.py#L47-L59 310 let token = auth::sign( 311 &skey, 312 "JWT", 313 &serde_json::json!({ 314 "iss": user_did.as_str(), 315 "aud": did.as_str(), 316 "lxm": lxm, 317 "exp": exp, 318 "jti": jti, 319 }), 320 ) 321 .context("failed to sign jwt")?; 322 323 let mut h = HeaderMap::new(); 324 if let Some(hdr) = request.headers().get("atproto-accept-labelers") { 325 drop(h.insert("atproto-accept-labelers", hdr.clone())); 326 } 327 if let Some(hdr) = request.headers().get(http::header::CONTENT_TYPE) { 328 drop(h.insert(http::header::CONTENT_TYPE, hdr.clone())); 329 } 330 331 let r = client 332 .request(request.method().clone(), target_url) 333 .headers(h) 334 .header(http::header::AUTHORIZATION, format!("Bearer {token}")) 335 .body(reqwest::Body::wrap_stream( 336 request.into_body().into_data_stream(), 337 )) 338 .send() 339 .await 340 .context("failed to send request")?; 341 342 let mut resp = Response::builder().status(r.status()); 343 if let Some(hdrs) = resp.headers_mut() { 344 *hdrs = r.headers().clone(); 345 } 346 347 let resp = resp 348 .body(Body::from_stream(r.bytes_stream())) 349 .context("failed to construct response")?; 350 351 Ok(resp) 352} 353 354/// The main application entry point. 355#[expect( 356 clippy::cognitive_complexity, 357 clippy::too_many_lines, 358 reason = "main function has high complexity" 359)] 360async fn run() -> anyhow::Result<()> { 361 let args = Args::parse(); 362 363 // Set up trace logging to console and account for the user-provided verbosity flag. 
364 if args.verbosity.log_level_filter() != LevelFilter::Off { 365 let lvl = match args.verbosity.log_level_filter() { 366 LevelFilter::Error => tracing::Level::ERROR, 367 LevelFilter::Warn => tracing::Level::WARN, 368 LevelFilter::Info | LevelFilter::Off => tracing::Level::INFO, 369 LevelFilter::Debug => tracing::Level::DEBUG, 370 LevelFilter::Trace => tracing::Level::TRACE, 371 }; 372 tracing_subscriber::fmt().with_max_level(lvl).init(); 373 } 374 375 if !args.config.exists() { 376 // Throw up a warning if the config file does not exist. 377 // 378 // This is not fatal because users can specify all configuration settings via 379 // the environment, but the most likely scenario here is that a user accidentally 380 // omitted the config file for some reason (e.g. forgot to mount it into Docker). 381 warn!( 382 "configuration file {} does not exist", 383 args.config.display() 384 ); 385 } 386 387 // Read and parse the user-provided configuration. 388 let config: AppConfig = Figment::new() 389 .admerge(figment::providers::Toml::file(args.config)) 390 .admerge(figment::providers::Env::prefixed("BLUEPDS_")) 391 .extract() 392 .context("failed to load configuration")?; 393 394 if config.test { 395 warn!("BluePDS starting up in TEST mode."); 396 warn!("This means the application will not federate with the rest of the network."); 397 warn!( 398 "If you want to turn this off, either set `test` to false in the config or define `BLUEPDS_TEST = false`" 399 ); 400 } 401 402 // Initialize metrics reporting. 403 metrics::setup(config.metrics.as_ref()).context("failed to set up metrics exporter")?; 404 405 // Create a reqwest client that will be used for all outbound requests. 
406 let simple_client = reqwest::Client::builder() 407 .user_agent(APP_USER_AGENT) 408 .build() 409 .context("failed to build requester client")?; 410 let client = reqwest_middleware::ClientBuilder::new(simple_client.clone()) 411 .with(http_cache_reqwest::Cache(http_cache_reqwest::HttpCache { 412 mode: CacheMode::Default, 413 manager: MokaManager::default(), 414 options: HttpCacheOptions::default(), 415 })) 416 .build(); 417 418 tokio::fs::create_dir_all(&config.key.parent().context("should have parent")?) 419 .await 420 .context("failed to create key directory")?; 421 422 // Check if crypto keys exist. If not, create new ones. 423 let (skey, rkey) = if let Ok(f) = std::fs::File::open(&config.key) { 424 let keys: KeyData = serde_ipld_dagcbor::from_reader(std::io::BufReader::new(f)) 425 .context("failed to deserialize crypto keys")?; 426 427 let skey = Secp256k1Keypair::import(&keys.skey).context("failed to import signing key")?; 428 let rkey = Secp256k1Keypair::import(&keys.rkey).context("failed to import rotation key")?; 429 430 (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey))) 431 } else { 432 info!("signing keys not found, generating new ones"); 433 434 let skey = Secp256k1Keypair::create(&mut rand::thread_rng()); 435 let rkey = Secp256k1Keypair::create(&mut rand::thread_rng()); 436 437 let keys = KeyData { 438 skey: skey.export(), 439 rkey: rkey.export(), 440 }; 441 442 let mut f = std::fs::File::create(&config.key).context("failed to create key file")?; 443 serde_ipld_dagcbor::to_writer(&mut f, &keys).context("failed to serialize crypto keys")?; 444 445 (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey))) 446 }; 447 448 tokio::fs::create_dir_all(&config.repo.path).await?; 449 tokio::fs::create_dir_all(&config.plc.path).await?; 450 tokio::fs::create_dir_all(&config.blob.path).await?; 451 452 let cred = azure_identity::DefaultAzureCredential::new() 453 .context("failed to create Azure credential")?; 454 let opts = 
SqliteConnectOptions::from_str(&config.db) 455 .context("failed to parse database options")? 456 .create_if_missing(true); 457 let db = SqlitePool::connect_with(opts).await?; 458 459 sqlx::migrate!() 460 .run(&db) 461 .await 462 .context("failed to apply migrations")?; 463 464 let (_fh, fhp) = firehose::spawn(client.clone(), config.clone()); 465 466 let addr = config 467 .listen_address 468 .unwrap_or(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8000)); 469 470 let app = Router::new() 471 .route("/", get(index)) 472 .merge(oauth::routes()) 473 .nest( 474 "/xrpc", 475 endpoints::routes() 476 .merge(actor_endpoints::routes()) 477 .fallback(service_proxy), 478 ) 479 // .layer(RateLimitLayer::new(30, Duration::from_secs(30))) 480 .layer(CorsLayer::permissive()) 481 .layer(TraceLayer::new_for_http()) 482 .with_state(AppState { 483 cred, 484 config: config.clone(), 485 db: db.clone(), 486 client: client.clone(), 487 simple_client, 488 firehose: fhp, 489 signing_key: skey, 490 rotation_key: rkey, 491 }); 492 493 info!("listening on {addr}"); 494 info!("connect to: http://127.0.0.1:{}", addr.port()); 495 496 // Determine whether or not this was the first startup (i.e. no accounts exist and no invite codes were created). 497 // If so, create an invite code and share it via the console. 498 let c = sqlx::query_scalar!( 499 r#" 500 SELECT 501 (SELECT COUNT(*) FROM accounts) + (SELECT COUNT(*) FROM invites) 502 AS total_count 503 "# 504 ) 505 .fetch_one(&db) 506 .await 507 .context("failed to query database")?; 508 509 #[expect(clippy::print_stdout)] 510 if c == 0 { 511 let uuid = Uuid::new_v4().to_string(); 512 513 _ = sqlx::query!( 514 r#" 515 INSERT INTO invites (id, did, count, created_at) 516 VALUES (?, NULL, 1, datetime('now')) 517 "#, 518 uuid, 519 ) 520 .execute(&db) 521 .await 522 .context("failed to create new invite code")?; 523 524 // N.B: This is a sensitive message, so we're bypassing `tracing` here and 525 // logging it directly to console. 
526 println!("====================================="); 527 println!(" FIRST STARTUP "); 528 println!("====================================="); 529 println!("Use this code to create an account:"); 530 println!("{uuid}"); 531 println!("====================================="); 532 } 533 534 let listener = TcpListener::bind(&addr) 535 .await 536 .context("failed to bind address")?; 537 538 // Serve the app, and request crawling from upstream relays. 539 let serve = tokio::spawn(async move { 540 axum::serve(listener, app.into_make_service()) 541 .await 542 .context("failed to serve app") 543 }); 544 545 // Now that the app is live, request a crawl from upstream relays. 546 firehose::reconnect_relays(&client, &config).await; 547 548 serve 549 .await 550 .map_err(Into::into) 551 .and_then(|r| r) 552 .context("failed to serve app") 553} 554 555#[tokio::main(flavor = "multi_thread")] 556async fn main() -> anyhow::Result<()> { 557 // Dispatch out to a separate function without a derive macro to help rust-analyzer along. 558 run().await 559}