Highly ambitious ATProtocol AppView service and SDKs
1//! Cursor-based pagination utilities.
2//!
3//! Cursors encode the position in a result set as base64(field1|field2|...|cid)
4//! to enable stable pagination even when new records are inserted.
5//!
6//! The cursor format:
7//! - All sort field values are included in the cursor
8//! - Values are separated by pipe (|) characters
9//! - CID is always the last element as the ultimate tiebreaker
10
11use super::types::SortField;
12use crate::models::Record;
13use base64::{Engine as _, engine::general_purpose};
14
15/// Generates a cursor from a record based on the sort configuration.
16///
17/// Extracts all sort field values from the record and encodes them along with the CID.
18/// Format: `base64(field1_value|field2_value|...|cid)`
19///
20/// # Arguments
21/// * `record` - The record to generate a cursor for
22/// * `sort_by` - Optional array defining sort fields
23///
24/// # Returns
25/// Base64-encoded cursor string
26pub fn generate_cursor_from_record(record: &Record, sort_by: Option<&Vec<SortField>>) -> String {
27 let mut cursor_parts = Vec::new();
28
29 // Extract values for all sort fields
30 if let Some(sort_fields) = sort_by {
31 for sort_field in sort_fields {
32 let field_value = extract_field_value(record, &sort_field.field);
33 cursor_parts.push(field_value);
34 }
35 }
36
37 // Always add CID as the final tiebreaker
38 cursor_parts.push(record.cid.clone());
39
40 // Join with pipe and encode
41 let cursor_content = cursor_parts.join("|");
42 general_purpose::URL_SAFE_NO_PAD.encode(cursor_content)
43}
44
45/// Extracts a field value from a record.
46///
47/// Handles both table columns and JSON fields with nested paths.
48fn extract_field_value(record: &Record, field: &str) -> String {
49 match field {
50 "indexed_at" => record.indexed_at.to_rfc3339(),
51 "uri" => record.uri.clone(),
52 "cid" => record.cid.clone(),
53 "did" => record.did.clone(),
54 "collection" => record.collection.clone(),
55 _ => {
56 // Handle nested JSON paths
57 let field_path: Vec<&str> = field.split('.').collect();
58 let mut value = &record.json;
59
60 for key in &field_path {
61 value = match value.get(key) {
62 Some(v) => v,
63 None => return "NULL".to_string(),
64 };
65 }
66
67 match value {
68 serde_json::Value::String(s) => s.clone(),
69 serde_json::Value::Number(n) => n.to_string(),
70 serde_json::Value::Bool(b) => b.to_string(),
71 serde_json::Value::Null => "NULL".to_string(),
72 _ => "NULL".to_string(),
73 }
74 }
75 }
76}
77
/// Decoded cursor components for pagination.
///
/// Produced by splitting the decoded cursor payload on `|`: every element except the
/// last is a sort-field value, and the last element is the CID.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecodedCursor {
    /// Field values in the order they appear in sortBy
    pub field_values: Vec<String>,
    /// CID (always the last element)
    pub cid: String,
}
86
87/// Decodes a base64-encoded cursor back into its components.
88///
89/// The cursor format is: `base64(field1|field2|...|cid)`
90///
91/// # Arguments
92/// * `cursor` - Base64-encoded cursor string
93/// * `sort_by` - Optional array of sort fields to validate cursor format
94///
95/// # Returns
96/// Result containing DecodedCursor or error if decoding fails
97pub fn decode_cursor(
98 cursor: &str,
99 sort_by: Option<&Vec<SortField>>,
100) -> Result<DecodedCursor, String> {
101 let decoded_bytes = general_purpose::URL_SAFE_NO_PAD
102 .decode(cursor)
103 .map_err(|e| format!("Failed to decode base64: {}", e))?;
104 let decoded_str =
105 String::from_utf8(decoded_bytes).map_err(|e| format!("Invalid UTF-8 in cursor: {}", e))?;
106
107 let parts: Vec<&str> = decoded_str.split('|').collect();
108
109 // Validate cursor format matches sortBy fields
110 let expected_parts = if let Some(fields) = sort_by {
111 fields.len() + 1 // sort fields + CID
112 } else {
113 1 // just CID if no sortBy
114 };
115
116 if parts.len() != expected_parts {
117 return Err(format!(
118 "Invalid cursor format: expected {} parts, got {}",
119 expected_parts,
120 parts.len()
121 ));
122 }
123
124 let cid = parts[parts.len() - 1].to_string();
125 let field_values: Vec<String> = parts[..parts.len() - 1]
126 .iter()
127 .map(|s| s.to_string())
128 .collect();
129
130 Ok(DecodedCursor { field_values, cid })
131}
132
133/// Builds cursor-based WHERE conditions for proper multi-field pagination.
134///
135/// Creates progressive equality checks for stable multi-field sorting.
136/// For each field, we OR together:
137/// 1. field1 > cursor_value1
138/// 2. field1 = cursor_value1 AND field2 > cursor_value2
139/// 3. field1 = cursor_value1 AND field2 = cursor_value2 AND field3 > cursor_value3
140/// ... and so on
141///
142/// Finally: all fields equal AND cid > cursor_cid
143///
144/// # Arguments
145/// * `decoded_cursor` - The decoded cursor components
146/// * `sort_by` - Optional array of sort fields
147/// * `param_count` - Mutable counter for parameter numbering
148/// * `field_types` - Optional array indicating if each field is a datetime
149///
150/// # Returns
151/// Tuple of (where_condition_sql, bind_values)
152pub fn build_cursor_where_condition(
153 decoded_cursor: &DecodedCursor,
154 sort_by: Option<&Vec<SortField>>,
155 param_count: &mut usize,
156 field_types: Option<&[bool]>,
157) -> (String, Vec<String>) {
158 let mut bind_values = Vec::new();
159 let mut clauses = Vec::new();
160
161 let sort_fields = match sort_by {
162 Some(fields) if !fields.is_empty() => fields,
163 _ => {
164 // No sort fields, shouldn't happen but handle gracefully
165 return ("1=1".to_string(), vec![]);
166 }
167 };
168
169 // Build progressive equality checks for each level
170 for i in 0..sort_fields.len() {
171 let mut clause_parts = Vec::new();
172
173 // Add equality checks for all previous fields
174 for (j, sort_field) in sort_fields.iter().enumerate().take(i) {
175 let field = &sort_field.field;
176 let cursor_value = &decoded_cursor.field_values[j];
177 let is_datetime = field_types
178 .and_then(|types| types.get(j).copied())
179 .unwrap_or(false);
180
181 let field_ref = build_field_reference(field, is_datetime);
182 let param_cast = if is_datetime { "::timestamp" } else { "" };
183
184 clause_parts.push(format!("{} = ${}{}", field_ref, param_count, param_cast));
185 *param_count += 1;
186 bind_values.push(cursor_value.clone());
187 }
188
189 // Add comparison for current field
190 let field = &sort_fields[i].field;
191 let cursor_value = &decoded_cursor.field_values[i];
192 let direction = &sort_fields[i].direction;
193 let is_datetime = field_types
194 .and_then(|types| types.get(i).copied())
195 .unwrap_or(false);
196
197 let comparison_op = if direction.to_lowercase() == "desc" {
198 "<"
199 } else {
200 ">"
201 };
202 let field_ref = build_field_reference(field, is_datetime);
203 let param_cast = if is_datetime { "::timestamp" } else { "" };
204
205 clause_parts.push(format!(
206 "{} {} ${}{}",
207 field_ref, comparison_op, param_count, param_cast
208 ));
209 *param_count += 1;
210 bind_values.push(cursor_value.clone());
211
212 // Combine with AND
213 clauses.push(format!("({})", clause_parts.join(" AND ")));
214 }
215
216 // Add final clause: all fields equal AND cid comparison
217 let mut final_clause_parts = Vec::new();
218 for (j, field) in sort_fields.iter().enumerate() {
219 let cursor_value = &decoded_cursor.field_values[j];
220 let is_datetime = field_types
221 .and_then(|types| types.get(j).copied())
222 .unwrap_or(false);
223
224 let field_ref = build_field_reference(&field.field, is_datetime);
225 let param_cast = if is_datetime { "::timestamp" } else { "" };
226
227 final_clause_parts.push(format!("{} = ${}{}", field_ref, param_count, param_cast));
228 *param_count += 1;
229 bind_values.push(cursor_value.clone());
230 }
231
232 // CID comparison uses the direction of the last sort field
233 let last_direction = &sort_fields[sort_fields.len() - 1].direction;
234 let cid_comparison_op = if last_direction.to_lowercase() == "desc" {
235 "<"
236 } else {
237 ">"
238 };
239
240 final_clause_parts.push(format!("cid {} ${}", cid_comparison_op, param_count));
241 *param_count += 1;
242 bind_values.push(decoded_cursor.cid.clone());
243
244 clauses.push(format!("({})", final_clause_parts.join(" AND ")));
245
246 // Combine all clauses with OR
247 let where_condition = format!("({})", clauses.join(" OR "));
248
249 (where_condition, bind_values)
250}
251
252/// Builds a field reference for SQL queries.
253///
254/// Handles table columns, JSON fields, and nested paths with optional timestamp casting.
255pub(super) fn build_field_reference(field: &str, is_datetime: bool) -> String {
256 // Table columns don't need JSON extraction
257 if matches!(field, "uri" | "cid" | "did" | "collection" | "indexed_at") {
258 return field.to_string();
259 }
260
261 // Build JSON path for nested or simple fields
262 let json_path = if field.contains('.') {
263 let parts: Vec<&str> = field.split('.').collect();
264 let mut path = String::from("json");
265 for (i, part) in parts.iter().enumerate() {
266 if i == parts.len() - 1 {
267 path.push_str(&format!("->>'{}'", part));
268 } else {
269 path.push_str(&format!("->'{}'", part));
270 }
271 }
272 path
273 } else {
274 format!("json->>'{}'", field)
275 };
276
277 // Add timestamp cast if needed
278 if is_datetime {
279 format!("({})::timestamp", json_path)
280 } else {
281 json_path
282 }
283}
284
/// Unit tests for cursor generation, decoding, WHERE-clause construction and SQL field
/// references.
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Builds a record with string, datetime-like and nested JSON fields.
    fn create_test_record() -> Record {
        Record {
            uri: "at://did:plc:test/app.bsky.feed.post/123".to_string(),
            cid: "bafytest123".to_string(),
            did: "did:plc:test".to_string(),
            collection: "app.bsky.feed.post".to_string(),
            json: serde_json::json!({
                "text": "Hello world",
                "createdAt": "2025-01-15T12:00:00Z",
                "nested": {
                    "field": "value"
                }
            }),
            indexed_at: Utc::now(),
            slice_uri: Some("at://did:plc:slice/network.slices.slice/abc".to_string()),
        }
    }

    #[test]
    fn test_generate_cursor_no_sort() {
        let record = create_test_record();
        let cursor = generate_cursor_from_record(&record, None);

        let decoded = general_purpose::URL_SAFE_NO_PAD.decode(&cursor).unwrap();
        let decoded_str = String::from_utf8(decoded).unwrap();

        assert_eq!(decoded_str, "bafytest123");
    }

    #[test]
    fn test_generate_cursor_with_sort() {
        let record = create_test_record();
        let sort_by = vec![SortField {
            field: "text".to_string(),
            direction: "desc".to_string(),
        }];

        let cursor = generate_cursor_from_record(&record, Some(&sort_by));
        let decoded = general_purpose::URL_SAFE_NO_PAD.decode(&cursor).unwrap();
        let decoded_str = String::from_utf8(decoded).unwrap();

        assert_eq!(decoded_str, "Hello world|bafytest123");
    }

    #[test]
    fn test_generate_then_decode_round_trip() {
        let record = create_test_record();
        let sort_by = vec![
            SortField {
                field: "text".to_string(),
                direction: "asc".to_string(),
            },
            SortField {
                field: "createdAt".to_string(),
                direction: "desc".to_string(),
            },
        ];

        let cursor = generate_cursor_from_record(&record, Some(&sort_by));
        let decoded = decode_cursor(&cursor, Some(&sort_by)).unwrap();

        assert_eq!(
            decoded.field_values,
            vec!["Hello world", "2025-01-15T12:00:00Z"]
        );
        assert_eq!(decoded.cid, "bafytest123");
    }

    #[test]
    fn test_decode_cursor_single_field() {
        let sort_by = vec![SortField {
            field: "createdAt".to_string(),
            direction: "desc".to_string(),
        }];

        let cursor_content = "2025-01-15T12:00:00Z|bafytest123";
        let cursor = general_purpose::URL_SAFE_NO_PAD.encode(cursor_content);

        let decoded = decode_cursor(&cursor, Some(&sort_by)).unwrap();

        assert_eq!(decoded.field_values, vec!["2025-01-15T12:00:00Z"]);
        assert_eq!(decoded.cid, "bafytest123");
    }

    #[test]
    fn test_decode_cursor_multiple_fields() {
        let sort_by = vec![
            SortField {
                field: "text".to_string(),
                direction: "desc".to_string(),
            },
            SortField {
                field: "createdAt".to_string(),
                direction: "desc".to_string(),
            },
        ];

        let cursor_content = "Hello world|2025-01-15T12:00:00Z|bafytest123";
        let cursor = general_purpose::URL_SAFE_NO_PAD.encode(cursor_content);

        let decoded = decode_cursor(&cursor, Some(&sort_by)).unwrap();

        assert_eq!(
            decoded.field_values,
            vec!["Hello world", "2025-01-15T12:00:00Z"]
        );
        assert_eq!(decoded.cid, "bafytest123");
    }

    #[test]
    fn test_decode_cursor_no_sort() {
        let cursor = general_purpose::URL_SAFE_NO_PAD.encode("bafytest123");

        let decoded = decode_cursor(&cursor, None).unwrap();

        assert!(decoded.field_values.is_empty());
        assert_eq!(decoded.cid, "bafytest123");
    }

    #[test]
    fn test_decode_cursor_invalid_format() {
        let sort_by = vec![SortField {
            field: "text".to_string(),
            direction: "desc".to_string(),
        }];

        let cursor_content = "bafytest123";
        let cursor = general_purpose::URL_SAFE_NO_PAD.encode(cursor_content);

        let result = decode_cursor(&cursor, Some(&sort_by));
        assert!(result.is_err());
    }

    #[test]
    fn test_decode_cursor_invalid_base64() {
        let error = decode_cursor("not valid base64!", None).unwrap_err();

        assert!(error.starts_with("Failed to decode base64"));
    }

    #[test]
    fn test_build_cursor_where_condition_no_sort() {
        let decoded = DecodedCursor {
            field_values: vec![],
            cid: "bafytest123".to_string(),
        };
        let mut param_count = 5;

        let (sql, binds) = build_cursor_where_condition(&decoded, None, &mut param_count, None);

        assert_eq!(sql, "1=1");
        assert!(binds.is_empty());
        assert_eq!(param_count, 5);
    }

    #[test]
    fn test_build_cursor_where_condition_single_datetime_desc() {
        let sort_by = vec![SortField {
            field: "createdAt".to_string(),
            direction: "desc".to_string(),
        }];
        let decoded = DecodedCursor {
            field_values: vec!["2025-01-15T12:00:00Z".to_string()],
            cid: "bafytest123".to_string(),
        };
        let field_types = [true];
        let mut param_count = 1;

        let (sql, binds) = build_cursor_where_condition(
            &decoded,
            Some(&sort_by),
            &mut param_count,
            Some(&field_types),
        );

        assert_eq!(
            sql,
            "(((json->>'createdAt')::timestamp < $1::timestamp) OR \
             ((json->>'createdAt')::timestamp = $2::timestamp AND cid < $3))"
        );
        assert_eq!(
            binds,
            vec!["2025-01-15T12:00:00Z", "2025-01-15T12:00:00Z", "bafytest123"]
        );
        assert_eq!(param_count, 4);
    }

    #[test]
    fn test_build_cursor_where_condition_multiple_fields() {
        let sort_by = vec![
            SortField {
                field: "text".to_string(),
                direction: "asc".to_string(),
            },
            SortField {
                field: "indexed_at".to_string(),
                direction: "desc".to_string(),
            },
        ];
        let decoded = DecodedCursor {
            field_values: vec!["Hello world".to_string(), "2025-01-15T12:00:00Z".to_string()],
            cid: "bafytest123".to_string(),
        };
        // Start above 1 to check that numbering continues from the caller's counter.
        let mut param_count = 3;

        let (sql, binds) =
            build_cursor_where_condition(&decoded, Some(&sort_by), &mut param_count, None);

        assert_eq!(
            sql,
            "((json->>'text' > $3) OR \
             (json->>'text' = $4 AND indexed_at < $5) OR \
             (json->>'text' = $6 AND indexed_at = $7 AND cid < $8))"
        );
        assert_eq!(
            binds,
            vec![
                "Hello world",
                "Hello world",
                "2025-01-15T12:00:00Z",
                "Hello world",
                "2025-01-15T12:00:00Z",
                "bafytest123",
            ]
        );
        assert_eq!(param_count, 9);
    }

    #[test]
    fn test_build_cursor_where_condition_direction_is_case_insensitive() {
        let sort_by = vec![SortField {
            field: "uri".to_string(),
            direction: "DESC".to_string(),
        }];
        let decoded = DecodedCursor {
            field_values: vec!["at://did:plc:test/app.bsky.feed.post/123".to_string()],
            cid: "bafytest123".to_string(),
        };
        let mut param_count = 1;

        let (sql, _binds) =
            build_cursor_where_condition(&decoded, Some(&sort_by), &mut param_count, None);

        assert_eq!(sql, "((uri < $1) OR (uri = $2 AND cid < $3))");
    }

    #[test]
    fn test_build_field_reference_table_column() {
        assert_eq!(build_field_reference("uri", false), "uri");
        assert_eq!(build_field_reference("cid", false), "cid");
        assert_eq!(build_field_reference("indexed_at", false), "indexed_at");
    }

    #[test]
    fn test_build_field_reference_json_field() {
        assert_eq!(build_field_reference("text", false), "json->>'text'");
        assert_eq!(
            build_field_reference("text", true),
            "(json->>'text')::timestamp"
        );
    }

    #[test]
    fn test_build_field_reference_nested_json() {
        assert_eq!(
            build_field_reference("nested.field", false),
            "json->'nested'->>'field'"
        );
        assert_eq!(
            build_field_reference("nested.field", true),
            "(json->'nested'->>'field')::timestamp"
        );
    }

    #[test]
    fn test_extract_field_value_table_columns() {
        let record = create_test_record();

        assert_eq!(extract_field_value(&record, "uri"), record.uri);
        assert_eq!(extract_field_value(&record, "cid"), record.cid);
        assert_eq!(extract_field_value(&record, "did"), record.did);
        assert_eq!(
            extract_field_value(&record, "collection"),
            record.collection
        );
        assert_eq!(
            extract_field_value(&record, "indexed_at"),
            record.indexed_at.to_rfc3339()
        );
    }

    #[test]
    fn test_extract_field_value_json() {
        let record = create_test_record();

        assert_eq!(extract_field_value(&record, "text"), "Hello world");
        assert_eq!(
            extract_field_value(&record, "createdAt"),
            "2025-01-15T12:00:00Z"
        );
    }

    #[test]
    fn test_extract_field_value_nested_json() {
        let record = create_test_record();

        assert_eq!(extract_field_value(&record, "nested.field"), "value");
    }

    #[test]
    fn test_extract_field_value_missing() {
        let record = create_test_record();

        assert_eq!(extract_field_value(&record, "nonexistent"), "NULL");
        assert_eq!(extract_field_value(&record, "nested.nonexistent"), "NULL");
    }

    #[test]
    fn test_extract_field_value_non_scalar() {
        let record = create_test_record();

        // Objects have no scalar representation and fall back to the NULL marker.
        assert_eq!(extract_field_value(&record, "nested"), "NULL");
    }
}
455}