Alternative ATProto PDS implementation

add db module

+94
src/db/cast.rs
··· 1 + //! Type-safe casting utilities. 2 + 3 + use serde::{Serialize, de::DeserializeOwned}; 4 + use std::fmt; 5 + 6 + /// Represents an ISO 8601 date string (e.g., "2023-01-01T12:00:00Z"). 7 + #[derive(Debug, Clone, PartialEq, Eq)] 8 + pub struct DateISO(String); 9 + 10 + impl DateISO { 11 + /// Converts a `chrono::DateTime<Utc>` to a `DateISO`. 12 + pub fn from_date(date: chrono::DateTime<chrono::Utc>) -> Self { 13 + Self(date.to_rfc3339()) 14 + } 15 + 16 + /// Converts a `DateISO` back to a `chrono::DateTime<Utc>`. 17 + pub fn to_date(&self) -> Result<chrono::DateTime<chrono::Utc>, chrono::ParseError> { 18 + self.0.parse::<chrono::DateTime<chrono::Utc>>() 19 + } 20 + } 21 + 22 + impl fmt::Display for DateISO { 23 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 24 + write!(f, "{}", self.0) 25 + } 26 + } 27 + 28 + /// Represents a JSON-encoded string. 29 + #[derive(Debug, Clone, PartialEq, Eq)] 30 + pub struct JsonEncoded<T: Serialize>(String, std::marker::PhantomData<T>); 31 + 32 + impl<T: Serialize> JsonEncoded<T> { 33 + /// Encodes a value into a JSON string. 34 + pub fn to_json(value: &T) -> Result<Self, serde_json::Error> { 35 + let json = serde_json::to_string(value)?; 36 + Ok(Self(json, std::marker::PhantomData)) 37 + } 38 + 39 + /// Decodes a JSON string back into a value. 40 + pub fn from_json(json_str: &str) -> Result<T, serde_json::Error> 41 + where 42 + T: DeserializeOwned, 43 + { 44 + serde_json::from_str(json_str) 45 + } 46 + 47 + /// Returns the underlying JSON string. 
48 + pub fn as_str(&self) -> &str { 49 + &self.0 50 + } 51 + } 52 + 53 + impl<T: Serialize> fmt::Display for JsonEncoded<T> { 54 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 55 + write!(f, "{}", self.0) 56 + } 57 + } 58 + 59 + #[cfg(test)] 60 + mod tests { 61 + use super::*; 62 + use chrono::Utc; 63 + use serde::{Deserialize, Serialize}; 64 + 65 + #[test] 66 + fn test_date_iso() { 67 + let now = Utc::now(); 68 + let date_iso = DateISO::from_date(now); 69 + let parsed_date = date_iso.to_date().unwrap(); 70 + assert_eq!(now.to_rfc3339(), parsed_date.to_rfc3339()); 71 + } 72 + 73 + #[derive(Serialize, Deserialize, PartialEq, Debug)] 74 + struct TestStruct { 75 + name: String, 76 + value: i32, 77 + } 78 + 79 + #[test] 80 + fn test_json_encoded() { 81 + let test_value = TestStruct { 82 + name: "example".to_string(), 83 + value: 42, 84 + }; 85 + 86 + // Encode to JSON 87 + let encoded = JsonEncoded::to_json(&test_value).unwrap(); 88 + assert_eq!(encoded.as_str(), r#"{"name":"example","value":42}"#); 89 + 90 + // Decode from JSON 91 + let decoded: TestStruct = JsonEncoded::from_json(encoded.as_str()).unwrap(); 92 + assert_eq!(decoded, test_value); 93 + } 94 + }
+111
src/db/db.rs
··· 1 + //! Database connection and transaction management. 2 + 3 + use sqlx::{ 4 + Sqlite, Transaction, 5 + sqlite::{SqliteConnectOptions, SqlitePool, SqliteQueryResult, SqliteTransactionManager}, 6 + }; 7 + use std::collections::VecDeque; 8 + use std::str::FromStr; 9 + use std::sync::{Arc, Mutex}; 10 + use tokio::sync::Mutex as AsyncMutex; 11 + 12 + use crate::db::util::retry_sqlite; 13 + 14 + /// Default pragmas for SQLite. 15 + const DEFAULT_PRAGMAS: &[(&str, &str)] = &[ 16 + // Add default pragmas here if needed, e.g., ("foreign_keys", "ON") 17 + ]; 18 + 19 + /// Database struct for managing SQLite connections and transactions. 20 + pub struct Database { 21 + /// SQLite connection pool. 22 + db: Arc<SqlitePool>, 23 + /// Flag indicating if the database is destroyed. 24 + destroyed: Arc<Mutex<bool>>, 25 + /// Queue of commit hooks. 26 + commit_hooks: Arc<AsyncMutex<VecDeque<Box<dyn FnOnce() + Send>>>>, 27 + } 28 + 29 + impl Database { 30 + /// Creates a new database instance with the given location and optional pragmas. 31 + pub async fn new(location: &str, pragmas: Option<&[(&str, &str)]>) -> sqlx::Result<Self> { 32 + let mut options = SqliteConnectOptions::from_str(location)?.create_if_missing(true); 33 + 34 + // Apply default and user-provided pragmas. 35 + for &(key, value) in DEFAULT_PRAGMAS.iter().chain(pragmas.unwrap_or(&[])) { 36 + options = options.pragma(key.to_string(), value.to_string()); 37 + } 38 + 39 + let pool = SqlitePool::connect_with(options).await?; 40 + Ok(Self { 41 + db: Arc::new(pool), 42 + destroyed: Arc::new(Mutex::new(false)), 43 + commit_hooks: Arc::new(AsyncMutex::new(VecDeque::new())), 44 + }) 45 + } 46 + 47 + /// Ensures the database is using Write-Ahead Logging (WAL) mode. 
48 + pub async fn ensure_wal(&self) -> sqlx::Result<()> { 49 + let mut conn = self.db.acquire().await?; 50 + sqlx::query("PRAGMA journal_mode = WAL") 51 + .execute(&mut *conn) 52 + .await?; 53 + Ok(()) 54 + } 55 + 56 + /// Executes a transaction without retry logic. 57 + pub async fn transaction_no_retry<F, T>(&self, func: F) -> sqlx::Result<T> 58 + where 59 + F: FnOnce(&mut Transaction<'_, Sqlite>) -> sqlx::Result<T>, 60 + { 61 + let mut tx = self.db.begin().await?; 62 + let result = func(&mut tx)?; 63 + tx.commit().await?; 64 + self.run_commit_hooks().await; 65 + Ok(result) 66 + } 67 + 68 + /// Executes a transaction with retry logic. 69 + pub async fn transaction<F, T>(&self, func: F) -> sqlx::Result<T> 70 + where 71 + F: FnOnce(&mut Transaction<'_, Sqlite>) -> sqlx::Result<T> + Copy, 72 + { 73 + retry_sqlite(|| self.transaction_no_retry(func)).await 74 + } 75 + 76 + /// Executes a query with retry logic. 77 + pub async fn execute_with_retry<F, T>(&self, query: F) -> sqlx::Result<T> 78 + where 79 + F: Fn() -> std::pin::Pin<Box<dyn futures::Future<Output = sqlx::Result<T>> + Send>> + Copy, 80 + { 81 + retry_sqlite(|| query()).await 82 + } 83 + 84 + /// Adds a commit hook to be executed after a successful transaction. 85 + pub async fn on_commit<F>(&self, hook: F) 86 + where 87 + F: FnOnce() + Send + 'static, 88 + { 89 + let mut hooks = self.commit_hooks.lock().await; 90 + hooks.push_back(Box::new(hook)); 91 + } 92 + 93 + /// Closes the database connection pool. 94 + pub async fn close(&self) -> sqlx::Result<()> { 95 + let mut destroyed = self.destroyed.lock().unwrap(); 96 + if *destroyed { 97 + return Ok(()); 98 + } 99 + *destroyed = true; 100 + drop(self.db.clone()); // Drop the pool to close connections. 101 + Ok(()) 102 + } 103 + 104 + /// Runs all commit hooks in the queue. 105 + async fn run_commit_hooks(&self) { 106 + let mut hooks = self.commit_hooks.lock().await; 107 + while let Some(hook) = hooks.pop_front() { 108 + hook(); 109 + } 110 + } 111 + }
+64
src/db/migrator.rs
··· 1 + //! Database migration management. 2 + 3 + use sqlx::{SqlitePool, migrate::Migrator}; 4 + use std::path::Path; 5 + use thiserror::Error; 6 + 7 + /// Error type for migration-related issues. 8 + #[derive(Debug, Error)] 9 + pub enum MigrationError { 10 + #[error("Migration failed: {0}")] 11 + MigrationFailed(String), 12 + #[error("Unknown failure occurred while migrating")] 13 + UnknownFailure, 14 + } 15 + 16 + /// Migrator struct for managing database migrations. 17 + pub struct DatabaseMigrator { 18 + /// SQLx migrator instance. 19 + migrator: Migrator, 20 + /// SQLite connection pool. 21 + db: SqlitePool, 22 + } 23 + 24 + impl DatabaseMigrator { 25 + /// Creates a new `DatabaseMigrator` instance. 26 + /// 27 + /// # Arguments 28 + /// - `migrations_path`: Path to the directory containing migration files. 29 + /// - `db`: SQLite connection pool. 30 + pub async fn new(migrations_path: &Path, db: SqlitePool) -> Self { 31 + let migrator = Migrator::new(migrations_path) 32 + .await 33 + .expect("Failed to initialize migrator"); 34 + Self { migrator, db } 35 + } 36 + 37 + /// Migrates the database to a specific migration or throws an error. 38 + /// 39 + /// # Arguments 40 + /// - `migration`: The target migration name. 41 + /// 42 + /// # Unimplemented 43 + /// This currently runs all migrations instead of a specific one. 44 + pub async fn migrate_to_or_throw(&self, _migration: &str) -> Result<(), MigrationError> { 45 + // TODO: Implement migration to a specific version 46 + // For now, we will just run all migrations 47 + let result = self.migrator.run(&self.db).await; 48 + 49 + match result { 50 + Ok(_) => Ok(()), 51 + Err(err) => Err(MigrationError::MigrationFailed(err.to_string())), 52 + } 53 + } 54 + 55 + /// Migrates the database to the latest migration or throws an error. 
56 + pub async fn migrate_to_latest_or_throw(&self) -> Result<(), MigrationError> { 57 + let result = self.migrator.run(&self.db).await; 58 + 59 + match result { 60 + Ok(_) => Ok(()), 61 + Err(err) => Err(MigrationError::MigrationFailed(err.to_string())), 62 + } 63 + } 64 + }
+5
src/db/mod.rs
··· 1 + mod cast; 2 + mod db; 3 + mod migrator; 4 + mod pagination; 5 + mod util;
+163
src/db/pagination.rs
//! Keyset (cursor-based) pagination helpers.

