Alternative ATProto PDS implementation

prototype account_manager; reorganize

Changed files
+738 -868
src
account_manager
actor_store
endpoints
storage
+491
src/account_manager/mod.rs
···
··· 1 + use anyhow::Result; 2 + use chrono::DateTime; 3 + use chrono::offset::Utc as UtcOffset; 4 + use cidv10::Cid; 5 + use futures::try_join; 6 + use rsky_common::RFC3339_VARIANT; 7 + use rsky_common::time::{HOUR, from_micros_to_str, from_str_to_micros}; 8 + use rsky_lexicon::com::atproto::admin::StatusAttr; 9 + use rsky_lexicon::com::atproto::server::{AccountCodes, CreateAppPasswordOutput}; 10 + use rsky_pds::account_manager::CreateAccountOpts; 11 + use rsky_pds::account_manager::helpers::account::{ 12 + AccountStatus, ActorAccount, AvailabilityFlags, GetAccountAdminStatusOutput, 13 + }; 14 + use rsky_pds::account_manager::helpers::auth::{ 15 + AuthHelperError, CreateTokensOpts, RefreshGracePeriodOpts, 16 + }; 17 + use rsky_pds::account_manager::helpers::invite::CodeDetail; 18 + use rsky_pds::account_manager::helpers::password::UpdateUserPasswordOpts; 19 + use rsky_pds::account_manager::helpers::repo; 20 + use rsky_pds::account_manager::helpers::{account, auth, email_token, invite, password}; 21 + use rsky_pds::auth_verifier::AuthScope; 22 + use rsky_pds::models::models::EmailTokenPurpose; 23 + use secp256k1::{Keypair, Secp256k1, SecretKey}; 24 + use std::collections::BTreeMap; 25 + use std::env; 26 + use std::sync::Arc; 27 + use std::time::SystemTime; 28 + 29 + use crate::db::DbConn; 30 + 31 + #[derive(Clone, Debug)] 32 + pub struct AccountManager { 33 + pub db: Arc<DbConn>, 34 + } 35 + 36 + pub type AccountManagerCreator = Box<dyn Fn(Arc<DbConn>) -> AccountManager + Send + Sync>; 37 + 38 + impl AccountManager { 39 + pub fn new(db: Arc<DbConn>) -> Self { 40 + Self { db } 41 + } 42 + 43 + pub fn creator() -> AccountManagerCreator { 44 + Box::new(move |db: Arc<DbConn>| -> AccountManager { AccountManager::new(db) }) 45 + } 46 + 47 + pub async fn get_account( 48 + &self, 49 + handle_or_did: &str, 50 + flags: Option<AvailabilityFlags>, 51 + ) -> Result<Option<ActorAccount>> { 52 + let db = self.db.clone(); 53 + account::get_account(handle_or_did, flags, db.as_ref()).await 54 + } 55 + 56 + pub async fn get_account_by_email( 57 + &self, 58 + email: &str, 59 + flags: Option<AvailabilityFlags>, 60 + ) -> Result<Option<ActorAccount>> { 61 + let db = self.db.clone(); 62 + account::get_account_by_email(email, flags, db.as_ref()).await 63 + } 64 + 65 + pub async fn is_account_activated(&self, did: &str) -> Result<bool> { 66 + let account = self 67 + .get_account( 68 + did, 69 + Some(AvailabilityFlags { 70 + include_taken_down: None, 71 + include_deactivated: Some(true), 72 + }), 73 + ) 74 + .await?; 75 + if let Some(account) = account { 76 + Ok(account.deactivated_at.is_none()) 77 + } else { 78 + Ok(false) 79 + } 80 + } 81 + 82 + pub async fn get_did_for_actor( 83 + &self, 84 + handle_or_did: &str, 85 + flags: Option<AvailabilityFlags>, 86 + ) -> Result<Option<String>> { 87 + match self.get_account(handle_or_did, flags).await { 88 + Ok(Some(got)) => Ok(Some(got.did)), 89 + _ => Ok(None), 90 + } 91 + } 92 + 93 + pub async fn create_account(&self, opts: CreateAccountOpts) -> Result<(String, String)> { 94 + let db = self.db.clone(); 95 + let CreateAccountOpts { 96 + did, 97 + handle, 98 + email, 99 + password, 100 + repo_cid, 101 + repo_rev, 102 + invite_code, 103 + deactivated, 104 + } = opts; 105 + let password_encrypted: Option<String> = match password { 106 + Some(password) => Some(password::gen_salt_and_hash(password)?), 107 + None => None, 108 + }; 109 + // Should be a global var so this only happens once 110 + let secp = Secp256k1::new(); 111 + let private_key = env::var("PDS_JWT_KEY_K256_PRIVATE_KEY_HEX")?; 112 + let secret_key = 113 + SecretKey::from_slice(&Result::unwrap(hex::decode(private_key.as_bytes())))?; 114 + let jwt_key = Keypair::from_secret_key(&secp, &secret_key); 115 + let (access_jwt, refresh_jwt) = auth::create_tokens(CreateTokensOpts { 116 + did: did.clone(), 117 + jwt_key, 118 + service_did: env::var("PDS_SERVICE_DID").unwrap(), 119 + scope: Some(AuthScope::Access), 120 + jti: None, 121 + expires_in: None, 122 + })?; 123 + let refresh_payload = auth::decode_refresh_token(refresh_jwt.clone(), jwt_key)?; 124 + let now = rsky_common::now(); 125 + 126 + if let Some(invite_code) = invite_code.clone() { 127 + invite::ensure_invite_is_available(invite_code, db.as_ref()).await?; 128 + } 129 + account::register_actor(did.clone(), handle, deactivated, db.as_ref()).await?; 130 + if let (Some(email), Some(password_encrypted)) = (email, password_encrypted) { 131 + account::register_account(did.clone(), email, password_encrypted, db.as_ref()).await?; 132 + } 133 + invite::record_invite_use(did.clone(), invite_code, now, db.as_ref()).await?; 134 + auth::store_refresh_token(refresh_payload, None, db.as_ref()).await?; 135 + repo::update_root(did, repo_cid, repo_rev, db.as_ref()).await?; 136 + Ok((access_jwt, refresh_jwt)) 137 + } 138 + 139 + pub async fn get_account_admin_status( 140 + &self, 141 + did: &str, 142 + ) -> Result<Option<GetAccountAdminStatusOutput>> { 143 + let db = self.db.clone(); 144 + account::get_account_admin_status(did, db.as_ref()).await 145 + } 146 + 147 + pub async fn update_repo_root(&self, did: String, cid: Cid, rev: String) -> Result<()> { 148 + let db = self.db.clone(); 149 + repo::update_root(did, cid, rev, db.as_ref()).await 150 + } 151 + 152 + pub async fn delete_account(&self, did: &str) -> Result<()> { 153 + let db = self.db.clone(); 154 + account::delete_account(did, db.as_ref()).await 155 + } 156 + 157 + pub async fn takedown_account(&self, did: &str, takedown: StatusAttr) -> Result<()> { 158 + (_, _) = try_join!( 159 + account::update_account_takedown_status(did, takedown, self.db.as_ref()), 160 + auth::revoke_refresh_tokens_by_did(did, self.db.as_ref()) 161 + )?; 162 + Ok(()) 163 + } 164 + 165 + // @NOTE should always be paired with a sequenceHandle(). 166 + pub async fn update_handle(&self, did: &str, handle: &str) -> Result<()> { 167 + let db = self.db.clone(); 168 + account::update_handle(did, handle, db.as_ref()).await 169 + } 170 + 171 + pub async fn deactivate_account(&self, did: &str, delete_after: Option<String>) -> Result<()> { 172 + account::deactivate_account(did, delete_after, self.db.as_ref()).await 173 + } 174 + 175 + pub async fn activate_account(&self, did: &str) -> Result<()> { 176 + let db = self.db.clone(); 177 + account::activate_account(did, db.as_ref()).await 178 + } 179 + 180 + pub async fn get_account_status(&self, handle_or_did: &str) -> Result<AccountStatus> { 181 + let got = account::get_account( 182 + handle_or_did, 183 + Some(AvailabilityFlags { 184 + include_deactivated: Some(true), 185 + include_taken_down: Some(true), 186 + }), 187 + self.db.as_ref(), 188 + ) 189 + .await?; 190 + let res = account::format_account_status(got); 191 + match res.active { 192 + true => Ok(AccountStatus::Active), 193 + false => Ok(res.status.expect("Account status not properly formatted.")), 194 + } 195 + } 196 + 197 + // Auth 198 + // ---------- 199 + pub async fn create_session( 200 + &self, 201 + did: String, 202 + app_password_name: Option<String>, 203 + ) -> Result<(String, String)> { 204 + let db = self.db.clone(); 205 + let secp = Secp256k1::new(); 206 + let private_key = env::var("PDS_JWT_KEY_K256_PRIVATE_KEY_HEX")?; 207 + let secret_key = SecretKey::from_slice(&hex::decode(private_key.as_bytes())?)?; 208 + let jwt_key = Keypair::from_secret_key(&secp, &secret_key); 209 + let scope = if app_password_name.is_none() { 210 + AuthScope::Access 211 + } else { 212 + AuthScope::AppPass 213 + }; 214 + let (access_jwt, refresh_jwt) = auth::create_tokens(CreateTokensOpts { 215 + did, 216 + jwt_key, 217 + service_did: env::var("PDS_SERVICE_DID").unwrap(), 218 + scope: Some(scope), 219 + jti: None, 220 + expires_in: None, 221 + })?; 222 + let refresh_payload = auth::decode_refresh_token(refresh_jwt.clone(), jwt_key)?; 223 + auth::store_refresh_token(refresh_payload, app_password_name, db.as_ref()).await?; 224 + Ok((access_jwt, refresh_jwt)) 225 + } 226 + 227 + pub async fn rotate_refresh_token(&self, id: &String) -> Result<Option<(String, String)>> { 228 + let token = auth::get_refresh_token(id, self.db.as_ref()).await?; 229 + if let Some(token) = token { 230 + let system_time = SystemTime::now(); 231 + let dt: DateTime<UtcOffset> = system_time.into(); 232 + let now = format!("{}", dt.format(RFC3339_VARIANT)); 233 + 234 + // take the chance to tidy all of a user's expired tokens 235 + // does not need to be transactional since this is just best-effort 236 + auth::delete_expired_refresh_tokens(&token.did, now, self.db.as_ref()).await?; 237 + 238 + // Shorten the refresh token lifespan down from its 239 + // original expiration time to its revocation grace period. 240 + let prev_expires_at = from_str_to_micros(&token.expires_at); 241 + 242 + const REFRESH_GRACE_MS: i32 = 2 * HOUR; 243 + let grace_expires_at = dt.timestamp_micros() + REFRESH_GRACE_MS as i64; 244 + 245 + let expires_at = if grace_expires_at < prev_expires_at { 246 + grace_expires_at 247 + } else { 248 + prev_expires_at 249 + }; 250 + 251 + if expires_at <= dt.timestamp_micros() { 252 + return Ok(None); 253 + } 254 + 255 + // Determine the next refresh token id: upon refresh token 256 + // reuse you always receive a refresh token with the same id. 257 + let next_id = token.next_id.unwrap_or_else(auth::get_refresh_token_id); 258 + 259 + let secp = Secp256k1::new(); 260 + let private_key = env::var("PDS_JWT_KEY_K256_PRIVATE_KEY_HEX").unwrap(); 261 + let secret_key = 262 + SecretKey::from_slice(&hex::decode(private_key.as_bytes()).unwrap()).unwrap(); 263 + let jwt_key = Keypair::from_secret_key(&secp, &secret_key); 264 + 265 + let (access_jwt, refresh_jwt) = auth::create_tokens(CreateTokensOpts { 266 + did: token.did, 267 + jwt_key, 268 + service_did: env::var("PDS_SERVICE_DID").unwrap(), 269 + scope: Some(if token.app_password_name.is_none() { 270 + AuthScope::Access 271 + } else { 272 + AuthScope::AppPass 273 + }), 274 + jti: Some(next_id.clone()), 275 + expires_in: None, 276 + })?; 277 + let refresh_payload = auth::decode_refresh_token(refresh_jwt.clone(), jwt_key)?; 278 + match try_join!( 279 + auth::add_refresh_grace_period( 280 + RefreshGracePeriodOpts { 281 + id: id.clone(), 282 + expires_at: from_micros_to_str(expires_at), 283 + next_id 284 + }, 285 + self.db.as_ref() 286 + ), 287 + auth::store_refresh_token( 288 + refresh_payload, 289 + token.app_password_name, 290 + self.db.as_ref() 291 + ) 292 + ) { 293 + Ok(_) => Ok(Some((access_jwt, refresh_jwt))), 294 + Err(e) => match e.downcast_ref() { 295 + Some(AuthHelperError::ConcurrentRefresh) => { 296 + Box::pin(self.rotate_refresh_token(id)).await 297 + } 298 + _ => Err(e), 299 + }, 300 + } 301 + } else { 302 + Ok(None) 303 + } 304 + } 305 + 306 + pub async fn revoke_refresh_token(&self, id: String) -> Result<bool> { 307 + auth::revoke_refresh_token(id, self.db.as_ref()).await 308 + } 309 + 310 + // Invites 311 + // ---------- 312 + 313 + pub async fn create_invite_codes( 314 + &self, 315 + to_create: Vec<AccountCodes>, 316 + use_count: i32, 317 + ) -> Result<()> { 318 + let db = self.db.clone(); 319 + invite::create_invite_codes(to_create, use_count, db.as_ref()).await 320 + } 321 + 322 + pub async fn create_account_invite_codes( 323 + &self, 324 + for_account: &str, 325 + codes: Vec<String>, 326 + expected_total: usize, 327 + disabled: bool, 328 + ) -> Result<Vec<CodeDetail>> { 329 + invite::create_account_invite_codes( 330 + for_account, 331 + codes, 332 + expected_total, 333 + disabled, 334 + self.db.as_ref(), 335 + ) 336 + .await 337 + } 338 + 339 + pub async fn get_account_invite_codes(&self, did: &str) -> Result<Vec<CodeDetail>> { 340 + let db = self.db.clone(); 341 + invite::get_account_invite_codes(did, db.as_ref()).await 342 + } 343 + 344 + pub async fn get_invited_by_for_accounts( 345 + &self, 346 + dids: Vec<String>, 347 + ) -> Result<BTreeMap<String, CodeDetail>> { 348 + let db = self.db.clone(); 349 + invite::get_invited_by_for_accounts(dids, db.as_ref()).await 350 + } 351 + 352 + pub async fn set_account_invites_disabled(&self, did: &str, disabled: bool) -> Result<()> { 353 + invite::set_account_invites_disabled(did, disabled, self.db.as_ref()).await 354 + } 355 + 356 + pub async fn disable_invite_codes(&self, opts: DisableInviteCodesOpts) -> Result<()> { 357 + invite::disable_invite_codes(opts, self.db.as_ref()).await 358 + } 359 + 360 + // Passwords 361 + // ---------- 362 + 363 + pub async fn create_app_password( 364 + &self, 365 + did: String, 366 + name: String, 367 + ) -> Result<CreateAppPasswordOutput> { 368 + password::create_app_password(did, name, self.db.as_ref()).await 369 + } 370 + 371 + pub async fn list_app_passwords(&self, did: &str) -> Result<Vec<(String, String)>> { 372 + password::list_app_passwords(did, self.db.as_ref()).await 373 + } 374 + 375 + pub async fn verify_account_password(&self, did: &str, password_str: &String) -> Result<bool> { 376 + let db = self.db.clone(); 377 + password::verify_account_password(did, password_str, db.as_ref()).await 378 + } 379 + 380 + pub async fn verify_app_password( 381 + &self, 382 + did: &str, 383 + password_str: &str, 384 + ) -> Result<Option<String>> { 385 + let db = self.db.clone(); 386 + password::verify_app_password(did, password_str, db.as_ref()).await 387 + } 388 + 389 + pub async fn reset_password(&self, opts: ResetPasswordOpts) -> Result<()> { 390 + let db = self.db.clone(); 391 + let did = email_token::assert_valid_token_and_find_did( 392 + EmailTokenPurpose::ResetPassword, 393 + &opts.token, 394 + None, 395 + db.as_ref(), 396 + ) 397 + .await?; 398 + self.update_account_password(UpdateAccountPasswordOpts { 399 + did, 400 + password: opts.password, 401 + }) 402 + .await 403 + } 404 + 405 + pub async fn update_account_password(&self, opts: UpdateAccountPasswordOpts) -> Result<()> { 406 + let db = self.db.clone(); 407 + let UpdateAccountPasswordOpts { did, .. } = opts; 408 + let password_encrypted = password::gen_salt_and_hash(opts.password)?; 409 + try_join!( 410 + password::update_user_password( 411 + UpdateUserPasswordOpts { 412 + did: did.clone(), 413 + password_encrypted 414 + }, 415 + self.db.as_ref() 416 + ), 417 + email_token::delete_email_token(&did, EmailTokenPurpose::ResetPassword, db.as_ref()), 418 + auth::revoke_refresh_tokens_by_did(&did, self.db.as_ref()) 419 + )?; 420 + Ok(()) 421 + } 422 + 423 + pub async fn revoke_app_password(&self, did: String, name: String) -> Result<()> { 424 + try_join!( 425 + password::delete_app_password(&did, &name, self.db.as_ref()), 426 + auth::revoke_app_password_refresh_token(&did, &name, self.db.as_ref()) 427 + )?; 428 + Ok(()) 429 + } 430 + 431 + // Email Tokens 432 + // ---------- 433 + pub async fn confirm_email<'em>(&self, opts: ConfirmEmailOpts<'em>) -> Result<()> { 434 + let db = self.db.clone(); 435 + let ConfirmEmailOpts { did, token } = opts; 436 + email_token::assert_valid_token( 437 + did, 438 + EmailTokenPurpose::ConfirmEmail, 439 + token, 440 + None, 441 + db.as_ref(), 442 + ) 443 + .await?; 444 + let now = rsky_common::now(); 445 + try_join!( 446 + email_token::delete_email_token(did, EmailTokenPurpose::ConfirmEmail, db.as_ref()), 447 + account::set_email_confirmed_at(did, now, self.db.as_ref()) 448 + )?; 449 + Ok(()) 450 + } 451 + 452 + pub async fn update_email(&self, opts: UpdateEmailOpts) -> Result<()> { 453 + let db = self.db.clone(); 454 + let UpdateEmailOpts { did, email } = opts; 455 + try_join!( 456 + account::update_email(&did, &email, db.as_ref()), 457 + email_token::delete_all_email_tokens(&did, db.as_ref()) 458 + )?; 459 + Ok(()) 460 + } 461 + 462 + pub async fn assert_valid_email_token( 463 + &self, 464 + did: &str, 465 + purpose: EmailTokenPurpose, 466 + token: &str, 467 + ) -> Result<()> { 468 + let db = self.db.clone(); 469 + email_token::assert_valid_token(did, purpose, token, None, db.as_ref()).await 470 + } 471 + 472 + pub async fn assert_valid_email_token_and_cleanup( 473 + &self, 474 + did: &str, 475 + purpose: EmailTokenPurpose, 476 + token: &str, 477 + ) -> Result<()> { 478 + let db = self.db.clone(); 479 + email_token::assert_valid_token(did, purpose, token, None, db.as_ref()).await?; 480 + email_token::delete_email_token(did, purpose, db.as_ref()).await 481 + } 482 + 483 + pub async fn create_email_token( 484 + &self, 485 + did: &str, 486 + purpose: EmailTokenPurpose, 487 + ) -> Result<String> { 488 + let db = self.db.clone(); 489 + email_token::create_email_token(did, purpose, db.as_ref()).await 490 + } 491 + }
+76
src/actor_endpoints.rs
···
··· 1 + use atrium_api::app::bsky::actor; 2 + use axum::{Json, routing::post}; 3 + use constcat::concat; 4 + use diesel::prelude::*; 5 + 6 + use super::*; 7 + 8 + async fn put_preferences( 9 + user: AuthenticatedUser, 10 + State(db): State<Db>, 11 + Json(input): Json<actor::put_preferences::Input>, 12 + ) -> Result<()> { 13 + let did = user.did(); 14 + let json_string = 15 + serde_json::to_string(&input.preferences).context("failed to serialize preferences")?; 16 + 17 + // Use the db connection pool to execute the update 18 + let conn = &mut db.get().context("failed to get database connection")?; 19 + diesel::sql_query("UPDATE accounts SET private_prefs = ? WHERE did = ?") 20 + .bind::<diesel::sql_types::Text, _>(json_string) 21 + .bind::<diesel::sql_types::Text, _>(did) 22 + .execute(conn) 23 + .context("failed to update user preferences")?; 24 + 25 + Ok(()) 26 + } 27 + 28 + async fn get_preferences( 29 + user: AuthenticatedUser, 30 + State(db): State<Db>, 31 + ) -> Result<Json<actor::get_preferences::Output>> { 32 + let did = user.did(); 33 + let conn = &mut db.get().context("failed to get database connection")?; 34 + 35 + #[derive(QueryableByName)] 36 + struct Prefs { 37 + #[diesel(sql_type = diesel::sql_types::Text)] 38 + private_prefs: Option<String>, 39 + } 40 + 41 + let result = diesel::sql_query("SELECT private_prefs FROM accounts WHERE did = ?") 42 + .bind::<diesel::sql_types::Text, _>(did) 43 + .get_result::<Prefs>(conn) 44 + .context("failed to fetch preferences")?; 45 + 46 + if let Some(prefs_json) = result.private_prefs { 47 + let prefs: actor::defs::Preferences = 48 + serde_json::from_str(&prefs_json).context("failed to deserialize preferences")?; 49 + 50 + Ok(Json( 51 + actor::get_preferences::OutputData { preferences: prefs }.into(), 52 + )) 53 + } else { 54 + Ok(Json( 55 + actor::get_preferences::OutputData { 56 + preferences: Vec::new(), 57 + } 58 + .into(), 59 + )) 60 + } 61 + } 62 + 63 + /// Register all actor endpoints. 64 + pub(crate) fn routes() -> Router<AppState> { 65 + // AP /xrpc/app.bsky.actor.putPreferences 66 + // AG /xrpc/app.bsky.actor.getPreferences 67 + Router::new() 68 + .route( 69 + concat!("/", actor::put_preferences::NSID), 70 + post(put_preferences), 71 + ) 72 + .route( 73 + concat!("/", actor::get_preferences::NSID), 74 + get(get_preferences), 75 + ) 76 + }
+1 -1
src/actor_store/mod.rs
··· 9 mod blob; 10 mod preference; 11 mod record; 12 - mod sql_blob; 13 mod sql_repo; 14 15 use anyhow::Result;
··· 9 mod blob; 10 mod preference; 11 mod record; 12 + pub(crate) mod sql_blob; 13 mod sql_repo; 14 15 use anyhow::Result;
src/db/mod.rs src/db.rs
+132 -434
src/endpoints/repo/apply_writes.rs
··· 1 //! Apply a batch transaction of repository creates, updates, and deletes. Requires auth, implemented by PDS. 2 - use std::{collections::HashSet, str::FromStr}; 3 - 4 use anyhow::{Context as _, anyhow}; 5 use atrium_api::com::atproto::repo::apply_writes::{self, InputWritesItem, OutputResultsItem}; 6 use atrium_api::{ ··· 10 string::{AtIdentifier, Nsid, Tid}, 11 }, 12 }; 13 - use atrium_repo::{Cid, blockstore::CarStore}; 14 use axum::{ 15 Json, Router, 16 body::Body, ··· 18 http::{self, StatusCode}, 19 routing::{get, post}, 20 }; 21 use constcat::concat; 22 use futures::TryStreamExt as _; 23 use metrics::counter; 24 use rsky_syntax::aturi::AtUri; 25 use serde::Deserialize; 26 use tokio::io::AsyncWriteExt as _; 27 28 - use crate::repo::block_map::cid_for_cbor; 29 - use crate::repo::types::PreparedCreateOrUpdate; 30 - use crate::{ 31 - AppState, Db, Error, Result, SigningKey, 32 - actor_store::{ActorStoreTransactor, ActorStoreWriter}, 33 - auth::AuthenticatedUser, 34 - config::AppConfig, 35 - error::ErrorMessage, 36 - firehose::{self, FirehoseProducer, RepoOp}, 37 - metrics::{REPO_COMMITS, REPO_OP_CREATE, REPO_OP_DELETE, REPO_OP_UPDATE}, 38 - repo::types::{PreparedWrite, WriteOpAction}, 39 - storage, 40 - }; 41 - 42 use super::resolve_did; 43 44 /// Apply a batch transaction of repository creates, updates, and deletes. Requires auth, implemented by PDS. ··· 57 State(config): State<AppConfig>, 58 State(db): State<Db>, 59 State(fhp): State<FirehoseProducer>, 60 - Json(input): Json<repo::apply_writes::Input>, 61 ) -> Result<Json<repo::apply_writes::Output>> { 62 - todo!(); 63 - // // TODO: `input.validate` 64 - 65 - // // Resolve DID from identifier 66 - // let (target_did, _) = resolve_did(&db, &input.repo) 67 - // .await 68 - // .context("failed to resolve did")?; 69 - 70 - // // Ensure that we are updating the correct repository 71 - // if target_did.as_str() != user.did() { 72 - // return Err(Error::with_status( 73 - // StatusCode::BAD_REQUEST, 74 - // anyhow!("repo did not match the authenticated user"), 75 - // )); 76 - // } 77 - 78 - // // Validate writes count 79 - // if input.writes.len() > 200 { 80 - // return Err(Error::with_status( 81 - // StatusCode::BAD_REQUEST, 82 - // anyhow!("Too many writes. Max: 200"), 83 - // )); 84 - // } 85 - 86 - // // Convert input writes to prepared format 87 - // let mut prepared_writes = Vec::with_capacity(input.writes.len()); 88 - // for write in input.writes.iter() { 89 - // match write { 90 - // InputWritesItem::Create(create) => { 91 - // let uri = AtUri::make( 92 - // user.did(), 93 - // &create.collection.as_str(), 94 - // create 95 - // .rkey 96 - // .as_deref() 97 - // .unwrap_or(&Tid::now(LimitedU32::MIN).to_string()), 98 - // ); 99 - 100 - // let cid = match cid_for_cbor(&create.value) { 101 - // Ok(cid) => cid, 102 - // Err(e) => { 103 - // return Err(Error::with_status( 104 - // StatusCode::BAD_REQUEST, 105 - // anyhow!("Failed to encode record: {}", e), 106 - // )); 107 - // } 108 - // }; 109 - 110 - // let blobs = scan_blobs(&create.value) 111 - // .unwrap_or_default() 112 - // .into_iter() 113 - // .map(|cid| { 114 - // // TODO: Create BlobRef from cid with proper metadata 115 - // BlobRef { 116 - // cid, 117 - // mime_type: "application/octet-stream".to_string(), // Default 118 - // size: 0, // Unknown at this point 119 - // } 120 - // }) 121 - // .collect(); 122 - 123 - // prepared_writes.push(PreparedCreateOrUpdate { 124 - // action: WriteOpAction::Create, 125 - // uri: uri?.to_string(), 126 - // cid, 127 - // record: create.value.clone(), 128 - // blobs, 129 - // swap_cid: None, 130 - // }); 131 - // } 132 - // InputWritesItem::Update(update) => { 133 - // let uri = AtUri::make( 134 - // user.did(), 135 - // Some(update.collection.to_string()), 136 - // Some(update.rkey.to_string()), 137 - // ); 138 - 139 - // let cid = match cid_for_cbor(&update.value) { 140 - // Ok(cid) => cid, 141 - // Err(e) => { 142 - // return Err(Error::with_status( 143 - // StatusCode::BAD_REQUEST, 144 - // anyhow!("Failed to encode record: {}", e), 145 - // )); 146 - // } 147 - // }; 148 - 149 - // let blobs = scan_blobs(&update.value) 150 - // .unwrap_or_default() 151 - // .into_iter() 152 - // .map(|cid| { 153 - // // TODO: Create BlobRef from cid with proper metadata 154 - // BlobRef { 155 - // cid, 156 - // mime_type: "application/octet-stream".to_string(), 157 - // size: 0, 158 - // } 159 - // }) 160 - // .collect(); 161 - 162 - // prepared_writes.push(PreparedCreateOrUpdate { 163 - // action: WriteOpAction::Update, 164 - // uri: uri?.to_string(), 165 - // cid, 166 - // record: update.value.clone(), 167 - // blobs, 168 - // swap_cid: None, 169 - // }); 170 - // } 171 - // InputWritesItem::Delete(delete) => { 172 - // let uri = AtUri::make(user.did(), &delete.collection.as_str(), &delete.rkey); 173 - 174 - // prepared_writes.push(PreparedCreateOrUpdate { 175 - // action: WriteOpAction::Delete, 176 - // uri: uri?.to_string(), 177 - // cid: Cid::default(), // Not needed for delete 178 - // record: serde_json::Value::Null, 179 - // blobs: vec![], 180 - // swap_cid: None, 181 - // }); 182 - // } 183 - // } 184 - // } 185 - 186 - // // Get swap commit CID if provided 187 - // let swap_commit_cid = input.swap_commit.as_ref().map(|cid| *cid.as_ref()); 188 - 189 - // let did_str = user.did(); 190 - // let mut repo = storage::open_repo_db(&config.repo, &db, did_str) 191 - // .await 192 - // .context("failed to open user repo")?; 193 - // let orig_cid = repo.root(); 194 - // let orig_rev = repo.commit().rev(); 195 - 196 - // let mut blobs = vec![]; 197 - // let mut res = vec![]; 198 - // let mut ops = vec![]; 199 - 200 - // for write in &prepared_writes { 201 - // let (builder, key) = match write.action { 202 - // WriteOpAction::Create => { 203 - // let key = format!("{}/{}", write.uri.collection, write.uri.rkey); 204 - // let uri = format!("at://{}/{}", user.did(), key); 205 - 206 - // let (builder, cid) = repo 207 - // .add_raw(&key, &write.record) 208 - // .await 209 - // .context("failed to add record")?; 210 - 211 - // // Extract and track blobs 212 - // if let Ok(new_blobs) = scan_blobs(&write.record) { 213 - // blobs.extend( 214 - // new_blobs 215 - // .into_iter() 216 - // .map(|blob_cid| (key.clone(), blob_cid)), 217 - // ); 218 - // } 219 - 220 - // ops.push(RepoOp::Create { 221 - // cid, 222 - // path: key.clone(), 223 - // }); 224 - 225 - // res.push(OutputResultsItem::CreateResult(Box::new( 226 - // apply_writes::CreateResultData { 227 - // cid: atrium_api::types::string::Cid::new(cid), 228 - // uri, 229 - // validation_status: None, 230 - // } 231 - // .into(), 232 - // ))); 233 - 234 - // (builder, key) 235 - // } 236 - // WriteOpAction::Update => { 237 - // let key = format!("{}/{}", write.uri.collection, write.uri.rkey); 238 - // let uri = format!("at://{}/{}", user.did(), key); 239 - 240 - // let prev = repo 241 - // .tree() 242 - // .get(&key) 243 - // .await 244 - // .context("failed to search MST")?; 245 - 246 - // if prev.is_none() { 247 - // // No existing record, treat as create 248 - // let (create_builder, cid) = repo 249 - // .add_raw(&key, &write.record) 250 - // .await 251 - // .context("failed to add record")?; 252 - 253 - // if let Ok(new_blobs) = scan_blobs(&write.record) { 254 - // blobs.extend( 255 - // new_blobs 256 - // .into_iter() 257 - // .map(|blob_cid| (key.clone(), blob_cid)), 258 - // ); 259 - // } 260 - 261 - // ops.push(RepoOp::Create { 262 - // cid, 263 - // path: key.clone(), 264 - // }); 265 - 266 - // res.push(OutputResultsItem::CreateResult(Box::new( 267 - // apply_writes::CreateResultData { 268 - // cid: atrium_api::types::string::Cid::new(cid), 269 - // uri, 270 - // validation_status: None, 271 - // } 272 - // .into(), 273 - // ))); 274 - 275 - // (create_builder, key) 276 - // } else { 277 - // // Update existing record 278 - // let prev = prev.context("should be able to find previous record")?; 279 - // let (update_builder, cid) = repo 280 - // .update_raw(&key, &write.record) 281 - // .await 282 - // .context("failed to add record")?; 283 - 284 - // if let Ok(new_blobs) = scan_blobs(&write.record) { 285 - // blobs.extend( 286 - // new_blobs 287 - // .into_iter() 288 - // .map(|blob_cid| (key.clone(), blob_cid)), 289 - // ); 290 - // } 291 - 292 - // ops.push(RepoOp::Update { 293 - // cid, 294 - // path: key.clone(), 295 - // prev, 296 - // }); 297 - 298 - // res.push(OutputResultsItem::UpdateResult(Box::new( 299 - // apply_writes::UpdateResultData { 300 - // cid: atrium_api::types::string::Cid::new(cid), 301 - // uri, 302 - // validation_status: None, 303 - // } 304 - // .into(), 305 - // ))); 306 307 - // (update_builder, key) 308 - // } 309 - // } 310 - // WriteOpAction::Delete => { 311 - // let key = format!("{}/{}", write.uri.collection, write.uri.rkey); 312 313 - // let prev = repo 314 - // .tree() 315 - // .get(&key) 316 - // .await 317 - // .context("failed to search MST")? 318 - // .context("previous record does not exist")?; 319 - 320 - // ops.push(RepoOp::Delete { 321 - // path: key.clone(), 322 - // prev, 323 - // }); 324 - 325 - // res.push(OutputResultsItem::DeleteResult(Box::new( 326 - // apply_writes::DeleteResultData {}.into(), 327 - // ))); 328 329 - // let builder = repo 330 - // .delete_raw(&key) 331 - // .await 332 - // .context("failed to add record")?; 333 334 - // (builder, key) 335 - // } 336 - // }; 337 338 - // let sig = skey 339 - // .sign(&builder.bytes()) 340 - // .context("failed to sign commit")?; 341 342 - // _ = builder 343 - // .finalize(sig) 344 - // .await 345 - // .context("failed to write signed commit")?; 346 - // } 347 - 348 - // // Construct a firehose record 349 - // let mut mem = Vec::new(); 350 - // let mut store = CarStore::create_with_roots(std::io::Cursor::new(&mut mem), [repo.root()]) 351 - // .await 352 - // .context("failed to create temp store")?; 353 - 354 - // // Extract the records out of the user's repository 355 - // for write in &prepared_writes { 356 - // let key = format!("{}/{}", write.uri.collection, write.uri.rkey); 357 - // repo.extract_raw_into(&key, &mut store) 358 - // .await 359 - // .context("failed to extract key")?; 360 - // } 361 - 362 - // let mut tx = db.begin().await.context("failed to begin transaction")?; 363 - 364 - // if !swap_commit( 365 - // &mut *tx, 366 - // repo.root(), 367 - // repo.commit().rev(), 368 - // input.swap_commit.as_ref().map(|cid| *cid.as_ref()), 369 - // &user.did(), 370 - // ) 371 - // .await 372 - // .context("failed to swap commit")? 373 - // { 374 - // // This should always succeed. 375 - // let old = input 376 - // .swap_commit 377 - // .clone() 378 - // .context("swap_commit should always be Some")?; 379 - 380 - // // The swap failed. Return the old commit and do not update the repository. 381 - // return Ok(Json( 382 - // apply_writes::OutputData { 383 - // results: None, 384 - // commit: Some( 385 - // CommitMetaData { 386 - // cid: old, 387 - // rev: orig_rev, 388 - // } 389 - // .into(), 390 - // ), 391 - // } 392 - // .into(), 393 - // )); 394 - // } 395 - 396 - // // For updates and removals, unlink the old/deleted record from the blob_ref table 397 - // for op in &ops { 398 - // match op { 399 - // &RepoOp::Update { ref path, .. } | &RepoOp::Delete { ref path, .. } => { 400 - // // FIXME: This may cause issues if a user deletes more than one record referencing the same blob. 401 - // _ = &sqlx::query!( 402 - // r#"UPDATE blob_ref SET record = NULL WHERE did = ? AND record = ?"#, 403 - // did_str, 404 - // path 405 - // ) 406 - // .execute(&mut *tx) 407 - // .await 408 - // .context("failed to remove blob_ref")?; 409 - // } 410 - // &RepoOp::Create { .. } => {} 411 - // } 412 - // } 413 - 414 - // // Process blobs 415 - // for (key, cid) in &blobs { 416 - // let cid_str = cid.to_string(); 417 - 418 - // // Handle the case where a new record references an existing blob 419 - // if sqlx::query!( 420 - // r#"UPDATE blob_ref SET record = ? WHERE cid = ? AND did = ? AND record IS NULL"#, 421 - // key, 422 - // cid_str, 423 - // did_str, 424 - // ) 425 - // .execute(&mut *tx) 426 - // .await 427 - // .context("failed to update blob_ref")? 428 - // .rows_affected() 429 - // == 0 430 - // { 431 - // _ = sqlx::query!( 432 - // r#"INSERT INTO blob_ref (record, cid, did) VALUES (?, ?, ?)"#, 433 - // key, 434 - // cid_str, 435 - // did_str, 436 - // ) 437 - // .execute(&mut *tx) 438 - // .await 439 - // .context("failed to update blob_ref")?; 440 - // } 441 - // } 442 - 443 - // tx.commit() 444 - // .await 445 - // .context("failed to commit blob ref to database")?; 446 - 447 - // // Update counters 448 - // counter!(REPO_COMMITS).increment(1); 449 - // for op in &ops { 450 - // match *op { 451 - // RepoOp::Create { .. } => counter!(REPO_OP_CREATE).increment(1), 452 - // RepoOp::Update { .. } => counter!(REPO_OP_UPDATE).increment(1), 453 - // RepoOp::Delete { .. } => counter!(REPO_OP_DELETE).increment(1), 454 - // } 455 - // } 456 - 457 - // // We've committed the transaction to the database, and the commit is now stored in the user's 458 - // // canonical repository. 459 - // // We can now broadcast this on the firehose. 460 - // fhp.commit(firehose::Commit { 461 - // car: mem, 462 - // ops, 463 - // cid: repo.root(), 464 - // rev: repo.commit().rev().to_string(), 465 - // did: atrium_api::types::string::Did::new(user.did()).expect("should be valid DID"), 466 - // pcid: Some(orig_cid), 467 - // blobs: blobs.into_iter().map(|(_, cid)| cid).collect::<Vec<_>>(), 468 - // }) 469 - // .await; 470 - 471 - // Ok(Json( 472 - // apply_writes::OutputData { 473 - // results: Some(res), 474 - // commit: Some( 475 - // CommitMetaData { 476 - // cid: atrium_api::types::string::Cid::new(repo.root()), 477 - // rev: repo.commit().rev(), 478 - // } 479 - // .into(), 480 - // ), 481 - // } 482 - // .into(), 483 - // )) 484 }
··· 1 //! Apply a batch transaction of repository creates, updates, and deletes. Requires auth, implemented by PDS. 2 + use crate::{ 3 + AppState, Db, Error, Result, SigningKey, 4 + actor_store::ActorStore, 5 + actor_store::sql_blob::BlobStoreSql, 6 + auth::AuthenticatedUser, 7 + config::AppConfig, 8 + error::ErrorMessage, 9 + firehose::{self, FirehoseProducer, RepoOp}, 10 + metrics::{REPO_COMMITS, REPO_OP_CREATE, REPO_OP_DELETE, REPO_OP_UPDATE}, 11 + storage, 12 + }; 13 + use anyhow::bail; 14 use anyhow::{Context as _, anyhow}; 15 use atrium_api::com::atproto::repo::apply_writes::{self, InputWritesItem, OutputResultsItem}; 16 use atrium_api::{ ··· 20 string::{AtIdentifier, Nsid, Tid}, 21 }, 22 }; 23 + use atrium_repo::blockstore::CarStore; 24 use axum::{ 25 Json, Router, 26 body::Body, ··· 28 http::{self, StatusCode}, 29 routing::{get, post}, 30 }; 31 + use cidv10::Cid; 32 use constcat::concat; 33 use futures::TryStreamExt as _; 34 + use futures::stream::{self, StreamExt}; 35 use metrics::counter; 36 + use rsky_lexicon::com::atproto::repo::{ApplyWritesInput, ApplyWritesInputRefWrite}; 37 + use rsky_pds::SharedSequencer; 38 + use rsky_pds::account_manager::AccountManager; 39 + use rsky_pds::account_manager::helpers::account::AvailabilityFlags; 40 + use rsky_pds::apis::ApiError; 41 + use rsky_pds::auth_verifier::AccessStandardIncludeChecks; 42 + use rsky_pds::repo::prepare::{ 43 + PrepareCreateOpts, PrepareDeleteOpts, PrepareUpdateOpts, prepare_create, prepare_delete, 44 + prepare_update, 45 + }; 46 + use rsky_repo::types::PreparedWrite; 47 use rsky_syntax::aturi::AtUri; 48 use serde::Deserialize; 49 + use std::{collections::HashSet, str::FromStr}; 50 use tokio::io::AsyncWriteExt as _; 51 52 use super::resolve_did; 53 54 /// Apply a batch transaction of repository creates, updates, and deletes. Requires auth, implemented by PDS. ··· 67 State(config): State<AppConfig>, 68 State(db): State<Db>, 69 State(fhp): State<FirehoseProducer>, 70 + Json(input): Json<ApplyWritesInput>, 71 ) -> Result<Json<repo::apply_writes::Output>> { 72 + let tx: ApplyWritesInput = input; 73 + let ApplyWritesInput { 74 + repo, 75 + validate, 76 + swap_commit, 77 + .. 78 + } = tx; 79 + let account = account_manager 80 + .get_account( 81 + &repo, 82 + Some(AvailabilityFlags { 83 + include_deactivated: Some(true), 84 + include_taken_down: None, 85 + }), 86 + ) 87 + .await?; 88 89 + if let Some(account) = account { 90 + if account.deactivated_at.is_some() { 91 + return Err(Error::with_message( 92 + StatusCode::FORBIDDEN, 93 + anyhow!("Account is deactivated"), 94 + ErrorMessage::new("AccountDeactivated", "Account is deactivated"), 95 + )); 96 + } 97 + let did = account.did; 98 + if did != user.did() { 99 + return Err(Error::with_message( 100 + StatusCode::FORBIDDEN, 101 + anyhow!("AuthRequiredError"), 102 + ErrorMessage::new("AuthRequiredError", "Auth required"), 103 + )); 104 + } 105 + let did: &String = &did; 106 + if tx.writes.len() > 200 { 107 + return Err(Error::with_message( 108 + StatusCode::BAD_REQUEST, 109 + anyhow!("Too many writes. Max: 200"), 110 + ErrorMessage::new("TooManyWrites", "Too many writes. Max: 200"), 111 + )); 112 + } 113 114 + let writes: Vec<PreparedWrite> = stream::iter(tx.writes) 115 + .then(|write| async move { 116 + Ok::<PreparedWrite, anyhow::Error>(match write { 117 + ApplyWritesInputRefWrite::Create(write) => PreparedWrite::Create( 118 + prepare_create(PrepareCreateOpts { 119 + did: did.clone(), 120 + collection: write.collection, 121 + rkey: write.rkey, 122 + swap_cid: None, 123 + record: serde_json::from_value(write.value)?, 124 + validate, 125 + }) 126 + .await?, 127 + ), 128 + ApplyWritesInputRefWrite::Update(write) => PreparedWrite::Update( 129 + prepare_update(PrepareUpdateOpts { 130 + did: did.clone(), 131 + collection: write.collection, 132 + rkey: write.rkey, 133 + swap_cid: None, 134 + record: serde_json::from_value(write.value)?, 135 + validate, 136 + }) 137 + .await?, 138 + ), 139 + ApplyWritesInputRefWrite::Delete(write) => { 140 + PreparedWrite::Delete(prepare_delete(PrepareDeleteOpts { 141 + did: did.clone(), 142 + collection: write.collection, 143 + rkey: write.rkey, 144 + swap_cid: None, 145 + })?) 146 + } 147 + }) 148 + }) 149 + .collect::<Vec<_>>() 150 + .await 151 + .into_iter() 152 + .collect::<Result<Vec<PreparedWrite>, _>>()?; 153 154 + let swap_commit_cid = match swap_commit { 155 + Some(swap_commit) => Some(Cid::from_str(&swap_commit)?), 156 + None => None, 157 + }; 158 159 + let mut actor_store = ActorStore::new(did.clone(), BlobStoreSql::new(did.clone(), db), db); 160 161 + let commit = actor_store 162 + .process_writes(writes.clone(), swap_commit_cid) 163 + .await?; 164 165 + let mut lock = sequencer.sequencer.write().await; 166 + lock.sequence_commit(did.clone(), commit.clone()).await?; 167 + account_manager 168 + .update_repo_root( 169 + did.to_string(), 170 + commit.commit_data.cid, 171 + commit.commit_data.rev, 172 + ) 173 + .await?; 174 + Ok(()) 175 + } else { 176 + Err(Error::with_message( 177 + StatusCode::NOT_FOUND, 178 + anyhow!("Could not find repo: `{repo}`"), 179 + ErrorMessage::new("RepoNotFound", "Could not find repo"), 180 + )) 181 + } 182 }
+38 -97
src/main.rs
··· 1 //! PDS implementation. 2 mod actor_store; 3 mod auth; 4 mod config; ··· 11 mod mmap; 12 mod oauth; 13 mod plc; 14 - mod storage; 15 #[cfg(test)] 16 mod tests; 17 ··· 19 /// 20 /// We shouldn't have to know about any bsky endpoints to store private user data. 21 /// This will _very likely_ be changed in the future. 22 - mod actor_endpoints { 23 - use atrium_api::app::bsky::actor; 24 - use axum::{Json, routing::post}; 25 - use constcat::concat; 26 - 27 - use super::*; 28 - 29 - async fn put_preferences( 30 - user: AuthenticatedUser, 31 - State(db): State<Db>, 32 - Json(input): Json<actor::put_preferences::Input>, 33 - ) -> Result<()> { 34 - let did = user.did(); 35 - let prefs = sqlx::types::Json(input.preferences.clone()); 36 - _ = sqlx::query!( 37 - r#"UPDATE accounts SET private_prefs = ? WHERE did = ?"#, 38 - prefs, 39 - did 40 - ) 41 - .execute(&db) 42 - .await 43 - .context("failed to update user preferences")?; 44 - 45 - Ok(()) 46 - } 47 - 48 - async fn get_preferences( 49 - user: AuthenticatedUser, 50 - State(db): State<Db>, 51 - ) -> Result<Json<actor::get_preferences::Output>> { 52 - let did = user.did(); 53 - let json: Option<sqlx::types::Json<actor::defs::Preferences>> = 54 - sqlx::query_scalar("SELECT private_prefs FROM accounts WHERE did = ?") 55 - .bind(did) 56 - .fetch_one(&db) 57 - .await 58 - .context("failed to fetch preferences")?; 59 - 60 - if let Some(prefs) = json { 61 - Ok(Json( 62 - actor::get_preferences::OutputData { 63 - preferences: prefs.0, 64 - } 65 - .into(), 66 - )) 67 - } else { 68 - Ok(Json( 69 - actor::get_preferences::OutputData { 70 - preferences: Vec::new(), 71 - } 72 - .into(), 73 - )) 74 - } 75 - } 76 - 77 - /// Register all actor endpoints. 78 - pub(crate) fn routes() -> Router<AppState> { 79 - // AP /xrpc/app.bsky.actor.putPreferences 80 - // AG /xrpc/app.bsky.actor.getPreferences 81 - Router::new() 82 - .route( 83 - concat!("/", actor::put_preferences::NSID), 84 - post(put_preferences), 85 - ) 86 - .route( 87 - concat!("/", actor::get_preferences::NSID), 88 - get(get_preferences), 89 - ) 90 - } 91 - } 92 93 use anyhow::{Context as _, anyhow}; 94 use atrium_api::types::string::Did; ··· 106 use clap::Parser; 107 use clap_verbosity_flag::{InfoLevel, Verbosity, log::LevelFilter}; 108 use config::AppConfig; 109 #[expect(clippy::pub_use, clippy::useless_attribute)] 110 pub use error::Error; 111 use figment::{Figment, providers::Format as _}; ··· 113 use http_cache_reqwest::{CacheMode, HttpCacheOptions, MokaManager}; 114 use rand::Rng as _; 115 use serde::{Deserialize, Serialize}; 116 - use sqlx::{SqlitePool, sqlite::SqliteConnectOptions}; 117 use std::{ 118 net::{IpAddr, Ipv4Addr, SocketAddr}, 119 path::PathBuf, ··· 128 /// The application user agent. Concatenates the package name and version. e.g. `bluepds/0.0.0`. 129 pub const APP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),); 130 131 /// The application-wide result type. 132 pub type Result<T> = std::result::Result<T, Error>; 133 /// The reqwest client type with middleware. 134 pub type Client = reqwest_middleware::ClientWithMiddleware; 135 /// The database connection pool. 136 - pub type Db = SqlitePool; 137 /// The Azure credential type. 138 pub type Cred = Arc<dyn TokenCredential>; 139 ··· 451 452 let cred = azure_identity::DefaultAzureCredential::new() 453 .context("failed to create Azure credential")?; 454 - let opts = SqliteConnectOptions::from_str(&config.db) 455 - .context("failed to parse database options")? 456 - .create_if_missing(true); 457 - let db = SqlitePool::connect_with(opts).await?; 458 459 - sqlx::migrate!() 460 - .run(&db) 461 - .await 462 - .context("failed to apply migrations")?; 463 464 let (_fh, fhp) = firehose::spawn(client.clone(), config.clone()); 465 ··· 495 496 // Determine whether or not this was the first startup (i.e. no accounts exist and no invite codes were created). 497 // If so, create an invite code and share it via the console. 498 - let c = sqlx::query_scalar!( 499 - r#" 500 - SELECT 501 - (SELECT COUNT(*) FROM accounts) + (SELECT COUNT(*) FROM invites) 502 - AS total_count 503 - "# 504 ) 505 - .fetch_one(&db) 506 - .await 507 .context("failed to query database")?; 508 509 #[expect(clippy::print_stdout)] 510 if c == 0 { 511 let uuid = Uuid::new_v4().to_string(); 512 513 - _ = sqlx::query!( 514 - r#" 515 - INSERT INTO invites (id, did, count, created_at) 516 - VALUES (?, NULL, 1, datetime('now')) 517 - "#, 518 - uuid, 519 ) 520 - .execute(&db) 521 - .await 522 .context("failed to create new invite code")?; 523 524 // N.B: This is a sensitive message, so we're bypassing `tracing` here and
··· 1 //! PDS implementation. 2 + mod account_manager; 3 mod actor_store; 4 mod auth; 5 mod config; ··· 12 mod mmap; 13 mod oauth; 14 mod plc; 15 #[cfg(test)] 16 mod tests; 17 ··· 19 /// 20 /// We shouldn't have to know about any bsky endpoints to store private user data. 21 /// This will _very likely_ be changed in the future. 22 + mod actor_endpoints; 23 24 use anyhow::{Context as _, anyhow}; 25 use atrium_api::types::string::Did; ··· 37 use clap::Parser; 38 use clap_verbosity_flag::{InfoLevel, Verbosity, log::LevelFilter}; 39 use config::AppConfig; 40 + use diesel::prelude::*; 41 + use diesel::r2d2::{self, ConnectionManager}; 42 + use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations}; 43 #[expect(clippy::pub_use, clippy::useless_attribute)] 44 pub use error::Error; 45 use figment::{Figment, providers::Format as _}; ··· 47 use http_cache_reqwest::{CacheMode, HttpCacheOptions, MokaManager}; 48 use rand::Rng as _; 49 use serde::{Deserialize, Serialize}; 50 use std::{ 51 net::{IpAddr, Ipv4Addr, SocketAddr}, 52 path::PathBuf, ··· 61 /// The application user agent. Concatenates the package name and version. e.g. `bluepds/0.0.0`. 62 pub const APP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),); 63 64 + /// Embedded migrations 65 + pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations"); 66 + 67 /// The application-wide result type. 68 pub type Result<T> = std::result::Result<T, Error>; 69 /// The reqwest client type with middleware. 70 pub type Client = reqwest_middleware::ClientWithMiddleware; 71 /// The database connection pool. 72 + pub type Db = r2d2::Pool<ConnectionManager<SqliteConnection>>; 73 /// The Azure credential type. 74 pub type Cred = Arc<dyn TokenCredential>; 75 ··· 387 388 let cred = azure_identity::DefaultAzureCredential::new() 389 .context("failed to create Azure credential")?; 390 391 + // Create a database connection manager and pool 392 + let manager = ConnectionManager::<SqliteConnection>::new(&config.db); 393 + let db = r2d2::Pool::builder() 394 + .build(manager) 395 + .context("failed to create database connection pool")?; 396 + 397 + // Apply pending migrations 398 + let conn = &mut db 399 + .get() 400 + .context("failed to get database connection for migrations")?; 401 + conn.run_pending_migrations(MIGRATIONS) 402 + .expect("should be able to run migrations"); 403 404 let (_fh, fhp) = firehose::spawn(client.clone(), config.clone()); 405 ··· 435 436 // Determine whether or not this was the first startup (i.e. no accounts exist and no invite codes were created). 437 // If so, create an invite code and share it via the console. 438 + let conn = &mut db.get().context("failed to get database connection")?; 439 + 440 + #[derive(QueryableByName)] 441 + struct TotalCount { 442 + #[diesel(sql_type = diesel::sql_types::Integer)] 443 + total_count: i32, 444 + } 445 + 446 + let result = diesel::sql_query( 447 + "SELECT (SELECT COUNT(*) FROM accounts) + (SELECT COUNT(*) FROM invites) AS total_count", 448 ) 449 + .get_result::<TotalCount>(conn) 450 .context("failed to query database")?; 451 452 + let c = result.total_count; 453 + 454 #[expect(clippy::print_stdout)] 455 if c == 0 { 456 let uuid = Uuid::new_v4().to_string(); 457 458 + diesel::sql_query( 459 + "INSERT INTO invites (id, did, count, created_at) VALUES (?, NULL, 1, datetime('now'))", 460 ) 461 + .bind::<diesel::sql_types::Text, _>(uuid.clone()) 462 + .execute(conn) 463 .context("failed to create new invite code")?; 464 465 // N.B: This is a sensitive message, so we're bypassing `tracing` here and
-28
src/storage/car.rs
··· 1 - //! CAR file-based repository storage 2 - 3 - use anyhow::{Context as _, Result}; 4 - use atrium_repo::blockstore::{AsyncBlockStoreRead, AsyncBlockStoreWrite, CarStore}; 5 - 6 - use crate::{config::RepoConfig, mmap::MappedFile}; 7 - 8 - /// Open a CAR block store for a given DID. 9 - pub(crate) async fn open_car_store( 10 - config: &RepoConfig, 11 - did: impl AsRef<str>, 12 - ) -> Result<impl AsyncBlockStoreRead + AsyncBlockStoreWrite> { 13 - let id = did 14 - .as_ref() 15 - .strip_prefix("did:plc:") 16 - .context("did in unknown format")?; 17 - 18 - let p = config.path.join(id).with_extension("car"); 19 - 20 - let f = std::fs::File::options() 21 - .read(true) 22 - .write(true) 23 - .open(p) 24 - .context("failed to open repository file")?; 25 - let f = MappedFile::new(f).context("failed to map repo")?; 26 - 27 - CarStore::open(f).await.context("failed to open car store") 28 - }
···
-159
src/storage/mod.rs
··· 1 - //! `ATProto` user repository datastore functionality. 2 - 3 - pub(crate) mod car; 4 - mod sqlite; 5 - 6 - use anyhow::{Context as _, Result}; 7 - use atrium_repo::{ 8 - Cid, Repository, 9 - blockstore::{AsyncBlockStoreRead, AsyncBlockStoreWrite}, 10 - }; 11 - use std::str::FromStr as _; 12 - 13 - use crate::{Db, config::RepoConfig}; 14 - 15 - // Re-export public items 16 - pub(crate) use car::open_car_store; 17 - pub(crate) use sqlite::{SQLiteStore, open_sqlite_store}; 18 - 19 - /// Open a repository for a given DID. 20 - pub(crate) async fn open_repo_db( 21 - config: &RepoConfig, 22 - db: &Db, 23 - did: impl Into<String>, 24 - ) -> Result<Repository<impl AsyncBlockStoreRead + AsyncBlockStoreWrite>> { 25 - let did = did.into(); 26 - let cid = sqlx::query_scalar!( 27 - r#" 28 - SELECT root FROM accounts 29 - WHERE did = ? 30 - "#, 31 - did 32 - ) 33 - .fetch_one(db) 34 - .await 35 - .context("failed to query database")?; 36 - 37 - open_repo( 38 - config, 39 - did, 40 - Cid::from_str(&cid).context("should be valid CID")?, 41 - ) 42 - .await 43 - } 44 - 45 - /// Open a repository for a given DID and CID. 46 - pub(crate) async fn open_repo( 47 - config: &RepoConfig, 48 - did: impl Into<String>, 49 - cid: Cid, 50 - ) -> Result<Repository<impl AsyncBlockStoreRead + AsyncBlockStoreWrite>> { 51 - let store = open_car_store(config, did.into()).await?; 52 - Repository::open(store, cid) 53 - .await 54 - .context("failed to open repo") 55 - } 56 - /// Open a repository for a given DID and CID. 57 - /// SQLite backend. 58 - pub(crate) async fn open_repo_sqlite( 59 - config: &RepoConfig, 60 - did: impl Into<String>, 61 - cid: Cid, 62 - ) -> Result<Repository<impl AsyncBlockStoreRead + AsyncBlockStoreWrite>> { 63 - let store = open_sqlite_store(config, did.into()).await?; 64 - return Repository::open(store, cid) 65 - .await 66 - .context("failed to open repo"); 67 - } 68 - 69 - /// Open a block store for a given DID. 70 - pub(crate) async fn open_store( 71 - config: &RepoConfig, 72 - did: impl Into<String>, 73 - ) -> Result<impl AsyncBlockStoreRead + AsyncBlockStoreWrite> { 74 - let did = did.into(); 75 - 76 - // if config.use_sqlite { 77 - return open_sqlite_store(config, did.clone()).await; 78 - // } 79 - // Default to CAR store 80 - // open_car_store(config, &did).await 81 - } 82 - 83 - /// Create a storage backend for a DID 84 - pub(crate) async fn create_storage_for_did( 85 - config: &RepoConfig, 86 - did_hash: &str, 87 - ) -> Result<impl AsyncBlockStoreRead + AsyncBlockStoreWrite> { 88 - // Use standard file structure but change extension based on type 89 - // if config.use_sqlite { 90 - // For SQLite, create a new database file 91 - let db_path = config.path.join(format!("{}.db", did_hash)); 92 - 93 - // Ensure parent directory exists 94 - if let Some(parent) = db_path.parent() { 95 - tokio::fs::create_dir_all(parent) 96 - .await 97 - .context("failed to create directory")?; 98 - } 99 - 100 - // Create SQLite store 101 - let pool = sqlx::sqlite::SqlitePoolOptions::new() 102 - .max_connections(5) 103 - .connect_with( 104 - sqlx::sqlite::SqliteConnectOptions::new() 105 - .filename(&db_path) 106 - .create_if_missing(true), 107 - ) 108 - .await 109 - .context("failed to connect to SQLite database")?; 110 - 111 - // Initialize tables 112 - _ = sqlx::query( 113 - " 114 - CREATE TABLE IF NOT EXISTS blocks ( 115 - cid TEXT PRIMARY KEY NOT NULL, 116 - data BLOB NOT NULL, 117 - multicodec INTEGER NOT NULL, 118 - multihash INTEGER NOT NULL 119 - ); 120 - CREATE TABLE IF NOT EXISTS tree_nodes ( 121 - repo_did TEXT NOT NULL, 122 - key TEXT NOT NULL, 123 - value_cid TEXT NOT NULL, 124 - PRIMARY KEY (repo_did, key), 125 - FOREIGN KEY (value_cid) REFERENCES blocks(cid) 126 - ); 127 - CREATE INDEX IF NOT EXISTS idx_blocks_cid ON blocks(cid); 128 - CREATE INDEX IF NOT EXISTS idx_tree_nodes_repo ON tree_nodes(repo_did); 129 - PRAGMA journal_mode=WAL; 130 - ", 131 - ) 132 - .execute(&pool) 133 - .await 134 - .context("failed to create tables")?; 135 - 136 - Ok(SQLiteStore { 137 - pool, 138 - did: format!("did:plc:{}", did_hash), 139 - }) 140 - // } else { 141 - // // For CAR files, create a new file 142 - // let file_path = config.path.join(format!("{}.car", did_hash)); 143 - 144 - // // Ensure parent directory exists 145 - // if let Some(parent) = file_path.parent() { 146 - // tokio::fs::create_dir_all(parent) 147 - // .await 148 - // .context("failed to create directory")?; 149 - // } 150 - 151 - // let file = tokio::fs::File::create_new(file_path) 152 - // .await 153 - // .context("failed to create repo file")?; 154 - 155 - // CarStore::create(file) 156 - // .await 157 - // .context("failed to create carstore") 158 - // } 159 - }
···
-149
src/storage/sqlite.rs
··· 1 - //! SQLite-based repository storage implementation. 2 - 3 - use anyhow::{Context as _, Result}; 4 - use atrium_repo::{ 5 - Cid, Multihash, 6 - blockstore::{AsyncBlockStoreRead, AsyncBlockStoreWrite, Error as BlockstoreError}, 7 - }; 8 - use sha2::Digest; 9 - use sqlx::SqlitePool; 10 - 11 - use crate::config::RepoConfig; 12 - 13 - /// SQLite-based implementation of block storage. 14 - pub(crate) struct SQLiteStore { 15 - pub did: String, 16 - pub pool: SqlitePool, 17 - } 18 - 19 - impl AsyncBlockStoreRead for SQLiteStore { 20 - async fn read_block(&mut self, cid: Cid) -> Result<Vec<u8>, BlockstoreError> { 21 - let mut contents = Vec::new(); 22 - self.read_block_into(cid, &mut contents).await?; 23 - Ok(contents) 24 - } 25 - async fn read_block_into( 26 - &mut self, 27 - cid: Cid, 28 - contents: &mut Vec<u8>, 29 - ) -> Result<(), BlockstoreError> { 30 - let cid_str = cid.to_string(); 31 - let record = sqlx::query!(r#"SELECT data FROM blocks WHERE cid = ?"#, cid_str) 32 - .fetch_optional(&self.pool) 33 - .await 34 - .map_err(|e| BlockstoreError::Other(Box::new(e)))? 35 - .ok_or(BlockstoreError::CidNotFound)?; 36 - 37 - contents.clear(); 38 - contents.extend_from_slice(&record.data); 39 - Ok(()) 40 - } 41 - } 42 - 43 - impl AsyncBlockStoreWrite for SQLiteStore { 44 - async fn write_block( 45 - &mut self, 46 - codec: u64, 47 - hash: u64, 48 - contents: &[u8], 49 - ) -> Result<Cid, BlockstoreError> { 50 - let digest = match hash { 51 - atrium_repo::blockstore::SHA2_256 => sha2::Sha256::digest(&contents), 52 - _ => return Err(BlockstoreError::UnsupportedHash(hash)), 53 - }; 54 - 55 - let multihash = Multihash::wrap(hash, digest.as_slice()) 56 - .map_err(|_| BlockstoreError::UnsupportedHash(hash))?; 57 - 58 - let cid = Cid::new_v1(codec, multihash); 59 - let cid_str = cid.to_string(); 60 - 61 - // Use a transaction for atomicity 62 - let mut tx = self 63 - .pool 64 - .begin() 65 - .await 66 - .map_err(|e| BlockstoreError::Other(Box::new(e)))?; 67 - 68 - // Check if block already exists 69 - let exists = sqlx::query_scalar!(r#"SELECT COUNT(*) FROM blocks WHERE cid = ?"#, cid_str) 70 - .fetch_one(&mut *tx) 71 - .await 72 - .map_err(|e| BlockstoreError::Other(Box::new(e)))?; 73 - 74 - // Only insert if block doesn't exist 75 - let codec = codec as i64; 76 - let hash = hash as i64; 77 - if exists == 0 { 78 - _ = sqlx::query!( 79 - r#"INSERT INTO blocks (cid, data, multicodec, multihash) VALUES (?, ?, ?, ?)"#, 80 - cid_str, 81 - contents, 82 - codec, 83 - hash 84 - ) 85 - .execute(&mut *tx) 86 - .await 87 - .map_err(|e| BlockstoreError::Other(Box::new(e)))?; 88 - } 89 - 90 - tx.commit() 91 - .await 92 - .map_err(|e| BlockstoreError::Other(Box::new(e)))?; 93 - 94 - Ok(cid) 95 - } 96 - } 97 - 98 - /// Open a SQLite store for the given DID. 99 - pub(crate) async fn open_sqlite_store( 100 - config: &RepoConfig, 101 - did: impl Into<String>, 102 - ) -> Result<impl AsyncBlockStoreRead + AsyncBlockStoreWrite> { 103 - tracing::info!("Opening SQLite store for DID"); 104 - let did_str = did.into(); 105 - 106 - // Extract the PLC ID from the DID 107 - let id = did_str 108 - .strip_prefix("did:plc:") 109 - .context("DID in unknown format")?; 110 - 111 - // Create database connection pool 112 - let db_path = config.path.join(format!("{id}.db")); 113 - 114 - let pool = sqlx::sqlite::SqlitePoolOptions::new() 115 - .max_connections(5) 116 - .connect_with( 117 - sqlx::sqlite::SqliteConnectOptions::new() 118 - .filename(&db_path) 119 - .create_if_missing(true), 120 - ) 121 - .await 122 - .context("failed to connect to SQLite database")?; 123 - 124 - // Ensure tables exist 125 - _ = sqlx::query( 126 - " 127 - CREATE TABLE IF NOT EXISTS blocks ( 128 - cid TEXT PRIMARY KEY NOT NULL, 129 - data BLOB NOT NULL, 130 - multicodec INTEGER NOT NULL, 131 - multihash INTEGER NOT NULL 132 - ); 133 - CREATE TABLE IF NOT EXISTS tree_nodes ( 134 - repo_did TEXT NOT NULL, 135 - key TEXT NOT NULL, 136 - value_cid TEXT NOT NULL, 137 - PRIMARY KEY (repo_did, key), 138 - FOREIGN KEY (value_cid) REFERENCES blocks(cid) 139 - ); 140 - CREATE INDEX IF NOT EXISTS idx_blocks_cid ON blocks(cid); 141 - CREATE INDEX IF NOT EXISTS idx_tree_nodes_repo ON tree_nodes(repo_did); 142 - ", 143 - ) 144 - .execute(&pool) 145 - .await 146 - .context("failed to create tables")?; 147 - 148 - Ok(SQLiteStore { pool, did: did_str }) 149 - }
···