forked from
lewis.moe/bspds-sandbox
PDS software with bells & whistles you didn’t even know you needed. Will move this to its own account when ready.
1use axum::http::HeaderMap;
2use cid::Cid;
3use ipld_core::ipld::Ipld;
4use rand::Rng;
5use serde_json::Value as JsonValue;
6use sqlx::PgPool;
7use std::collections::BTreeMap;
8use std::str::FromStr;
9use std::sync::OnceLock;
10use uuid::Uuid;
11
12use crate::types::{Did, Handle};
13
/// Characters token codes are drawn from: lowercase `a`-`z` followed by the digits `2`-`7`
/// (32 symbols in total).
const BASE32_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz234567";
/// Fallback blob size limit used when the `MAX_BLOB_SIZE` environment variable is unset or
/// does not parse: 10 * 1024^3 (presumably bytes, i.e. 10 GiB — confirm against callers).
// NOTE(review): this product exceeds `u32::MAX`, so the constant cannot be evaluated on
// targets where `usize` is 32 bits wide.
const DEFAULT_MAX_BLOB_SIZE: usize = 10 * 1024 * 1024 * 1024;

/// Process-wide cache of the resolved limit; populated on the first call to
/// `get_max_blob_size`.
static MAX_BLOB_SIZE: OnceLock<usize> = OnceLock::new();
18
19pub fn get_max_blob_size() -> usize {
20 *MAX_BLOB_SIZE.get_or_init(|| {
21 std::env::var("MAX_BLOB_SIZE")
22 .ok()
23 .and_then(|s| s.parse().ok())
24 .unwrap_or(DEFAULT_MAX_BLOB_SIZE)
25 })
26}
27
28pub fn generate_token_code() -> String {
29 generate_token_code_parts(2, 5)
30}
31
32pub fn generate_token_code_parts(parts: usize, part_len: usize) -> String {
33 let mut rng = rand::thread_rng();
34 let chars: Vec<char> = BASE32_ALPHABET.chars().collect();
35
36 (0..parts)
37 .map(|_| {
38 (0..part_len)
39 .map(|_| chars[rng.gen_range(0..chars.len())])
40 .collect::<String>()
41 })
42 .collect::<Vec<_>>()
43 .join("-")
44}
45
/// Error type returned by the user lookup helpers in this file.
#[derive(Debug)]
pub enum DbLookupError {
    /// The query ran successfully but matched no row.
    NotFound,
    /// The underlying `sqlx` call failed; the original error is carried along.
    DatabaseError(sqlx::Error),
}
51
52impl From<sqlx::Error> for DbLookupError {
53 fn from(e: sqlx::Error) -> Self {
54 DbLookupError::DatabaseError(e)
55 }
56}
57
58pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> {
59 sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
60 .fetch_optional(db)
61 .await?
62 .ok_or(DbLookupError::NotFound)
63}
64
/// Identity columns of a row in the `users` table, as returned by the lookup helpers
/// in this file.
pub struct UserInfo {
    /// Value of the `id` column.
    pub id: Uuid,
    /// Value of the `did` column.
    pub did: Did,
    /// Value of the `handle` column.
    pub handle: Handle,
}
70
71pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> {
72 sqlx::query_as!(
73 UserInfo,
74 "SELECT id, did, handle FROM users WHERE did = $1",
75 did
76 )
77 .fetch_optional(db)
78 .await?
79 .ok_or(DbLookupError::NotFound)
80}
81
82pub async fn get_user_by_identifier(
83 db: &PgPool,
84 identifier: &str,
85) -> Result<UserInfo, DbLookupError> {
86 sqlx::query_as!(
87 UserInfo,
88 "SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1",
89 identifier
90 )
91 .fetch_optional(db)
92 .await?
93 .ok_or(DbLookupError::NotFound)
94}
95
96pub async fn is_account_migrated(db: &PgPool, did: &str) -> Result<bool, sqlx::Error> {
97 let row = sqlx::query!(
98 r#"SELECT (migrated_to_pds IS NOT NULL AND deactivated_at IS NOT NULL) as "migrated!: bool" FROM users WHERE did = $1"#,
99 did
100 )
101 .fetch_optional(db)
102 .await?;
103 Ok(row.map(|r| r.migrated).unwrap_or(false))
104}
105
106pub fn parse_repeated_query_param(query: Option<&str>, key: &str) -> Vec<String> {
107 query
108 .map(|q| {
109 let mut values = Vec::new();
110 for pair in q.split('&') {
111 if let Some((k, v)) = pair.split_once('=')
112 && k == key
113 && let Ok(decoded) = urlencoding::decode(v)
114 {
115 let decoded = decoded.into_owned();
116 if decoded.contains(',') {
117 for part in decoded.split(',') {
118 let trimmed = part.trim();
119 if !trimmed.is_empty() {
120 values.push(trimmed.to_string());
121 }
122 }
123 } else if !decoded.is_empty() {
124 values.push(decoded);
125 }
126 }
127 }
128 values
129 })
130 .unwrap_or_default()
131}
132
133pub fn extract_client_ip(headers: &HeaderMap) -> String {
134 if let Some(forwarded) = headers.get("x-forwarded-for")
135 && let Ok(value) = forwarded.to_str()
136 && let Some(first_ip) = value.split(',').next()
137 {
138 return first_ip.trim().to_string();
139 }
140 if let Some(real_ip) = headers.get("x-real-ip")
141 && let Ok(value) = real_ip.to_str()
142 {
143 return value.trim().to_string();
144 }
145 "unknown".to_string()
146}
147
/// Returns the hostname from the `PDS_HOSTNAME` environment variable.
///
/// Falls back to `"localhost"` when the variable is unset or cannot be read as Unicode.
/// The environment is consulted on every call.
pub fn pds_hostname() -> String {
    match std::env::var("PDS_HOSTNAME") {
        Ok(hostname) => hostname,
        Err(_) => "localhost".to_string(),
    }
}
151
152pub fn pds_public_url() -> String {
153 format!("https://{}", pds_hostname())
154}
155
156pub fn build_full_url(path: &str) -> String {
157 let normalized_path = if !path.starts_with("/xrpc/")
158 && (path.starts_with("/com.atproto.")
159 || path.starts_with("/app.bsky.")
160 || path.starts_with("/_"))
161 {
162 format!("/xrpc{}", path)
163 } else {
164 path.to_string()
165 };
166 format!("{}{}", pds_public_url(), normalized_path)
167}
168
169pub fn json_to_ipld(value: &JsonValue) -> Ipld {
170 match value {
171 JsonValue::Null => Ipld::Null,
172 JsonValue::Bool(b) => Ipld::Bool(*b),
173 JsonValue::Number(n) => {
174 if let Some(i) = n.as_i64() {
175 Ipld::Integer(i as i128)
176 } else if let Some(f) = n.as_f64() {
177 Ipld::Float(f)
178 } else {
179 Ipld::Null
180 }
181 }
182 JsonValue::String(s) => Ipld::String(s.clone()),
183 JsonValue::Array(arr) => Ipld::List(arr.iter().map(json_to_ipld).collect()),
184 JsonValue::Object(obj) => {
185 if let Some(JsonValue::String(link)) = obj.get("$link")
186 && obj.len() == 1
187 && let Ok(cid) = Cid::from_str(link)
188 {
189 return Ipld::Link(cid);
190 }
191 let map: BTreeMap<String, Ipld> = obj
192 .iter()
193 .map(|(k, v)| (k.clone(), json_to_ipld(v)))
194 .collect();
195 Ipld::Map(map)
196 }
197 }
198}
199
/// Unit tests for the query-string parser, token code generator, JSON-to-IPLD
/// conversion and URL builder defined in this file.
#[cfg(test)]
mod tests {
    use super::*;