use std::fmt::Debug;

/// Represents a cursor with primary and secondary parts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Cursor {
    pub primary: String,
    pub secondary: String,
}

/// Represents a labeled result with primary and secondary parts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LabeledResult {
    pub primary: String,
    pub secondary: String,
}

/// Trait defining the interface for a keyset-paginated cursor.
///
/// Implementors supply the three conversions between a row (`R`), its labeled
/// form (`LR`) and a [`Cursor`]; the packing helpers are provided.
pub trait GenericKeyset<R, LR: Debug> {
    /// Extracts the pagination keys from a result row.
    fn label_result(&self, result: R) -> LR;
    /// Converts labeled keys into a cursor.
    fn labeled_result_to_cursor(&self, labeled: LR) -> Cursor;
    /// Converts a cursor back into labeled keys.
    fn cursor_to_labeled_result(&self, cursor: Cursor) -> LR;

    /// Packs the cursor for the last row of `results`.
    ///
    /// Returns `None` when `results` is empty.
    fn pack_from_result(&self, results: Vec<R>) -> Option<String> {
        let last = results.into_iter().next_back()?;
        self.pack(Some(self.label_result(last)))
    }

    /// Packs labeled keys into their string cursor form, if any.
    fn pack(&self, labeled: Option<LR>) -> Option<String> {
        labeled.map(|l| self.pack_cursor(self.labeled_result_to_cursor(l)))
    }

