PDS software with bells & whistles you didn’t even know you needed. This will move to its own account when ready.
at main 12 kB view raw
1use axum::http::HeaderMap; 2use cid::Cid; 3use ipld_core::ipld::Ipld; 4use rand::Rng; 5use serde_json::Value as JsonValue; 6use sqlx::PgPool; 7use std::collections::BTreeMap; 8use std::str::FromStr; 9use std::sync::OnceLock; 10use uuid::Uuid; 11 12use crate::types::{Did, Handle}; 13 14const BASE32_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz234567"; 15const DEFAULT_MAX_BLOB_SIZE: usize = 10 * 1024 * 1024 * 1024; 16 17static MAX_BLOB_SIZE: OnceLock<usize> = OnceLock::new(); 18 19pub fn get_max_blob_size() -> usize { 20 *MAX_BLOB_SIZE.get_or_init(|| { 21 std::env::var("MAX_BLOB_SIZE") 22 .ok() 23 .and_then(|s| s.parse().ok()) 24 .unwrap_or(DEFAULT_MAX_BLOB_SIZE) 25 }) 26} 27 28pub fn generate_token_code() -> String { 29 generate_token_code_parts(2, 5) 30} 31 32pub fn generate_token_code_parts(parts: usize, part_len: usize) -> String { 33 let mut rng = rand::thread_rng(); 34 let chars: Vec<char> = BASE32_ALPHABET.chars().collect(); 35 36 (0..parts) 37 .map(|_| { 38 (0..part_len) 39 .map(|_| chars[rng.gen_range(0..chars.len())]) 40 .collect::<String>() 41 }) 42 .collect::<Vec<_>>() 43 .join("-") 44} 45 46#[derive(Debug)] 47pub enum DbLookupError { 48 NotFound, 49 DatabaseError(sqlx::Error), 50} 51 52impl From<sqlx::Error> for DbLookupError { 53 fn from(e: sqlx::Error) -> Self { 54 DbLookupError::DatabaseError(e) 55 } 56} 57 58pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> { 59 sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 60 .fetch_optional(db) 61 .await? 62 .ok_or(DbLookupError::NotFound) 63} 64 65pub struct UserInfo { 66 pub id: Uuid, 67 pub did: Did, 68 pub handle: Handle, 69} 70 71pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> { 72 sqlx::query_as!( 73 UserInfo, 74 "SELECT id, did, handle FROM users WHERE did = $1", 75 did 76 ) 77 .fetch_optional(db) 78 .await? 
79 .ok_or(DbLookupError::NotFound) 80} 81 82pub async fn get_user_by_identifier( 83 db: &PgPool, 84 identifier: &str, 85) -> Result<UserInfo, DbLookupError> { 86 sqlx::query_as!( 87 UserInfo, 88 "SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1", 89 identifier 90 ) 91 .fetch_optional(db) 92 .await? 93 .ok_or(DbLookupError::NotFound) 94} 95 96pub async fn is_account_migrated(db: &PgPool, did: &str) -> Result<bool, sqlx::Error> { 97 let row = sqlx::query!( 98 r#"SELECT (migrated_to_pds IS NOT NULL AND deactivated_at IS NOT NULL) as "migrated!: bool" FROM users WHERE did = $1"#, 99 did 100 ) 101 .fetch_optional(db) 102 .await?; 103 Ok(row.map(|r| r.migrated).unwrap_or(false)) 104} 105 106pub fn parse_repeated_query_param(query: Option<&str>, key: &str) -> Vec<String> { 107 query 108 .map(|q| { 109 let mut values = Vec::new(); 110 for pair in q.split('&') { 111 if let Some((k, v)) = pair.split_once('=') 112 && k == key 113 && let Ok(decoded) = urlencoding::decode(v) 114 { 115 let decoded = decoded.into_owned(); 116 if decoded.contains(',') { 117 for part in decoded.split(',') { 118 let trimmed = part.trim(); 119 if !trimmed.is_empty() { 120 values.push(trimmed.to_string()); 121 } 122 } 123 } else if !decoded.is_empty() { 124 values.push(decoded); 125 } 126 } 127 } 128 values 129 }) 130 .unwrap_or_default() 131} 132 133pub fn extract_client_ip(headers: &HeaderMap) -> String { 134 if let Some(forwarded) = headers.get("x-forwarded-for") 135 && let Ok(value) = forwarded.to_str() 136 && let Some(first_ip) = value.split(',').next() 137 { 138 return first_ip.trim().to_string(); 139 } 140 if let Some(real_ip) = headers.get("x-real-ip") 141 && let Ok(value) = real_ip.to_str() 142 { 143 return value.trim().to_string(); 144 } 145 "unknown".to_string() 146} 147 148pub fn pds_hostname() -> String { 149 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 150} 151 152pub fn pds_public_url() -> String { 153 format!("https://{}", pds_hostname()) 
154} 155 156pub fn build_full_url(path: &str) -> String { 157 let normalized_path = if !path.starts_with("/xrpc/") 158 && (path.starts_with("/com.atproto.") 159 || path.starts_with("/app.bsky.") 160 || path.starts_with("/_")) 161 { 162 format!("/xrpc{}", path) 163 } else { 164 path.to_string() 165 }; 166 format!("{}{}", pds_public_url(), normalized_path) 167} 168 169pub fn json_to_ipld(value: &JsonValue) -> Ipld { 170 match value { 171 JsonValue::Null => Ipld::Null, 172 JsonValue::Bool(b) => Ipld::Bool(*b), 173 JsonValue::Number(n) => { 174 if let Some(i) = n.as_i64() { 175 Ipld::Integer(i as i128) 176 } else if let Some(f) = n.as_f64() { 177 Ipld::Float(f) 178 } else { 179 Ipld::Null 180 } 181 } 182 JsonValue::String(s) => Ipld::String(s.clone()), 183 JsonValue::Array(arr) => Ipld::List(arr.iter().map(json_to_ipld).collect()), 184 JsonValue::Object(obj) => { 185 if let Some(JsonValue::String(link)) = obj.get("$link") 186 && obj.len() == 1 187 && let Ok(cid) = Cid::from_str(link) 188 { 189 return Ipld::Link(cid); 190 } 191 let map: BTreeMap<String, Ipld> = obj 192 .iter() 193 .map(|(k, v)| (k.clone(), json_to_ipld(v))) 194 .collect(); 195 Ipld::Map(map) 196 } 197 } 198} 199 200#[cfg(test)] 201mod tests { 202 use super::*; 203 204 #[test] 205 fn test_parse_repeated_query_param_repeated() { 206 let query = "did=test&cids=a&cids=b&cids=c"; 207 let result = parse_repeated_query_param(Some(query), "cids"); 208 assert_eq!(result, vec!["a", "b", "c"]); 209 } 210 211 #[test] 212 fn test_parse_repeated_query_param_comma_separated() { 213 let query = "did=test&cids=a,b,c"; 214 let result = parse_repeated_query_param(Some(query), "cids"); 215 assert_eq!(result, vec!["a", "b", "c"]); 216 } 217 218 #[test] 219 fn test_parse_repeated_query_param_mixed() { 220 let query = "did=test&cids=a,b&cids=c"; 221 let result = parse_repeated_query_param(Some(query), "cids"); 222 assert_eq!(result, vec!["a", "b", "c"]); 223 } 224 225 #[test] 226 fn test_parse_repeated_query_param_single() { 
227 let query = "did=test&cids=a"; 228 let result = parse_repeated_query_param(Some(query), "cids"); 229 assert_eq!(result, vec!["a"]); 230 } 231 232 #[test] 233 fn test_parse_repeated_query_param_empty() { 234 let query = "did=test"; 235 let result = parse_repeated_query_param(Some(query), "cids"); 236 assert!(result.is_empty()); 237 } 238 239 #[test] 240 fn test_parse_repeated_query_param_url_encoded() { 241 let query = "did=test&cids=bafyreib%2Btest"; 242 let result = parse_repeated_query_param(Some(query), "cids"); 243 assert_eq!(result, vec!["bafyreib+test"]); 244 } 245 246 #[test] 247 fn test_generate_token_code() { 248 let code = generate_token_code(); 249 assert_eq!(code.len(), 11); 250 assert!(code.contains('-')); 251 252 let parts: Vec<&str> = code.split('-').collect(); 253 assert_eq!(parts.len(), 2); 254 assert_eq!(parts[0].len(), 5); 255 assert_eq!(parts[1].len(), 5); 256 257 for c in code.chars() { 258 if c != '-' { 259 assert!(BASE32_ALPHABET.contains(c)); 260 } 261 } 262 } 263 264 #[test] 265 fn test_generate_token_code_parts() { 266 let code = generate_token_code_parts(3, 4); 267 let parts: Vec<&str> = code.split('-').collect(); 268 assert_eq!(parts.len(), 3); 269 270 for part in parts { 271 assert_eq!(part.len(), 4); 272 } 273 } 274 275 #[test] 276 fn test_json_to_ipld_cid_link() { 277 let json = serde_json::json!({ 278 "$link": "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" 279 }); 280 let ipld = json_to_ipld(&json); 281 match ipld { 282 Ipld::Link(cid) => { 283 assert_eq!( 284 cid.to_string(), 285 "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" 286 ); 287 } 288 _ => panic!("Expected Ipld::Link, got {:?}", ipld), 289 } 290 } 291 292 #[test] 293 fn test_json_to_ipld_blob_ref() { 294 let json = serde_json::json!({ 295 "$type": "blob", 296 "ref": { 297 "$link": "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" 298 }, 299 "mimeType": "image/jpeg", 300 "size": 12345 301 }); 302 let ipld = json_to_ipld(&json); 
303 match ipld { 304 Ipld::Map(map) => { 305 assert_eq!(map.get("$type"), Some(&Ipld::String("blob".to_string()))); 306 assert_eq!( 307 map.get("mimeType"), 308 Some(&Ipld::String("image/jpeg".to_string())) 309 ); 310 assert_eq!(map.get("size"), Some(&Ipld::Integer(12345))); 311 match map.get("ref") { 312 Some(Ipld::Link(cid)) => { 313 assert_eq!( 314 cid.to_string(), 315 "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" 316 ); 317 } 318 _ => panic!("Expected Ipld::Link in ref field, got {:?}", map.get("ref")), 319 } 320 } 321 _ => panic!("Expected Ipld::Map, got {:?}", ipld), 322 } 323 } 324 325 #[test] 326 fn test_json_to_ipld_nested_blob_refs_serializes_correctly() { 327 let record = serde_json::json!({ 328 "$type": "app.bsky.feed.post", 329 "text": "Hello world", 330 "embed": { 331 "$type": "app.bsky.embed.images", 332 "images": [ 333 { 334 "alt": "Test image", 335 "image": { 336 "$type": "blob", 337 "ref": { 338 "$link": "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" 339 }, 340 "mimeType": "image/jpeg", 341 "size": 12345 342 } 343 } 344 ] 345 } 346 }); 347 let ipld = json_to_ipld(&record); 348 let cbor_bytes = serde_ipld_dagcbor::to_vec(&ipld).expect("CBOR serialization failed"); 349 assert!(!cbor_bytes.is_empty()); 350 let parsed: Ipld = 351 serde_ipld_dagcbor::from_slice(&cbor_bytes).expect("CBOR deserialization failed"); 352 if let Ipld::Map(map) = &parsed 353 && let Some(Ipld::Map(embed)) = map.get("embed") 354 && let Some(Ipld::List(images)) = embed.get("images") 355 && let Some(Ipld::Map(img)) = images.first() 356 && let Some(Ipld::Map(blob)) = img.get("image") 357 && let Some(Ipld::Link(cid)) = blob.get("ref") 358 { 359 assert_eq!( 360 cid.to_string(), 361 "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" 362 ); 363 return; 364 } 365 panic!("Failed to find CID link in parsed CBOR"); 366 } 367 368 #[test] 369 fn test_build_full_url_adds_xrpc_prefix_for_atproto_paths() { 370 unsafe { std::env::set_var("PDS_HOSTNAME", 
"example.com") }; 371 assert_eq!( 372 build_full_url("/com.atproto.server.getSession"), 373 "https://example.com/xrpc/com.atproto.server.getSession" 374 ); 375 assert_eq!( 376 build_full_url("/app.bsky.feed.getTimeline"), 377 "https://example.com/xrpc/app.bsky.feed.getTimeline" 378 ); 379 assert_eq!( 380 build_full_url("/_health"), 381 "https://example.com/xrpc/_health" 382 ); 383 assert_eq!( 384 build_full_url("/xrpc/com.atproto.server.getSession"), 385 "https://example.com/xrpc/com.atproto.server.getSession" 386 ); 387 assert_eq!( 388 build_full_url("/oauth/token"), 389 "https://example.com/oauth/token" 390 ); 391 } 392}