Highly ambitious AT Protocol AppView service and SDKs
at main 15 kB view raw
1//! Cursor-based pagination utilities. 2//! 3//! Cursors encode the position in a result set as base64(field1|field2|...|cid) 4//! to enable stable pagination even when new records are inserted. 5//! 6//! The cursor format: 7//! - All sort field values are included in the cursor 8//! - Values are separated by pipe (|) characters 9//! - CID is always the last element as the ultimate tiebreaker 10 11use super::types::SortField; 12use crate::models::Record; 13use base64::{Engine as _, engine::general_purpose}; 14 15/// Generates a cursor from a record based on the sort configuration. 16/// 17/// Extracts all sort field values from the record and encodes them along with the CID. 18/// Format: `base64(field1_value|field2_value|...|cid)` 19/// 20/// # Arguments 21/// * `record` - The record to generate a cursor for 22/// * `sort_by` - Optional array defining sort fields 23/// 24/// # Returns 25/// Base64-encoded cursor string 26pub fn generate_cursor_from_record(record: &Record, sort_by: Option<&Vec<SortField>>) -> String { 27 let mut cursor_parts = Vec::new(); 28 29 // Extract values for all sort fields 30 if let Some(sort_fields) = sort_by { 31 for sort_field in sort_fields { 32 let field_value = extract_field_value(record, &sort_field.field); 33 cursor_parts.push(field_value); 34 } 35 } 36 37 // Always add CID as the final tiebreaker 38 cursor_parts.push(record.cid.clone()); 39 40 // Join with pipe and encode 41 let cursor_content = cursor_parts.join("|"); 42 general_purpose::URL_SAFE_NO_PAD.encode(cursor_content) 43} 44 45/// Extracts a field value from a record. 46/// 47/// Handles both table columns and JSON fields with nested paths. 
48fn extract_field_value(record: &Record, field: &str) -> String { 49 match field { 50 "indexed_at" => record.indexed_at.to_rfc3339(), 51 "uri" => record.uri.clone(), 52 "cid" => record.cid.clone(), 53 "did" => record.did.clone(), 54 "collection" => record.collection.clone(), 55 _ => { 56 // Handle nested JSON paths 57 let field_path: Vec<&str> = field.split('.').collect(); 58 let mut value = &record.json; 59 60 for key in &field_path { 61 value = match value.get(key) { 62 Some(v) => v, 63 None => return "NULL".to_string(), 64 }; 65 } 66 67 match value { 68 serde_json::Value::String(s) => s.clone(), 69 serde_json::Value::Number(n) => n.to_string(), 70 serde_json::Value::Bool(b) => b.to_string(), 71 serde_json::Value::Null => "NULL".to_string(), 72 _ => "NULL".to_string(), 73 } 74 } 75 } 76} 77 78/// Decoded cursor components for pagination. 79#[derive(Debug, Clone)] 80pub struct DecodedCursor { 81 /// Field values in the order they appear in sortBy 82 pub field_values: Vec<String>, 83 /// CID (always the last element) 84 pub cid: String, 85} 86 87/// Decodes a base64-encoded cursor back into its components. 
88/// 89/// The cursor format is: `base64(field1|field2|...|cid)` 90/// 91/// # Arguments 92/// * `cursor` - Base64-encoded cursor string 93/// * `sort_by` - Optional array of sort fields to validate cursor format 94/// 95/// # Returns 96/// Result containing DecodedCursor or error if decoding fails 97pub fn decode_cursor(cursor: &str, sort_by: Option<&Vec<SortField>>) -> Result<DecodedCursor, String> { 98 let decoded_bytes = general_purpose::URL_SAFE_NO_PAD.decode(cursor) 99 .map_err(|e| format!("Failed to decode base64: {}", e))?; 100 let decoded_str = String::from_utf8(decoded_bytes) 101 .map_err(|e| format!("Invalid UTF-8 in cursor: {}", e))?; 102 103 let parts: Vec<&str> = decoded_str.split('|').collect(); 104 105 // Validate cursor format matches sortBy fields 106 let expected_parts = if let Some(fields) = sort_by { 107 fields.len() + 1 // sort fields + CID 108 } else { 109 1 // just CID if no sortBy 110 }; 111 112 if parts.len() != expected_parts { 113 return Err(format!( 114 "Invalid cursor format: expected {} parts, got {}", 115 expected_parts, 116 parts.len() 117 )); 118 } 119 120 let cid = parts[parts.len() - 1].to_string(); 121 let field_values: Vec<String> = parts[..parts.len() - 1] 122 .iter() 123 .map(|s| s.to_string()) 124 .collect(); 125 126 Ok(DecodedCursor { 127 field_values, 128 cid, 129 }) 130} 131 132/// Builds cursor-based WHERE conditions for proper multi-field pagination. 133/// 134/// Creates progressive equality checks for stable multi-field sorting. 135/// For each field, we OR together: 136/// 1. field1 > cursor_value1 137/// 2. field1 = cursor_value1 AND field2 > cursor_value2 138/// 3. field1 = cursor_value1 AND field2 = cursor_value2 AND field3 > cursor_value3 139/// ... 
and so on 140/// 141/// Finally: all fields equal AND cid > cursor_cid 142/// 143/// # Arguments 144/// * `decoded_cursor` - The decoded cursor components 145/// * `sort_by` - Optional array of sort fields 146/// * `param_count` - Mutable counter for parameter numbering 147/// * `field_types` - Optional array indicating if each field is a datetime 148/// 149/// # Returns 150/// Tuple of (where_condition_sql, bind_values) 151pub fn build_cursor_where_condition( 152 decoded_cursor: &DecodedCursor, 153 sort_by: Option<&Vec<SortField>>, 154 param_count: &mut usize, 155 field_types: Option<&[bool]>, 156) -> (String, Vec<String>) { 157 let mut bind_values = Vec::new(); 158 let mut clauses = Vec::new(); 159 160 let sort_fields = match sort_by { 161 Some(fields) if !fields.is_empty() => fields, 162 _ => { 163 // No sort fields, shouldn't happen but handle gracefully 164 return ("1=1".to_string(), vec![]); 165 } 166 }; 167 168 // Build progressive equality checks for each level 169 for i in 0..sort_fields.len() { 170 let mut clause_parts = Vec::new(); 171 172 // Add equality checks for all previous fields 173 for (j, sort_field) in sort_fields.iter().enumerate().take(i) { 174 let field = &sort_field.field; 175 let cursor_value = &decoded_cursor.field_values[j]; 176 let is_datetime = field_types.and_then(|types| types.get(j).copied()).unwrap_or(false); 177 178 let field_ref = build_field_reference(field, is_datetime); 179 let param_cast = if is_datetime { "::timestamp" } else { "" }; 180 181 clause_parts.push(format!("{} = ${}{}", field_ref, param_count, param_cast)); 182 *param_count += 1; 183 bind_values.push(cursor_value.clone()); 184 } 185 186 // Add comparison for current field 187 let field = &sort_fields[i].field; 188 let cursor_value = &decoded_cursor.field_values[i]; 189 let direction = &sort_fields[i].direction; 190 let is_datetime = field_types.and_then(|types| types.get(i).copied()).unwrap_or(false); 191 192 let comparison_op = if direction.to_lowercase() == 
"desc" { "<" } else { ">" }; 193 let field_ref = build_field_reference(field, is_datetime); 194 let param_cast = if is_datetime { "::timestamp" } else { "" }; 195 196 clause_parts.push(format!("{} {} ${}{}", field_ref, comparison_op, param_count, param_cast)); 197 *param_count += 1; 198 bind_values.push(cursor_value.clone()); 199 200 // Combine with AND 201 clauses.push(format!("({})", clause_parts.join(" AND "))); 202 } 203 204 // Add final clause: all fields equal AND cid comparison 205 let mut final_clause_parts = Vec::new(); 206 for (j, field) in sort_fields.iter().enumerate() { 207 let cursor_value = &decoded_cursor.field_values[j]; 208 let is_datetime = field_types.and_then(|types| types.get(j).copied()).unwrap_or(false); 209 210 let field_ref = build_field_reference(&field.field, is_datetime); 211 let param_cast = if is_datetime { "::timestamp" } else { "" }; 212 213 final_clause_parts.push(format!("{} = ${}{}", field_ref, param_count, param_cast)); 214 *param_count += 1; 215 bind_values.push(cursor_value.clone()); 216 } 217 218 // CID comparison uses the direction of the last sort field 219 let last_direction = &sort_fields[sort_fields.len() - 1].direction; 220 let cid_comparison_op = if last_direction.to_lowercase() == "desc" { "<" } else { ">" }; 221 222 final_clause_parts.push(format!("cid {} ${}", cid_comparison_op, param_count)); 223 *param_count += 1; 224 bind_values.push(decoded_cursor.cid.clone()); 225 226 clauses.push(format!("({})", final_clause_parts.join(" AND "))); 227 228 // Combine all clauses with OR 229 let where_condition = format!("({})", clauses.join(" OR ")); 230 231 (where_condition, bind_values) 232} 233 234/// Builds a field reference for SQL queries. 235/// 236/// Handles table columns, JSON fields, and nested paths with optional timestamp casting. 
237pub(super) fn build_field_reference(field: &str, is_datetime: bool) -> String { 238 // Table columns don't need JSON extraction 239 if matches!(field, "uri" | "cid" | "did" | "collection" | "indexed_at") { 240 return field.to_string(); 241 } 242 243 // Build JSON path for nested or simple fields 244 let json_path = if field.contains('.') { 245 let parts: Vec<&str> = field.split('.').collect(); 246 let mut path = String::from("json"); 247 for (i, part) in parts.iter().enumerate() { 248 if i == parts.len() - 1 { 249 path.push_str(&format!("->>'{}'", part)); 250 } else { 251 path.push_str(&format!("->'{}'", part)); 252 } 253 } 254 path 255 } else { 256 format!("json->>'{}'", field) 257 }; 258 259 // Add timestamp cast if needed 260 if is_datetime { 261 format!("({})::timestamp", json_path) 262 } else { 263 json_path 264 } 265} 266 267#[cfg(test)] 268mod tests { 269 use super::*; 270 use chrono::Utc; 271 272 fn create_test_record() -> Record { 273 Record { 274 uri: "at://did:plc:test/app.bsky.feed.post/123".to_string(), 275 cid: "bafytest123".to_string(), 276 did: "did:plc:test".to_string(), 277 collection: "app.bsky.feed.post".to_string(), 278 json: serde_json::json!({ 279 "text": "Hello world", 280 "createdAt": "2025-01-15T12:00:00Z", 281 "nested": { 282 "field": "value" 283 } 284 }), 285 indexed_at: Utc::now(), 286 slice_uri: Some("at://did:plc:slice/network.slices.slice/abc".to_string()), 287 } 288 } 289 290 #[test] 291 fn test_generate_cursor_no_sort() { 292 let record = create_test_record(); 293 let cursor = generate_cursor_from_record(&record, None); 294 295 let decoded = general_purpose::URL_SAFE_NO_PAD.decode(&cursor).unwrap(); 296 let decoded_str = String::from_utf8(decoded).unwrap(); 297 298 assert_eq!(decoded_str, "bafytest123"); 299 } 300 301 #[test] 302 fn test_generate_cursor_with_sort() { 303 let record = create_test_record(); 304 let sort_by = vec![ 305 SortField { 306 field: "text".to_string(), 307 direction: "desc".to_string(), 308 }, 309 ]; 310 
311 let cursor = generate_cursor_from_record(&record, Some(&sort_by)); 312 let decoded = general_purpose::URL_SAFE_NO_PAD.decode(&cursor).unwrap(); 313 let decoded_str = String::from_utf8(decoded).unwrap(); 314 315 assert_eq!(decoded_str, "Hello world|bafytest123"); 316 } 317 318 #[test] 319 fn test_decode_cursor_single_field() { 320 let sort_by = vec![ 321 SortField { 322 field: "createdAt".to_string(), 323 direction: "desc".to_string(), 324 }, 325 ]; 326 327 let cursor_content = "2025-01-15T12:00:00Z|bafytest123"; 328 let cursor = general_purpose::URL_SAFE_NO_PAD.encode(cursor_content); 329 330 let decoded = decode_cursor(&cursor, Some(&sort_by)).unwrap(); 331 332 assert_eq!(decoded.field_values, vec!["2025-01-15T12:00:00Z"]); 333 assert_eq!(decoded.cid, "bafytest123"); 334 } 335 336 #[test] 337 fn test_decode_cursor_multiple_fields() { 338 let sort_by = vec![ 339 SortField { 340 field: "text".to_string(), 341 direction: "desc".to_string(), 342 }, 343 SortField { 344 field: "createdAt".to_string(), 345 direction: "desc".to_string(), 346 }, 347 ]; 348 349 let cursor_content = "Hello world|2025-01-15T12:00:00Z|bafytest123"; 350 let cursor = general_purpose::URL_SAFE_NO_PAD.encode(cursor_content); 351 352 let decoded = decode_cursor(&cursor, Some(&sort_by)).unwrap(); 353 354 assert_eq!(decoded.field_values, vec!["Hello world", "2025-01-15T12:00:00Z"]); 355 assert_eq!(decoded.cid, "bafytest123"); 356 } 357 358 #[test] 359 fn test_decode_cursor_invalid_format() { 360 let sort_by = vec![ 361 SortField { 362 field: "text".to_string(), 363 direction: "desc".to_string(), 364 }, 365 ]; 366 367 let cursor_content = "bafytest123"; 368 let cursor = general_purpose::URL_SAFE_NO_PAD.encode(cursor_content); 369 370 let result = decode_cursor(&cursor, Some(&sort_by)); 371 assert!(result.is_err()); 372 } 373 374 #[test] 375 fn test_build_field_reference_table_column() { 376 assert_eq!(build_field_reference("uri", false), "uri"); 377 assert_eq!(build_field_reference("cid", false), 
"cid"); 378 assert_eq!(build_field_reference("indexed_at", false), "indexed_at"); 379 } 380 381 #[test] 382 fn test_build_field_reference_json_field() { 383 assert_eq!(build_field_reference("text", false), "json->>'text'"); 384 assert_eq!( 385 build_field_reference("text", true), 386 "(json->>'text')::timestamp" 387 ); 388 } 389 390 #[test] 391 fn test_build_field_reference_nested_json() { 392 assert_eq!( 393 build_field_reference("nested.field", false), 394 "json->'nested'->>'field'" 395 ); 396 assert_eq!( 397 build_field_reference("nested.field", true), 398 "(json->'nested'->>'field')::timestamp" 399 ); 400 } 401 402 #[test] 403 fn test_extract_field_value_table_columns() { 404 let record = create_test_record(); 405 406 assert_eq!(extract_field_value(&record, "uri"), record.uri); 407 assert_eq!(extract_field_value(&record, "cid"), record.cid); 408 assert_eq!(extract_field_value(&record, "did"), record.did); 409 assert_eq!(extract_field_value(&record, "collection"), record.collection); 410 } 411 412 #[test] 413 fn test_extract_field_value_json() { 414 let record = create_test_record(); 415 416 assert_eq!(extract_field_value(&record, "text"), "Hello world"); 417 assert_eq!(extract_field_value(&record, "createdAt"), "2025-01-15T12:00:00Z"); 418 } 419 420 #[test] 421 fn test_extract_field_value_nested_json() { 422 let record = create_test_record(); 423 424 assert_eq!(extract_field_value(&record, "nested.field"), "value"); 425 } 426 427 #[test] 428 fn test_extract_field_value_missing() { 429 let record = create_test_record(); 430 431 assert_eq!(extract_field_value(&record, "nonexistent"), "NULL"); 432 assert_eq!(extract_field_value(&record, "nested.nonexistent"), "NULL"); 433 } 434}