    /// Parses a string cursor into labeled keys.
    ///
    /// Returns `None` when no cursor is given or when it is malformed.
    fn unpack(&self, cursor_str: Option<String>) -> Option<LR> {
        cursor_str
            .and_then(|cursor| self.unpack_cursor(cursor))
            .map(|c| self.cursor_to_labeled_result(c))
    }

    /// Serializes a cursor as `primary::secondary`.
    fn pack_cursor(&self, cursor: Cursor) -> String {
        format!("{}::{}", cursor.primary, cursor.secondary)
    }

    /// Parses `primary::secondary`; any other number of `::`-separated parts
    /// yields `None`.
    fn unpack_cursor(&self, cursor_str: String) -> Option<Cursor> {
        let mut parts = cursor_str.split("::");
        match (parts.next(), parts.next(), parts.next()) {
            (Some(primary), Some(secondary), None) => Some(Cursor {
                primary: primary.to_string(),
                secondary: secondary.to_string(),
            }),
            _ => None,
        }
    }
}

/// A concrete implementation of `GenericKeyset` for time and CID-based pagination.
#[derive(Debug, Clone, Copy, Default)]
pub struct TimeCidKeyset;

impl TimeCidKeyset {
    /// Creates the keyset.
    pub fn new() -> Self {
        Self
    }
}