    /// The same key given several times yields every value in order.
    #[test]
    fn test_parse_repeated_query_param_repeated() {
        let query = "did=test&cids=a&cids=b&cids=c";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert_eq!(result, vec!["a", "b", "c"]);
    }

    /// A single comma-separated value is split into its pieces.
    #[test]
    fn test_parse_repeated_query_param_comma_separated() {
        let query = "did=test&cids=a,b,c";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert_eq!(result, vec!["a", "b", "c"]);
    }

    /// Repeated and comma-separated forms can be combined in one query.
    #[test]
    fn test_parse_repeated_query_param_mixed() {
        let query = "did=test&cids=a,b&cids=c";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert_eq!(result, vec!["a", "b", "c"]);
    }

    /// A single plain value yields a one-element result.
    #[test]
    fn test_parse_repeated_query_param_single() {
        let query = "did=test&cids=a";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert_eq!(result, vec!["a"]);
    }

    /// A query that never mentions the key yields an empty result.
    #[test]
    fn test_parse_repeated_query_param_empty() {
        let query = "did=test";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert!(result.is_empty());
    }

    /// Percent-encoded characters are decoded (`%2B` becomes `+`).
    #[test]
    fn test_parse_repeated_query_param_url_encoded() {
        let query = "did=test&cids=bafyreib%2Btest";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert_eq!(result, vec!["bafyreib+test"]);
    }

