+94
src/db/cast.rs
+94
src/db/cast.rs
···
1
+
//! Type-safe casting utilities.
2
+
3
+
use serde::{Serialize, de::DeserializeOwned};
4
+
use std::fmt;
5
+
6
+
/// Represents an ISO 8601 date string (e.g., "2023-01-01T12:00:00Z").
7
+
#[derive(Debug, Clone, PartialEq, Eq)]
8
+
pub struct DateISO(String);
9
+
10
+
impl DateISO {
11
+
/// Converts a `chrono::DateTime<Utc>` to a `DateISO`.
12
+
pub fn from_date(date: chrono::DateTime<chrono::Utc>) -> Self {
13
+
Self(date.to_rfc3339())
14
+
}
15
+
16
+
/// Converts a `DateISO` back to a `chrono::DateTime<Utc>`.
17
+
pub fn to_date(&self) -> Result<chrono::DateTime<chrono::Utc>, chrono::ParseError> {
18
+
self.0.parse::<chrono::DateTime<chrono::Utc>>()
19
+
}
20
+
}
21
+
22
+
impl fmt::Display for DateISO {
23
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24
+
write!(f, "{}", self.0)
25
+
}
26
+
}
27
+
28
+
/// Represents a JSON-encoded string.
29
+
#[derive(Debug, Clone, PartialEq, Eq)]
30
+
pub struct JsonEncoded<T: Serialize>(String, std::marker::PhantomData<T>);
31
+
32
+
impl<T: Serialize> JsonEncoded<T> {
33
+
/// Encodes a value into a JSON string.
34
+
pub fn to_json(value: &T) -> Result<Self, serde_json::Error> {
35
+
let json = serde_json::to_string(value)?;
36
+
Ok(Self(json, std::marker::PhantomData))
37
+
}
38
+
39
+
/// Decodes a JSON string back into a value.
40
+
pub fn from_json(json_str: &str) -> Result<T, serde_json::Error>
41
+
where
42
+
T: DeserializeOwned,
43
+
{
44
+
serde_json::from_str(json_str)
45
+
}
46
+
47
+
/// Returns the underlying JSON string.
48
+
pub fn as_str(&self) -> &str {
49
+
&self.0
50
+
}
51
+
}
52
+
53
+
impl<T: Serialize> fmt::Display for JsonEncoded<T> {
54
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55
+
write!(f, "{}", self.0)
56
+
}
57
+
}
58
+
59
+
#[cfg(test)]
60
+
mod tests {
61
+
use super::*;
62
+
use chrono::Utc;
63
+
use serde::{Deserialize, Serialize};
64
+
65
+
#[test]
66
+
fn test_date_iso() {
67
+
let now = Utc::now();
68
+
let date_iso = DateISO::from_date(now);
69
+
let parsed_date = date_iso.to_date().unwrap();
70
+
assert_eq!(now.to_rfc3339(), parsed_date.to_rfc3339());
71
+
}
72
+
73
+
#[derive(Serialize, Deserialize, PartialEq, Debug)]
74
+
struct TestStruct {
75
+
name: String,
76
+
value: i32,
77
+
}
78
+
79
+
#[test]
80
+
fn test_json_encoded() {
81
+
let test_value = TestStruct {
82
+
name: "example".to_string(),
83
+
value: 42,
84
+
};
85
+
86
+
// Encode to JSON
87
+
let encoded = JsonEncoded::to_json(&test_value).unwrap();
88
+
assert_eq!(encoded.as_str(), r#"{"name":"example","value":42}"#);
89
+
90
+
// Decode from JSON
91
+
let decoded: TestStruct = JsonEncoded::from_json(encoded.as_str()).unwrap();
92
+
assert_eq!(decoded, test_value);
93
+
}
94
+
}
+111
src/db/db.rs
+111
src/db/db.rs
···
1
+
//! Database connection and transaction management.
2
+
3
+
use sqlx::{
4
+
Sqlite, Transaction,
5
+
sqlite::{SqliteConnectOptions, SqlitePool, SqliteQueryResult, SqliteTransactionManager},
6
+
};
7
+
use std::collections::VecDeque;
8
+
use std::str::FromStr;
9
+
use std::sync::{Arc, Mutex};
10
+
use tokio::sync::Mutex as AsyncMutex;
11
+
12
+
use crate::db::util::retry_sqlite;
13
+
14
+
/// Default pragmas for SQLite.
15
+
const DEFAULT_PRAGMAS: &[(&str, &str)] = &[
16
+
// Add default pragmas here if needed, e.g., ("foreign_keys", "ON")
17
+
];
18
+
19
+
/// Database struct for managing SQLite connections and transactions.
20
+
pub struct Database {
21
+
/// SQLite connection pool.
22
+
db: Arc<SqlitePool>,
23
+
/// Flag indicating if the database is destroyed.
24
+
destroyed: Arc<Mutex<bool>>,
25
+
/// Queue of commit hooks.
26
+
commit_hooks: Arc<AsyncMutex<VecDeque<Box<dyn FnOnce() + Send>>>>,
27
+
}
28
+
29
+
impl Database {
30
+
/// Creates a new database instance with the given location and optional pragmas.
31
+
pub async fn new(location: &str, pragmas: Option<&[(&str, &str)]>) -> sqlx::Result<Self> {
32
+
let mut options = SqliteConnectOptions::from_str(location)?.create_if_missing(true);
33
+
34
+
// Apply default and user-provided pragmas.
35
+
for &(key, value) in DEFAULT_PRAGMAS.iter().chain(pragmas.unwrap_or(&[])) {
36
+
options = options.pragma(key.to_string(), value.to_string());
37
+
}
38
+
39
+
let pool = SqlitePool::connect_with(options).await?;
40
+
Ok(Self {
41
+
db: Arc::new(pool),
42
+
destroyed: Arc::new(Mutex::new(false)),
43
+
commit_hooks: Arc::new(AsyncMutex::new(VecDeque::new())),
44
+
})
45
+
}
46
+
47
+
/// Ensures the database is using Write-Ahead Logging (WAL) mode.
48
+
pub async fn ensure_wal(&self) -> sqlx::Result<()> {
49
+
let mut conn = self.db.acquire().await?;
50
+
sqlx::query("PRAGMA journal_mode = WAL")
51
+
.execute(&mut *conn)
52
+
.await?;
53
+
Ok(())
54
+
}
55
+
56
+
/// Executes a transaction without retry logic.
57
+
pub async fn transaction_no_retry<F, T>(&self, func: F) -> sqlx::Result<T>
58
+
where
59
+
F: FnOnce(&mut Transaction<'_, Sqlite>) -> sqlx::Result<T>,
60
+
{
61
+
let mut tx = self.db.begin().await?;
62
+
let result = func(&mut tx)?;
63
+
tx.commit().await?;
64
+
self.run_commit_hooks().await;
65
+
Ok(result)
66
+
}
67
+
68
+
/// Executes a transaction with retry logic.
69
+
pub async fn transaction<F, T>(&self, func: F) -> sqlx::Result<T>
70
+
where
71
+
F: FnOnce(&mut Transaction<'_, Sqlite>) -> sqlx::Result<T> + Copy,
72
+
{
73
+
retry_sqlite(|| self.transaction_no_retry(func)).await
74
+
}
75
+
76
+
/// Executes a query with retry logic.
77
+
pub async fn execute_with_retry<F, T>(&self, query: F) -> sqlx::Result<T>
78
+
where
79
+
F: Fn() -> std::pin::Pin<Box<dyn futures::Future<Output = sqlx::Result<T>> + Send>> + Copy,
80
+
{
81
+
retry_sqlite(|| query()).await
82
+
}
83
+
84
+
/// Adds a commit hook to be executed after a successful transaction.
85
+
pub async fn on_commit<F>(&self, hook: F)
86
+
where
87
+
F: FnOnce() + Send + 'static,
88
+
{
89
+
let mut hooks = self.commit_hooks.lock().await;
90
+
hooks.push_back(Box::new(hook));
91
+
}
92
+
93
+
/// Closes the database connection pool.
94
+
pub async fn close(&self) -> sqlx::Result<()> {
95
+
let mut destroyed = self.destroyed.lock().unwrap();
96
+
if *destroyed {
97
+
return Ok(());
98
+
}
99
+
*destroyed = true;
100
+
drop(self.db.clone()); // Drop the pool to close connections.
101
+
Ok(())
102
+
}
103
+
104
+
/// Runs all commit hooks in the queue.
105
+
async fn run_commit_hooks(&self) {
106
+
let mut hooks = self.commit_hooks.lock().await;
107
+
while let Some(hook) = hooks.pop_front() {
108
+
hook();
109
+
}
110
+
}
111
+
}
+64
src/db/migrator.rs
+64
src/db/migrator.rs
···
1
+
//! Database migration management.
2
+
3
+
use sqlx::{SqlitePool, migrate::Migrator};
4
+
use std::path::Path;
5
+
use thiserror::Error;
6
+
7
+
/// Error type for migration-related issues.
8
+
#[derive(Debug, Error)]
9
+
pub enum MigrationError {
10
+
#[error("Migration failed: {0}")]
11
+
MigrationFailed(String),
12
+
#[error("Unknown failure occurred while migrating")]
13
+
UnknownFailure,
14
+
}
15
+
16
+
/// Migrator struct for managing database migrations.
17
+
pub struct DatabaseMigrator {
18
+
/// SQLx migrator instance.
19
+
migrator: Migrator,
20
+
/// SQLite connection pool.
21
+
db: SqlitePool,
22
+
}
23
+
24
+
impl DatabaseMigrator {
25
+
/// Creates a new `DatabaseMigrator` instance.
26
+
///
27
+
/// # Arguments
28
+
/// - `migrations_path`: Path to the directory containing migration files.
29
+
/// - `db`: SQLite connection pool.
30
+
pub async fn new(migrations_path: &Path, db: SqlitePool) -> Self {
31
+
let migrator = Migrator::new(migrations_path)
32
+
.await
33
+
.expect("Failed to initialize migrator");
34
+
Self { migrator, db }
35
+
}
36
+
37
+
/// Migrates the database to a specific migration or throws an error.
38
+
///
39
+
/// # Arguments
40
+
/// - `migration`: The target migration name.
41
+
///
42
+
/// # Unimplemented
43
+
/// This currently runs all migrations instead of a specific one.
44
+
pub async fn migrate_to_or_throw(&self, _migration: &str) -> Result<(), MigrationError> {
45
+
// TODO: Implement migration to a specific version
46
+
// For now, we will just run all migrations
47
+
let result = self.migrator.run(&self.db).await;
48
+
49
+
match result {
50
+
Ok(_) => Ok(()),
51
+
Err(err) => Err(MigrationError::MigrationFailed(err.to_string())),
52
+
}
53
+
}
54
+
55
+
/// Migrates the database to the latest migration or throws an error.
56
+
pub async fn migrate_to_latest_or_throw(&self) -> Result<(), MigrationError> {
57
+
let result = self.migrator.run(&self.db).await;
58
+
59
+
match result {
60
+
Ok(_) => Ok(()),
61
+
Err(err) => Err(MigrationError::MigrationFailed(err.to_string())),
62
+
}
63
+
}
64
+
}
+163
src/db/pagination.rs
+163
src/db/pagination.rs
···
1
+
use std::fmt::Debug;
2
+
3
+
/// Represents a cursor with primary and secondary parts.
4
+
#[derive(Debug, Clone)]
5
+
pub struct Cursor {
6
+
pub primary: String,
7
+
pub secondary: String,
8
+
}
9
+
10
+
/// Represents a labeled result with primary and secondary parts.
11
+
#[derive(Debug, Clone)]
12
+
pub struct LabeledResult {
13
+
pub primary: String,
14
+
pub secondary: String,
15
+
}
16
+
17
+
/// Trait defining the interface for a keyset-paginated cursor.
18
+
pub trait GenericKeyset<R, LR: Debug> {
19
+
fn label_result(&self, result: R) -> LR;
20
+
fn labeled_result_to_cursor(&self, labeled: LR) -> Cursor;
21
+
fn cursor_to_labeled_result(&self, cursor: Cursor) -> LR;
22
+
23
+
fn pack_from_result(&self, results: Vec<R>) -> Option<String> {
24
+
todo!()
25
+
// results
26
+
// .last()
27
+
// .map(|result| self.pack(Some(self.label_result(result.clone()))))
28
+
}
29
+
30
+
fn pack(&self, labeled: Option<LR>) -> Option<String> {
31
+
labeled.map(|l| self.pack_cursor(self.labeled_result_to_cursor(l)))
32
+
}
33
+
34
+
fn unpack(&self, cursor_str: Option<String>) -> Option<LR> {
35
+
cursor_str
36
+
.and_then(|cursor| self.unpack_cursor(cursor))
37
+
.map(|c| self.cursor_to_labeled_result(c))
38
+
}
39
+
40
+
fn pack_cursor(&self, cursor: Cursor) -> String {
41
+
format!("{}::{}", cursor.primary, cursor.secondary)
42
+
}
43
+
44
+
fn unpack_cursor(&self, cursor_str: String) -> Option<Cursor> {
45
+
let parts: Vec<&str> = cursor_str.split("::").collect();
46
+
if parts.len() == 2 {
47
+
Some(Cursor {
48
+
primary: parts[0].to_string(),
49
+
secondary: parts[1].to_string(),
50
+
})
51
+
} else {
52
+
None
53
+
}
54
+
}
55
+
}
56
+
57
+
/// A concrete implementation of `GenericKeyset` for time and CID-based pagination.
58
+
pub struct TimeCidKeyset;
59
+
60
+
impl TimeCidKeyset {
61
+
pub fn new() -> Self {
62
+
Self
63
+
}
64
+
}
65
+
66
+
impl GenericKeyset<CreatedAtCidResult, LabeledResult> for TimeCidKeyset {
67
+
fn label_result(&self, result: CreatedAtCidResult) -> LabeledResult {
68
+
LabeledResult {
69
+
primary: result.created_at,
70
+
secondary: result.cid,
71
+
}
72
+
}
73
+
74
+
fn labeled_result_to_cursor(&self, labeled: LabeledResult) -> Cursor {
75
+
Cursor {
76
+
primary: labeled.primary,
77
+
secondary: labeled.secondary,
78
+
}
79
+
}
80
+
81
+
fn cursor_to_labeled_result(&self, cursor: Cursor) -> LabeledResult {
82
+
LabeledResult {
83
+
primary: cursor.primary,
84
+
secondary: cursor.secondary,
85
+
}
86
+
}
87
+
}
88
+
89
+
/// Represents a database result with created_at and cid fields.
90
+
#[derive(Debug, Clone)]
91
+
pub struct CreatedAtCidResult {
92
+
pub created_at: String,
93
+
pub cid: String,
94
+
}
95
+
96
+
/// Pagination options for queries.
97
+
pub struct PaginationOptions<'a> {
98
+
pub limit: Option<usize>,
99
+
pub cursor: Option<String>,
100
+
pub direction: Option<&'a str>,
101
+
pub try_index: Option<bool>,
102
+
}
103
+
104
+
/// Applies pagination to a query.
105
+
pub fn paginate<K>(query: &mut String, opts: PaginationOptions, keyset: &K) -> String
106
+
where
107
+
K: GenericKeyset<CreatedAtCidResult, LabeledResult>,
108
+
{
109
+
let PaginationOptions {
110
+
limit,
111
+
cursor,
112
+
direction,
113
+
try_index,
114
+
} = opts;
115
+
116
+
let direction = direction.unwrap_or("desc");
117
+
let labeled = cursor.and_then(|c| keyset.unpack(Some(c)));
118
+
let keyset_sql = labeled.map(|l| get_sql(&l, direction, try_index.unwrap_or(false)));
119
+
120
+
if let Some(sql) = keyset_sql {
121
+
query.push_str(&format!(" WHERE {}", sql));
122
+
}
123
+
124
+
if let Some(l) = limit {
125
+
query.push_str(&format!(" LIMIT {}", l));
126
+
}
127
+
128
+
query.push_str(&format!(
129
+
" ORDER BY primary {} secondary {}",
130
+
direction, direction
131
+
));
132
+
133
+
query.clone()
134
+
}
135
+
136
+
/// Generates SQL conditions for pagination.
137
+
fn get_sql(labeled: &LabeledResult, direction: &str, try_index: bool) -> String {
138
+
if try_index {
139
+
if direction == "asc" {
140
+
format!(
141
+
"(primary, secondary) > ('{}', '{}')",
142
+
labeled.primary, labeled.secondary
143
+
)
144
+
} else {
145
+
format!(
146
+
"(primary, secondary) < ('{}', '{}')",
147
+
labeled.primary, labeled.secondary
148
+
)
149
+
}
150
+
} else {
151
+
if direction == "asc" {
152
+
format!(
153
+
"(primary > '{}' OR (primary = '{}' AND secondary > '{}'))",
154
+
labeled.primary, labeled.primary, labeled.secondary
155
+
)
156
+
} else {
157
+
format!(
158
+
"(primary < '{}' OR (primary = '{}' AND secondary < '{}'))",
159
+
labeled.primary, labeled.primary, labeled.secondary
160
+
)
161
+
}
162
+
}
163
+
}
+124
src/db/util.rs
+124
src/db/util.rs
···
1
+
//! This module contains utility functions and types for working with SQLite databases using SQLx.
2
+
3
+
use sqlx::Error;
4
+
use std::collections::HashSet;
5
+
6
+
/// Returns a SQL clause to check if a record is not soft-deleted.
7
+
pub fn not_soft_deleted_clause(alias: &str) -> String {
8
+
format!(r#"{}."takedownRef" IS NULL"#, alias)
9
+
}
10
+
11
+
/// Checks if a record is soft-deleted.
12
+
pub fn is_soft_deleted(takedown_ref: Option<&str>) -> bool {
13
+
takedown_ref.is_some()
14
+
}
15
+
16
+
/// SQL clause to count all rows.
17
+
pub const COUNT_ALL: &str = "COUNT(*)";
18
+
19
+
/// SQL clause to count distinct rows based on a reference.
20
+
pub fn count_distinct(ref_col: &str) -> String {
21
+
format!("COUNT(DISTINCT {})", ref_col)
22
+
}
23
+
24
+
/// Generates a SQL clause for the `excluded` column in an `ON CONFLICT` clause.
25
+
pub fn excluded(col: &str) -> String {
26
+
format!("excluded.{}", col)
27
+
}
28
+
29
+
/// Generates a SQL clause for a large `WHERE IN` clause using a hash lookup.
30
+
/// # DEPRECATED
31
+
/// Use SQLx parameterized queries instead.
32
+
#[deprecated = "Use SQLx parameterized queries instead"]
33
+
pub fn values_list(vals: &[&str]) -> String {
34
+
let values = vals
35
+
.iter()
36
+
.map(|val| format!("('{}')", val))
37
+
.collect::<Vec<_>>()
38
+
.join(", ");
39
+
format!("(VALUES {})", values)
40
+
}
41
+
42
+
/// Retries an asynchronous SQLite operation with exponential backoff.
43
+
pub async fn retry_sqlite<F, Fut, T>(operation: F) -> Result<T, sqlx::Error>
44
+
where
45
+
F: Fn() -> Fut,
46
+
Fut: std::future::Future<Output = Result<T, sqlx::Error>>,
47
+
{
48
+
let max_retries = 60;
49
+
let mut attempt = 0;
50
+
51
+
while attempt < max_retries {
52
+
match operation().await {
53
+
Ok(result) => return Ok(result),
54
+
Err(err) if is_retryable_sqlite_error(&err) => {
55
+
if let Some(wait_ms) = get_wait_ms_sqlite(attempt, 5000) {
56
+
tokio::time::sleep(std::time::Duration::from_millis(wait_ms)).await;
57
+
attempt += 1;
58
+
} else {
59
+
return Err(err);
60
+
}
61
+
}
62
+
Err(err) => return Err(err),
63
+
}
64
+
}
65
+
66
+
Err(sqlx::Error::Protocol("Max retries exceeded".into()))
67
+
}
68
+
69
+
/// Checks if an error is retryable for SQLite.
70
+
fn is_retryable_sqlite_error(err: &Error) -> bool {
71
+
matches!(
72
+
err,
73
+
Error::Database(db_err) if {
74
+
let code = db_err.code().unwrap_or_default().to_string();
75
+
RETRY_ERRORS.contains(code.as_str())
76
+
}
77
+
)
78
+
}
79
+
80
+
/// Calculates the wait time for retries based on SQLite's backoff strategy.
81
+
fn get_wait_ms_sqlite(attempt: usize, timeout: u64) -> Option<u64> {
82
+
const DELAYS: [u64; 12] = [1, 2, 5, 10, 15, 20, 25, 25, 25, 50, 50, 100];
83
+
const TOTALS: [u64; 12] = [0, 1, 3, 8, 18, 33, 53, 78, 103, 128, 178, 228];
84
+
85
+
if attempt >= DELAYS.len() {
86
+
let delay = DELAYS.last().unwrap();
87
+
let prior = TOTALS.last().unwrap() + delay * (attempt as u64 - (DELAYS.len() as u64 - 1));
88
+
if prior + delay > timeout {
89
+
return None;
90
+
}
91
+
Some(*delay)
92
+
} else {
93
+
let delay = DELAYS[attempt];
94
+
let prior = TOTALS[attempt];
95
+
if prior + delay > timeout {
96
+
None
97
+
} else {
98
+
Some(delay)
99
+
}
100
+
}
101
+
}
102
+
103
+
/// Checks if an error is a unique constraint violation.
104
+
pub fn is_err_unique_violation(err: &Error) -> bool {
105
+
matches!(
106
+
err,
107
+
Error::Database(db_err) if {
108
+
let code = db_err.code().unwrap_or_default();
109
+
code == "23505" || code == "SQLITE_CONSTRAINT_UNIQUE"
110
+
}
111
+
)
112
+
}
113
+
114
+
lazy_static::lazy_static! {
115
+
/// Set of retryable SQLite error codes.
116
+
static ref RETRY_ERRORS: HashSet<&'static str> = {
117
+
let mut set = HashSet::new();
118
+
set.insert("SQLITE_BUSY");
119
+
set.insert("SQLITE_BUSY_SNAPSHOT");
120
+
set.insert("SQLITE_BUSY_RECOVERY");
121
+
set.insert("SQLITE_BUSY_TIMEOUT");
122
+
set
123
+
};
124
+
}