impl GenericKeyset<CreatedAtCidResult, LabeledResult> for TimeCidKeyset {
    fn label_result(&self, result: CreatedAtCidResult) -> LabeledResult {
        LabeledResult {
            primary: result.created_at,
            secondary: result.cid,
        }
    }

    fn labeled_result_to_cursor(&self, labeled: LabeledResult) -> Cursor {
        Cursor {
            primary: labeled.primary,
            secondary: labeled.secondary,
        }
    }

    fn cursor_to_labeled_result(&self, cursor: Cursor) -> LabeledResult {
        LabeledResult {
            primary: cursor.primary,
            secondary: cursor.secondary,
        }
    }
}

/// Represents a database result with created_at and cid fields.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CreatedAtCidResult {
    pub created_at: String,
    pub cid: String,
}

/// Pagination options for queries.
#[derive(Debug, Clone, Default)]
pub struct PaginationOptions<'a> {
    /// Maximum number of rows; no `LIMIT` is emitted when `None`.
    pub limit: Option<usize>,
    /// Packed cursor (`primary::secondary`) to resume from.
    pub cursor: Option<String>,
    /// `"asc"` (case-insensitive) for ascending; anything else means descending.
    pub direction: Option<&'a str>,
    /// When `Some(true)`, emit a row-value comparison instead of the expanded form.
    pub try_index: Option<bool>,
}

