auth dns over atproto
at main 183 lines 5.6 kB view raw
1//! SQLite database helpers for onis using sqlx. 2//! 3//! Two database types: 4//! - Per-DID databases: one per user, stores their DNS records 5//! - Reverse index: shared database mapping domains → DIDs + verification status 6//! 7//! Migrations live in: 8//! migrations/user/ — per-DID database migrations 9//! migrations/index/ — reverse index migrations 10 11use std::path::Path; 12 13use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions}; 14use thiserror::Error; 15 16use crate::config::DatabaseConfig; 17 18#[derive(Debug, Error)] 19pub enum DbError { 20 #[error("sqlx error: {0}")] 21 Sqlx(#[from] sqlx::Error), 22 #[error("migrate error: {0}")] 23 Migrate(#[from] sqlx::migrate::MigrateError), 24 #[error("io error: {0}")] 25 Io(#[from] std::io::Error), 26} 27 28/// Opens (or creates) a per-DID SQLite database. 29/// 30/// Runs migrations from `migrations/user/` on first open. 31pub async fn open_user_db(path: &Path, db_config: &DatabaseConfig) -> Result<SqlitePool, DbError> { 32 if let Some(parent) = path.parent() { 33 tokio::fs::create_dir_all(parent).await?; 34 } 35 36 let opts = SqliteConnectOptions::new() 37 .filename(path) 38 .create_if_missing(true) 39 .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal) 40 .busy_timeout(std::time::Duration::from_secs(db_config.busy_timeout)); 41 42 let pool = SqlitePoolOptions::new() 43 .max_connections(db_config.user_max_connections) 44 .connect_with(opts) 45 .await?; 46 47 let migrator = sqlx::migrate!("../migrations/user"); 48 tracing::info!("migrations to run: {}", migrator.migrations.len()); 49 migrator.run(&pool).await?; 50 51 Ok(pool) 52} 53 54/// Opens (or creates) the shared reverse index database. 55/// 56/// Runs migrations from `migrations/index/` on first open. 57pub async fn open_index_db(path: &Path, db_config: &DatabaseConfig) -> Result<SqlitePool, DbError> { 58 if let Some(parent) = path.parent() { 59 tokio::fs::create_dir_all(parent).await?; 60 } 61 62 let opts = SqliteConnectOptions::new() 63 .filename(path) 64 .create_if_missing(true) 65 .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal) 66 .busy_timeout(std::time::Duration::from_secs(db_config.busy_timeout)); 67 68 let pool = SqlitePoolOptions::new() 69 .max_connections(db_config.index_max_connections) 70 .connect_with(opts) 71 .await?; 72 73 sqlx::migrate!("../migrations/index").run(&pool).await?; 74 75 Ok(pool) 76} 77 78/// Converts a DID to a filesystem-safe path for its database. 79/// 80/// e.g. `did:plc:adtzorbhmmjbzxsl2y4vqlqs` → `{base}/ad/tz/did_plc_adtzorbhmmjbzxsl2y4vqlqs.db` 81/// e.g. `did:web:vim.sh` → `{base}/vi/m./did_web_vim.sh.db` 82pub fn did_to_db_path(base: &Path, did: &str) -> std::path::PathBuf { 83 let safe = did.replace(':', "_"); 84 85 // use first 4 chars after "did_(plc|web)_" for sharding 86 // XXX: this will break if another did method is added 87 // thats method is longer than 3 characters. 88 let shard = if safe.len() > 8 { 89 &safe[8..] 90 } else { 91 &safe 92 }; 93 94 let (a, b) = if shard.len() >= 4 { 95 (&shard[..2], &shard[2..4]) 96 } else { 97 ("xx", "xx") 98 }; 99 base.join(a).join(b).join(format!("{safe}.db")) 100} 101 102#[cfg(test)] 103mod tests { 104 use super::*; 105 use std::path::PathBuf; 106 107 #[test] 108 fn did_plc_standard() { 109 let base = PathBuf::from("/data/dbs"); 110 let path = did_to_db_path(&base, "did:plc:adtzorbhmmjbzxsl2y4vqlqs"); 111 assert_eq!( 112 path, 113 PathBuf::from("/data/dbs/ad/tz/did_plc_adtzorbhmmjbzxsl2y4vqlqs.db") 114 ); 115 } 116 117 #[test] 118 fn did_web_domain() { 119 let base = PathBuf::from("/data/dbs"); 120 let path = did_to_db_path(&base, "did:web:example.com"); 121 assert_eq!( 122 path, 123 PathBuf::from("/data/dbs/ex/am/did_web_example.com.db") 124 ); 125 } 126 127 #[test] 128 fn did_web_short_domain() { 129 let base = PathBuf::from("/data/dbs"); 130 let path = did_to_db_path(&base, "did:web:vim.sh"); 131 assert_eq!( 132 path, 133 PathBuf::from("/data/dbs/vi/m./did_web_vim.sh.db") 134 ); 135 } 136 137 #[test] 138 fn did_web_subdomain() { 139 let base = PathBuf::from("/data/dbs"); 140 let path = did_to_db_path(&base, "did:web:sub.example.com"); 141 assert_eq!( 142 path, 143 PathBuf::from("/data/dbs/su/b./did_web_sub.example.com.db") 144 ); 145 } 146 147 #[test] 148 fn did_web_very_short_falls_back() { 149 let base = PathBuf::from("/data/dbs"); 150 let path = did_to_db_path(&base, "did:web:a.b"); 151 assert_eq!( 152 path, 153 PathBuf::from("/data/dbs/xx/xx/did_web_a.b.db") 154 ); 155 } 156 157 #[test] 158 fn did_plc_short_falls_back() { 159 let base = PathBuf::from("/data/dbs"); 160 let path = did_to_db_path(&base, "did:plc:abc"); 161 assert_eq!( 162 path, 163 PathBuf::from("/data/dbs/xx/xx/did_plc_abc.db") 164 ); 165 } 166 167 #[test] 168 fn different_dids_produce_different_paths() { 169 let base = PathBuf::from("/data/dbs"); 170 let a = did_to_db_path(&base, "did:plc:aaaa1111bbbb2222"); 171 let b = did_to_db_path(&base, "did:plc:cccc3333dddd4444"); 172 assert_ne!(a, b); 173 } 174 175 #[test] 176 fn same_identifier_different_method_produces_different_paths() { 177 let base = PathBuf::from("/data/dbs"); 178 let plc = did_to_db_path(&base, "did:plc:example.com"); 179 let web = did_to_db_path(&base, "did:web:example.com"); 180 assert_ne!(plc, web); 181 assert_eq!(plc.parent(), web.parent()); 182 } 183}