Highly ambitious ATProtocol AppView service and sdks
at main 15 kB view raw
1//! Cursor-based pagination utilities. 2//! 3//! Cursors encode the position in a result set as base64(field1|field2|...|cid) 4//! to enable stable pagination even when new records are inserted. 5//! 6//! The cursor format: 7//! - All sort field values are included in the cursor 8//! - Values are separated by pipe (|) characters 9//! - CID is always the last element as the ultimate tiebreaker 10 11use super::types::SortField; 12use crate::models::Record; 13use base64::{Engine as _, engine::general_purpose}; 14 15/// Generates a cursor from a record based on the sort configuration. 16/// 17/// Extracts all sort field values from the record and encodes them along with the CID. 18/// Format: `base64(field1_value|field2_value|...|cid)` 19/// 20/// # Arguments 21/// * `record` - The record to generate a cursor for 22/// * `sort_by` - Optional array defining sort fields 23/// 24/// # Returns 25/// Base64-encoded cursor string 26pub fn generate_cursor_from_record(record: &Record, sort_by: Option<&Vec<SortField>>) -> String { 27 let mut cursor_parts = Vec::new(); 28 29 // Extract values for all sort fields 30 if let Some(sort_fields) = sort_by { 31 for sort_field in sort_fields { 32 let field_value = extract_field_value(record, &sort_field.field); 33 cursor_parts.push(field_value); 34 } 35 } 36 37 // Always add CID as the final tiebreaker 38 cursor_parts.push(record.cid.clone()); 39 40 // Join with pipe and encode 41 let cursor_content = cursor_parts.join("|"); 42 general_purpose::URL_SAFE_NO_PAD.encode(cursor_content) 43} 44 45/// Extracts a field value from a record. 46/// 47/// Handles both table columns and JSON fields with nested paths. 48fn extract_field_value(record: &Record, field: &str) -> String { 49 match field { 50 "indexed_at" => record.indexed_at.to_rfc3339(), 51 "uri" => record.uri.clone(), 52 "cid" => record.cid.clone(), 53 "did" => record.did.clone(), 54 "collection" => record.collection.clone(), 55 _ => { 56 // Handle nested JSON paths 57 let field_path: Vec<&str> = field.split('.').collect(); 58 let mut value = &record.json; 59 60 for key in &field_path { 61 value = match value.get(key) { 62 Some(v) => v, 63 None => return "NULL".to_string(), 64 }; 65 } 66 67 match value { 68 serde_json::Value::String(s) => s.clone(), 69 serde_json::Value::Number(n) => n.to_string(), 70 serde_json::Value::Bool(b) => b.to_string(), 71 serde_json::Value::Null => "NULL".to_string(), 72 _ => "NULL".to_string(), 73 } 74 } 75 } 76} 77 78/// Decoded cursor components for pagination. 79#[derive(Debug, Clone)] 80pub struct DecodedCursor { 81 /// Field values in the order they appear in sortBy 82 pub field_values: Vec<String>, 83 /// CID (always the last element) 84 pub cid: String, 85} 86 87/// Decodes a base64-encoded cursor back into its components. 88/// 89/// The cursor format is: `base64(field1|field2|...|cid)` 90/// 91/// # Arguments 92/// * `cursor` - Base64-encoded cursor string 93/// * `sort_by` - Optional array of sort fields to validate cursor format 94/// 95/// # Returns 96/// Result containing DecodedCursor or error if decoding fails 97pub fn decode_cursor( 98 cursor: &str, 99 sort_by: Option<&Vec<SortField>>, 100) -> Result<DecodedCursor, String> { 101 let decoded_bytes = general_purpose::URL_SAFE_NO_PAD 102 .decode(cursor) 103 .map_err(|e| format!("Failed to decode base64: {}", e))?; 104 let decoded_str = 105 String::from_utf8(decoded_bytes).map_err(|e| format!("Invalid UTF-8 in cursor: {}", e))?; 106 107 let parts: Vec<&str> = decoded_str.split('|').collect(); 108 109 // Validate cursor format matches sortBy fields 110 let expected_parts = if let Some(fields) = sort_by { 111 fields.len() + 1 // sort fields + CID 112 } else { 113 1 // just CID if no sortBy 114 }; 115 116 if parts.len() != expected_parts { 117 return Err(format!( 118 "Invalid cursor format: expected {} parts, got {}", 119 expected_parts, 120 parts.len() 121 )); 122 } 123 124 let cid = parts[parts.len() - 1].to_string(); 125 let field_values: Vec<String> = parts[..parts.len() - 1] 126 .iter() 127 .map(|s| s.to_string()) 128 .collect(); 129 130 Ok(DecodedCursor { field_values, cid }) 131} 132 133/// Builds cursor-based WHERE conditions for proper multi-field pagination. 134/// 135/// Creates progressive equality checks for stable multi-field sorting. 136/// For each field, we OR together: 137/// 1. field1 > cursor_value1 138/// 2. field1 = cursor_value1 AND field2 > cursor_value2 139/// 3. field1 = cursor_value1 AND field2 = cursor_value2 AND field3 > cursor_value3 140/// ... and so on 141/// 142/// Finally: all fields equal AND cid > cursor_cid 143/// 144/// # Arguments 145/// * `decoded_cursor` - The decoded cursor components 146/// * `sort_by` - Optional array of sort fields 147/// * `param_count` - Mutable counter for parameter numbering 148/// * `field_types` - Optional array indicating if each field is a datetime 149/// 150/// # Returns 151/// Tuple of (where_condition_sql, bind_values) 152pub fn build_cursor_where_condition( 153 decoded_cursor: &DecodedCursor, 154 sort_by: Option<&Vec<SortField>>, 155 param_count: &mut usize, 156 field_types: Option<&[bool]>, 157) -> (String, Vec<String>) { 158 let mut bind_values = Vec::new(); 159 let mut clauses = Vec::new(); 160 161 let sort_fields = match sort_by { 162 Some(fields) if !fields.is_empty() => fields, 163 _ => { 164 // No sort fields, shouldn't happen but handle gracefully 165 return ("1=1".to_string(), vec![]); 166 } 167 }; 168 169 // Build progressive equality checks for each level 170 for i in 0..sort_fields.len() { 171 let mut clause_parts = Vec::new(); 172 173 // Add equality checks for all previous fields 174 for (j, sort_field) in sort_fields.iter().enumerate().take(i) { 175 let field = &sort_field.field; 176 let cursor_value = &decoded_cursor.field_values[j]; 177 let is_datetime = field_types 178 .and_then(|types| types.get(j).copied()) 179 .unwrap_or(false); 180 181 let field_ref = build_field_reference(field, is_datetime); 182 let param_cast = if is_datetime { "::timestamp" } else { "" }; 183 184 clause_parts.push(format!("{} = ${}{}", field_ref, param_count, param_cast)); 185 *param_count += 1; 186 bind_values.push(cursor_value.clone()); 187 } 188 189 // Add comparison for current field 190 let field = &sort_fields[i].field; 191 let cursor_value = &decoded_cursor.field_values[i]; 192 let direction = &sort_fields[i].direction; 193 let is_datetime = field_types 194 .and_then(|types| types.get(i).copied()) 195 .unwrap_or(false); 196 197 let comparison_op = if direction.to_lowercase() == "desc" { 198 "<" 199 } else { 200 ">" 201 }; 202 let field_ref = build_field_reference(field, is_datetime); 203 let param_cast = if is_datetime { "::timestamp" } else { "" }; 204 205 clause_parts.push(format!( 206 "{} {} ${}{}", 207 field_ref, comparison_op, param_count, param_cast 208 )); 209 *param_count += 1; 210 bind_values.push(cursor_value.clone()); 211 212 // Combine with AND 213 clauses.push(format!("({})", clause_parts.join(" AND "))); 214 } 215 216 // Add final clause: all fields equal AND cid comparison 217 let mut final_clause_parts = Vec::new(); 218 for (j, field) in sort_fields.iter().enumerate() { 219 let cursor_value = &decoded_cursor.field_values[j]; 220 let is_datetime = field_types 221 .and_then(|types| types.get(j).copied()) 222 .unwrap_or(false); 223 224 let field_ref = build_field_reference(&field.field, is_datetime); 225 let param_cast = if is_datetime { "::timestamp" } else { "" }; 226 227 final_clause_parts.push(format!("{} = ${}{}", field_ref, param_count, param_cast)); 228 *param_count += 1; 229 bind_values.push(cursor_value.clone()); 230 } 231 232 // CID comparison uses the direction of the last sort field 233 let last_direction = &sort_fields[sort_fields.len() - 1].direction; 234 let cid_comparison_op = if last_direction.to_lowercase() == "desc" { 235 "<" 236 } else { 237 ">" 238 }; 239 240 final_clause_parts.push(format!("cid {} ${}", cid_comparison_op, param_count)); 241 *param_count += 1; 242 bind_values.push(decoded_cursor.cid.clone()); 243 244 clauses.push(format!("({})", final_clause_parts.join(" AND "))); 245 246 // Combine all clauses with OR 247 let where_condition = format!("({})", clauses.join(" OR ")); 248 249 (where_condition, bind_values) 250} 251 252/// Builds a field reference for SQL queries. 253/// 254/// Handles table columns, JSON fields, and nested paths with optional timestamp casting. 255pub(super) fn build_field_reference(field: &str, is_datetime: bool) -> String { 256 // Table columns don't need JSON extraction 257 if matches!(field, "uri" | "cid" | "did" | "collection" | "indexed_at") { 258 return field.to_string(); 259 } 260 261 // Build JSON path for nested or simple fields 262 let json_path = if field.contains('.') { 263 let parts: Vec<&str> = field.split('.').collect(); 264 let mut path = String::from("json"); 265 for (i, part) in parts.iter().enumerate() { 266 if i == parts.len() - 1 { 267 path.push_str(&format!("->>'{}'", part)); 268 } else { 269 path.push_str(&format!("->'{}'", part)); 270 } 271 } 272 path 273 } else { 274 format!("json->>'{}'", field) 275 }; 276 277 // Add timestamp cast if needed 278 if is_datetime { 279 format!("({})::timestamp", json_path) 280 } else { 281 json_path 282 } 283} 284 285#[cfg(test)] 286mod tests { 287 use super::*; 288 use chrono::Utc; 289 290 fn create_test_record() -> Record { 291 Record { 292 uri: "at://did:plc:test/app.bsky.feed.post/123".to_string(), 293 cid: "bafytest123".to_string(), 294 did: "did:plc:test".to_string(), 295 collection: "app.bsky.feed.post".to_string(), 296 json: serde_json::json!({ 297 "text": "Hello world", 298 "createdAt": "2025-01-15T12:00:00Z", 299 "nested": { 300 "field": "value" 301 } 302 }), 303 indexed_at: Utc::now(), 304 slice_uri: Some("at://did:plc:slice/network.slices.slice/abc".to_string()), 305 } 306 } 307 308 #[test] 309 fn test_generate_cursor_no_sort() { 310 let record = create_test_record(); 311 let cursor = generate_cursor_from_record(&record, None); 312 313 let decoded = general_purpose::URL_SAFE_NO_PAD.decode(&cursor).unwrap(); 314 let decoded_str = String::from_utf8(decoded).unwrap(); 315 316 assert_eq!(decoded_str, "bafytest123"); 317 } 318 319 #[test] 320 fn test_generate_cursor_with_sort() { 321 let record = create_test_record(); 322 let sort_by = vec![SortField { 323 field: "text".to_string(), 324 direction: "desc".to_string(), 325 }]; 326 327 let cursor = generate_cursor_from_record(&record, Some(&sort_by)); 328 let decoded = general_purpose::URL_SAFE_NO_PAD.decode(&cursor).unwrap(); 329 let decoded_str = String::from_utf8(decoded).unwrap(); 330 331 assert_eq!(decoded_str, "Hello world|bafytest123"); 332 } 333 334 #[test] 335 fn test_decode_cursor_single_field() { 336 let sort_by = vec![SortField { 337 field: "createdAt".to_string(), 338 direction: "desc".to_string(), 339 }]; 340 341 let cursor_content = "2025-01-15T12:00:00Z|bafytest123"; 342 let cursor = general_purpose::URL_SAFE_NO_PAD.encode(cursor_content); 343 344 let decoded = decode_cursor(&cursor, Some(&sort_by)).unwrap(); 345 346 assert_eq!(decoded.field_values, vec!["2025-01-15T12:00:00Z"]); 347 assert_eq!(decoded.cid, "bafytest123"); 348 } 349 350 #[test] 351 fn test_decode_cursor_multiple_fields() { 352 let sort_by = vec![ 353 SortField { 354 field: "text".to_string(), 355 direction: "desc".to_string(), 356 }, 357 SortField { 358 field: "createdAt".to_string(), 359 direction: "desc".to_string(), 360 }, 361 ]; 362 363 let cursor_content = "Hello world|2025-01-15T12:00:00Z|bafytest123"; 364 let cursor = general_purpose::URL_SAFE_NO_PAD.encode(cursor_content); 365 366 let decoded = decode_cursor(&cursor, Some(&sort_by)).unwrap(); 367 368 assert_eq!( 369 decoded.field_values, 370 vec!["Hello world", "2025-01-15T12:00:00Z"] 371 ); 372 assert_eq!(decoded.cid, "bafytest123"); 373 } 374 375 #[test] 376 fn test_decode_cursor_invalid_format() { 377 let sort_by = vec![SortField { 378 field: "text".to_string(), 379 direction: "desc".to_string(), 380 }]; 381 382 let cursor_content = "bafytest123"; 383 let cursor = general_purpose::URL_SAFE_NO_PAD.encode(cursor_content); 384 385 let result = decode_cursor(&cursor, Some(&sort_by)); 386 assert!(result.is_err()); 387 } 388 389 #[test] 390 fn test_build_field_reference_table_column() { 391 assert_eq!(build_field_reference("uri", false), "uri"); 392 assert_eq!(build_field_reference("cid", false), "cid"); 393 assert_eq!(build_field_reference("indexed_at", false), "indexed_at"); 394 } 395 396 #[test] 397 fn test_build_field_reference_json_field() { 398 assert_eq!(build_field_reference("text", false), "json->>'text'"); 399 assert_eq!( 400 build_field_reference("text", true), 401 "(json->>'text')::timestamp" 402 ); 403 } 404 405 #[test] 406 fn test_build_field_reference_nested_json() { 407 assert_eq!( 408 build_field_reference("nested.field", false), 409 "json->'nested'->>'field'" 410 ); 411 assert_eq!( 412 build_field_reference("nested.field", true), 413 "(json->'nested'->>'field')::timestamp" 414 ); 415 } 416 417 #[test] 418 fn test_extract_field_value_table_columns() { 419 let record = create_test_record(); 420 421 assert_eq!(extract_field_value(&record, "uri"), record.uri); 422 assert_eq!(extract_field_value(&record, "cid"), record.cid); 423 assert_eq!(extract_field_value(&record, "did"), record.did); 424 assert_eq!( 425 extract_field_value(&record, "collection"), 426 record.collection 427 ); 428 } 429 430 #[test] 431 fn test_extract_field_value_json() { 432 let record = create_test_record(); 433 434 assert_eq!(extract_field_value(&record, "text"), "Hello world"); 435 assert_eq!( 436 extract_field_value(&record, "createdAt"), 437 "2025-01-15T12:00:00Z" 438 ); 439 } 440 441 #[test] 442 fn test_extract_field_value_nested_json() { 443 let record = create_test_record(); 444 445 assert_eq!(extract_field_value(&record, "nested.field"), "value"); 446 } 447 448 #[test] 449 fn test_extract_field_value_missing() { 450 let record = create_test_record(); 451 452 assert_eq!(extract_field_value(&record, "nonexistent"), "NULL"); 453 assert_eq!(extract_field_value(&record, "nested.nonexistent"), "NULL"); 454 } 455}