learn and share notes on atproto (wip) 馃 malfestio.stormlightlabs.org/
readability solid axum atproto srs
5
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 518 lines 18 kB view raw
1//! Sync service for coordinating bi-directional PDS synchronization. 2//! 3//! Handles push/pull operations and conflict resolution. 4 5use crate::middleware::auth::UserContext; 6use crate::pds::client::{GetRecordResult, PdsClient, PdsError}; 7use crate::pds::records::{prepare_card_record, prepare_deck_record, prepare_note_record}; 8use crate::repository::card::CardRepository; 9use crate::repository::deck::DeckRepository; 10use crate::repository::note::NoteRepository; 11use crate::repository::oauth::OAuthRepository; 12use crate::repository::sync::{LogOperationParams, SyncRepoError, SyncRepository, SyncStatus}; 13use std::str::FromStr; 14use std::sync::Arc; 15 16/// Result of a sync operation. 17#[derive(Debug, Clone)] 18pub struct SyncResult { 19 pub entity_type: String, 20 pub entity_id: String, 21 pub pds_uri: Option<String>, 22 pub pds_cid: Option<String>, 23 pub new_version: i32, 24 pub status: SyncStatus, 25} 26 27/// Conflict information for UI display. 28#[derive(Debug, Clone)] 29pub struct ConflictInfo { 30 pub entity_type: String, 31 pub entity_id: String, 32 pub local_version: i32, 33 pub remote_version: Option<i32>, 34 pub local_updated_at: Option<String>, 35 pub remote_updated_at: Option<String>, 36} 37 38/// Summary of sync status for a user. 39#[derive(Debug, Clone)] 40pub struct SyncStatusSummary { 41 pub pending_count: usize, 42 pub conflict_count: usize, 43 pub pending_items: Vec<(String, String)>, 44 pub conflicts: Vec<ConflictInfo>, 45} 46 47/// Conflict resolution strategy. 48#[derive(Debug, Clone, Copy, PartialEq, Eq)] 49pub enum ConflictStrategy { 50 /// Use the most recently modified version (default) 51 LastWriteWins, 52 /// Keep local version, overwrite remote 53 KeepLocal, 54 /// Keep remote version, overwrite local 55 KeepRemote, 56 // TODO: MergeUI - Show UI for manual merge 57} 58 59impl ConflictStrategy { 60 pub fn as_str(&self) -> &'static str { 61 match self { 62 ConflictStrategy::LastWriteWins => "last_write_wins", 63 ConflictStrategy::KeepLocal => "keep_local", 64 ConflictStrategy::KeepRemote => "keep_remote", 65 } 66 } 67} 68 69impl FromStr for ConflictStrategy { 70 type Err = String; 71 72 fn from_str(s: &str) -> Result<Self, Self::Err> { 73 match s { 74 "last_write_wins" => Ok(ConflictStrategy::LastWriteWins), 75 "keep_local" => Ok(ConflictStrategy::KeepLocal), 76 "keep_remote" => Ok(ConflictStrategy::KeepRemote), 77 _ => Err(format!("Invalid conflict strategy: {}", s)), 78 } 79 } 80} 81 82/// Error type for sync operations. 83#[derive(Debug)] 84pub enum SyncError { 85 /// Entity not found 86 NotFound(String), 87 /// Authentication required 88 AuthRequired(String), 89 /// No OAuth tokens available 90 NoTokens(String), 91 /// PDS operation failed 92 PdsError(PdsError), 93 /// Repository error 94 RepoError(SyncRepoError), 95 /// Invalid argument 96 InvalidArgument(String), 97 /// Conflict detected 98 ConflictDetected(ConflictInfo), 99} 100 101impl std::fmt::Display for SyncError { 102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 103 match self { 104 SyncError::NotFound(e) => write!(f, "Not found: {}", e), 105 SyncError::AuthRequired(e) => write!(f, "Authentication required: {}", e), 106 SyncError::NoTokens(e) => write!(f, "No OAuth tokens: {}", e), 107 SyncError::PdsError(e) => write!(f, "PDS error: {}", e), 108 SyncError::RepoError(e) => write!(f, "Repository error: {}", e), 109 SyncError::InvalidArgument(e) => write!(f, "Invalid argument: {}", e), 110 SyncError::ConflictDetected(c) => { 111 write!(f, "Conflict detected for {}:{}", c.entity_type, c.entity_id) 112 } 113 } 114 } 115} 116 117impl std::error::Error for SyncError {} 118 119impl From<SyncRepoError> for SyncError { 120 fn from(e: SyncRepoError) -> Self { 121 SyncError::RepoError(e) 122 } 123} 124 125impl From<PdsError> for SyncError { 126 fn from(e: PdsError) -> Self { 127 SyncError::PdsError(e) 128 } 129} 130 131/// Remote record data retrieved from PDS. 132#[derive(Debug, Clone)] 133pub struct RemoteRecord { 134 pub uri: String, 135 pub cid: String, 136 pub value: serde_json::Value, 137} 138 139/// Sync service for coordinating sync operations. 140pub struct SyncService { 141 sync_repo: Arc<dyn SyncRepository>, 142 deck_repo: Arc<dyn DeckRepository>, 143 card_repo: Arc<dyn CardRepository>, 144 note_repo: Arc<dyn NoteRepository>, 145 oauth_repo: Arc<dyn OAuthRepository>, 146} 147 148impl SyncService { 149 pub fn new( 150 sync_repo: Arc<dyn SyncRepository>, deck_repo: Arc<dyn DeckRepository>, card_repo: Arc<dyn CardRepository>, 151 note_repo: Arc<dyn NoteRepository>, oauth_repo: Arc<dyn OAuthRepository>, 152 ) -> Self { 153 Self { sync_repo, deck_repo, card_repo, note_repo, oauth_repo } 154 } 155 156 /// Push a local deck to the user's PDS. 157 pub async fn push_deck(&self, deck_id: &str, user_ctx: &UserContext) -> Result<SyncResult, SyncError> { 158 // Log the operation 159 let log_id = self 160 .sync_repo 161 .log_operation(LogOperationParams { 162 owner_did: &user_ctx.did, 163 entity_type: "deck", 164 entity_id: deck_id, 165 operation: "push", 166 status: "pending", 167 pds_cid: None, 168 error_message: None, 169 }) 170 .await?; 171 172 // Get PDS client 173 let pds_client = self.get_pds_client(user_ctx).await?; 174 175 // Get deck from repository 176 let deck = self 177 .deck_repo 178 .get(deck_id) 179 .await 180 .map_err(|e| SyncError::NotFound(format!("Deck not found: {:?}", e)))?; 181 182 // Get cards for the deck 183 let cards = self 184 .card_repo 185 .list_by_deck(deck_id) 186 .await 187 .map_err(|e| SyncError::RepoError(SyncRepoError::DatabaseError(format!("{:?}", e))))?; 188 189 // Push cards first, collect AT-URIs 190 let mut card_at_uris = Vec::with_capacity(cards.len()); 191 for card in &cards { 192 let prepared = prepare_card_record(card, ""); // deck_ref filled later 193 let at_uri = pds_client 194 .put_record(&user_ctx.did, &prepared.collection, &prepared.rkey, prepared.record) 195 .await?; 196 card_at_uris.push(at_uri.to_string()); 197 198 // Mark card as synced 199 self.sync_repo 200 .mark_synced("card", &card.id, "", &at_uri.to_string()) 201 .await?; 202 } 203 204 // Push deck with card refs 205 let prepared = prepare_deck_record(&deck, card_at_uris); 206 let at_uri = pds_client 207 .put_record(&user_ctx.did, &prepared.collection, &prepared.rkey, prepared.record) 208 .await?; 209 210 // Mark deck as synced 211 self.sync_repo 212 .mark_synced("deck", deck_id, "", &at_uri.to_string()) 213 .await?; 214 215 let metadata = self.sync_repo.get_sync_metadata("deck", deck_id).await?; 216 217 // Complete log entry 218 self.sync_repo 219 .complete_log_entry(&log_id, "success", metadata.pds_cid.as_deref(), None) 220 .await?; 221 222 Ok(SyncResult { 223 entity_type: "deck".to_string(), 224 entity_id: deck_id.to_string(), 225 pds_uri: Some(at_uri.to_string()), 226 pds_cid: metadata.pds_cid, 227 new_version: metadata.version, 228 status: SyncStatus::Synced, 229 }) 230 } 231 232 /// Push a local note to the user's PDS. 233 pub async fn push_note(&self, note_id: &str, user_ctx: &UserContext) -> Result<SyncResult, SyncError> { 234 // Log the operation 235 let log_id = self 236 .sync_repo 237 .log_operation(LogOperationParams { 238 owner_did: &user_ctx.did, 239 entity_type: "note", 240 entity_id: note_id, 241 operation: "push", 242 status: "pending", 243 pds_cid: None, 244 error_message: None, 245 }) 246 .await?; 247 248 // Get PDS client 249 let pds_client = self.get_pds_client(user_ctx).await?; 250 251 // Get note from repository 252 let note = self 253 .note_repo 254 .get(note_id, Some(&user_ctx.did)) 255 .await 256 .map_err(|e| SyncError::NotFound(format!("Note not found: {:?}", e)))?; 257 258 let prepared = prepare_note_record(&note); 259 let at_uri = pds_client 260 .put_record(&user_ctx.did, &prepared.collection, &prepared.rkey, prepared.record) 261 .await?; 262 263 self.sync_repo 264 .mark_synced("note", note_id, "", &at_uri.to_string()) 265 .await?; 266 267 let metadata = self.sync_repo.get_sync_metadata("note", note_id).await?; 268 269 // Complete log entry 270 self.sync_repo 271 .complete_log_entry(&log_id, "success", metadata.pds_cid.as_deref(), None) 272 .await?; 273 274 Ok(SyncResult { 275 entity_type: "note".to_string(), 276 entity_id: note_id.to_string(), 277 pds_uri: Some(at_uri.to_string()), 278 pds_cid: metadata.pds_cid, 279 new_version: metadata.version, 280 status: SyncStatus::Synced, 281 }) 282 } 283 284 /// Pull a record from the user's PDS. 285 pub async fn pull_record( 286 &self, entity_type: &str, at_uri: &str, user_ctx: &UserContext, 287 ) -> Result<RemoteRecord, SyncError> { 288 let parsed = malfestio_core::at_uri::AtUri::parse(at_uri) 289 .map_err(|e| SyncError::InvalidArgument(format!("Invalid AT-URI: {}", e)))?; 290 291 let log_id = self 292 .sync_repo 293 .log_operation(LogOperationParams { 294 owner_did: &user_ctx.did, 295 entity_type, 296 entity_id: at_uri, 297 operation: "pull", 298 status: "pending", 299 pds_cid: None, 300 error_message: None, 301 }) 302 .await?; 303 304 let pds_client = self.get_pds_client(user_ctx).await?; 305 let result: GetRecordResult = pds_client 306 .get_record(&parsed.authority, &parsed.collection, &parsed.rkey) 307 .await 308 .map_err(|e| { 309 tracing::error!("Failed to pull record from PDS: {:?}", e); 310 SyncError::PdsError(e) 311 })?; 312 313 self.sync_repo 314 .complete_log_entry(&log_id, "success", Some(&result.cid), None) 315 .await?; 316 317 // TODO: Offline queue - Store pulled record in IndexedDB for offline access 318 319 Ok(RemoteRecord { uri: result.uri, cid: result.cid, value: result.value }) 320 } 321 322 /// Check if there's a conflict between local and remote versions. 323 pub async fn check_conflict( 324 &self, entity_type: &str, entity_id: &str, remote_cid: &str, 325 ) -> Result<bool, SyncError> { 326 let metadata = self.sync_repo.get_sync_metadata(entity_type, entity_id).await?; 327 328 let has_conflict = 329 metadata.status == SyncStatus::PendingPush && metadata.pds_cid.as_deref() != Some(remote_cid); 330 331 if has_conflict { 332 self.sync_repo.mark_conflict(entity_type, entity_id).await?; 333 } 334 335 Ok(has_conflict) 336 } 337 338 /// Get sync status for a user. 339 pub async fn get_sync_status(&self, user_ctx: &UserContext) -> Result<SyncStatusSummary, SyncError> { 340 let pending = self.sync_repo.get_pending_items(&user_ctx.did).await?; 341 let conflicts = self.sync_repo.get_conflicts(&user_ctx.did).await?; 342 343 Ok(SyncStatusSummary { 344 pending_count: pending.len(), 345 conflict_count: conflicts.len(), 346 pending_items: pending.into_iter().map(|p| (p.entity_type, p.entity_id)).collect(), 347 conflicts: conflicts 348 .into_iter() 349 .map(|c| ConflictInfo { 350 entity_type: c.entity_type, 351 entity_id: c.entity_id, 352 local_version: c.version, 353 remote_version: None, 354 local_updated_at: None, 355 remote_updated_at: None, 356 }) 357 .collect(), 358 }) 359 } 360 361 /// Resolve a conflict using the specified strategy. 362 pub async fn resolve_conflict( 363 &self, entity_type: &str, id: &str, strategy: ConflictStrategy, user_ctx: &UserContext, 364 ) -> Result<SyncResult, SyncError> { 365 let metadata = self.sync_repo.get_sync_metadata(entity_type, id).await?; 366 367 if metadata.status != SyncStatus::Conflict { 368 return Err(SyncError::InvalidArgument(format!( 369 "Entity is not in conflict state: {}:{}", 370 entity_type, id 371 ))); 372 } 373 374 match strategy { 375 ConflictStrategy::LastWriteWins | ConflictStrategy::KeepLocal => match entity_type { 376 "deck" => self.push_deck(id, user_ctx).await, 377 "note" => self.push_note(id, user_ctx).await, 378 _ => Err(SyncError::InvalidArgument(format!( 379 "Unknown entity type: {}", 380 entity_type 381 ))), 382 }, 383 ConflictStrategy::KeepRemote => { 384 if let Some(pds_uri) = &metadata.pds_uri { 385 let remote = self.pull_record(entity_type, pds_uri, user_ctx).await?; 386 387 self.sync_repo 388 .mark_synced(entity_type, id, &remote.cid, &remote.uri) 389 .await?; 390 391 let new_metadata = self.sync_repo.get_sync_metadata(entity_type, id).await?; 392 393 Ok(SyncResult { 394 entity_type: entity_type.to_string(), 395 entity_id: id.to_string(), 396 pds_uri: Some(remote.uri), 397 pds_cid: Some(remote.cid), 398 new_version: new_metadata.version, 399 status: SyncStatus::Synced, 400 }) 401 } else { 402 Err(SyncError::InvalidArgument("No PDS URI for remote record".to_string())) 403 } 404 } 405 } 406 } 407 408 async fn get_pds_client(&self, user_ctx: &UserContext) -> Result<PdsClient, SyncError> { 409 if user_ctx.has_dpop 410 && let Ok(stored_token) = self.oauth_repo.get_tokens(&user_ctx.did).await 411 && let Some(dpop_keypair) = stored_token.dpop_keypair() 412 { 413 Ok(PdsClient::new_with_dpop( 414 stored_token.pds_url.clone(), 415 stored_token.access_token.clone(), 416 dpop_keypair, 417 )) 418 } else { 419 Ok(PdsClient::new_bearer( 420 user_ctx.pds_url.clone(), 421 user_ctx.access_token.clone(), 422 )) 423 } 424 } 425} 426 427#[cfg(test)] 428mod tests { 429 use super::*; 430 431 #[test] 432 fn test_conflict_strategy_from_str() { 433 assert_eq!( 434 ConflictStrategy::from_str("last_write_wins"), 435 Ok(ConflictStrategy::LastWriteWins) 436 ); 437 assert_eq!( 438 ConflictStrategy::from_str("keep_local"), 439 Ok(ConflictStrategy::KeepLocal) 440 ); 441 assert_eq!( 442 ConflictStrategy::from_str("keep_remote"), 443 Ok(ConflictStrategy::KeepRemote) 444 ); 445 assert!(ConflictStrategy::from_str("unknown").is_err()); 446 } 447 448 #[test] 449 fn test_conflict_strategy_as_str() { 450 assert_eq!(ConflictStrategy::LastWriteWins.as_str(), "last_write_wins"); 451 assert_eq!(ConflictStrategy::KeepLocal.as_str(), "keep_local"); 452 assert_eq!(ConflictStrategy::KeepRemote.as_str(), "keep_remote"); 453 } 454 455 #[test] 456 fn test_sync_error_display() { 457 let err = SyncError::NotFound("deck:123".to_string()); 458 assert!(err.to_string().contains("Not found")); 459 460 let err = SyncError::AuthRequired("missing token".to_string()); 461 assert!(err.to_string().contains("Authentication required")); 462 463 let err = SyncError::InvalidArgument("bad type".to_string()); 464 assert!(err.to_string().contains("Invalid argument")); 465 } 466 467 #[test] 468 fn test_sync_result_creation() { 469 let result = SyncResult { 470 entity_type: "deck".to_string(), 471 entity_id: "123".to_string(), 472 pds_uri: Some("at://did:plc:test/deck/tid".to_string()), 473 pds_cid: Some("bafycid".to_string()), 474 new_version: 2, 475 status: SyncStatus::Synced, 476 }; 477 478 assert_eq!(result.entity_type, "deck"); 479 assert_eq!(result.new_version, 2); 480 assert_eq!(result.status, SyncStatus::Synced); 481 } 482 483 #[test] 484 fn test_sync_status_summary() { 485 let summary = SyncStatusSummary { 486 pending_count: 3, 487 conflict_count: 1, 488 pending_items: vec![ 489 ("deck".to_string(), "1".to_string()), 490 ("note".to_string(), "2".to_string()), 491 ], 492 conflicts: vec![ConflictInfo { 493 entity_type: "deck".to_string(), 494 entity_id: "3".to_string(), 495 local_version: 5, 496 remote_version: Some(6), 497 local_updated_at: None, 498 remote_updated_at: None, 499 }], 500 }; 501 502 assert_eq!(summary.pending_count, 3); 503 assert_eq!(summary.conflict_count, 1); 504 assert_eq!(summary.pending_items.len(), 2); 505 } 506 507 #[test] 508 fn test_remote_record_creation() { 509 let record = RemoteRecord { 510 uri: "at://did:plc:test/deck/tid".to_string(), 511 cid: "bafycid123".to_string(), 512 value: serde_json::json!({"title": "Test"}), 513 }; 514 515 assert_eq!(record.uri, "at://did:plc:test/deck/tid"); 516 assert!(record.value.get("title").is_some()); 517 } 518}