/// Applies pagination to a query.
///
/// Appends, in this order, an optional ` WHERE <keyset condition>` (only when
/// a well-formed cursor is supplied), an ` ORDER BY primary <dir>, secondary <dir>`
/// clause and an optional ` LIMIT <n>`. `query` is modified in place and a copy
/// of the final text is returned.
///
/// The sort direction is normalized to exactly `asc` or `desc`, so the
/// caller-supplied string is never spliced into the SQL, and cursor values are
/// emitted as quoted literals with embedded single quotes doubled.
///
/// NOTE(review): the column names `primary`/`secondary` are emitted literally
/// and the clauses are appended blindly, so `query` is assumed not to contain
/// its own `WHERE`/`ORDER BY`/`LIMIT` yet — confirm against callers.
pub fn paginate<K>(query: &mut String, opts: PaginationOptions, keyset: &K) -> String
where
    K: GenericKeyset<CreatedAtCidResult, LabeledResult>,
{
    let PaginationOptions {
        limit,
        cursor,
        direction,
        try_index,
    } = opts;

    let direction = normalize_direction(direction);

    if let Some(labeled) = keyset.unpack(cursor) {
        query.push_str(" WHERE ");
        query.push_str(&get_sql(&labeled, direction, try_index.unwrap_or(false)));
    }

    // ORDER BY has to precede LIMIT, and its terms are comma-separated.
    query.push_str(&format!(
        " ORDER BY primary {}, secondary {}",
        direction, direction
    ));

    if let Some(l) = limit {
        query.push_str(&format!(" LIMIT {}", l));
    }

    query.clone()
}

/// Maps a caller-supplied direction onto the only two keywords ever emitted.
fn normalize_direction(direction: Option<&str>) -> &'static str {
    match direction {
        Some(d) if d.eq_ignore_ascii_case("asc") => "asc",
        _ => "desc",
    }
}

/// Renders `value` as a single-quoted SQL string literal, doubling embedded quotes.
fn quote_literal(value: &str) -> String {
    format!("'{}'", value.replace('\'', "''"))
}

