forked from
slices.network/slices
Highly ambitious ATProtocol AppView service and sdks
1//! Cursor-based pagination utilities.
2//!
3//! Cursors encode the position in a result set as base64(field1|field2|...|cid)
4//! to enable stable pagination even when new records are inserted.
5//!
6//! The cursor format:
7//! - All sort field values are included in the cursor
8//! - Values are separated by pipe (|) characters
9//! - CID is always the last element as the ultimate tiebreaker
10
11use super::types::SortField;
12use crate::models::Record;
13use base64::{Engine as _, engine::general_purpose};
14
15/// Generates a cursor from a record based on the sort configuration.
16///
17/// Extracts all sort field values from the record and encodes them along with the CID.
18/// Format: `base64(field1_value|field2_value|...|cid)`
19///
20/// # Arguments
21/// * `record` - The record to generate a cursor for
22/// * `sort_by` - Optional array defining sort fields
23///
24/// # Returns
25/// Base64-encoded cursor string
26pub fn generate_cursor_from_record(record: &Record, sort_by: Option<&Vec<SortField>>) -> String {
27 let mut cursor_parts = Vec::new();
28
29 // Extract values for all sort fields
30 if let Some(sort_fields) = sort_by {
31 for sort_field in sort_fields {
32 let field_value = extract_field_value(record, &sort_field.field);
33 cursor_parts.push(field_value);
34 }
35 }
36
37 // Always add CID as the final tiebreaker
38 cursor_parts.push(record.cid.clone());
39
40 // Join with pipe and encode
41 let cursor_content = cursor_parts.join("|");
42 general_purpose::URL_SAFE_NO_PAD.encode(cursor_content)
43}
44
45/// Extracts a field value from a record.
46///
47/// Handles both table columns and JSON fields with nested paths.
48fn extract_field_value(record: &Record, field: &str) -> String {
49 match field {
50 "indexed_at" => record.indexed_at.to_rfc3339(),
51 "uri" => record.uri.clone(),
52 "cid" => record.cid.clone(),
53 "did" => record.did.clone(),
54 "collection" => record.collection.clone(),
55 _ => {
56 // Handle nested JSON paths
57 let field_path: Vec<&str> = field.split('.').collect();
58 let mut value = &record.json;
59
60 for key in &field_path {
61 value = match value.get(key) {
62 Some(v) => v,
63 None => return "NULL".to_string(),
64 };
65 }
66
67 match value {
68 serde_json::Value::String(s) => s.clone(),
69 serde_json::Value::Number(n) => n.to_string(),
70 serde_json::Value::Bool(b) => b.to_string(),
71 serde_json::Value::Null => "NULL".to_string(),
72 _ => "NULL".to_string(),
73 }
74 }
75 }
76}
77
/// Decoded cursor components for pagination.
///
/// A cursor carries one value per sort field followed by the record CID.
/// Equality is derived so decoded cursors can be compared directly
/// (for example in assertions).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecodedCursor {
    /// Field values in the order they appear in sortBy
    pub field_values: Vec<String>,
    /// CID (always the last element of the cursor)
    pub cid: String,
}
86
87/// Decodes a base64-encoded cursor back into its components.
88///
89/// The cursor format is: `base64(field1|field2|...|cid)`
90///
91/// # Arguments
92/// * `cursor` - Base64-encoded cursor string
93/// * `sort_by` - Optional array of sort fields to validate cursor format
94///
95/// # Returns
96/// Result containing DecodedCursor or error if decoding fails
97pub fn decode_cursor(cursor: &str, sort_by: Option<&Vec<SortField>>) -> Result<DecodedCursor, String> {
98 let decoded_bytes = general_purpose::URL_SAFE_NO_PAD.decode(cursor)
99 .map_err(|e| format!("Failed to decode base64: {}", e))?;
100 let decoded_str = String::from_utf8(decoded_bytes)
101 .map_err(|e| format!("Invalid UTF-8 in cursor: {}", e))?;
102
103 let parts: Vec<&str> = decoded_str.split('|').collect();
104
105 // Validate cursor format matches sortBy fields
106 let expected_parts = if let Some(fields) = sort_by {
107 fields.len() + 1 // sort fields + CID
108 } else {
109 1 // just CID if no sortBy
110 };
111
112 if parts.len() != expected_parts {
113 return Err(format!(
114 "Invalid cursor format: expected {} parts, got {}",
115 expected_parts,
116 parts.len()
117 ));
118 }
119
120 let cid = parts[parts.len() - 1].to_string();
121 let field_values: Vec<String> = parts[..parts.len() - 1]
122 .iter()
123 .map(|s| s.to_string())
124 .collect();
125
126 Ok(DecodedCursor {
127 field_values,
128 cid,
129 })
130}
131
132/// Builds cursor-based WHERE conditions for proper multi-field pagination.
133///
134/// Creates progressive equality checks for stable multi-field sorting.
135/// For each field, we OR together:
136/// 1. field1 > cursor_value1
137/// 2. field1 = cursor_value1 AND field2 > cursor_value2
138/// 3. field1 = cursor_value1 AND field2 = cursor_value2 AND field3 > cursor_value3
139/// ... and so on
140///
141/// Finally: all fields equal AND cid > cursor_cid
142///
143/// # Arguments
144/// * `decoded_cursor` - The decoded cursor components
145/// * `sort_by` - Optional array of sort fields
146/// * `param_count` - Mutable counter for parameter numbering
147/// * `field_types` - Optional array indicating if each field is a datetime
148///
149/// # Returns
150/// Tuple of (where_condition_sql, bind_values)
151pub fn build_cursor_where_condition(
152 decoded_cursor: &DecodedCursor,
153 sort_by: Option<&Vec<SortField>>,
154 param_count: &mut usize,
155 field_types: Option<&[bool]>,
156) -> (String, Vec<String>) {
157 let mut bind_values = Vec::new();
158 let mut clauses = Vec::new();
159
160 let sort_fields = match sort_by {
161 Some(fields) if !fields.is_empty() => fields,
162 _ => {
163 // No sort fields, shouldn't happen but handle gracefully
164 return ("1=1".to_string(), vec![]);
165 }
166 };
167
168 // Build progressive equality checks for each level
169 for i in 0..sort_fields.len() {
170 let mut clause_parts = Vec::new();
171
172 // Add equality checks for all previous fields
173 for (j, sort_field) in sort_fields.iter().enumerate().take(i) {
174 let field = &sort_field.field;
175 let cursor_value = &decoded_cursor.field_values[j];
176 let is_datetime = field_types.and_then(|types| types.get(j).copied()).unwrap_or(false);
177
178 let field_ref = build_field_reference(field, is_datetime);
179 let param_cast = if is_datetime { "::timestamp" } else { "" };
180
181 clause_parts.push(format!("{} = ${}{}", field_ref, param_count, param_cast));
182 *param_count += 1;
183 bind_values.push(cursor_value.clone());
184 }
185
186 // Add comparison for current field
187 let field = &sort_fields[i].field;
188 let cursor_value = &decoded_cursor.field_values[i];
189 let direction = &sort_fields[i].direction;
190 let is_datetime = field_types.and_then(|types| types.get(i).copied()).unwrap_or(false);
191
192 let comparison_op = if direction.to_lowercase() == "desc" { "<" } else { ">" };
193 let field_ref = build_field_reference(field, is_datetime);
194 let param_cast = if is_datetime { "::timestamp" } else { "" };
195
196 clause_parts.push(format!("{} {} ${}{}", field_ref, comparison_op, param_count, param_cast));
197 *param_count += 1;
198 bind_values.push(cursor_value.clone());
199
200 // Combine with AND
201 clauses.push(format!("({})", clause_parts.join(" AND ")));
202 }
203
204 // Add final clause: all fields equal AND cid comparison
205 let mut final_clause_parts = Vec::new();
206 for (j, field) in sort_fields.iter().enumerate() {
207 let cursor_value = &decoded_cursor.field_values[j];
208 let is_datetime = field_types.and_then(|types| types.get(j).copied()).unwrap_or(false);
209
210 let field_ref = build_field_reference(&field.field, is_datetime);
211 let param_cast = if is_datetime { "::timestamp" } else { "" };
212
213 final_clause_parts.push(format!("{} = ${}{}", field_ref, param_count, param_cast));
214 *param_count += 1;
215 bind_values.push(cursor_value.clone());
216 }
217
218 // CID comparison uses the direction of the last sort field
219 let last_direction = &sort_fields[sort_fields.len() - 1].direction;
220 let cid_comparison_op = if last_direction.to_lowercase() == "desc" { "<" } else { ">" };
221
222 final_clause_parts.push(format!("cid {} ${}", cid_comparison_op, param_count));
223 *param_count += 1;
224 bind_values.push(decoded_cursor.cid.clone());
225
226 clauses.push(format!("({})", final_clause_parts.join(" AND ")));
227
228 // Combine all clauses with OR
229 let where_condition = format!("({})", clauses.join(" OR "));
230
231 (where_condition, bind_values)
232}
233
234/// Builds a field reference for SQL queries.
235///
236/// Handles table columns, JSON fields, and nested paths with optional timestamp casting.
237pub(super) fn build_field_reference(field: &str, is_datetime: bool) -> String {
238 // Table columns don't need JSON extraction
239 if matches!(field, "uri" | "cid" | "did" | "collection" | "indexed_at") {
240 return field.to_string();
241 }
242
243 // Build JSON path for nested or simple fields
244 let json_path = if field.contains('.') {
245 let parts: Vec<&str> = field.split('.').collect();
246 let mut path = String::from("json");
247 for (i, part) in parts.iter().enumerate() {
248 if i == parts.len() - 1 {
249 path.push_str(&format!("->>'{}'", part));
250 } else {
251 path.push_str(&format!("->'{}'", part));
252 }
253 }
254 path
255 } else {
256 format!("json->>'{}'", field)
257 };
258
259 // Add timestamp cast if needed
260 if is_datetime {
261 format!("({})::timestamp", json_path)
262 } else {
263 json_path
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Record fixture shared by the tests below.
    fn sample_record() -> Record {
        let json = serde_json::json!({
            "text": "Hello world",
            "createdAt": "2025-01-15T12:00:00Z",
            "nested": {
                "field": "value"
            }
        });

        Record {
            uri: String::from("at://did:plc:test/app.bsky.feed.post/123"),
            cid: String::from("bafytest123"),
            did: String::from("did:plc:test"),
            collection: String::from("app.bsky.feed.post"),
            json,
            indexed_at: Utc::now(),
            slice_uri: Some(String::from("at://did:plc:slice/network.slices.slice/abc")),
        }
    }

    /// Sort field on `field` in descending order.
    fn descending(field: &str) -> SortField {
        SortField {
            field: field.to_string(),
            direction: "desc".to_string(),
        }
    }

    /// Encodes raw cursor content the same way the generator does.
    fn encode_content(content: &str) -> String {
        general_purpose::URL_SAFE_NO_PAD.encode(content)
    }

    /// Decodes a cursor back to its raw `field|...|cid` content.
    fn raw_content(cursor: &str) -> String {
        let bytes = general_purpose::URL_SAFE_NO_PAD.decode(cursor).unwrap();
        String::from_utf8(bytes).unwrap()
    }

    #[test]
    fn test_generate_cursor_no_sort() {
        let cursor = generate_cursor_from_record(&sample_record(), None);

        assert_eq!(raw_content(&cursor), "bafytest123");
    }

    #[test]
    fn test_generate_cursor_with_sort() {
        let sort_by = vec![descending("text")];

        let cursor = generate_cursor_from_record(&sample_record(), Some(&sort_by));

        assert_eq!(raw_content(&cursor), "Hello world|bafytest123");
    }

    #[test]
    fn test_decode_cursor_single_field() {
        let sort_by = vec![descending("createdAt")];
        let cursor = encode_content("2025-01-15T12:00:00Z|bafytest123");

        let decoded = decode_cursor(&cursor, Some(&sort_by)).unwrap();

        assert_eq!(decoded.field_values, vec!["2025-01-15T12:00:00Z"]);
        assert_eq!(decoded.cid, "bafytest123");
    }

    #[test]
    fn test_decode_cursor_multiple_fields() {
        let sort_by = vec![descending("text"), descending("createdAt")];
        let cursor = encode_content("Hello world|2025-01-15T12:00:00Z|bafytest123");

        let decoded = decode_cursor(&cursor, Some(&sort_by)).unwrap();

        assert_eq!(
            decoded.field_values,
            vec!["Hello world", "2025-01-15T12:00:00Z"]
        );
        assert_eq!(decoded.cid, "bafytest123");
    }

    #[test]
    fn test_decode_cursor_invalid_format() {
        // One sort field requires two segments, but only the CID is present.
        let sort_by = vec![descending("text")];
        let cursor = encode_content("bafytest123");

        assert!(decode_cursor(&cursor, Some(&sort_by)).is_err());
    }

    #[test]
    fn test_build_field_reference_table_column() {
        assert_eq!(build_field_reference("uri", false), "uri");
        assert_eq!(build_field_reference("cid", false), "cid");
        assert_eq!(build_field_reference("indexed_at", false), "indexed_at");
    }

    #[test]
    fn test_build_field_reference_json_field() {
        let plain = build_field_reference("text", false);
        let as_timestamp = build_field_reference("text", true);

        assert_eq!(plain, "json->>'text'");
        assert_eq!(as_timestamp, "(json->>'text')::timestamp");
    }

    #[test]
    fn test_build_field_reference_nested_json() {
        let plain = build_field_reference("nested.field", false);
        let as_timestamp = build_field_reference("nested.field", true);

        assert_eq!(plain, "json->'nested'->>'field'");
        assert_eq!(as_timestamp, "(json->'nested'->>'field')::timestamp");
    }

    #[test]
    fn test_extract_field_value_table_columns() {
        let record = sample_record();

        assert_eq!(extract_field_value(&record, "uri"), record.uri);
        assert_eq!(extract_field_value(&record, "cid"), record.cid);
        assert_eq!(extract_field_value(&record, "did"), record.did);
        assert_eq!(
            extract_field_value(&record, "collection"),
            record.collection
        );
    }

    #[test]
    fn test_extract_field_value_json() {
        let record = sample_record();

        assert_eq!(extract_field_value(&record, "text"), "Hello world");
        assert_eq!(
            extract_field_value(&record, "createdAt"),
            "2025-01-15T12:00:00Z"
        );
    }

    #[test]
    fn test_extract_field_value_nested_json() {
        let record = sample_record();

        assert_eq!(extract_field_value(&record, "nested.field"), "value");
    }

    #[test]
    fn test_extract_field_value_missing() {
        let record = sample_record();

        // Both a missing top-level key and a missing nested key map to "NULL".
        assert_eq!(extract_field_value(&record, "nonexistent"), "NULL");
        assert_eq!(extract_field_value(&record, "nested.nonexistent"), "NULL");
    }
}