Alternative ATProto PDS implementation

prototype deadpool_diesel connection pooling

+43
Cargo.lock
··· 1298 1298 "clap", 1299 1299 "clap-verbosity-flag", 1300 1300 "constcat", 1301 + "deadpool-diesel", 1301 1302 "diesel", 1302 1303 "diesel_migrations", 1303 1304 "dotenvy", ··· 1862 1863 dependencies = [ 1863 1864 "data-encoding", 1864 1865 "syn 2.0.101", 1866 + ] 1867 + 1868 + [[package]] 1869 + name = "deadpool" 1870 + version = "0.12.2" 1871 + source = "registry+https://github.com/rust-lang/crates.io-index" 1872 + checksum = "5ed5957ff93768adf7a65ab167a17835c3d2c3c50d084fe305174c112f468e2f" 1873 + dependencies = [ 1874 + "deadpool-runtime", 1875 + "num_cpus", 1876 + "serde", 1877 + "tokio", 1878 + ] 1879 + 1880 + [[package]] 1881 + name = "deadpool-diesel" 1882 + version = "0.6.1" 1883 + source = "registry+https://github.com/rust-lang/crates.io-index" 1884 + checksum = "590573e9e29c5190a5ff782136f871e6e652e35d598a349888e028693601adf1" 1885 + dependencies = [ 1886 + "deadpool", 1887 + "deadpool-sync", 1888 + "diesel", 1889 + ] 1890 + 1891 + [[package]] 1892 + name = "deadpool-runtime" 1893 + version = "0.1.4" 1894 + source = "registry+https://github.com/rust-lang/crates.io-index" 1895 + checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" 1896 + dependencies = [ 1897 + "tokio", 1898 + ] 1899 + 1900 + [[package]] 1901 + name = "deadpool-sync" 1902 + version = "0.1.4" 1903 + source = "registry+https://github.com/rust-lang/crates.io-index" 1904 + checksum = "524bc3df0d57e98ecd022e21ba31166c2625e7d3e5bcc4510efaeeab4abcab04" 1905 + dependencies = [ 1906 + "deadpool-runtime", 1907 + "tracing", 1865 1908 ] 1866 1909 1867 1910 [[package]]
+1
Cargo.toml
··· 252 252 lazy_static = "1.5.0" 253 253 secp256k1 = "0.28.2" 254 254 dotenvy = "0.15.7" 255 + deadpool-diesel = { version = "0.6.1", features = ["serde", "sqlite", "tracing"] } 255 256 [dependencies.rocket_sync_db_pools] 256 257 version = "=0.1.0" 257 258 features = ["diesel_sqlite_pool"]
+9 -4
src/account_manager/mod.rs
··· 30 30 31 31 #[derive(Clone, Debug)] 32 32 pub struct AccountManager { 33 - pub db: Arc<DbConn>, 33 + pub db: deadpool_diesel::Connection<SqliteConnection>, 34 34 } 35 35 36 - pub type AccountManagerCreator = Box<dyn Fn(Arc<DbConn>) -> AccountManager + Send + Sync>; 36 + pub type AccountManagerCreator = 37 + Box<dyn Fn(deadpool_diesel::Connection<SqliteConnection>) -> AccountManager + Send + Sync>; 37 38 38 39 impl AccountManager { 39 - pub fn new(db: Arc<DbConn>) -> Self { 40 + pub fn new(db: deadpool_diesel::Connection<SqliteConnection>) -> Self { 40 41 Self { db } 41 42 } 42 43 43 44 pub fn creator() -> AccountManagerCreator { 44 - Box::new(move |db: Arc<DbConn>| -> AccountManager { AccountManager::new(db) }) 45 + Box::new( 46 + move |db: deadpool_diesel::Connection<SqliteConnection>| -> AccountManager { 47 + AccountManager::new(db) 48 + }, 49 + ) 45 50 } 46 51 47 52 pub async fn get_account(
+50 -32
src/actor_store/blob.rs
··· 27 27 use rsky_repo::error::BlobError; 28 28 use rsky_repo::types::{PreparedBlobRef, PreparedWrite}; 29 29 use std::str::FromStr as _; 30 - use std::sync::Arc; 31 30 32 31 use super::sql_blob::{BlobStoreSql, ByteStream}; 33 - use crate::db::DbConn; 34 32 35 33 pub struct GetBlobOutput { 36 34 pub size: i32, ··· 45 43 /// DID of the actor 46 44 pub did: String, 47 45 /// Database connection 48 - pub db: Arc<DbConn>, 46 + pub db: deadpool_diesel::Connection<SqliteConnection>, 49 47 } 50 48 51 49 impl BlobReader { 52 50 /// Create a new blob reader 53 - pub fn new(blobstore: BlobStoreSql, db: Arc<DbConn>) -> Self { 51 + pub fn new(blobstore: BlobStoreSql, db: deadpool_diesel::Connection<SqliteConnection>) -> Self { 54 52 BlobReader { 55 53 did: blobstore.did.clone(), 56 54 blobstore, ··· 65 63 let did = self.did.clone(); 66 64 let found = self 67 65 .db 68 - .run(move |conn| { 66 + .interact(move |conn| { 69 67 BlobSchema::blob 70 68 .filter(BlobSchema::did.eq(did)) 71 69 .filter(BlobSchema::cid.eq(cid.to_string())) ··· 74 72 .first(conn) 75 73 .optional() 76 74 }) 77 - .await?; 75 + .await 76 + .expect("Failed to get blob metadata")?; 78 77 79 78 match found { 80 79 None => bail!("Blob not found"), ··· 107 106 let did = self.did.clone(); 108 107 let res = self 109 108 .db 110 - .run(move |conn| { 109 + .interact(move |conn| { 111 110 let results = RecordBlobSchema::record_blob 112 111 .filter(RecordBlobSchema::blobCid.eq(cid.to_string())) 113 112 .filter(RecordBlobSchema::did.eq(did)) ··· 115 114 .get_results(conn)?; 116 115 Ok::<_, result::Error>(results.into_iter().map(|row| row.record_uri)) 117 116 }) 118 - .await? 117 + .await 118 + .expect("Failed to get records for blob")? 
119 119 .collect::<Vec<String>>(); 120 120 121 121 Ok(res) ··· 163 163 use rsky_pds::schema::pds::blob::dsl as BlobSchema; 164 164 165 165 let did = self.did.clone(); 166 - self.db.run(move |conn| { 166 + self.db.interact(move |conn| { 167 167 let BlobMetadata { 168 168 temp_key, 169 169 size, ··· 207 207 .execute(conn)?; 208 208 209 209 Ok(BlobRef::new(cid, mime_type, size, None)) 210 - }).await 210 + }).await.expect("Failed to track untethered blob") 211 211 } 212 212 213 213 /// Process blobs associated with writes ··· 265 265 let uris_clone = uris.clone(); 266 266 let deleted_repo_blobs: Vec<models::RecordBlob> = self 267 267 .db 268 - .run(move |conn| { 268 + .interact(move |conn| { 269 269 RecordBlobSchema::record_blob 270 270 .filter(RecordBlobSchema::recordUri.eq_any(&uris_clone)) 271 271 .filter(RecordBlobSchema::did.eq(&did)) 272 272 .load::<models::RecordBlob>(conn) 273 273 }) 274 - .await?; 274 + .await 275 + .expect("Failed to get deleted repo blobs")?; 275 276 276 277 if deleted_repo_blobs.is_empty() { 277 278 return Ok(()); ··· 280 281 // Now perform the delete 281 282 let uris_clone = uris.clone(); 282 283 self.db 283 - .run(move |conn| { 284 + .interact(move |conn| { 284 285 delete(RecordBlobSchema::record_blob) 285 286 .filter(RecordBlobSchema::recordUri.eq_any(uris_clone)) 286 287 .execute(conn) 287 288 }) 288 - .await?; 289 + .await 290 + .expect("Failed to delete repo blobs")?; 289 291 290 292 // Extract blob cids from the deleted records 291 293 let deleted_repo_blob_cids: Vec<String> = deleted_repo_blobs ··· 298 300 let did_clone = self.did.clone(); 299 301 let duplicated_cids: Vec<String> = self 300 302 .db 301 - .run(move |conn| { 303 + .interact(move |conn| { 302 304 RecordBlobSchema::record_blob 303 305 .filter(RecordBlobSchema::blobCid.eq_any(cids_clone)) 304 306 .filter(RecordBlobSchema::did.eq(did_clone)) 305 307 .select(RecordBlobSchema::blobCid) 306 308 .load::<String>(conn) 307 309 }) 308 - .await?; 310 + .await 311 + 
.expect("Failed to get duplicated cids")?; 309 312 310 313 // Extract new blob cids from writes (creates and updates) 311 314 let new_blob_cids: Vec<String> = writes ··· 333 336 let cids = cids_to_delete.clone(); 334 337 let did_clone = self.did.clone(); 335 338 self.db 336 - .run(move |conn| { 339 + .interact(move |conn| { 337 340 delete(BlobSchema::blob) 338 341 .filter(BlobSchema::cid.eq_any(cids)) 339 342 .filter(BlobSchema::did.eq(did_clone)) 340 343 .execute(conn) 341 344 }) 342 - .await?; 345 + .await 346 + .expect("Failed to delete blobs")?; 343 347 344 348 // Delete from blob storage 345 349 // Ideally we'd use a background queue here, but for now: ··· 364 368 365 369 let found = self 366 370 .db 367 - .run(move |conn| { 371 + .interact(move |conn| { 368 372 BlobSchema::blob 369 373 .filter( 370 374 BlobSchema::cid ··· 375 379 .first(conn) 376 380 .optional() 377 381 }) 378 - .await?; 382 + .await 383 + .expect("Failed to verify blob")?; 379 384 380 385 if let Some(found) = found { 381 386 verify_blob(&blob, &found).await?; ··· 385 390 .await?; 386 391 } 387 392 self.db 388 - .run(move |conn| { 393 + .interact(move |conn| { 389 394 update(BlobSchema::blob) 390 395 .filter(BlobSchema::tempKey.eq(found.temp_key)) 391 396 .set(BlobSchema::tempKey.eq::<Option<String>>(None)) 392 397 .execute(conn) 393 398 }) 394 - .await?; 399 + .await 400 + .expect("Failed to update blob")?; 395 401 Ok(()) 396 402 } else { 397 403 bail!("Could not find blob: {:?}", blob.cid.to_string()) ··· 406 412 let did = self.did.clone(); 407 413 408 414 self.db 409 - .run(move |conn| { 415 + .interact(move |conn| { 410 416 insert_into(RecordBlobSchema::record_blob) 411 417 .values(( 412 418 RecordBlobSchema::blobCid.eq(cid), ··· 416 422 .on_conflict_do_nothing() 417 423 .execute(conn) 418 424 }) 419 - .await?; 425 + .await 426 + .expect("Failed to associate blob")?; 420 427 421 428 Ok(()) 422 429 } ··· 427 434 428 435 let did = self.did.clone(); 429 436 self.db 430 - .run(move |conn| { 
437 + .interact(move |conn| { 431 438 let res = BlobSchema::blob 432 439 .filter(BlobSchema::did.eq(&did)) 433 440 .count() ··· 435 442 Ok(res) 436 443 }) 437 444 .await 445 + .expect("Failed to count blobs") 438 446 } 439 447 440 448 /// Count blobs associated with records ··· 443 451 444 452 let did = self.did.clone(); 445 453 self.db 446 - .run(move |conn| { 454 + .interact(move |conn| { 447 455 let res: i64 = RecordBlobSchema::record_blob 448 456 .filter(RecordBlobSchema::did.eq(&did)) 449 457 .select(count_distinct(RecordBlobSchema::blobCid)) ··· 451 459 Ok(res) 452 460 }) 453 461 .await 462 + .expect("Failed to count record blobs") 454 463 } 455 464 456 465 /// List blobs that are referenced but missing ··· 463 472 464 473 let did = self.did.clone(); 465 474 self.db 466 - .run(move |conn| { 475 + .interact(move |conn| { 467 476 let ListMissingBlobsOpts { cursor, limit } = opts; 468 477 469 478 if limit > 1000 { ··· 513 522 Ok(result) 514 523 }) 515 524 .await 525 + .expect("Failed to list missing blobs") 516 526 } 517 527 518 528 /// List all blobs with optional filtering ··· 541 551 if let Some(cursor) = cursor { 542 552 builder = builder.filter(RecordBlobSchema::blobCid.gt(cursor)); 543 553 } 544 - self.db.run(move |conn| builder.load(conn)).await? 554 + self.db 555 + .interact(move |conn| builder.load(conn)) 556 + .await 557 + .expect("Failed to list blobs")? 545 558 } else { 546 559 let mut builder = RecordBlobSchema::record_blob 547 560 .select(RecordBlobSchema::blobCid) ··· 553 566 if let Some(cursor) = cursor { 554 567 builder = builder.filter(RecordBlobSchema::blobCid.gt(cursor)); 555 568 } 556 - self.db.run(move |conn| builder.load(conn)).await? 569 + self.db 570 + .interact(move |conn| builder.load(conn)) 571 + .await 572 + .expect("Failed to list blobs")? 
557 573 }; 558 574 559 575 Ok(res) ··· 564 580 use rsky_pds::schema::pds::blob::dsl as BlobSchema; 565 581 566 582 self.db 567 - .run(move |conn| { 583 + .interact(move |conn| { 568 584 let res = BlobSchema::blob 569 585 .filter(BlobSchema::cid.eq(cid.to_string())) 570 586 .select(models::Blob::as_select()) ··· 586 602 } 587 603 }) 588 604 .await 605 + .expect("Failed to get blob takedown status") 589 606 } 590 607 591 608 /// Update the takedown status of a blob ··· 604 621 let did_clone = self.did.clone(); 605 622 606 623 self.db 607 - .run(move |conn| { 624 + .interact(move |conn| { 608 625 update(BlobSchema::blob) 609 626 .filter(BlobSchema::cid.eq(blob_cid)) 610 627 .filter(BlobSchema::did.eq(did_clone)) ··· 612 629 .execute(conn)?; 613 630 Ok::<_, result::Error>(blob) 614 631 }) 615 - .await?; 632 + .await 633 + .expect("Failed to update blob takedown status")?; 616 634 617 635 let res = match takedown.applied { 618 636 true => self.blobstore.quarantine(blob).await,
+7 -6
src/actor_store/mod.rs
··· 32 32 use std::{env, fmt}; 33 33 use tokio::sync::RwLock; 34 34 35 - use crate::db::DbConn; 36 - 37 35 use blob::BlobReader; 38 36 use preference::PreferenceReader; 39 37 use record::RecordReader; ··· 74 72 // Combination of RepoReader/Transactor, BlobReader/Transactor, SqlRepoReader/Transactor 75 73 impl ActorStore { 76 74 /// Concrete reader of an individual repo (hence BlobStoreSql which takes `did` param) 77 - pub fn new(did: String, blobstore: BlobStoreSql, db: DbConn) -> Self { 78 - let db = Arc::new(db); 75 + pub fn new( 76 + did: String, 77 + blobstore: BlobStoreSql, 78 + db: deadpool_diesel::Connection<SqliteConnection>, 79 + ) -> Self { 79 80 ActorStore { 80 81 storage: Arc::new(RwLock::new(SqlRepoReader::new( 81 82 did.clone(), ··· 437 438 pub async fn destroy(&mut self) -> Result<()> { 438 439 let did: String = self.did.clone(); 439 440 let storage_guard = self.storage.read().await; 440 - let db: Arc<DbConn> = storage_guard.db.clone(); 441 + let db: deadpool_diesel::Connection<SqliteConnection> = storage_guard.db.clone(); 441 442 use rsky_pds::schema::pds::blob::dsl as BlobSchema; 442 443 443 444 let blob_rows: Vec<String> = db ··· 471 472 } 472 473 let did: String = self.did.clone(); 473 474 let storage_guard = self.storage.read().await; 474 - let db: Arc<DbConn> = storage_guard.db.clone(); 475 + let db: deadpool_diesel::Connection<SqliteConnection> = storage_guard.db.clone(); 475 476 use rsky_pds::schema::pds::record::dsl as RecordSchema; 476 477 477 478 let cid_strs: Vec<String> = cids.into_iter().map(|c| c.to_string()).collect();
+6 -8
src/actor_store/preference.rs
··· 4 4 //! 5 5 //! Modified for SQLite backend 6 6 7 - use std::sync::Arc; 8 - 9 7 use anyhow::{Result, bail}; 10 8 use diesel::*; 11 9 use rsky_lexicon::app::bsky::actor::RefPreferences; ··· 13 11 use rsky_pds::actor_store::preference::util::pref_in_scope; 14 12 use rsky_pds::auth_verifier::AuthScope; 15 13 use rsky_pds::models::AccountPref; 16 - 17 - use crate::db::DbConn; 18 14 19 15 pub struct PreferenceReader { 20 16 pub did: String, 21 - pub db: Arc<DbConn>, 17 + pub db: deadpool_diesel::Connection<SqliteConnection>, 22 18 } 23 19 24 20 impl PreferenceReader { 25 - pub fn new(did: String, db: Arc<DbConn>) -> Self { 21 + pub fn new(did: String, db: deadpool_diesel::Connection<SqliteConnection>) -> Self { 26 22 PreferenceReader { did, db } 27 23 } 28 24 ··· 35 31 36 32 let did = self.did.clone(); 37 33 self.db 38 - .run(move |conn| { 34 + .interact(move |conn| { 39 35 let prefs_res = AccountPrefSchema::account_pref 40 36 .filter(AccountPrefSchema::did.eq(&did)) 41 37 .select(AccountPref::as_select()) ··· 62 58 Ok(account_prefs) 63 59 }) 64 60 .await 61 + .expect("Failed to get preferences") 65 62 } 66 63 67 64 #[tracing::instrument(skip_all)] ··· 73 70 ) -> Result<()> { 74 71 let did = self.did.clone(); 75 72 self.db 76 - .run(move |conn| { 73 + .interact(move |conn| { 77 74 match values 78 75 .iter() 79 76 .all(|value| pref_match_namespace(&namespace, &value.get_type())) ··· 142 139 } 143 140 }) 144 141 .await 142 + .expect("Failed to put preferences") 145 143 } 146 144 }
+37 -24
src/actor_store/record.rs
··· 17 17 use rsky_syntax::aturi::AtUri; 18 18 use std::env; 19 19 use std::str::FromStr; 20 - use std::sync::Arc; 21 - 22 - use crate::db::DbConn; 23 20 24 21 /// Combined handler for record operations with both read and write capabilities. 25 22 pub(crate) struct RecordReader { 26 23 /// Database connection. 27 - pub db: Arc<DbConn>, 24 + pub db: deadpool_diesel::Connection<SqliteConnection>, 28 25 /// DID of the actor. 29 26 pub did: String, 30 27 } 31 28 32 29 impl RecordReader { 33 30 /// Create a new record handler. 34 - pub(crate) fn new(did: String, db: Arc<DbConn>) -> Self { 31 + pub(crate) fn new(did: String, db: deadpool_diesel::Connection<SqliteConnection>) -> Self { 35 32 Self { did, db } 36 33 } 37 34 ··· 41 38 42 39 let other_did = self.did.clone(); 43 40 self.db 44 - .run(move |conn| { 41 + .interact(move |conn| { 45 42 let res: i64 = record.filter(did.eq(&other_did)).count().get_result(conn)?; 46 43 Ok(res) 47 44 }) 48 45 .await 46 + .expect("Failed to count records") 49 47 } 50 48 51 49 /// List all collections in the repository. ··· 54 52 55 53 let other_did = self.did.clone(); 56 54 self.db 57 - .run(move |conn| { 55 + .interact(move |conn| { 58 56 let collections = record 59 57 .filter(did.eq(&other_did)) 60 58 .select(collection) ··· 65 63 Ok(collections) 66 64 }) 67 65 .await 66 + .expect("Failed to list collections") 68 67 } 69 68 70 69 /// List records for a specific collection. 
··· 116 115 builder = builder.filter(RecordSchema::rkey.lt(rkey_end)); 117 116 } 118 117 } 119 - let res: Vec<(Record, RepoBlock)> = self.db.run(move |conn| builder.load(conn)).await?; 118 + let res: Vec<(Record, RepoBlock)> = self 119 + .db 120 + .interact(move |conn| builder.load(conn)) 121 + .await 122 + .expect("Failed to load records")?; 120 123 res.into_iter() 121 124 .map(|row| { 122 125 Ok(RecordsForCollection { ··· 156 159 } 157 160 let record: Option<(Record, RepoBlock)> = self 158 161 .db 159 - .run(move |conn| builder.first(conn).optional()) 160 - .await?; 162 + .interact(move |conn| builder.first(conn).optional()) 163 + .await 164 + .expect("Failed to load record")?; 161 165 if let Some(record) = record { 162 166 Ok(Some(GetRecord { 163 167 uri: record.0.uri, ··· 197 201 } 198 202 let record_uri = self 199 203 .db 200 - .run(move |conn| builder.first::<String>(conn).optional()) 201 - .await?; 204 + .interact(move |conn| builder.first::<String>(conn).optional()) 205 + .await 206 + .expect("Failed to check record")?; 202 207 Ok(!!record_uri.is_some()) 203 208 } 204 209 ··· 211 216 212 217 let res = self 213 218 .db 214 - .run(move |conn| { 219 + .interact(move |conn| { 215 220 RecordSchema::record 216 221 .select(RecordSchema::takedownRef) 217 222 .filter(RecordSchema::uri.eq(uri)) 218 223 .first::<Option<String>>(conn) 219 224 .optional() 220 225 }) 221 - .await?; 226 + .await 227 + .expect("Failed to get takedown status")?; 222 228 if let Some(res) = res { 223 229 if let Some(takedown_ref) = res { 224 230 Ok(Some(StatusAttr { ··· 242 248 243 249 let res = self 244 250 .db 245 - .run(move |conn| { 251 + .interact(move |conn| { 246 252 RecordSchema::record 247 253 .select(RecordSchema::cid) 248 254 .filter(RecordSchema::uri.eq(uri)) 249 255 .first::<String>(conn) 250 256 .optional() 251 257 }) 252 - .await?; 258 + .await 259 + .expect("Failed to get current CID")?; 253 260 if let Some(res) = res { 254 261 Ok(Some(Cid::from_str(&res)?)) 255 262 } else { 
··· 269 276 270 277 let res = self 271 278 .db 272 - .run(move |conn| { 279 + .interact(move |conn| { 273 280 RecordSchema::record 274 281 .inner_join( 275 282 BacklinkSchema::backlink.on(BacklinkSchema::uri.eq(RecordSchema::uri)), ··· 280 287 .filter(RecordSchema::collection.eq(collection)) 281 288 .load::<Record>(conn) 282 289 }) 283 - .await?; 290 + .await 291 + .expect("Failed to get backlinks")?; 284 292 Ok(res) 285 293 } 286 294 ··· 365 373 // Track current version of record 366 374 let (record, uri) = self 367 375 .db 368 - .run(move |conn| { 376 + .interact(move |conn| { 369 377 insert_into(RecordSchema::record) 370 378 .values(row) 371 379 .on_conflict(RecordSchema::uri) ··· 378 386 .execute(conn)?; 379 387 Ok::<_, Error>((record, uri)) 380 388 }) 381 - .await?; 389 + .await 390 + .expect("Failed to index record")?; 382 391 383 392 if let Some(record) = record { 384 393 // Maintain backlinks ··· 402 411 use rsky_pds::schema::pds::record::dsl as RecordSchema; 403 412 let uri = uri.to_string(); 404 413 self.db 405 - .run(move |conn| { 414 + .interact(move |conn| { 406 415 delete(RecordSchema::record) 407 416 .filter(RecordSchema::uri.eq(&uri)) 408 417 .execute(conn)?; ··· 415 424 Ok(()) 416 425 }) 417 426 .await 427 + .expect("Failed to delete record") 418 428 } 419 429 420 430 /// Remove backlinks for a URI. ··· 422 432 use rsky_pds::schema::pds::backlink::dsl as BacklinkSchema; 423 433 let uri = uri.to_string(); 424 434 self.db 425 - .run(move |conn| { 435 + .interact(move |conn| { 426 436 delete(BacklinkSchema::backlink) 427 437 .filter(BacklinkSchema::uri.eq(uri)) 428 438 .execute(conn)?; 429 439 Ok(()) 430 440 }) 431 441 .await 442 + .expect("Failed to remove backlinks") 432 443 } 433 444 434 445 /// Add backlinks to the database. 
··· 438 449 } else { 439 450 use rsky_pds::schema::pds::backlink::dsl as BacklinkSchema; 440 451 self.db 441 - .run(move |conn| { 452 + .interact(move |conn| { 442 453 insert_or_ignore_into(BacklinkSchema::backlink) 443 454 .values(&backlinks) 444 455 .execute(conn)?; 445 456 Ok(()) 446 457 }) 447 458 .await 459 + .expect("Failed to add backlinks") 448 460 } 449 461 } 450 462 ··· 466 478 let uri_string = uri.to_string(); 467 479 468 480 self.db 469 - .run(move |conn| { 481 + .interact(move |conn| { 470 482 update(RecordSchema::record) 471 483 .filter(RecordSchema::uri.eq(uri_string)) 472 484 .set(RecordSchema::takedownRef.eq(takedown_ref)) ··· 474 486 Ok(()) 475 487 }) 476 488 .await 489 + .expect("Failed to update takedown status") 477 490 } 478 491 }
+30 -25
src/actor_store/sql_blob.rs
··· 7 7 use anyhow::{Context, Result}; 8 8 use cidv10::Cid; 9 9 use diesel::prelude::*; 10 - use std::sync::Arc; 11 - 12 - use crate::db::DbConn; 13 10 14 11 /// ByteStream implementation for blob data 15 12 pub struct ByteStream { ··· 27 24 } 28 25 29 26 /// SQL-based implementation of blob storage 30 - #[derive(Clone)] 31 27 pub struct BlobStoreSql { 32 28 /// Database connection for metadata 33 - pub db: Arc<DbConn>, 29 + pub db: deadpool_diesel::Connection<SqliteConnection>, 34 30 /// DID of the actor 35 31 pub did: String, 36 32 } ··· 61 57 62 58 impl BlobStoreSql { 63 59 /// Create a new SQL-based blob store for the given DID 64 - pub fn new(did: String, db: Arc<DbConn>) -> Self { 60 + pub fn new(did: String, db: deadpool_diesel::Connection<SqliteConnection>) -> Self { 65 61 BlobStoreSql { db, did } 66 62 } 67 63 68 - /// Create a factory function for blob stores 69 - pub fn creator(db: Arc<DbConn>) -> Box<dyn Fn(String) -> BlobStoreSql> { 70 - let db_clone = db.clone(); 71 - Box::new(move |did: String| BlobStoreSql::new(did, db_clone.clone())) 72 - } 64 + // /// Create a factory function for blob stores 65 + // pub fn creator( 66 + // db: deadpool_diesel::Connection<SqliteConnection>, 67 + // ) -> Box<dyn Fn(String) -> BlobStoreSql> { 68 + // let db_clone = db.clone(); 69 + // Box::new(move |did: String| BlobStoreSql::new(did, db_clone.clone())) 70 + // } 73 71 74 72 /// Store a blob temporarily - now just stores permanently with a key returned for API compatibility 75 73 pub async fn put_temp(&self, bytes: Vec<u8>) -> Result<String> { ··· 109 107 110 108 // Store directly in the database 111 109 self.db 112 - .run(move |conn| { 110 + .interact(move |conn| { 113 111 let data_clone = bytes.clone(); 114 112 let entry = BlobEntry { 115 113 cid: cid_str.clone(), ··· 128 126 .execute(conn) 129 127 .context("Failed to insert blob data") 130 128 }) 131 - .await?; 129 + .await 130 + .expect("Failed to store blob data")?; 132 131 133 132 Ok(()) 134 133 } ··· 146 145 
147 146 // Update the quarantine flag in the database 148 147 self.db 149 - .run(move |conn| { 148 + .interact(move |conn| { 150 149 diesel::update(blobs::table) 151 150 .filter(blobs::cid.eq(&cid_str)) 152 151 .filter(blobs::did.eq(&did_clone)) ··· 154 153 .execute(conn) 155 154 .context("Failed to quarantine blob") 156 155 }) 157 - .await?; 156 + .await 157 + .expect("Failed to update quarantine status")?; 158 158 159 159 Ok(()) 160 160 } ··· 166 166 167 167 // Update the quarantine flag in the database 168 168 self.db 169 - .run(move |conn| { 169 + .interact(move |conn| { 170 170 diesel::update(blobs::table) 171 171 .filter(blobs::cid.eq(&cid_str)) 172 172 .filter(blobs::did.eq(&did_clone)) ··· 174 174 .execute(conn) 175 175 .context("Failed to unquarantine blob") 176 176 }) 177 - .await?; 177 + .await 178 + .expect("Failed to update unquarantine status")?; 178 179 179 180 Ok(()) 180 181 } ··· 189 190 // Get the blob data from the database 190 191 let blob_data = self 191 192 .db 192 - .run(move |conn| { 193 + .interact(move |conn| { 193 194 blobs 194 195 .filter(self::blobs::cid.eq(&cid_str)) 195 196 .filter(did.eq(&did_clone)) ··· 199 200 .optional() 200 201 .context("Failed to query blob data") 201 202 }) 202 - .await?; 203 + .await 204 + .expect("Failed to get blob data")?; 203 205 204 206 if let Some(bytes) = blob_data { 205 207 Ok(ByteStream::new(bytes)) ··· 227 229 228 230 // Delete from database 229 231 self.db 230 - .run(move |conn| { 232 + .interact(move |conn| { 231 233 diesel::delete(blobs) 232 234 .filter(self::blobs::cid.eq(&blob_cid)) 233 235 .filter(did.eq(&did_clone)) 234 236 .execute(conn) 235 237 .context("Failed to delete blob") 236 238 }) 237 - .await?; 239 + .await 240 + .expect("Failed to delete blob")?; 238 241 239 242 Ok(()) 240 243 } ··· 248 251 249 252 // Delete all blobs in one operation 250 253 self.db 251 - .run(move |conn| { 254 + .interact(move |conn| { 252 255 diesel::delete(blobs) 253 256 
.filter(self::blobs::cid.eq_any(cid_strings)) 254 257 .filter(did.eq(&did_clone)) 255 258 .execute(conn) 256 259 .context("Failed to delete multiple blobs") 257 260 }) 258 - .await?; 261 + .await 262 + .expect("Failed to delete multiple blobs")?; 259 263 260 264 Ok(()) 261 265 } ··· 269 273 270 274 let exists = self 271 275 .db 272 - .run(move |conn| { 276 + .interact(move |conn| { 273 277 diesel::select(diesel::dsl::exists( 274 278 blobs 275 279 .filter(self::blobs::cid.eq(&cid_str)) ··· 278 282 .get_result::<bool>(conn) 279 283 .context("Failed to check if blob exists") 280 284 }) 281 - .await?; 285 + .await 286 + .expect("Failed to check blob existence")?; 282 287 283 288 Ok(exists) 284 289 }
+105 -93
src/actor_store/sql_repo.rs
··· 25 25 use std::sync::Arc; 26 26 use tokio::sync::RwLock; 27 27 28 - use crate::db::DbConn; 29 - 30 - #[derive(Clone, Debug)] 31 28 pub struct SqlRepoReader { 32 29 pub cache: Arc<RwLock<BlockMap>>, 33 - pub db: Arc<DbConn>, 30 + pub db: deadpool_diesel::Connection<SqliteConnection>, 34 31 pub root: Option<Cid>, 35 32 pub rev: Option<String>, 36 33 pub now: String, 37 34 pub did: String, 38 35 } 39 36 37 + impl std::fmt::Debug for SqlRepoReader { 38 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 39 + f.debug_struct("SqlRepoReader") 40 + .field("did", &self.did) 41 + .field("root", &self.root) 42 + .field("rev", &self.rev) 43 + .finish() 44 + } 45 + } 46 + 40 47 impl ReadableBlockstore for SqlRepoReader { 41 48 fn get_bytes<'life>( 42 49 &'life self, 43 50 cid: &'life Cid, 44 51 ) -> Pin<Box<dyn Future<Output = Result<Option<Vec<u8>>>> + Send + Sync + 'life>> { 45 52 let did: String = self.did.clone(); 46 - let db: Arc<DbConn> = self.db.clone(); 47 53 let cid = cid.clone(); 48 54 49 55 Box::pin(async move { ··· 56 62 return Ok(Some(cached_result.clone())); 57 63 } 58 64 59 - let found: Option<Vec<u8>> = db 60 - .run(move |conn| { 65 + let found: Option<Vec<u8>> = self 66 + .db 67 + .interact(move |conn| { 61 68 RepoBlockSchema::repo_block 62 69 .filter(RepoBlockSchema::cid.eq(cid.to_string())) 63 70 .filter(RepoBlockSchema::did.eq(did)) ··· 65 72 .first(conn) 66 73 .optional() 67 74 }) 68 - .await?; 75 + .await 76 + .expect("Failed to get block")?; 69 77 match found { 70 78 None => Ok(None), 71 79 Some(result) => { ··· 94 102 cids: Vec<Cid>, 95 103 ) -> Pin<Box<dyn Future<Output = Result<BlocksAndMissing>> + Send + Sync + 'life>> { 96 104 let did: String = self.did.clone(); 97 - let db: Arc<DbConn> = self.db.clone(); 98 105 99 106 Box::pin(async move { 100 107 use rsky_pds::schema::pds::repo_block::dsl as RepoBlockSchema; ··· 115 122 116 123 let _: Vec<_> = stream::iter(missing_strings.chunks(500)) 117 124 .then(|batch| { 118 - let this_db 
= db.clone(); 119 125 let this_did = did.clone(); 120 126 let blocks = Arc::clone(&blocks); 121 127 let missing = Arc::clone(&missing_set); ··· 123 129 124 130 async move { 125 131 // Database query 126 - let rows: Vec<(String, Vec<u8>)> = this_db 127 - .run(move |conn| { 132 + let rows: Vec<(String, Vec<u8>)> = self 133 + .db 134 + .interact(move |conn| { 128 135 RepoBlockSchema::repo_block 129 136 .filter(RepoBlockSchema::cid.eq_any(batch)) 130 137 .filter(RepoBlockSchema::did.eq(this_did)) 131 138 .select((RepoBlockSchema::cid, RepoBlockSchema::content)) 132 139 .load(conn) 133 140 }) 134 - .await?; 141 + .await 142 + .expect("Failed to get blocks")?; 135 143 136 144 // Process rows with locked access 137 145 let mut blocks = blocks.lock().await; ··· 191 199 rev: String, 192 200 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + Sync + 'life>> { 193 201 let did: String = self.did.clone(); 194 - let db: Arc<DbConn> = self.db.clone(); 195 202 let bytes_cloned = bytes.clone(); 196 203 Box::pin(async move { 197 204 use rsky_pds::schema::pds::repo_block::dsl as RepoBlockSchema; 198 205 199 - db.run(move |conn| { 200 - insert_into(RepoBlockSchema::repo_block) 201 - .values(( 202 - RepoBlockSchema::did.eq(did), 203 - RepoBlockSchema::cid.eq(cid.to_string()), 204 - RepoBlockSchema::repoRev.eq(rev), 205 - RepoBlockSchema::size.eq(bytes.len() as i32), 206 - RepoBlockSchema::content.eq(bytes), 207 - )) 208 - .execute(conn) 209 - }) 210 - .await?; 206 + self.db 207 + .interact(move |conn| { 208 + insert_into(RepoBlockSchema::repo_block) 209 + .values(( 210 + RepoBlockSchema::did.eq(did), 211 + RepoBlockSchema::cid.eq(cid.to_string()), 212 + RepoBlockSchema::repoRev.eq(rev), 213 + RepoBlockSchema::size.eq(bytes.len() as i32), 214 + RepoBlockSchema::content.eq(bytes), 215 + )) 216 + .execute(conn) 217 + }) 218 + .await 219 + .expect("Failed to put block")?; 211 220 { 212 221 let mut cache_guard = self.cache.write().await; 213 222 cache_guard.set(cid, bytes_cloned); ··· 222 
231 rev: String, 223 232 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + Sync + 'life>> { 224 233 let did: String = self.did.clone(); 225 - let db: Arc<DbConn> = self.db.clone(); 226 234 227 235 Box::pin(async move { 228 236 use rsky_pds::schema::pds::repo_block::dsl as RepoBlockSchema; ··· 242 250 let chunks: Vec<Vec<RepoBlock>> = 243 251 blocks.chunks(50).map(|chunk| chunk.to_vec()).collect(); 244 252 245 - let _: Vec<_> = stream::iter(chunks) 246 - .then(|batch| { 247 - let db = db.clone(); 248 - async move { 249 - db.run(move |conn| { 250 - insert_or_ignore_into(RepoBlockSchema::repo_block) 251 - .values(batch) 252 - .execute(conn) 253 - .map(|_| ()) 254 - }) 255 - .await 256 - .map_err(anyhow::Error::from) 257 - } 258 - }) 259 - .collect::<Vec<_>>() 260 - .await 261 - .into_iter() 262 - .collect::<Result<Vec<()>>>()?; 253 + for batch in chunks { 254 + self.db 255 + .interact(move |conn| { 256 + insert_or_ignore_into(RepoBlockSchema::repo_block) 257 + .values(&batch) 258 + .execute(conn) 259 + }) 260 + .await 261 + .expect("Failed to insert blocks")?; 262 + } 263 263 264 264 Ok(()) 265 265 }) ··· 271 271 is_create: Option<bool>, 272 272 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + Sync + 'life>> { 273 273 let did: String = self.did.clone(); 274 - let db: Arc<DbConn> = self.db.clone(); 275 274 let now: String = self.now.clone(); 276 275 277 276 Box::pin(async move { ··· 279 278 280 279 let is_create = is_create.unwrap_or(false); 281 280 if is_create { 282 - db.run(move |conn| { 283 - insert_into(RepoRootSchema::repo_root) 284 - .values(( 285 - RepoRootSchema::did.eq(did), 286 - RepoRootSchema::cid.eq(cid.to_string()), 287 - RepoRootSchema::rev.eq(rev), 288 - RepoRootSchema::indexedAt.eq(now), 289 - )) 290 - .execute(conn) 291 - }) 292 - .await?; 281 + self.db 282 + .interact(move |conn| { 283 + insert_into(RepoRootSchema::repo_root) 284 + .values(( 285 + RepoRootSchema::did.eq(did), 286 + RepoRootSchema::cid.eq(cid.to_string()), 287 + 
RepoRootSchema::rev.eq(rev), 288 + RepoRootSchema::indexedAt.eq(now), 289 + )) 290 + .execute(conn) 291 + }) 292 + .await 293 + .expect("Failed to create root")?; 293 294 } else { 294 - db.run(move |conn| { 295 - update(RepoRootSchema::repo_root) 296 - .filter(RepoRootSchema::did.eq(did)) 297 - .set(( 298 - RepoRootSchema::cid.eq(cid.to_string()), 299 - RepoRootSchema::rev.eq(rev), 300 - RepoRootSchema::indexedAt.eq(now), 301 - )) 302 - .execute(conn) 303 - }) 304 - .await?; 295 + self.db 296 + .interact(move |conn| { 297 + update(RepoRootSchema::repo_root) 298 + .filter(RepoRootSchema::did.eq(did)) 299 + .set(( 300 + RepoRootSchema::cid.eq(cid.to_string()), 301 + RepoRootSchema::rev.eq(rev), 302 + RepoRootSchema::indexedAt.eq(now), 303 + )) 304 + .execute(conn) 305 + }) 306 + .await 307 + .expect("Failed to update root")?; 305 308 } 306 309 Ok(()) 307 310 }) ··· 324 327 325 328 // Basically handles getting ipld blocks from db 326 329 impl SqlRepoReader { 327 - pub fn new(did: String, now: Option<String>, db: Arc<DbConn>) -> Self { 330 + pub fn new( 331 + did: String, 332 + now: Option<String>, 333 + db: deadpool_diesel::Connection<SqliteConnection>, 334 + ) -> Self { 328 335 let now = now.unwrap_or_else(rsky_common::now); 329 336 SqlRepoReader { 330 337 cache: Arc::new(RwLock::new(BlockMap::new())), ··· 371 378 cursor: &Option<CidAndRev>, 372 379 ) -> Result<Vec<RepoBlock>> { 373 380 let did: String = self.did.clone(); 374 - let db: Arc<DbConn> = self.db.clone(); 375 381 let since = since.clone(); 376 382 let cursor = cursor.clone(); 377 383 use rsky_pds::schema::pds::repo_block::dsl as RepoBlockSchema; 378 384 379 - Ok(db 380 - .run(move |conn| { 385 + Ok(self 386 + .db 387 + .interact(move |conn| { 381 388 let mut builder = RepoBlockSchema::repo_block 382 389 .select(RepoBlock::as_select()) 383 390 .order((RepoBlockSchema::repoRev.desc(), RepoBlockSchema::cid.desc())) ··· 404 411 } 405 412 builder.load(conn) 406 413 }) 407 - .await?) 
414 + .await 415 + .expect("Failed to get block range")?) 408 416 } 409 417 410 418 pub async fn count_blocks(&self) -> Result<i64> { 411 419 let did: String = self.did.clone(); 412 - let db: Arc<DbConn> = self.db.clone(); 413 420 use rsky_pds::schema::pds::repo_block::dsl as RepoBlockSchema; 414 421 415 - let res = db 416 - .run(move |conn| { 422 + let res = self 423 + .db 424 + .interact(move |conn| { 417 425 RepoBlockSchema::repo_block 418 426 .filter(RepoBlockSchema::did.eq(did)) 419 427 .count() 420 428 .get_result(conn) 421 429 }) 422 - .await?; 430 + .await 431 + .expect("Failed to count blocks")?; 423 432 Ok(res) 424 433 } 425 434 ··· 429 438 /// Proactively cache all blocks from a particular commit (to prevent multiple roundtrips) 430 439 pub async fn cache_rev(&mut self, rev: String) -> Result<()> { 431 440 let did: String = self.did.clone(); 432 - let db: Arc<DbConn> = self.db.clone(); 433 441 use rsky_pds::schema::pds::repo_block::dsl as RepoBlockSchema; 434 442 435 - let result: Vec<(String, Vec<u8>)> = db 436 - .run(move |conn| { 443 + let result: Vec<(String, Vec<u8>)> = self 444 + .db 445 + .interact(move |conn| { 437 446 RepoBlockSchema::repo_block 438 447 .filter(RepoBlockSchema::did.eq(did)) 439 448 .filter(RepoBlockSchema::repoRev.eq(rev)) ··· 441 450 .limit(15) 442 451 .get_results::<(String, Vec<u8>)>(conn) 443 452 }) 444 - .await?; 453 + .await 454 + .expect("Failed to cache rev")?; 445 455 for row in result { 446 456 let mut cache_guard = self.cache.write().await; 447 457 cache_guard.set(Cid::from_str(&row.0)?, row.1) ··· 454 464 return Ok(()); 455 465 } 456 466 let did: String = self.did.clone(); 457 - let db: Arc<DbConn> = self.db.clone(); 458 467 use rsky_pds::schema::pds::repo_block::dsl as RepoBlockSchema; 459 468 460 469 let cid_strings: Vec<String> = cids.into_iter().map(|c| c.to_string()).collect(); 461 - db.run(move |conn| { 462 - delete(RepoBlockSchema::repo_block) 463 - .filter(RepoBlockSchema::did.eq(did)) 464 - 
.filter(RepoBlockSchema::cid.eq_any(cid_strings)) 465 - .execute(conn) 466 - }) 467 - .await?; 470 + self.db 471 + .interact(move |conn| { 472 + delete(RepoBlockSchema::repo_block) 473 + .filter(RepoBlockSchema::did.eq(did)) 474 + .filter(RepoBlockSchema::cid.eq_any(cid_strings)) 475 + .execute(conn) 476 + }) 477 + .await 478 + .expect("Failed to delete many")?; 468 479 Ok(()) 469 480 } 470 481 471 482 pub async fn get_root_detailed(&self) -> Result<CidAndRev> { 472 483 let did: String = self.did.clone(); 473 - let db: Arc<DbConn> = self.db.clone(); 474 484 use rsky_pds::schema::pds::repo_root::dsl as RepoRootSchema; 475 485 476 - let res = db 477 - .run(move |conn| { 486 + let res = self 487 + .db 488 + .interact(move |conn| { 478 489 RepoRootSchema::repo_root 479 490 .filter(RepoRootSchema::did.eq(did)) 480 491 .select(models::RepoRoot::as_select()) 481 492 .first(conn) 482 493 }) 483 - .await?; 494 + .await 495 + .expect("Failed to get root")?; 484 496 485 497 Ok(CidAndRev { 486 498 cid: Cid::from_str(&res.cid)?,
+56 -24
src/auth.rs
··· 5 5 }; 6 6 use axum::{extract::FromRequestParts, http::StatusCode}; 7 7 use base64::Engine as _; 8 + use diesel::prelude::*; 8 9 use sha2::{Digest as _, Sha256}; 9 10 10 - use crate::{AppState, Error, error::ErrorMessage}; 11 + use crate::{AppState, Error, db::DbConn, error::ErrorMessage}; 11 12 12 13 /// Request extractor for authenticated users. 13 14 /// If specified in an API endpoint, this guarantees the API can only be called ··· 129 130 130 131 // Extract subject (DID) 131 132 if let Some(did) = claims.get("sub").and_then(serde_json::Value::as_str) { 132 - let _status = sqlx::query_scalar!(r#"SELECT status FROM accounts WHERE did = ?"#, did) 133 - .fetch_one(&state.db) 133 + // Convert SQLx query to Diesel query 134 + use crate::schema::accounts::dsl as AccountSchema; 135 + 136 + let _status = state 137 + .db 138 + .run(move |conn| { 139 + AccountSchema::accounts 140 + .filter(AccountSchema::did.eq(did.to_string())) 141 + .select(AccountSchema::status) 142 + .first::<String>(conn) 143 + }) 134 144 .await 135 145 .with_context(|| format!("failed to query account {did}")) 136 146 .context("should fetch account status")?; ··· 326 336 327 337 let timestamp = chrono::Utc::now().timestamp(); 328 338 339 + // Convert SQLx JTI check to Diesel 340 + use crate::schema::oauth_used_jtis::dsl as JtiSchema; 341 + 329 342 // Check if JTI has been used before 330 - let jti_used = 331 - sqlx::query_scalar!(r#"SELECT COUNT(*) FROM oauth_used_jtis WHERE jti = ?"#, jti) 332 - .fetch_one(&state.db) 333 - .await 334 - .context("failed to check JTI")?; 343 + let jti_string = jti.to_string(); 344 + let jti_used = state 345 + .db 346 + .run(move |conn| { 347 + JtiSchema::oauth_used_jtis 348 + .filter(JtiSchema::jti.eq(jti_string)) 349 + .count() 350 + .get_result::<i64>(conn) 351 + }) 352 + .await 353 + .context("failed to check JTI")?; 335 354 336 355 if jti_used > 0 { 337 356 return Err(Error::with_status( ··· 347 366 .and_then(serde_json::Value::as_i64) 348 367 
.unwrap_or_else(|| timestamp.checked_add(60).unwrap_or(timestamp)); 349 368 350 - _ = sqlx::query!( 351 - r#" 352 - INSERT INTO oauth_used_jtis (jti, issuer, created_at, expires_at) 353 - VALUES (?, ?, ?, ?) 354 - "#, 355 - jti, 356 - calculated_thumbprint, // Use thumbprint as issuer identifier 357 - timestamp, 358 - exp 359 - ) 360 - .execute(&state.db) 361 - .await 362 - .context("failed to store JTI")?; 369 + // Convert SQLx INSERT to Diesel. NOTE(review): `state.db` is now a deadpool pool, which exposes `get().await` + `interact()`, not `run()` — confirm a `run()` helper/wrapper exists or convert these calls. 370 + let jti_str = jti.to_string(); 371 + let thumbprint_str = calculated_thumbprint.to_string(); 372 + state 373 + .db 374 + .run(move |conn| { 375 + diesel::insert_into(JtiSchema::oauth_used_jtis) 376 + .values(( 377 + JtiSchema::jti.eq(jti_str), 378 + JtiSchema::issuer.eq(thumbprint_str), 379 + JtiSchema::created_at.eq(timestamp), 380 + JtiSchema::expires_at.eq(exp), 381 + )) 382 + .execute(conn) 383 + }) 384 + .await 385 + .context("failed to store JTI")?; 363 386 364 387 // Extract subject (DID) from access token 365 - if let Some(did) = claims.get("sub").and_then(|v| v.as_str()) { 366 - let _status = sqlx::query_scalar!(r#"SELECT status FROM accounts WHERE did = ?"#, did) 367 - .fetch_one(&state.db) 388 + if let Some(did) = claims.get("sub").and_then(|v| v.as_str()) { 389 + // Convert SQLx query to Diesel 390 + use crate::schema::accounts::dsl as AccountSchema; 391 + 392 + let _status = state 393 + .db 394 + .run(move |conn| { 395 + AccountSchema::accounts 396 + .filter(AccountSchema::did.eq(did.to_string())) 397 + .select(AccountSchema::status) 398 + .first::<String>(conn) 399 + }) 369 400 .await 370 401 .with_context(|| format!("failed to query account {did}")) 371 402 .context("should fetch account status")?;
+12 -23
src/db.rs
··· 1 1 use anyhow::Result; 2 - use diesel::prelude::*; 3 - use dotenvy::dotenv; 4 - use rocket_sync_db_pools::database; 5 - use std::env; 6 - use std::fmt::{Debug, Formatter}; 7 - 8 - #[database("sqlite_db")] 9 - pub struct DbConn(SqliteConnection); 10 - 11 - impl Debug for DbConn { 12 - fn fmt(&self, _f: &mut Formatter<'_>) -> std::fmt::Result { 13 - todo!() 14 - } 15 - } 2 + use deadpool_diesel::sqlite::{Manager, Pool, Runtime}; 16 3 17 4 #[tracing::instrument(skip_all)] 18 - pub fn establish_connection_for_sequencer() -> Result<SqliteConnection> { 19 - dotenv().ok(); 20 - tracing::debug!("Establishing database connection for Sequencer"); 21 - let database_url = env::var("BLUEPDS_DB").unwrap_or("sqlite://data/sqlite.db".into()); 22 - let db = SqliteConnection::establish(&database_url).map_err(|error| { 23 - let context = format!("Error connecting to {database_url:?}"); 24 - anyhow::Error::new(error).context(context) 25 - })?; 26 - Ok(db) 5 + /// Establish a connection to the database 6 + /// Takes a database URL as an argument (like "sqlite://data/sqlite.db") 7 + pub(crate) fn establish_pool(database_url: &str) -> Result<Pool> { 8 + tracing::debug!("Establishing database connection"); 9 + let manager = Manager::new(database_url, Runtime::Tokio1); 10 + let pool = Pool::builder(manager) 11 + .max_size(8) 12 + .build() 13 + .map_err(|e| anyhow::anyhow!("failed to create connection pool: {e}"))?; 14 + tracing::debug!("Database connection established"); 15 + Ok(pool) 27 16 }
+11 -4
src/endpoints/repo/apply_writes.rs
··· 1 1 //! Apply a batch transaction of repository creates, updates, and deletes. Requires auth, implemented by PDS. 2 2 use crate::{ 3 - AppState, Db, Error, Result, SigningKey, 4 - actor_store::ActorStore, 5 - actor_store::sql_blob::BlobStoreSql, 3 + ActorPools, AppState, Db, Error, Result, SigningKey, 4 + actor_store::{ActorStore, sql_blob::BlobStoreSql}, 6 5 auth::AuthenticatedUser, 7 6 config::AppConfig, 8 7 error::ErrorMessage, ··· 66 65 State(skey): State<SigningKey>, 67 66 State(config): State<AppConfig>, 68 67 State(db): State<Db>, 68 + State(db_actors): State<std::collections::HashMap<String, ActorPools>>, // NOTE(review): the `Db` alias and `AppState::db` changed in main.rs — confirm `State<Db>` above still resolves 69 69 State(fhp): State<FirehoseProducer>, 70 70 Json(input): Json<ApplyWritesInput>, 71 71 ) -> Result<Json<repo::apply_writes::Output>> { ··· 156 156 None => None, 157 157 }; 158 158 159 - let mut actor_store = ActorStore::new(did.clone(), BlobStoreSql::new(did.clone(), db), db); 159 + let actor_db = db_actors 160 + .get(&did) 161 + .ok_or_else(|| anyhow!("Actor DB not found"))?; 162 + let mut actor_store = ActorStore::new( 163 + did.clone(), 164 + BlobStoreSql::new(did.clone(), actor_db.blob.clone()), 165 + actor_db.repo.clone(), 166 + ); 160 167 161 168 let commit = actor_store 162 169 .process_writes(writes.clone(), swap_commit_cid)
+88 -26
src/main.rs
··· 37 37 use clap::Parser; 38 38 use clap_verbosity_flag::{InfoLevel, Verbosity, log::LevelFilter}; 39 39 use config::AppConfig; 40 + use db::establish_pool; 41 + use deadpool_diesel::sqlite::Pool; 40 42 use diesel::prelude::*; 41 - use diesel::r2d2::{self, ConnectionManager}; 42 - use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations}; 43 + use diesel_migrations::{EmbeddedMigrations, embed_migrations}; 43 44 #[expect(clippy::pub_use, clippy::useless_attribute)] 44 45 pub use error::Error; 45 46 use figment::{Figment, providers::Format as _}; ··· 68 69 pub type Result<T> = std::result::Result<T, Error>; 69 70 /// The reqwest client type with middleware. 70 71 pub type Client = reqwest_middleware::ClientWithMiddleware; 71 - /// The database connection pool. 72 - pub type Db = r2d2::Pool<ConnectionManager<SqliteConnection>>; 73 72 /// The Azure credential type. 74 73 pub type Cred = Arc<dyn TokenCredential>; 75 74 ··· 132 131 verbosity: Verbosity<InfoLevel>, 133 132 } 134 133 134 + struct ActorPools { 135 + repo: Pool, 136 + blob: Pool, 137 + } 138 + impl Clone for ActorPools { 139 + fn clone(&self) -> Self { 140 + Self { 141 + repo: self.repo.clone(), 142 + blob: self.blob.clone(), 143 + } 144 + } 145 + } 146 + 135 147 #[expect(clippy::arbitrary_source_item_ordering, reason = "arbitrary")] 136 148 #[derive(Clone, FromRef)] 137 149 struct AppState { ··· 139 151 config: AppConfig, 140 152 /// The Azure credential. 141 153 cred: Cred, 142 - /// The database connection pool. 143 - db: Db, 154 + /// The main database connection pool. Used for common PDS data, like invite codes. 155 + db: Pool, 156 + /// Actor-specific database connection pools. Hashed by DID. 157 + db_actors: std::collections::HashMap<String, ActorPools>, 144 158 145 159 /// The HTTP client with middleware. 
146 160 client: Client, ··· 291 305 #[expect( 292 306 clippy::cognitive_complexity, 293 307 clippy::too_many_lines, 308 + unused_qualifications, 294 309 reason = "main function has high complexity" 295 310 )] 296 311 async fn run() -> anyhow::Result<()> { ··· 388 403 let cred = azure_identity::DefaultAzureCredential::new() 389 404 .context("failed to create Azure credential")?; 390 405 391 - // Create a database connection manager and pool 392 - let manager = ConnectionManager::<SqliteConnection>::new(&config.db); 393 - let db = r2d2::Pool::builder() 394 - .build(manager) 395 - .context("failed to create database connection pool")?; 406 + // Create a database connection manager and pool for the main database. 407 + let pool = 408 + establish_pool(&config.db).context("failed to establish database connection pool")?; 409 + // Create a dictionary of database connection pools for each actor. 410 + let mut actor_pools = std::collections::HashMap::new(); 411 + // let mut actor_blob_pools = std::collections::HashMap::new(); 412 + // We'll determine actors by looking in the data/repo dir for .db files. 413 + let mut actor_dbs = tokio::fs::read_dir(&config.repo.path) 414 + .await 415 + .context("failed to read repo directory")?; 416 + while let Some(entry) = actor_dbs 417 + .next_entry() 418 + .await 419 + .context("failed to read repo dir")? 420 + { 421 + let path = entry.path(); 422 + if path.extension().and_then(|s| s.to_str()) == Some("db") { 423 + let did = path 424 + .file_stem() 425 + .and_then(|s| s.to_str()) 426 + .context("failed to get actor DID")?; 427 + let did = Did::from_str(did).expect("should be able to parse actor DID"); 396 428 429 + // Create a new database connection manager and pool for the actor. 
430 + // The path for the SQLite connection needs to look like "sqlite://data/repo/<actor>.db" 431 + let path_repo = format!("sqlite://{}", path.display()); 432 + let actor_repo_pool = 433 + establish_pool(&path_repo).context("failed to create database connection pool")?; 434 + // Create a new database connection manager and pool for the actor blobs. 435 + // The path for the SQLite connection needs to look like "sqlite://data/blob/<actor>.db" 436 + let path_blob = path_repo.replace("repo", "blob"); 437 + let actor_blob_pool = 438 + establish_pool(&path_blob).context("failed to create database connection pool")?; 439 + actor_pools.insert( 440 + did.to_string(), 441 + ActorPools { 442 + repo: actor_repo_pool, 443 + blob: actor_blob_pool, 444 + }, 445 + ); 446 + } 447 + } 397 448 // Apply pending migrations 398 - let conn = &mut db 399 - .get() 400 - .context("failed to get database connection for migrations")?; 401 - conn.run_pending_migrations(MIGRATIONS) 402 - .expect("should be able to run migrations"); 449 + // let conn = pool.get().await?; 450 + // conn.run_pending_migrations(MIGRATIONS) 451 + // .expect("should be able to run migrations"); 403 452 404 453 let (_fh, fhp) = firehose::spawn(client.clone(), config.clone()); 405 454 ··· 422 471 .with_state(AppState { 423 472 cred, 424 473 config: config.clone(), 425 - db: db.clone(), 474 + db: pool.clone(), 475 + db_actors: actor_pools.clone(), 426 476 client: client.clone(), 427 477 simple_client, 428 478 firehose: fhp, ··· 435 485 436 486 // Determine whether or not this was the first startup (i.e. no accounts exist and no invite codes were created). 437 487 // If so, create an invite code and share it via the console. 
438 - let conn = &mut db.get().context("failed to get database connection")?; 488 + let conn = pool.get().await.context("failed to get db connection")?; 439 489 440 490 #[derive(QueryableByName)] 441 491 struct TotalCount { ··· 443 493 total_count: i32, 444 494 } 445 495 446 - let result = diesel::sql_query( 447 - "SELECT (SELECT COUNT(*) FROM accounts) + (SELECT COUNT(*) FROM invites) AS total_count", 448 - ) 449 - .get_result::<TotalCount>(conn) 450 - .context("failed to query database")?; 496 + // let result = diesel::sql_query( 497 + // "SELECT (SELECT COUNT(*) FROM accounts) + (SELECT COUNT(*) FROM invites) AS total_count", 498 + // ) 499 + // .get_result::<TotalCount>(conn) 500 + // .context("failed to query database")?; 501 + let result = conn.interact(move |conn| { 502 + diesel::sql_query( 503 + "SELECT (SELECT COUNT(*) FROM accounts) + (SELECT COUNT(*) FROM invites) AS total_count", 504 + ) 505 + .get_result::<TotalCount>(conn) 506 + }) 507 + .await 508 + .expect("should be able to query database")?; 451 509 452 510 let c = result.total_count; 453 511 ··· 455 513 if c == 0 { 456 514 let uuid = Uuid::new_v4().to_string(); 457 515 458 - diesel::sql_query( 516 + let uuid_clone = uuid.clone(); 517 + conn.interact(move |conn| { 518 + diesel::sql_query( 459 519 "INSERT INTO invites (id, did, count, created_at) VALUES (?, NULL, 1, datetime('now'))", 460 520 ) 461 - .bind::<diesel::sql_types::Text, _>(uuid.clone()) 521 + .bind::<diesel::sql_types::Text, _>(uuid_clone) 462 522 .execute(conn) 463 - .context("failed to create new invite code")?; 523 + }) 524 + .await 525 + .expect("should be able to create invite code") 526 + .context("failed to create new invite code")?; 464 527 465 528 // N.B: This is a sensitive message, so we're bypassing `tracing` here and 466 529 // logging it directly to console.
+1 -1
src/tests.rs
··· 222 222 let opts = SqliteConnectOptions::from_str(&config.db) 223 223 .context("failed to parse database options")? 224 224 .create_if_missing(true); 225 - let db = SqlitePool::connect_with(opts).await?; 225 + let db = SqlitePool::connect_with(opts).await?; // NOTE(review): reverted `SqliteDbConn` — it is not defined anywhere in this change, and sqlx::migrate!().run(&db) below still requires the sqlx pool type 226 226 227 227 sqlx::migrate!() 228 228 .run(&db)