/// Generates SQL conditions for pagination.
///
/// Ascending direction selects rows strictly after the cursor (`>`); any other
/// direction selects rows strictly before it (`<`). With `try_index` a
/// row-value comparison is produced, otherwise the equivalent expanded
/// `OR`/`AND` form.
fn get_sql(labeled: &LabeledResult, direction: &str, try_index: bool) -> String {
    let op = if direction.eq_ignore_ascii_case("asc") {
        ">"
    } else {
        "<"
    };
    let primary = quote_literal(&labeled.primary);
    let secondary = quote_literal(&labeled.secondary);

    if try_index {
        format!("(primary, secondary) {} ({}, {})", op, primary, secondary)
    } else {
        format!(
            "(primary {} {} OR (primary = {} AND secondary {} {}))",
            op, primary, primary, op, secondary
        )
    }
}
+124
src/db/util.rs
··· 1 + //! This module contains utility functions and types for working with SQLite databases using SQLx. 2 + 3 + use sqlx::Error; 4 + use std::collections::HashSet; 5 + 6 + /// Returns a SQL clause to check if a record is not soft-deleted. 7 + pub fn not_soft_deleted_clause(alias: &str) -> String { 8 + format!(r#"{}."takedownRef" IS NULL"#, alias) 9 + } 10 + 11 + /// Checks if a record is soft-deleted. 12 + pub fn is_soft_deleted(takedown_ref: Option<&str>) -> bool { 13 + takedown_ref.is_some() 14 + } 15 + 16 + /// SQL clause to count all rows. 17 + pub const COUNT_ALL: &str = "COUNT(*)"; 18 + 19 + /// SQL clause to count distinct rows based on a reference. 20 + pub fn count_distinct(ref_col: &str) -> String { 21 + format!("COUNT(DISTINCT {})", ref_col) 22 + } 23 + 24 + /// Generates a SQL clause for the `excluded` column in an `ON CONFLICT` clause. 25 + pub fn excluded(col: &str) -> String { 26 + format!("excluded.{}", col) 27 + } 28 + 29 + /// Generates a SQL clause for a large `WHERE IN` clause using a hash lookup. 30 + /// # DEPRECATED 31 + /// Use SQLx parameterized queries instead. 32 + #[deprecated = "Use SQLx parameterized queries instead"] 33 + pub fn values_list(vals: &[&str]) -> String { 34 + let values = vals 35 + .iter() 36 + .map(|val| format!("('{}')", val)) 37 + .collect::<Vec<_>>() 38 + .join(", "); 39 + format!("(VALUES {})", values) 40 + } 41 + 42 + /// Retries an asynchronous SQLite operation with exponential backoff. 
43 + pub async fn retry_sqlite<F, Fut, T>(operation: F) -> Result<T, sqlx::Error> 44 + where 45 + F: Fn() -> Fut, 46 + Fut: std::future::Future<Output = Result<T, sqlx::Error>>, 47 + { 48 + let max_retries = 60; 49 + let mut attempt = 0; 50 + 51 + while attempt < max_retries { 52 + match operation().await { 53 + Ok(result) => return Ok(result), 54 + Err(err) if is_retryable_sqlite_error(&err) => { 55 + if let Some(wait_ms) = get_wait_ms_sqlite(attempt, 5000) { 56 + tokio::time::sleep(std::time::Duration::from_millis(wait_ms)).await; 57 + attempt += 1; 58 + } else { 59 + return Err(err); 60 + } 61 + } 62 + Err(err) => return Err(err), 63 + } 64 + } 65 + 66 + Err(sqlx::Error::Protocol("Max retries exceeded".into())) 67 + } 68 + 69 + /// Checks if an error is retryable for SQLite. 70 + fn is_retryable_sqlite_error(err: &Error) -> bool { 71 + matches!( 72 + err, 73 + Error::Database(db_err) if { 74 + let code = db_err.code().unwrap_or_default().to_string(); 75 + RETRY_ERRORS.contains(code.as_str()) 76 + } 77 + ) 78 + } 79 + 80 + /// Calculates the wait time for retries based on SQLite's backoff strategy. 81 + fn get_wait_ms_sqlite(attempt: usize, timeout: u64) -> Option<u64> { 82 + const DELAYS: [u64; 12] = [1, 2, 5, 10, 15, 20, 25, 25, 25, 50, 50, 100]; 83 + const TOTALS: [u64; 12] = [0, 1, 3, 8, 18, 33, 53, 78, 103, 128, 178, 228]; 84 + 85 + if attempt >= DELAYS.len() { 86 + let delay = DELAYS.last().unwrap(); 87 + let prior = TOTALS.last().unwrap() + delay * (attempt as u64 - (DELAYS.len() as u64 - 1)); 88 + if prior + delay > timeout { 89 + return None; 90 + } 91 + Some(*delay) 92 + } else { 93 + let delay = DELAYS[attempt]; 94 + let prior = TOTALS[attempt]; 95 + if prior + delay > timeout { 96 + None 97 + } else { 98 + Some(delay) 99 + } 100 + } 101 + } 102 + 103 + /// Checks if an error is a unique constraint violation. 
104 + pub fn is_err_unique_violation(err: &Error) -> bool { 105 + matches!( 106 + err, 107 + Error::Database(db_err) if { 108 + let code = db_err.code().unwrap_or_default(); 109 + code == "23505" || code == "SQLITE_CONSTRAINT_UNIQUE" 110 + } 111 + ) 112 + } 113 + 114 + lazy_static::lazy_static! { 115 + /// Set of retryable SQLite error codes. 116 + static ref RETRY_ERRORS: HashSet<&'static str> = { 117 + let mut set = HashSet::new(); 118 + set.insert("SQLITE_BUSY"); 119 + set.insert("SQLITE_BUSY_SNAPSHOT"); 120 + set.insert("SQLITE_BUSY_RECOVERY"); 121 + set.insert("SQLITE_BUSY_TIMEOUT"); 122 + set 123 + }; 124 + }