    /// The default code is two five-character groups from the alphabet joined by a dash.
    #[test]
    fn test_generate_token_code() {
        let code = generate_token_code();
        assert_eq!(code.len(), 11);
        assert!(code.contains('-'));

        let parts: Vec<&str> = code.split('-').collect();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].len(), 5);
        assert_eq!(parts[1].len(), 5);

        for c in code.chars() {
            if c != '-' {
                assert!(BASE32_ALPHABET.contains(c));
            }
        }
    }

    /// Custom group count and group length are honoured.
    #[test]
    fn test_generate_token_code_parts() {
        let code = generate_token_code_parts(3, 4);
        let parts: Vec<&str> = code.split('-').collect();
        assert_eq!(parts.len(), 3);

        for part in parts {
            assert_eq!(part.len(), 4);
        }
    }

    /// A bare `{"$link": <cid>}` object converts to `Ipld::Link`.
    #[test]
    fn test_json_to_ipld_cid_link() {
        let json = serde_json::json!({
            "$link": "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"
        });
        let ipld = json_to_ipld(&json);
        match ipld {
            Ipld::Link(cid) => {
                assert_eq!(
                    cid.to_string(),
                    "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"
                );
            }
            _ => panic!("Expected Ipld::Link, got {:?}", ipld),
        }
    }

    /// A blob reference stays a map, with its nested `ref` converted to a link.
    #[test]
    fn test_json_to_ipld_blob_ref() {
        let json = serde_json::json!({
            "$type": "blob",
            "ref": {
                "$link": "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"
            },
            "mimeType": "image/jpeg",
            "size": 12345
        });
        let ipld = json_to_ipld(&json);
        match ipld {
            Ipld::Map(map) => {
                assert_eq!(map.get("$type"), Some(&Ipld::String("blob".to_string())));
                assert_eq!(
                    map.get("mimeType"),
                    Some(&Ipld::String("image/jpeg".to_string()))
                );
                assert_eq!(map.get("size"), Some(&Ipld::Integer(12345)));
                match map.get("ref") {
                    Some(Ipld::Link(cid)) => {
                        assert_eq!(
                            cid.to_string(),
                            "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"
                        );
                    }
                    _ => panic!("Expected Ipld::Link in ref field, got {:?}", map.get("ref")),
                }
            }
            _ => panic!("Expected Ipld::Map, got {:?}", ipld),
        }
    }

    /// A deeply nested blob reference survives a DAG-CBOR encode/decode round trip
    /// as a link.
    #[test]
    fn test_json_to_ipld_nested_blob_refs_serializes_correctly() {
        let record = serde_json::json!({
            "$type": "app.bsky.feed.post",
            "text": "Hello world",
            "embed": {
                "$type": "app.bsky.embed.images",
                "images": [
                    {
                        "alt": "Test image",
                        "image": {
                            "$type": "blob",
                            "ref": {
                                "$link": "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"
                            },
                            "mimeType": "image/jpeg",
                            "size": 12345
                        }
                    }
                ]
            }
        });
        let ipld = json_to_ipld(&record);
        let cbor_bytes = serde_ipld_dagcbor::to_vec(&ipld).expect("CBOR serialization failed");
        assert!(!cbor_bytes.is_empty());
        let parsed: Ipld =
            serde_ipld_dagcbor::from_slice(&cbor_bytes).expect("CBOR deserialization failed");
        // Walk embed -> images[0] -> image -> ref; any mismatch falls through to the panic.
        if let Ipld::Map(map) = &parsed
            && let Some(Ipld::Map(embed)) = map.get("embed")
            && let Some(Ipld::List(images)) = embed.get("images")
            && let Some(Ipld::Map(img)) = images.first()
            && let Some(Ipld::Map(blob)) = img.get("image")
            && let Some(Ipld::Link(cid)) = blob.get("ref")
        {
            assert_eq!(
                cid.to_string(),
                "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"
            );
            return;
        }
        panic!("Failed to find CID link in parsed CBOR");
    }

    /// XRPC-style paths gain the `/xrpc` prefix; already-prefixed and unrelated paths
    /// are left unchanged.
    #[test]
    fn test_build_full_url_adds_xrpc_prefix_for_atproto_paths() {
        // SAFETY: mutating the process environment is only sound while no other thread
        // accesses it concurrently.
        // NOTE(review): tests run in parallel by default — confirm no other test in the
        // crate reads or writes `PDS_HOSTNAME` at the same time.
        unsafe { std::env::set_var("PDS_HOSTNAME", "example.com") };
        assert_eq!(
            build_full_url("/com.atproto.server.getSession"),
            "https://example.com/xrpc/com.atproto.server.getSession"
        );
        assert_eq!(
            build_full_url("/app.bsky.feed.getTimeline"),
            "https://example.com/xrpc/app.bsky.feed.getTimeline"
        );
        assert_eq!(
            build_full_url("/_health"),
            "https://example.com/xrpc/_health"
        );
        assert_eq!(
            build_full_url("/xrpc/com.atproto.server.getSession"),
            "https://example.com/xrpc/com.atproto.server.getSession"
        );
        assert_eq!(
            build_full_url("/oauth/token"),
            "https://example.com/oauth/token"
        );
    }
}
392}