BlueSky & more on desktop lazurite.stormlightlabs.org/
tauri rust typescript bluesky appview atproto solid
at main 2337 lines 79 kB view raw
1use super::auth::LazuriteOAuthSession; 2use super::error::{AppError, Result}; 3use super::state::AppState; 4use fastembed::{EmbeddingModel, TextEmbedding, TextInitOptions}; 5use hf_hub::api::{sync::ApiBuilder, Progress}; 6use hf_hub::Cache; 7use jacquard::api::app_bsky::actor::search_actors::SearchActors; 8use jacquard::api::app_bsky::bookmark::get_bookmarks::GetBookmarks; 9use jacquard::api::app_bsky::feed::get_actor_likes::GetActorLikes; 10use jacquard::api::app_bsky::feed::search_posts::SearchPosts; 11use jacquard::api::app_bsky::graph::search_starter_packs::SearchStarterPacks; 12use jacquard::types::datetime::Datetime; 13use jacquard::types::did::Did; 14use jacquard::types::ident::AtIdentifier; 15use jacquard::xrpc::XrpcClient; 16use rusqlite::{params, Connection, OptionalExtension}; 17use serde::{Deserialize, Serialize}; 18use std::collections::HashMap; 19use std::fs; 20use std::path::{Path, PathBuf}; 21use std::str::FromStr; 22use std::sync::{Arc, LazyLock, Mutex}; 23use std::time::{Duration, Instant}; 24use tauri::{AppHandle, Manager}; 25use tauri_plugin_log::log; 26 27const DEFAULT_RRF_K: f64 = 60.0; 28const EMBEDDING_MODEL_NAME: &str = "nomic-embed-text-v1.5"; 29const EMBEDDING_MODEL_REPO: &str = "nomic-ai/nomic-embed-text-v1.5"; 30const EMBEDDING_MODEL_FILE: &str = "onnx/model.onnx"; 31const EMBEDDING_TOKENIZER_FILES: &[&str] = &[ 32 "config.json", 33 "special_tokens_map.json", 34 "tokenizer.json", 35 "tokenizer_config.json", 36]; 37const EMBEDDING_DIMENSIONS: i64 = 768; 38const SEARCH_SYNC_CHECK_INTERVAL: Duration = Duration::from_secs(5); 39const SEARCH_SYNC_INTERVAL: Duration = Duration::from_secs(15 * 60); 40const EMBEDDINGS_ENABLED_KEY: &str = "embeddings_enabled"; 41const EMBEDDINGS_PREFLIGHT_SEEN_KEY: &str = "embeddings_preflight_seen"; 42static EMBEDDINGS_DOWNLOAD_STATE: LazyLock<Mutex<EmbeddingsDownloadState>> = 43 LazyLock::new(|| Mutex::new(EmbeddingsDownloadState::default())); 44 45#[derive(Debug, Serialize)] 46#[serde(rename_all = "camelCase")] 47pub struct SyncStatus { 48 pub did: String, 49 pub source: String, 50 pub post_count: i64, 51 pub cursor: Option<String>, 52 pub last_synced_at: Option<String>, 53} 54 55#[derive(Clone, Debug, Serialize)] 56#[serde(rename_all = "camelCase")] 57pub struct PostResult { 58 pub uri: String, 59 pub cid: String, 60 pub author_did: String, 61 pub author_handle: Option<String>, 62 pub text: Option<String>, 63 pub created_at: Option<String>, 64 pub source: String, 65 pub score: f64, 66 pub keyword_match: bool, 67 pub semantic_match: bool, 68} 69 70#[derive(Debug, Serialize)] 71#[serde(rename_all = "camelCase")] 72pub struct SavedPostsPage { 73 pub posts: Vec<PostResult>, 74 pub total: i64, 75 pub next_offset: Option<u32>, 76} 77 78#[derive(Clone, Copy, Debug, PartialEq, Eq)] 79enum SearchMode { 80 Keyword, 81 Semantic, 82 Hybrid, 83} 84 85#[derive(Clone, Debug)] 86struct SearchRow { 87 storage_key: String, 88 post: PostResult, 89} 90 91#[derive(Clone, Debug, Default)] 92struct EmbeddingsDownloadState { 93 active: bool, 94 current_file: Option<String>, 95 downloaded_files: usize, 96 total_files: usize, 97 current_bytes: usize, 98 current_total_bytes: usize, 99 started_at: Option<Instant>, 100 last_error: Option<String>, 101} 102 103struct ModelDownloadProgress { 104 file_index: usize, 105 total_files: usize, 106} 107 108#[derive(Clone, Debug, Deserialize)] 109#[serde(rename_all = "camelCase")] 110pub struct NetworkSearchQueryParams { 111 author: Option<String>, 112 cursor: Option<String>, 113 limit: Option<u32>, 114 mentions: Option<String>, 115 query: String, 116 since: Option<String>, 117 sort: Option<String>, 118 tags: Option<Vec<String>>, 119 until: Option<String>, 120} 121 122impl ModelDownloadProgress { 123 fn new(file_index: usize, total_files: usize) -> Self { 124 Self { file_index, total_files } 125 } 126} 127 128impl Progress for ModelDownloadProgress { 129 fn init(&mut self, size: usize, filename: &str) { 130 if let Ok(mut state) = EMBEDDINGS_DOWNLOAD_STATE.lock() { 131 state.active = true; 132 state.current_file = Some(filename.to_owned()); 133 state.downloaded_files = self.file_index; 134 state.total_files = self.total_files; 135 state.current_bytes = 0; 136 state.current_total_bytes = size; 137 state.started_at = Some(Instant::now()); 138 state.last_error = None; 139 } 140 } 141 142 fn update(&mut self, size: usize) { 143 if let Ok(mut state) = EMBEDDINGS_DOWNLOAD_STATE.lock() { 144 state.current_bytes = state.current_bytes.saturating_add(size); 145 } 146 } 147 148 fn finish(&mut self) { 149 if let Ok(mut state) = EMBEDDINGS_DOWNLOAD_STATE.lock() { 150 state.downloaded_files = self.file_index + 1; 151 state.current_bytes = state.current_total_bytes; 152 } 153 } 154} 155 156fn validate_query(query: &str) -> Result<()> { 157 if query.trim().is_empty() { 158 return Err(AppError::validation("search query must not be empty")); 159 } 160 Ok(()) 161} 162 163fn validate_limit(limit: u32) -> Result<usize> { 164 match limit { 165 0 => Err(AppError::validation("search limit must be greater than zero")), 166 _ => Ok(limit as usize), 167 } 168} 169 170fn normalize_identifier_filter(value: Option<&str>, label: &str) -> Result<Option<AtIdentifier<'static>>> { 171 let Some(value) = value.map(str::trim).filter(|value| !value.is_empty()) else { 172 return Ok(None); 173 }; 174 175 let normalized = value.trim_start_matches('@'); 176 AtIdentifier::new_owned(normalized).map(Some).map_err(|error| { 177 log::error!("invalid {label} filter: {error}"); 178 AppError::validation(format!("{label} must be a valid handle or DID.")) 179 }) 180} 181 182fn normalize_optional_filter(value: Option<&str>) -> Option<String> { 183 value 184 .map(str::trim) 185 .filter(|value| !value.is_empty()) 186 .map(str::to_owned) 187} 188 189fn normalize_datetime_filter(value: Option<&str>, label: &str) -> Result<Option<String>> { 190 let Some(value) = value.map(str::trim).filter(|value| !value.is_empty()) else { 191 return Ok(None); 192 }; 193 194 Datetime::from_str(value).map_err(|error| { 195 log::error!("invalid {label}: {error}"); 196 AppError::validation(format!("{label} must be a valid ISO 8601 datetime.")) 197 })?; 198 199 Ok(Some(value.to_owned())) 200} 201 202fn normalize_search_sort(value: Option<&str>) -> Result<Option<String>> { 203 let Some(value) = value.map(str::trim).filter(|value| !value.is_empty()) else { 204 return Ok(None); 205 }; 206 207 match value { 208 "top" | "latest" => Ok(Some(value.to_owned())), 209 _ => Err(AppError::validation("Search sort must be 'top' or 'latest'.")), 210 } 211} 212 213fn normalize_tag_filter(value: &str) -> Result<String> { 214 let normalized = value.trim().trim_start_matches('#').trim(); 215 if normalized.is_empty() { 216 return Err(AppError::validation("Tag filters must not be empty.")); 217 } 218 219 Ok(normalized.to_owned()) 220} 221 222fn normalize_tag_filters(tags: Option<Vec<String>>) -> Result<Option<Vec<String>>> { 223 let Some(tags) = tags else { 224 return Ok(None); 225 }; 226 227 let mut normalized = Vec::new(); 228 for tag in tags { 229 let tag = normalize_tag_filter(&tag)?; 230 if !normalized.iter().any(|existing| existing == &tag) { 231 normalized.push(tag); 232 } 233 } 234 235 if normalized.is_empty() { 236 return Ok(None); 237 } 238 239 Ok(Some(normalized)) 240} 241 242fn build_search_posts_request(params: &NetworkSearchQueryParams) -> Result<SearchPosts<'static>> { 243 validate_query(&params.query)?; 244 245 if let Some(limit) = params.limit { 246 let _ = validate_limit(limit)?; 247 } 248 249 let query = params.query.trim().to_owned(); 250 let sort = normalize_search_sort(params.sort.as_deref())?; 251 let since = normalize_datetime_filter(params.since.as_deref(), "Since filter")?; 252 let until = normalize_datetime_filter(params.until.as_deref(), "Until filter")?; 253 let author = normalize_identifier_filter(params.author.as_deref(), "Author filter")?; 254 let mentions = normalize_identifier_filter(params.mentions.as_deref(), "Mentions filter")?; 255 let tags = 256 normalize_tag_filters(params.tags.clone())?.map(|items| items.into_iter().map(Into::into).collect::<Vec<_>>()); 257 let cursor = normalize_optional_filter(params.cursor.as_deref()).map(Into::into); 258 259 if let (Some(since), Some(until)) = (since.as_deref(), until.as_deref()) { 260 let since = Datetime::from_str(since).map_err(|error| { 261 log::error!("invalid Since filter during range validation: {error}"); 262 AppError::validation("Since filter must be a valid ISO 8601 datetime.") 263 })?; 264 let until = Datetime::from_str(until).map_err(|error| { 265 log::error!("invalid Until filter during range validation: {error}"); 266 AppError::validation("Until filter must be a valid ISO 8601 datetime.") 267 })?; 268 269 if since >= until { 270 return Err(AppError::validation("Since filter must be earlier than until.")); 271 } 272 } 273 274 let since = since.map(Into::into); 275 let sort = sort.map(Into::into); 276 let until = until.map(Into::into); 277 278 Ok(SearchPosts::new() 279 .author(author) 280 .cursor(cursor) 281 .limit(params.limit.map(i64::from)) 282 .mentions(mentions) 283 .q(query) 284 .since(since) 285 .sort(sort) 286 .tag(tags) 287 .until(until) 288 .build()) 289} 290 291fn validate_search_mode(mode: &str) -> Result<SearchMode> { 292 match mode { 293 "keyword" => Ok(SearchMode::Keyword), 294 "semantic" => Ok(SearchMode::Semantic), 295 "hybrid" => Ok(SearchMode::Hybrid), 296 _ => Err(AppError::validation( 297 "search mode must be 'keyword', 'semantic', or 'hybrid'", 298 )), 299 } 300} 301 302fn validate_source(source: &str) -> Result<()> { 303 match source { 304 "like" | "bookmark" => Ok(()), 305 _ => Err(AppError::validation("source must be 'like' or 'bookmark'")), 306 } 307} 308 309fn storage_key(owner_did: &str, source: &str, uri: &str) -> String { 310 format!("{owner_did}|{source}|{uri}") 311} 312 313fn active_session_did(state: &AppState) -> Result<Option<String>> { 314 Ok(state 315 .active_session 316 .read() 317 .map_err(|error| { 318 log::error!("active_session poisoned: {error}"); 319 AppError::StatePoisoned("active_session") 320 })? 321 .as_ref() 322 .map(|session| session.did.clone())) 323} 324 325async fn get_session(state: &AppState) -> Result<Arc<LazuriteOAuthSession>> { 326 let did = state 327 .active_session 328 .read() 329 .map_err(|error| { 330 log::error!("active_session poisoned: {error}"); 331 AppError::StatePoisoned("active_session") 332 })? 333 .as_ref() 334 .ok_or_else(|| { 335 log::error!("no active account"); 336 AppError::Validation("no active account".into()) 337 })? 338 .did 339 .clone(); 340 341 state 342 .sessions 343 .read() 344 .map_err(|error| { 345 log::error!("sessions poisoned: {error}"); 346 AppError::StatePoisoned("sessions") 347 })? 348 .get(&did) 349 .cloned() 350 .ok_or_else(|| { 351 log::error!("session not found for active account"); 352 AppError::Validation("session not found for active account".into()) 353 }) 354} 355 356fn db_load_sync_cursor(conn: &Connection, did: &str, source: &str) -> Result<Option<String>> { 357 conn.query_row( 358 "SELECT cursor FROM sync_state WHERE did = ?1 AND source = ?2", 359 params![did, source], 360 |row| row.get::<_, Option<String>>(0), 361 ) 362 .optional() 363 .map(|opt| opt.flatten()) 364 .map_err(AppError::from) 365} 366 367fn db_save_sync_state(conn: &Connection, did: &str, source: &str, cursor: Option<&str>) -> Result<()> { 368 conn.execute( 369 "INSERT INTO sync_state(did, source, cursor, last_synced_at) 370 VALUES(?1, ?2, ?3, CURRENT_TIMESTAMP) 371 ON CONFLICT(did, source) DO UPDATE SET 372 cursor = excluded.cursor, 373 last_synced_at = excluded.last_synced_at", 374 params![did, source, cursor], 375 )?; 376 Ok(()) 377} 378 379fn db_post_exists(conn: &Connection, storage_key: &str) -> Result<bool> { 380 conn.query_row( 381 "SELECT 1 FROM posts WHERE storage_key = ?1", 382 params![storage_key], 383 |_| Ok(()), 384 ) 385 .optional() 386 .map(|row| row.is_some()) 387 .map_err(AppError::from) 388} 389 390/// Upsert a single `FeedViewPost` JSON item into the `posts` table. 391/// On conflict (same uri) updates mutable fields but keeps indexed_at. 392fn db_upsert_post_value(conn: &Connection, owner_did: &str, post: &serde_json::Value, source: &str) -> Result<bool> { 393 let uri = post 394 .get("uri") 395 .and_then(|v| v.as_str()) 396 .ok_or_else(|| AppError::validation("feed item missing post.uri"))?; 397 let cid = post 398 .get("cid") 399 .and_then(|v| v.as_str()) 400 .ok_or_else(|| AppError::validation("feed item missing post.cid"))?; 401 let author = post 402 .get("author") 403 .ok_or_else(|| AppError::validation("feed item missing post.author"))?; 404 let author_did = author 405 .get("did") 406 .and_then(|v| v.as_str()) 407 .ok_or_else(|| AppError::validation("feed item missing post.author.did"))?; 408 let author_handle = author.get("handle").and_then(|v| v.as_str()); 409 410 let record = post.get("record"); 411 let text = record.and_then(|r| r.get("text")).and_then(|v| v.as_str()); 412 let created_at = record.and_then(|r| r.get("createdAt")).and_then(|v| v.as_str()); 413 let json_record = record.map(|r| r.to_string()); 414 let storage_key = storage_key(owner_did, source, uri); 415 let inserted = !db_post_exists(conn, &storage_key)?; 416 417 conn.execute( 418 "INSERT INTO posts(storage_key, owner_did, uri, cid, author_did, author_handle, text, created_at, json_record, source) 419 VALUES(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10) 420 ON CONFLICT(storage_key) DO UPDATE SET 421 cid = excluded.cid, 422 author_handle = excluded.author_handle, 423 text = excluded.text, 424 created_at = excluded.created_at, 425 json_record = excluded.json_record", 426 params![ 427 storage_key, 428 owner_did, 429 uri, 430 cid, 431 author_did, 432 author_handle, 433 text, 434 created_at, 435 json_record, 436 source 437 ], 438 )?; 439 Ok(inserted) 440} 441 442fn db_upsert_post(conn: &Connection, owner_did: &str, feed_item: &serde_json::Value, source: &str) -> Result<bool> { 443 let post = feed_item.get("post").unwrap_or(feed_item); 444 let kind = post.get("$type").and_then(|value| value.as_str()); 445 match kind { 446 Some("app.bsky.feed.defs#blockedPost" | "app.bsky.feed.defs#notFoundPost") => Ok(true), 447 _ => db_upsert_post_value(conn, owner_did, post, source), 448 } 449} 450 451fn db_upsert_bookmark(conn: &Connection, owner_did: &str, bookmark: &serde_json::Value) -> Result<bool> { 452 let item = bookmark 453 .get("item") 454 .ok_or_else(|| AppError::validation("bookmark item missing item payload"))?; 455 let kind = item.get("$type").and_then(|value| value.as_str()); 456 match kind { 457 Some("app.bsky.feed.defs#blockedPost" | "app.bsky.feed.defs#notFoundPost") => Ok(true), 458 _ => db_upsert_post_value(conn, owner_did, item, "bookmark"), 459 } 460} 461 462fn db_post_count(conn: &Connection, owner_did: &str, source: &str) -> Result<i64> { 463 conn.query_row( 464 "SELECT COUNT(*) FROM posts WHERE owner_did = ?1 AND source = ?2", 465 params![owner_did, source], 466 |row| row.get(0), 467 ) 468 .map_err(AppError::from) 469} 470 471fn map_saved_post_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<PostResult> { 472 Ok(PostResult { 473 uri: row.get(0)?, 474 cid: row.get(1)?, 475 author_did: row.get(2)?, 476 author_handle: row.get(3)?, 477 text: row.get(4)?, 478 created_at: row.get(5)?, 479 source: row.get(6)?, 480 score: 0.0, 481 keyword_match: false, 482 semantic_match: false, 483 }) 484} 485 486fn db_list_saved_posts( 487 conn: &Connection, owner_did: &str, source: &str, limit: usize, offset: usize, query: Option<&str>, 488) -> Result<SavedPostsPage> { 489 let trimmed_query = query.map(str::trim).filter(|query| !query.is_empty()); 490 let total = match trimmed_query { 491 Some(query) => { 492 let match_query = build_fts_match_query(query); 493 conn.query_row( 494 "SELECT COUNT(*) 495 FROM posts_fts 496 JOIN posts p ON p.rowid = posts_fts.rowid 497 WHERE p.owner_did = ?1 498 AND p.source = ?2 499 AND posts_fts MATCH ?3", 500 params![owner_did, source, match_query], 501 |row| row.get(0), 502 )? 503 } 504 None => db_post_count(conn, owner_did, source)?, 505 }; 506 507 let posts = match trimmed_query { 508 Some(query) => { 509 let match_query = build_fts_match_query(query); 510 let mut stmt = conn.prepare( 511 "SELECT p.uri, p.cid, p.author_did, p.author_handle, p.text, p.created_at, p.source 512 FROM posts_fts 513 JOIN posts p ON p.rowid = posts_fts.rowid 514 WHERE p.owner_did = ?1 515 AND p.source = ?2 516 AND posts_fts MATCH ?3 517 ORDER BY p.created_at DESC, p.uri DESC 518 LIMIT ?4 OFFSET ?5", 519 )?; 520 521 let q = stmt.query_map( 522 params![owner_did, source, match_query, limit as i64, offset as i64], 523 map_saved_post_row, 524 )?; 525 526 q.collect::<rusqlite::Result<Vec<_>>>().map_err(AppError::from)? 527 } 528 None => { 529 let mut stmt = conn.prepare( 530 "SELECT uri, cid, author_did, author_handle, text, created_at, source 531 FROM posts 532 WHERE owner_did = ?1 AND source = ?2 533 ORDER BY created_at DESC, uri DESC 534 LIMIT ?3 OFFSET ?4", 535 )?; 536 537 let q = stmt.query_map( 538 params![owner_did, source, limit as i64, offset as i64], 539 map_saved_post_row, 540 )?; 541 542 q.collect::<rusqlite::Result<Vec<_>>>().map_err(AppError::from)? 543 } 544 }; 545 546 let consumed = offset.saturating_add(posts.len()); 547 let next_offset = (consumed < total as usize).then_some(consumed as u32); 548 549 Ok(SavedPostsPage { posts, total, next_offset }) 550} 551 552fn db_sync_status(conn: &Connection, did: &str, source: &str) -> Result<SyncStatus> { 553 let post_count = db_post_count(conn, did, source)?; 554 let (cursor, last_synced_at) = conn 555 .query_row( 556 "SELECT cursor, last_synced_at FROM sync_state WHERE did = ?1 AND source = ?2", 557 params![did, source], 558 |row| Ok((row.get::<_, Option<String>>(0)?, row.get::<_, Option<String>>(1)?)), 559 ) 560 .optional()? 561 .unwrap_or((None, None)); 562 563 Ok(SyncStatus { did: did.to_owned(), source: source.to_owned(), post_count, cursor, last_synced_at }) 564} 565 566pub async fn search_posts_network(params: NetworkSearchQueryParams, state: &AppState) -> Result<serde_json::Value> { 567 let session = get_session(state).await?; 568 let request = build_search_posts_request(&params)?; 569 570 let output = session 571 .send(request) 572 .await 573 .map_err(|error| { 574 log::error!("searchPosts error: {error}"); 575 AppError::validation("searchPosts error") 576 })? 577 .into_output() 578 .map_err(|error| { 579 log::error!("searchPosts output error: {error}"); 580 AppError::validation("searchPosts output error") 581 })?; 582 583 serde_json::to_value(&output).map_err(AppError::from) 584} 585 586pub async fn search_actors( 587 query: String, limit: Option<u32>, cursor: Option<String>, state: &AppState, 588) -> Result<serde_json::Value> { 589 validate_query(&query)?; 590 let session = get_session(state).await?; 591 592 let output = session 593 .send( 594 SearchActors::new() 595 .q(Some(query.as_str().into())) 596 .limit(limit.map(|l| l as i64)) 597 .cursor(cursor.as_deref().map(|c| c.into())) 598 .build(), 599 ) 600 .await 601 .map_err(|error| { 602 log::error!("searchActors error: {error}"); 603 AppError::validation("searchActors error") 604 })? 605 .into_output() 606 .map_err(|error| { 607 log::error!("searchActors output error: {error}"); 608 AppError::validation("searchActors output error") 609 })?; 610 611 serde_json::to_value(&output).map_err(AppError::from) 612} 613 614pub async fn search_starter_packs( 615 query: String, limit: Option<u32>, cursor: Option<String>, state: &AppState, 616) -> Result<serde_json::Value> { 617 validate_query(&query)?; 618 let session = get_session(state).await?; 619 620 let output = session 621 .send( 622 SearchStarterPacks::new() 623 .limit(limit.map(|l| l as i64)) 624 .cursor(cursor.as_deref().map(|c| c.into())) 625 .q(query.as_str()) 626 .build(), 627 ) 628 .await 629 .map_err(|error| { 630 log::error!("searchStarterPacks error: {error}"); 631 AppError::validation("searchStarterPacks error") 632 })? 633 .into_output() 634 .map_err(|error| { 635 log::error!("searchStarterPacks output error: {error}"); 636 AppError::validation("searchStarterPacks output error") 637 })?; 638 639 serde_json::to_value(&output).map_err(AppError::from) 640} 641 642/// Sync the authenticated user's likes (or bookmarks) into the local DB. 643/// 644/// Resumes from the last stored cursor if a previous sync was interrupted. 645/// During a fresh sync pass, we stop once we hit already-indexed items so we do not re-fetch the full history. 646pub async fn sync_posts(did: String, source: String, state: &AppState) -> Result<SyncStatus> { 647 validate_source(&source)?; 648 let session = get_session(state).await?; 649 650 let mut cursor: Option<String> = { 651 let conn = state.auth_store.lock_connection()?; 652 db_load_sync_cursor(&conn, &did, &source)? 653 }; 654 let resuming = cursor.is_some(); 655 656 log::info!("starting {source} sync for {did}, resume cursor: {cursor:?}"); 657 658 loop { 659 let (items, next_cursor) = match source.as_str() { 660 "like" => { 661 let output = session 662 .send( 663 GetActorLikes::new() 664 .limit(Some(100i64)) 665 .cursor(cursor.as_deref().map(|value| value.into())) 666 .actor(AtIdentifier::Did(Did::new(&did)?)) 667 .build(), 668 ) 669 .await 670 .map_err(|error| { 671 log::error!("getActorLikes error: {error}"); 672 AppError::validation("getActorLikes error") 673 })? 674 .into_output() 675 .map_err(|error| { 676 log::error!("getActorLikes output error: {error}"); 677 AppError::validation("getActorLikes output error") 678 })?; 679 let output_json = serde_json::to_value(&output)?; 680 let feed = output_json 681 .get("feed") 682 .and_then(|value| value.as_array()) 683 .cloned() 684 .unwrap_or_default(); 685 let next = output_json 686 .get("cursor") 687 .and_then(|value| value.as_str()) 688 .map(str::to_owned); 689 (feed, next) 690 } 691 "bookmark" => { 692 let output = session 693 .send( 694 GetBookmarks::new() 695 .limit(Some(100i64)) 696 .cursor(cursor.as_deref().map(|value| value.into())) 697 .build(), 698 ) 699 .await 700 .map_err(|error| { 701 log::error!("getBookmarks error: {error}"); 702 AppError::validation("getBookmarks error") 703 })? 704 .into_output() 705 .map_err(|error| { 706 log::error!("getBookmarks output error: {error}"); 707 AppError::validation("getBookmarks output error") 708 })?; 709 let output_json = serde_json::to_value(&output)?; 710 let bookmarks = output_json 711 .get("bookmarks") 712 .and_then(|value| value.as_array()) 713 .cloned() 714 .unwrap_or_default(); 715 let next = output_json 716 .get("cursor") 717 .and_then(|value| value.as_str()) 718 .map(str::to_owned); 719 (bookmarks, next) 720 } 721 _ => unreachable!(), 722 }; 723 724 { 725 let conn = state.auth_store.lock_connection()?; 726 if items.is_empty() { 727 db_save_sync_state(&conn, &did, &source, None)?; 728 log::info!("{source} sync for {did}: empty page, stopping"); 729 break; 730 } 731 732 let mut inserted_count = 0usize; 733 let mut existing_count = 0usize; 734 735 for item in &items { 736 let inserted = match source.as_str() { 737 "like" => db_upsert_post(&conn, &did, item, &source)?, 738 "bookmark" => db_upsert_bookmark(&conn, &did, item)?, 739 _ => unreachable!(), 740 }; 741 742 if inserted { 743 inserted_count += 1; 744 } else { 745 existing_count += 1; 746 } 747 } 748 749 let stop_after_page = !resuming && existing_count > 0; 750 let cursor_to_store = if stop_after_page { None } else { next_cursor.as_deref() }; 751 db_save_sync_state(&conn, &did, &source, cursor_to_store)?; 752 753 log::debug!( 754 "{source} sync for {did}: processed {} item(s), inserted {}, existing {}, next cursor: {next_cursor:?}", 755 items.len(), 756 inserted_count, 757 existing_count 758 ); 759 760 if stop_after_page { 761 log::info!("{source} sync for {did}: reached previously indexed items, stopping"); 762 break; 763 } 764 } 765 766 match next_cursor { 767 None => { 768 log::info!("{source} sync for {did}: reached end of feed"); 769 break; 770 } 771 Some(next) => cursor = Some(next), 772 } 773 } 774 775 let conn = state.auth_store.lock_connection()?; 776 db_sync_status(&conn, &did, &source) 777} 778 779/// Returns sync status for all sources for the given DID. 780pub fn get_sync_status(did: &str, state: &AppState) -> Result<Vec<SyncStatus>> { 781 let conn = state.auth_store.lock_connection()?; 782 ["like", "bookmark"] 783 .into_iter() 784 .map(|source| db_sync_status(&conn, did, source)) 785 .collect() 786} 787 788pub fn list_saved_posts( 789 source: &str, limit: u32, offset: u32, query: Option<&str>, state: &AppState, 790) -> Result<SavedPostsPage> { 791 validate_source(source)?; 792 let limit = validate_limit(limit)?; 793 let owner_did = active_session_did(state)?.ok_or_else(|| AppError::validation("no active account"))?; 794 let conn = state.auth_store.lock_connection()?; 795 db_list_saved_posts(&conn, &owner_did, source, limit, offset as usize, query) 796} 797 798const EMBED_BATCH_SIZE: usize = 32; 799 800fn build_embedding_model(models_dir: PathBuf) -> Result<TextEmbedding> { 801 ensure_model_downloaded(&models_dir)?; 802 TextEmbedding::try_new( 803 TextInitOptions::new(EmbeddingModel::NomicEmbedTextV15) 804 .with_cache_dir(models_dir) 805 .with_show_download_progress(false), 806 ) 807 .map_err(|error| AppError::validation(format!("failed to init embedding model: {error}"))) 808} 809 810fn resolve_models_dir(app: &AppHandle) -> Result<PathBuf> { 811 let mut dir = app 812 .path() 813 .app_data_dir() 814 .map_err(|error| AppError::PathResolve(error.to_string()))?; 815 dir.push("models"); 816 std::fs::create_dir_all(&dir)?; 817 Ok(dir) 818} 819 820fn models_dir_path(app: &AppHandle) -> Result<PathBuf> { 821 let mut dir = app 822 .path() 823 .app_data_dir() 824 .map_err(|error| AppError::PathResolve(error.to_string()))?; 825 dir.push("models"); 826 Ok(dir) 827} 828 829fn required_embedding_files() -> Vec<&'static str> { 830 let mut files = vec![EMBEDDING_MODEL_FILE]; 831 files.extend(EMBEDDING_TOKENIZER_FILES); 832 files 833} 834 835fn cached_embedding_files(models_dir: &Path) -> usize { 836 let cache = Cache::new(models_dir.to_path_buf()); 837 let repo = cache.model(EMBEDDING_MODEL_REPO.to_owned()); 838 required_embedding_files() 839 .into_iter() 840 .filter(|filename| repo.get(filename).is_some()) 841 .count() 842} 843 844fn embeddings_downloaded(models_dir: &Path) -> bool { 845 cached_embedding_files(models_dir) == required_embedding_files().len() 846} 847 848fn directory_size(path: &Path) -> Result<u64> { 849 if !path.exists() { 850 return Ok(0); 851 } 852 853 if path.is_file() { 854 return Ok(path.metadata()?.len()); 855 } 856 857 let mut total = 0_u64; 858 for entry in fs::read_dir(path)? { 859 let entry = entry?; 860 total = total.saturating_add(directory_size(&entry.path())?); 861 } 862 863 Ok(total) 864} 865 866fn set_download_idle_state(downloaded_files: usize, total_files: usize) { 867 if let Ok(mut state) = EMBEDDINGS_DOWNLOAD_STATE.lock() { 868 state.active = false; 869 state.current_file = None; 870 state.downloaded_files = downloaded_files; 871 state.total_files = total_files; 872 state.current_bytes = 0; 873 state.current_total_bytes = 0; 874 state.started_at = None; 875 state.last_error = None; 876 } 877} 878 879fn set_download_error(message: String) { 880 if let Ok(mut state) = EMBEDDINGS_DOWNLOAD_STATE.lock() { 881 state.active = false; 882 state.current_file = None; 883 state.current_bytes = 0; 884 state.current_total_bytes = 0; 885 state.started_at = None; 886 state.last_error = Some(message); 887 } 888} 889 890fn clear_embeddings_model_cache_dir(models_dir: &Path) -> Result<()> { 891 if models_dir.exists() { 892 fs::remove_dir_all(models_dir)?; 893 } 894 set_download_idle_state(0, required_embedding_files().len()); 895 Ok(()) 896} 897 898fn ensure_model_downloaded(models_dir: &Path) -> Result<()> { 899 let required_files = required_embedding_files(); 900 let total_files = required_files.len(); 901 let already_cached = cached_embedding_files(models_dir); 902 if already_cached == total_files { 903 set_download_idle_state(total_files, total_files); 904 return Ok(()); 905 } 906 907 set_download_idle_state(already_cached, total_files); 908 909 let api = ApiBuilder::new() 910 .with_cache_dir(models_dir.to_path_buf()) 911 .with_progress(false) 912 .build() 913 .map_err(|error| AppError::validation(format!("failed to initialize embeddings downloader: {error}")))?; 914 let repo = api.model(EMBEDDING_MODEL_REPO.to_owned()); 915 let cache = Cache::new(models_dir.to_path_buf()); 916 let cache_repo = cache.model(EMBEDDING_MODEL_REPO.to_owned()); 917 918 for (index, filename) in required_files.iter().enumerate() { 919 if cache_repo.get(filename).is_some() { 920 set_download_idle_state(index + 1, total_files); 921 continue; 922 } 923 924 let download = repo.download_with_progress(filename, ModelDownloadProgress::new(index, total_files)); 925 if let Err(error) = download { 926 let message = format!("failed to download embeddings file {filename}: {error}"); 927 set_download_error(message.clone()); 928 return Err(AppError::validation(message)); 929 } 930 } 931 932 set_download_idle_state(total_files, total_files); 933 Ok(()) 934} 935 936fn db_get_embeddings_enabled(conn: &Connection) -> Result<bool> { 937 let val: Option<String> = conn 938 .query_row( 939 "SELECT value FROM app_settings WHERE key = ?1", 940 params![EMBEDDINGS_ENABLED_KEY], 941 |row| row.get(0), 942 ) 943 .optional()?; 944 Ok(val.map(|v| v != "0").unwrap_or(false)) 945} 946 947fn db_set_embeddings_enabled(conn: &Connection, enabled: bool) -> Result<()> { 948 conn.execute( 949 "INSERT INTO app_settings(key, value) VALUES(?1, ?2) 950 ON CONFLICT(key) DO UPDATE SET value = excluded.value", 951 params![EMBEDDINGS_ENABLED_KEY, if enabled { "1" } else { "0" }], 952 )?; 953 Ok(()) 954} 955 956fn db_get_embeddings_preflight_seen(conn: &Connection) -> Result<bool> { 957 let val: Option<String> = conn 958 .query_row( 959 "SELECT value FROM app_settings WHERE key = ?1", 960 params![EMBEDDINGS_PREFLIGHT_SEEN_KEY], 961 |row| row.get(0), 962 ) 963 .optional()?; 964 Ok(val.map(|value| value != "0").unwrap_or(false)) 965} 966 967fn db_set_embeddings_preflight_seen(conn: &Connection, seen: bool) -> Result<()> { 968 conn.execute( 969 "INSERT INTO app_settings(key, value) VALUES(?1, ?2) 970 ON CONFLICT(key) DO UPDATE SET value = excluded.value", 971 params![EMBEDDINGS_PREFLIGHT_SEEN_KEY, if seen { "1" } else { "0" }], 972 )?; 973 Ok(()) 974} 975 976fn db_keyword_search(conn: &Connection, owner_did: &str, query: &str, limit: usize) -> Result<Vec<SearchRow>> { 977 let match_query = build_fts_match_query(query); 978 let mut stmt = conn.prepare( 979 "SELECT p.storage_key, 980 p.uri, 981 p.cid, 982 p.author_did, 983 p.author_handle, 984 p.text, 985 p.created_at, 986 p.source, 987 bm25(posts_fts) AS rank 988 FROM posts_fts 989 JOIN posts p ON p.rowid = posts_fts.rowid 990 WHERE p.owner_did = ?1 991 AND posts_fts MATCH ?2 992 ORDER BY rank ASC, p.created_at DESC, p.uri ASC 993 LIMIT ?3", 994 )?; 995 996 let rows = stmt.query_map( 997 params![owner_did, match_query, limit as i64], 998 search_row_from_keyword_row, 999 )?; 1000 rows.collect::<rusqlite::Result<Vec<_>>>().map_err(AppError::from) 1001} 1002 1003fn db_semantic_search( 1004 conn: &Connection, owner_did: &str, query_embedding: &[f32], limit: usize, 1005) -> Result<Vec<SearchRow>> { 1006 let bytes: Vec<u8> = query_embedding.iter().flat_map(|f| f.to_le_bytes()).collect(); 1007 let mut stmt = conn.prepare( 1008 "SELECT p.storage_key, 1009 p.uri, 1010 p.cid, 1011 p.author_did, 1012 p.author_handle, 1013 p.text, 1014 p.created_at, 1015 p.source, 1016 v.distance 1017 FROM posts_vec v 1018 JOIN posts p ON p.storage_key = v.storage_key 1019 WHERE p.owner_did = ?1 1020 AND v.embedding MATCH ?2 1021 AND v.k = ?3 1022 ORDER BY v.distance ASC, p.created_at DESC, p.uri ASC 1023 ", 1024 )?; 1025 1026 let rows = stmt.query_map( 1027 params![owner_did, bytes.as_slice(), limit as i64], 1028 search_row_from_semantic_row, 1029 )?; 1030 rows.collect::<rusqlite::Result<Vec<_>>>().map_err(AppError::from) 1031} 1032 1033fn search_row_from_keyword_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<SearchRow> { 1034 let raw_rank = row.get::<_, f64>(8)?; 1035 Ok(SearchRow { 1036 storage_key: row.get(0)?, 1037 post: PostResult { 1038 uri: row.get(1)?, 1039 cid: row.get(2)?, 1040 author_did: row.get(3)?, 1041 author_handle: row.get(4)?, 1042 text: row.get(5)?, 1043 created_at: row.get(6)?, 1044 source: row.get(7)?, 1045 score: -raw_rank, 1046 keyword_match: true, 1047 semantic_match: false, 1048 }, 1049 }) 1050} 1051 1052fn search_row_from_semantic_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<SearchRow> { 1053 let distance = row.get::<_, f64>(8)?; 1054 Ok(SearchRow { 1055 storage_key: row.get(0)?, 1056 post: PostResult { 1057 uri: row.get(1)?, 1058 cid: row.get(2)?, 1059 author_did: row.get(3)?, 1060 author_handle: row.get(4)?, 1061 text: row.get(5)?, 1062 created_at: row.get(6)?, 1063 source: row.get(7)?, 1064 score: 1.0 / (1.0 + distance), 1065 keyword_match: false, 1066 semantic_match: true, 1067 }, 1068 }) 1069} 1070 1071fn build_fts_match_query(query: &str) -> String { 1072 let tokens: Vec<String> = query 1073 .split_whitespace() 1074 .filter(|token| !token.is_empty()) 1075 .map(|token| format!("\"{}\"", token.replace('"', "\"\""))) 1076 .collect(); 1077 1078 if tokens.is_empty() { 1079 format!("\"{}\"", query.trim().replace('"', "\"\"")) 1080 } else { 1081 tokens.join(" AND ") 1082 } 1083} 1084 1085fn rrf_merge(keyword_rows: Vec<SearchRow>, semantic_rows: Vec<SearchRow>, limit: usize) -> Vec<PostResult> { 1086 let mut fused: HashMap<String, SearchRow> = HashMap::new(); 1087 let mut scores: HashMap<String, f64> = HashMap::new(); 1088 1089 for rows in [keyword_rows, semantic_rows] { 1090 for (rank, row) in rows.into_iter().enumerate() { 1091 let score = 1.0 / (DEFAULT_RRF_K + rank as f64 + 1.0); 1092 scores 1093 .entry(row.storage_key.clone()) 1094 .and_modify(|value| *value += score) 1095 .or_insert(score); 1096 fused 1097 .entry(row.storage_key.clone()) 1098 .and_modify(|existing| { 1099 existing.post.keyword_match |= row.post.keyword_match; 1100 existing.post.semantic_match |= row.post.semantic_match; 1101 }) 1102 .or_insert(row); 1103 } 1104 } 1105 1106 let mut rows: Vec<SearchRow> = fused 1107 .into_iter() 1108 .filter_map(|(key, mut row)| { 1109 scores.get(&key).map(|score| { 1110 row.post.score = *score; 1111 row 1112 }) 1113 }) 1114 .collect(); 1115 1116 rows.sort_by(|left, right| { 1117 right 1118 .post 1119 .score 1120 .total_cmp(&left.post.score) 1121 .then_with(|| right.post.created_at.cmp(&left.post.created_at)) 1122 .then_with(|| left.post.uri.cmp(&right.post.uri)) 1123 }); 1124 1125 rows.into_iter().take(limit).map(|row| row.post).collect() 1126} 1127 1128fn run_local_search( 1129 conn: &Connection, owner_did: &str, query: &str, mode: SearchMode, limit: usize, embeddings_enabled: bool, 1130 query_embedding: Option<&[f32]>, 1131) -> Result<Vec<PostResult>> { 1132 match mode { 1133 SearchMode::Keyword => { 1134 db_keyword_search(conn, owner_did, query, limit).map(|rows| rows.into_iter().map(|row| row.post).collect()) 1135 } 1136 SearchMode::Semantic => { 1137 if !embeddings_enabled { 1138 return Err(AppError::validation( 1139 "semantic search is unavailable while embeddings are disabled", 1140 )); 1141 } 1142 1143 let query_embedding = 1144 query_embedding.ok_or_else(|| AppError::validation("semantic search query embedding missing"))?; 1145 db_semantic_search(conn, owner_did, query_embedding, limit) 1146 .map(|rows| rows.into_iter().map(|row| row.post).collect()) 1147 } 1148 SearchMode::Hybrid => { 1149 let candidate_limit = limit.saturating_mul(4).min(100); 1150 let keyword_rows = db_keyword_search(conn, owner_did, query, candidate_limit)?; 1151 1152 if !embeddings_enabled { 1153 return Ok(keyword_rows.into_iter().take(limit).map(|row| row.post).collect()); 1154 } 1155 1156 let Some(query_embedding) = query_embedding else { 1157 return Err(AppError::validation("hybrid search query embedding missing")); 1158 }; 1159 1160 let semantic_rows = db_semantic_search(conn, owner_did, query_embedding, candidate_limit)?; 1161 Ok(rrf_merge(keyword_rows, semantic_rows, limit)) 1162 } 1163 } 1164} 1165 1166fn embed_query_text(query: &str, models_dir: PathBuf) -> Result<Vec<f32>> { 1167 let mut model = build_embedding_model(models_dir)?; 1168 let embeddings = model 1169 .embed(vec![query.to_owned()], Some(1)) 1170 .map_err(|error| AppError::validation(format!("embedding error: {error}")))?; 1171 1172 embeddings 1173 .into_iter() 1174 .next() 1175 .ok_or_else(|| AppError::validation("embedding model returned no query embedding")) 1176} 1177 1178/// Returns (storage_key, text) for posts that have no embedding yet. 1179fn db_posts_without_embeddings(conn: &Connection) -> Result<Vec<(String, String)>> { 1180 let mut stmt = conn.prepare( 1181 "SELECT p.storage_key, p.text 1182 FROM posts p 1183 WHERE p.text IS NOT NULL 1184 AND p.text != '' 1185 AND p.storage_key NOT IN (SELECT storage_key FROM posts_vec)", 1186 )?; 1187 1188 let rows = stmt.query_map([], |row| Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?)))?; 1189 rows.collect::<rusqlite::Result<Vec<_>>>().map_err(AppError::from) 1190} 1191 1192/// Returns (storage_key, text) for ALL posts that have non-empty text. 1193fn db_all_posts_with_text(conn: &Connection) -> Result<Vec<(String, String)>> { 1194 let mut stmt = conn.prepare("SELECT storage_key, text FROM posts WHERE text IS NOT NULL AND text != ''")?; 1195 1196 let rows = stmt.query_map([], |row| Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?)))?; 1197 rows.collect::<rusqlite::Result<Vec<_>>>().map_err(AppError::from) 1198} 1199 1200fn db_upsert_embedding(conn: &Connection, storage_key: &str, embedding: &[f32]) -> Result<()> { 1201 let bytes: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect(); 1202 conn.execute( 1203 "INSERT OR REPLACE INTO posts_vec(storage_key, embedding) VALUES(?1, ?2)", 1204 params![storage_key, bytes.as_slice()], 1205 )?; 1206 Ok(()) 1207} 1208 1209fn embed_posts(posts: &[(String, String)], models_dir: PathBuf, state: &AppState) -> Result<usize> { 1210 if posts.is_empty() { 1211 return Ok(0); 1212 } 1213 1214 let mut model = build_embedding_model(models_dir)?; 1215 1216 let mut total = 0usize; 1217 1218 for chunk in posts.chunks(EMBED_BATCH_SIZE) { 1219 let texts: Vec<String> = chunk.iter().map(|(_, text)| text.clone()).collect(); 1220 let embeddings = model 1221 .embed(texts, Some(EMBED_BATCH_SIZE)) 1222 .map_err(|error| AppError::validation(format!("embedding error: {error}")))?; 1223 1224 let conn = state.auth_store.lock_connection()?; 1225 for ((storage_key, _), embedding) in chunk.iter().zip(embeddings.iter()) { 1226 db_upsert_embedding(&conn, storage_key, embedding)?; 1227 } 1228 total += chunk.len(); 1229 } 1230 1231 Ok(total) 1232} 1233 1234pub fn search_posts(query: &str, mode: &str, limit: u32, app: &AppHandle, state: &AppState) -> Result<Vec<PostResult>> { 1235 validate_query(query)?; 1236 let limit = validate_limit(limit)?; 1237 let mode = validate_search_mode(mode)?; 1238 let owner_did = active_session_did(state)?.ok_or_else(|| AppError::validation("no active account"))?; 1239 1240 let embeddings_enabled = { 1241 let conn = state.auth_store.lock_connection()?; 1242 db_get_embeddings_enabled(&conn)? 1243 }; 1244 1245 let query_embedding = match mode { 1246 SearchMode::Keyword => None, 1247 SearchMode::Semantic | SearchMode::Hybrid if embeddings_enabled => { 1248 let models_dir = resolve_models_dir(app)?; 1249 Some(embed_query_text(query, models_dir)?) 1250 } 1251 SearchMode::Semantic => { 1252 return Err(AppError::validation( 1253 "semantic search is unavailable while embeddings are disabled", 1254 )); 1255 } 1256 SearchMode::Hybrid => None, 1257 }; 1258 1259 let conn = state.auth_store.lock_connection()?; 1260 run_local_search( 1261 &conn, 1262 &owner_did, 1263 query, 1264 mode, 1265 limit, 1266 embeddings_enabled, 1267 query_embedding.as_deref(), 1268 ) 1269} 1270 1271/// Embed all posts that do not yet have an embedding. Skipped when embeddings are disabled. 1272pub fn embed_pending_posts(app: &AppHandle, state: &AppState) -> Result<usize> { 1273 let enabled = { 1274 let conn = state.auth_store.lock_connection()?; 1275 db_get_embeddings_enabled(&conn)? 1276 }; 1277 if !enabled { 1278 log::info!("embeddings disabled, skipping embed_pending_posts"); 1279 return Ok(0); 1280 } 1281 1282 let posts = { 1283 let conn = state.auth_store.lock_connection()?; 1284 db_posts_without_embeddings(&conn)? 1285 }; 1286 1287 log::info!("embedding {} pending posts", posts.len()); 1288 let models_dir = resolve_models_dir(app)?; 1289 embed_posts(&posts, models_dir, state) 1290} 1291 1292/// Clear all embeddings from `posts_vec` then re-embed every post. 1293pub fn reindex_embeddings(app: &AppHandle, state: &AppState) -> Result<usize> { 1294 { 1295 let conn = state.auth_store.lock_connection()?; 1296 conn.execute("DELETE FROM posts_vec", [])?; 1297 } 1298 log::info!("cleared posts_vec for reindex"); 1299 1300 let posts = { 1301 let conn = state.auth_store.lock_connection()?; 1302 db_all_posts_with_text(&conn)? 1303 }; 1304 1305 log::info!("reindexing {} posts", posts.len()); 1306 let models_dir = resolve_models_dir(app)?; 1307 embed_posts(&posts, models_dir, state) 1308} 1309 1310/// Persist the embeddings-enabled preference. 1311pub fn set_embeddings_enabled(enabled: bool, state: &AppState) -> Result<()> { 1312 let conn = state.auth_store.lock_connection()?; 1313 db_set_embeddings_enabled(&conn, enabled) 1314} 1315 1316pub fn set_embeddings_preflight_seen(seen: bool, state: &AppState) -> Result<()> { 1317 let conn = state.auth_store.lock_connection()?; 1318 db_set_embeddings_preflight_seen(&conn, seen) 1319} 1320 1321/// Get the current embeddings-enabled preference. 1322pub fn get_embeddings_enabled(state: &AppState) -> Result<bool> { 1323 let conn = state.auth_store.lock_connection()?; 1324 db_get_embeddings_enabled(&conn) 1325} 1326 1327#[derive(Clone, Debug, Serialize)] 1328#[serde(rename_all = "camelCase")] 1329pub struct EmbeddingsConfig { 1330 pub enabled: bool, 1331 pub preflight_seen: bool, 1332 pub model_name: String, 1333 pub dimensions: i64, 1334 pub model_size_bytes: Option<u64>, 1335 pub downloaded: bool, 1336 pub download_active: bool, 1337 pub download_progress: Option<f64>, 1338 pub download_eta_seconds: Option<u64>, 1339 pub download_file: Option<String>, 1340 pub download_file_index: Option<usize>, 1341 pub download_file_total: Option<usize>, 1342 pub last_error: Option<String>, 1343} 1344 1345/// Get the embeddings configuration. 1346pub fn get_embeddings_config(app: &AppHandle, state: &AppState) -> Result<EmbeddingsConfig> { 1347 let conn = state.auth_store.lock_connection()?; 1348 let enabled = db_get_embeddings_enabled(&conn)?; 1349 let preflight_seen = db_get_embeddings_preflight_seen(&conn)?; 1350 let models_dir = resolve_models_dir(app)?; 1351 let downloaded = embeddings_downloaded(&models_dir); 1352 let model_size_bytes = directory_size(&models_dir).ok().filter(|bytes| *bytes > 0); 1353 let state = EMBEDDINGS_DOWNLOAD_STATE 1354 .lock() 1355 .map_err(|_| AppError::StatePoisoned("embeddings_download_state"))?; 1356 let download_progress = if state.active && state.current_total_bytes > 0 { 1357 Some((state.current_bytes as f64 / state.current_total_bytes as f64) * 100.0) 1358 } else if downloaded { 1359 Some(100.0) 1360 } else { 1361 None 1362 }; 1363 let download_eta_seconds = if state.active { 1364 state.started_at.and_then(|started_at| { 1365 let elapsed = started_at.elapsed().as_secs_f64(); 1366 let current = state.current_bytes as f64; 1367 let total = state.current_total_bytes as f64; 1368 if elapsed <= 0.0 || current <= 0.0 || total <= current { 1369 None 1370 } else { 1371 let bytes_per_second = current / elapsed; 1372 let remaining = total - current; 1373 Some((remaining / bytes_per_second).ceil() as u64) 1374 } 1375 }) 1376 } else { 1377 None 1378 }; 1379 1380 Ok(EmbeddingsConfig { 1381 enabled, 1382 preflight_seen, 1383 model_name: EMBEDDING_MODEL_NAME.to_string(), 1384 dimensions: EMBEDDING_DIMENSIONS, 1385 model_size_bytes, 1386 downloaded, 1387 download_active: state.active, 1388 download_progress, 1389 download_eta_seconds, 1390 download_file: state.current_file.clone(), 1391 download_file_index: state.active.then_some(state.downloaded_files + 1), 1392 download_file_total: (state.total_files > 0).then_some(state.total_files), 1393 last_error: state.last_error.clone(), 1394 }) 1395} 1396 1397pub fn prepare_embeddings_model(app: &AppHandle, state: &AppState) -> Result<EmbeddingsConfig> { 1398 let enabled = { 1399 let conn = state.auth_store.lock_connection()?; 1400 db_get_embeddings_enabled(&conn)? 1401 }; 1402 1403 if enabled { 1404 let models_dir = resolve_models_dir(app)?; 1405 ensure_model_downloaded(&models_dir)?; 1406 } 1407 1408 get_embeddings_config(app, state) 1409} 1410 1411pub fn clear_embeddings_model_cache(app: &AppHandle) -> Result<()> { 1412 let models_dir = models_dir_path(app)?; 1413 clear_embeddings_model_cache_dir(&models_dir) 1414} 1415 1416fn sync_due(active_did: Option<&str>, last_synced_did: Option<&str>, last_synced_at: Option<Instant>) -> bool { 1417 match active_did { 1418 None => false, 1419 Some(did) if Some(did) != last_synced_did => true, 1420 Some(_) => last_synced_at 1421 .map(|instant| instant.elapsed() >= SEARCH_SYNC_INTERVAL) 1422 .unwrap_or(true), 1423 } 1424} 1425 1426/// Keeps the active account's local search index warm by syncing likes on login/account switch 1427/// and then re-syncing every 15 minutes. Embeddings are refreshed for newly synced posts. 1428pub fn spawn_search_sync_task(app: AppHandle) { 1429 tauri::async_runtime::spawn(async move { 1430 let mut last_synced_did: Option<String> = None; 1431 let mut last_synced_at: Option<Instant> = None; 1432 1433 loop { 1434 let state = app.state::<AppState>(); 1435 let active_did = match active_session_did(&state) { 1436 Ok(value) => value, 1437 Err(error) => { 1438 log::warn!("search sync failed to read active session: {error}"); 1439 tokio::time::sleep(SEARCH_SYNC_CHECK_INTERVAL).await; 1440 continue; 1441 } 1442 }; 1443 1444 if active_did.is_none() { 1445 last_synced_did = None; 1446 last_synced_at = None; 1447 tokio::time::sleep(SEARCH_SYNC_CHECK_INTERVAL).await; 1448 continue; 1449 } 1450 1451 if sync_due(active_did.as_deref(), last_synced_did.as_deref(), last_synced_at) { 1452 let did = active_did.clone().unwrap_or_default(); 1453 let like_sync = sync_posts(did.clone(), "like".to_owned(), &state).await; 1454 let bookmark_sync = sync_posts(did.clone(), "bookmark".to_owned(), &state).await; 1455 match (like_sync, bookmark_sync) { 1456 (Ok(like_status), Ok(bookmark_status)) => { 1457 log::info!( 1458 "background search sync complete for {} likes/bookmarks: {}/{} post(s)", 1459 did, 1460 like_status.post_count, 1461 bookmark_status.post_count 1462 ); 1463 if let Err(error) = embed_pending_posts(&app, &state) { 1464 log::warn!("background embedding pass failed for {did}: {error}"); 1465 } 1466 last_synced_did = Some(did); 1467 last_synced_at = Some(Instant::now()); 1468 } 1469 (Err(error), _) | (_, Err(error)) => { 1470 log::warn!("background search sync failed: {error}"); 1471 } 1472 } 1473 } 1474 1475 tokio::time::sleep(SEARCH_SYNC_CHECK_INTERVAL).await; 1476 } 1477 }); 1478} 1479 1480#[cfg(test)] 1481mod tests { 1482 use super::{ 1483 build_fts_match_query, build_search_posts_request, clear_embeddings_model_cache_dir, db_get_embeddings_enabled, 1484 db_get_embeddings_preflight_seen, db_list_saved_posts, db_load_sync_cursor, db_post_count, db_save_sync_state, 1485 db_semantic_search, db_set_embeddings_enabled, db_set_embeddings_preflight_seen, db_sync_status, 1486 db_upsert_embedding, db_upsert_post, normalize_identifier_filter, normalize_tag_filter, run_local_search, 1487 storage_key, sync_due, validate_limit, validate_query, validate_search_mode, validate_source, 1488 NetworkSearchQueryParams, SearchMode, 1489 }; 1490 use rusqlite::{ffi::sqlite3_auto_extension, Connection}; 1491 use sqlite_vec::sqlite3_vec_init; 1492 use std::fs; 1493 use std::time::{SystemTime, UNIX_EPOCH}; 1494 1495 fn test_db() -> Connection { 1496 unsafe { 1497 sqlite3_auto_extension(Some(std::mem::transmute(sqlite3_vec_init as *const ()))); 1498 } 1499 1500 let conn = Connection::open_in_memory().expect("in-memory db should open"); 1501 conn.execute_batch( 1502 "CREATE TABLE posts ( 1503 storage_key TEXT PRIMARY KEY, 1504 owner_did TEXT NOT NULL, 1505 uri TEXT NOT NULL, 1506 cid TEXT NOT NULL, 1507 author_did TEXT NOT NULL, 1508 author_handle TEXT, 1509 text TEXT, 1510 created_at TEXT, 1511 indexed_at TEXT DEFAULT CURRENT_TIMESTAMP, 1512 json_record TEXT, 1513 source TEXT NOT NULL, 1514 UNIQUE(owner_did, source, uri) 1515 ); 1516 CREATE VIRTUAL TABLE posts_fts USING fts5( 1517 text, 1518 content=posts, 1519 content_rowid=rowid 1520 ); 1521 CREATE VIRTUAL TABLE posts_vec USING vec0( 1522 storage_key TEXT PRIMARY KEY, 1523 embedding float[3] 1524 ); 1525 CREATE TRIGGER posts_ai AFTER INSERT ON posts BEGIN 1526 INSERT INTO posts_fts(rowid, text) VALUES (new.rowid, new.text); 1527 END; 1528 CREATE TRIGGER posts_ad AFTER DELETE ON posts BEGIN 1529 INSERT INTO posts_fts(posts_fts, rowid, text) 1530 VALUES('delete', old.rowid, old.text); 1531 END; 1532 CREATE TRIGGER posts_au AFTER UPDATE ON posts BEGIN 1533 INSERT INTO posts_fts(posts_fts, rowid, text) 1534 VALUES('delete', old.rowid, old.text); 1535 INSERT INTO posts_fts(rowid, text) VALUES (new.rowid, new.text); 1536 END; 1537 CREATE TABLE sync_state ( 1538 did TEXT NOT NULL, 1539 source TEXT NOT NULL, 1540 cursor TEXT, 1541 last_synced_at TEXT, 1542 PRIMARY KEY (did, source) 1543 ); 1544 CREATE TABLE app_settings ( 1545 key TEXT PRIMARY KEY, 1546 value TEXT NOT NULL 1547 );", 1548 ) 1549 .expect("test schema should apply"); 1550 conn 1551 } 1552 1553 fn feed_item(uri: &str, cid: &str, did: &str, handle: &str, text: &str, created_at: &str) -> serde_json::Value { 1554 serde_json::json!({ 1555 "post": { 1556 "uri": uri, 1557 "cid": cid, 1558 "author": { "did": did, "handle": handle }, 1559 "record": { "$type": "app.bsky.feed.post", "text": text, "createdAt": created_at } 1560 } 1561 }) 1562 } 1563 1564 fn temp_models_dir() -> std::path::PathBuf { 1565 let unique = SystemTime::now() 1566 .duration_since(UNIX_EPOCH) 1567 .expect("clock should be after epoch") 1568 .as_nanos(); 1569 std::env::temp_dir().join(format!("lazurite-model-cache-{unique}")) 1570 } 1571 1572 #[test] 1573 fn clear_embeddings_model_cache_dir_removes_cached_files() { 1574 let models_dir = temp_models_dir(); 1575 let nested_dir = models_dir.join("nested"); 1576 fs::create_dir_all(&nested_dir).expect("nested models dir should be created"); 1577 fs::write(models_dir.join("model.onnx"), "model").expect("model file should be created"); 1578 fs::write(nested_dir.join("tokenizer.json"), "tokenizer").expect("tokenizer file should be created"); 1579 1580 clear_embeddings_model_cache_dir(&models_dir).expect("model cache should clear"); 1581 1582 assert!(!models_dir.exists(), "models dir should be removed after clearing"); 1583 } 1584 1585 fn insert_post(conn: &Connection, owner_did: &str, uri: &str, source: &str, text: &str, created_at: &str) { 1586 let item = feed_item(uri, "cid", "did:plc:author", "author.test", text, created_at); 1587 db_upsert_post(conn, owner_did, &item, source).expect("post should insert"); 1588 } 1589 1590 fn insert_embedding(conn: &Connection, owner_did: &str, source: &str, uri: &str, embedding: &[f32]) { 1591 let key = storage_key(owner_did, source, uri); 1592 db_upsert_embedding(conn, &key, embedding).expect("embedding should insert"); 1593 } 1594 1595 #[test] 1596 fn empty_query_is_rejected() { 1597 assert!(validate_query("").is_err()); 1598 } 1599 1600 #[test] 1601 fn whitespace_only_query_is_rejected() { 1602 assert!(validate_query(" ").is_err()); 1603 } 1604 1605 #[test] 1606 fn valid_query_is_accepted() { 1607 assert!(validate_query("rust programming").is_ok()); 1608 } 1609 1610 #[test] 1611 fn zero_limit_is_rejected() { 1612 assert!(validate_limit(0).is_err()); 1613 } 1614 1615 #[test] 1616 fn non_zero_limit_is_accepted() { 1617 assert_eq!(validate_limit(5).unwrap(), 5); 1618 } 1619 1620 #[test] 1621 fn single_char_query_is_accepted() { 1622 assert!(validate_query("a").is_ok()); 1623 } 1624 1625 #[test] 1626 fn from_handle_syntax_is_accepted() { 1627 assert!(validate_query("from:alice.bsky.social hello").is_ok()); 1628 } 1629 1630 #[test] 1631 fn valid_sources_are_accepted() { 1632 assert!(validate_source("like").is_ok()); 1633 assert!(validate_source("bookmark").is_ok()); 1634 } 1635 1636 #[test] 1637 fn valid_search_modes_are_accepted() { 1638 assert_eq!(validate_search_mode("keyword").unwrap(), SearchMode::Keyword); 1639 assert_eq!(validate_search_mode("semantic").unwrap(), SearchMode::Semantic); 1640 assert_eq!(validate_search_mode("hybrid").unwrap(), SearchMode::Hybrid); 1641 } 1642 1643 #[test] 1644 fn unknown_search_mode_is_rejected() { 1645 assert!(validate_search_mode("network").is_err()); 1646 } 1647 1648 #[test] 1649 fn unknown_source_is_rejected() { 1650 assert!(validate_source("repost").is_err()); 1651 assert!(validate_source("").is_err()); 1652 } 1653 1654 #[test] 1655 fn normalize_identifier_filter_accepts_handle() { 1656 let identifier = normalize_identifier_filter(Some("alice.test"), "Author filter").unwrap(); 1657 assert_eq!(identifier.unwrap().as_str(), "alice.test"); 1658 } 1659 1660 #[test] 1661 fn normalize_identifier_filter_strips_leading_at_sign() { 1662 let identifier = normalize_identifier_filter(Some("@alice.test"), "Author filter").unwrap(); 1663 assert_eq!(identifier.unwrap().as_str(), "alice.test"); 1664 } 1665 1666 #[test] 1667 fn normalize_identifier_filter_rejects_invalid_values() { 1668 assert!(normalize_identifier_filter(Some("not a valid handle"), "Author filter").is_err()); 1669 } 1670 1671 #[test] 1672 fn normalize_tag_filter_strips_hash_prefix() { 1673 assert_eq!(normalize_tag_filter("#rust").unwrap(), "rust"); 1674 assert_eq!(normalize_tag_filter("##solid").unwrap(), "solid"); 1675 } 1676 1677 #[test] 1678 fn normalize_tag_filter_rejects_blank_values() { 1679 assert!(normalize_tag_filter(" ").is_err()); 1680 assert!(normalize_tag_filter("###").is_err()); 1681 } 1682 1683 #[test] 1684 fn build_search_posts_request_includes_filters() { 1685 let request = build_search_posts_request(&NetworkSearchQueryParams { 1686 author: Some("@alice.test".to_owned()), 1687 cursor: Some("cursor-1".to_owned()), 1688 limit: Some(25), 1689 mentions: Some("did:plc:bob".to_owned()), 1690 query: "search text".to_owned(), 1691 since: Some("2026-04-01T05:00:00.000Z".to_owned()), 1692 sort: Some("latest".to_owned()), 1693 tags: Some(vec!["#rust".to_owned(), "solid".to_owned()]), 1694 until: Some("2026-04-02T05:00:00.000Z".to_owned()), 1695 }) 1696 .unwrap(); 1697 1698 assert_eq!(request.author.unwrap().as_str(), "alice.test"); 1699 assert_eq!(request.mentions.unwrap().as_str(), "did:plc:bob"); 1700 assert_eq!(request.cursor.unwrap().as_ref(), "cursor-1"); 1701 assert_eq!(request.limit, Some(25)); 1702 assert_eq!(request.q.as_ref(), "search text"); 1703 assert_eq!(request.since.unwrap().as_ref(), "2026-04-01T05:00:00.000Z"); 1704 assert_eq!(request.sort.unwrap().as_ref(), "latest"); 1705 assert_eq!( 1706 request 1707 .tag 1708 .unwrap() 1709 .iter() 1710 .map(|value| value.as_ref().to_owned()) 1711 .collect::<Vec<_>>(), 1712 vec!["rust".to_owned(), "solid".to_owned()] 1713 ); 1714 assert_eq!(request.until.unwrap().as_ref(), "2026-04-02T05:00:00.000Z"); 1715 } 1716 1717 #[test] 1718 fn build_search_posts_request_rejects_invalid_sort() { 1719 let result = build_search_posts_request(&NetworkSearchQueryParams { 1720 author: None, 1721 cursor: None, 1722 limit: Some(25), 1723 mentions: None, 1724 query: "search text".to_owned(), 1725 since: None, 1726 sort: Some("oldest".to_owned()), 1727 tags: None, 1728 until: None, 1729 }); 1730 1731 assert!(result.is_err()); 1732 } 1733 1734 #[test] 1735 fn build_search_posts_request_rejects_invalid_datetime() { 1736 let result = build_search_posts_request(&NetworkSearchQueryParams { 1737 author: None, 1738 cursor: None, 1739 limit: Some(25), 1740 mentions: None, 1741 query: "search text".to_owned(), 1742 since: Some("2026-04-01".to_owned()), 1743 sort: Some("latest".to_owned()), 1744 tags: None, 1745 until: None, 1746 }); 1747 1748 assert!(result.is_err()); 1749 } 1750 1751 #[test] 1752 fn build_search_posts_request_rejects_inverted_datetime_range() { 1753 let result = build_search_posts_request(&NetworkSearchQueryParams { 1754 author: None, 1755 cursor: None, 1756 limit: Some(25), 1757 mentions: None, 1758 query: "search text".to_owned(), 1759 since: Some("2026-04-02T05:00:00.000Z".to_owned()), 1760 sort: Some("latest".to_owned()), 1761 tags: None, 1762 until: Some("2026-04-01T05:00:00.000Z".to_owned()), 1763 }); 1764 1765 assert!(result.is_err()); 1766 } 1767 1768 #[test] 1769 fn cursor_is_none_when_no_sync_state_row_exists() { 1770 let conn = test_db(); 1771 let cursor = db_load_sync_cursor(&conn, "did:plc:alice", "like").unwrap(); 1772 assert!(cursor.is_none()); 1773 } 1774 1775 #[test] 1776 fn save_and_load_cursor_roundtrips() { 1777 let conn = test_db(); 1778 db_save_sync_state(&conn, "did:plc:alice", "like", Some("cursor-abc")).unwrap(); 1779 let loaded = db_load_sync_cursor(&conn, "did:plc:alice", "like").unwrap(); 1780 assert_eq!(loaded.as_deref(), Some("cursor-abc")); 1781 } 1782 1783 #[test] 1784 fn saving_none_cursor_clears_stored_cursor() { 1785 let conn = test_db(); 1786 db_save_sync_state(&conn, "did:plc:alice", "like", Some("cursor-abc")).unwrap(); 1787 db_save_sync_state(&conn, "did:plc:alice", "like", None).unwrap(); 1788 let loaded = db_load_sync_cursor(&conn, "did:plc:alice", "like").unwrap(); 1789 assert!(loaded.is_none()); 1790 } 1791 1792 #[test] 1793 fn cursor_is_per_did_and_source() { 1794 let conn = test_db(); 1795 db_save_sync_state(&conn, "did:plc:alice", "like", Some("cursor-alice-like")).unwrap(); 1796 db_save_sync_state(&conn, "did:plc:alice", "bookmark", Some("cursor-alice-bm")).unwrap(); 1797 db_save_sync_state(&conn, "did:plc:bob", "like", Some("cursor-bob-like")).unwrap(); 1798 1799 assert_eq!( 1800 db_load_sync_cursor(&conn, "did:plc:alice", "like").unwrap().as_deref(), 1801 Some("cursor-alice-like") 1802 ); 1803 assert_eq!( 1804 db_load_sync_cursor(&conn, "did:plc:alice", "bookmark") 1805 .unwrap() 1806 .as_deref(), 1807 Some("cursor-alice-bm") 1808 ); 1809 assert_eq!( 1810 db_load_sync_cursor(&conn, "did:plc:bob", "like").unwrap().as_deref(), 1811 Some("cursor-bob-like") 1812 ); 1813 } 1814 1815 #[test] 1816 fn build_fts_match_query_quotes_each_token() { 1817 assert_eq!(build_fts_match_query("rust sqlite"), "\"rust\" AND \"sqlite\""); 1818 } 1819 1820 #[test] 1821 fn upsert_inserts_new_post_for_owner_and_source() { 1822 let conn = test_db(); 1823 insert_post( 1824 &conn, 1825 "did:plc:alice", 1826 "at://did:plc:a/app.bsky.feed.post/1", 1827 "like", 1828 "hello world", 1829 "2024-01-01T00:00:00Z", 1830 ); 1831 assert_eq!(db_post_count(&conn, "did:plc:alice", "like").unwrap(), 1); 1832 } 1833 1834 #[test] 1835 fn upsert_is_scoped_by_owner_did() { 1836 let conn = test_db(); 1837 let item = feed_item( 1838 "at://did:plc:a/app.bsky.feed.post/1", 1839 "cid1", 1840 "did:plc:a", 1841 "alice", 1842 "hello", 1843 "2024-01-01T00:00:00Z", 1844 ); 1845 db_upsert_post(&conn, "did:plc:alice", &item, "like").unwrap(); 1846 db_upsert_post(&conn, "did:plc:bob", &item, "like").unwrap(); 1847 assert_eq!(db_post_count(&conn, "did:plc:alice", "like").unwrap(), 1); 1848 assert_eq!(db_post_count(&conn, "did:plc:bob", "like").unwrap(), 1); 1849 } 1850 1851 #[test] 1852 fn upsert_updates_text_on_conflict() { 1853 let conn = test_db(); 1854 let original = feed_item( 1855 "at://did:plc:a/app.bsky.feed.post/1", 1856 "cid1", 1857 "did:plc:a", 1858 "alice", 1859 "original", 1860 "2024-01-01T00:00:00Z", 1861 ); 1862 db_upsert_post(&conn, "did:plc:alice", &original, "like").unwrap(); 1863 1864 let updated = feed_item( 1865 "at://did:plc:a/app.bsky.feed.post/1", 1866 "cid2", 1867 "did:plc:a", 1868 "alice", 1869 "updated", 1870 "2024-01-02T00:00:00Z", 1871 ); 1872 db_upsert_post(&conn, "did:plc:alice", &updated, "like").unwrap(); 1873 1874 let text: String = conn 1875 .query_row( 1876 "SELECT text FROM posts WHERE storage_key = ?1", 1877 [storage_key( 1878 "did:plc:alice", 1879 "like", 1880 "at://did:plc:a/app.bsky.feed.post/1", 1881 )], 1882 |r| r.get(0), 1883 ) 1884 .unwrap(); 1885 assert_eq!(text, "updated"); 1886 } 1887 1888 #[test] 1889 fn upsert_stores_source() { 1890 let conn = test_db(); 1891 let item = feed_item( 1892 "at://did:plc:a/app.bsky.feed.post/1", 1893 "cid1", 1894 "did:plc:a", 1895 "alice", 1896 "hi", 1897 "2024-01-01T00:00:00Z", 1898 ); 1899 db_upsert_post(&conn, "did:plc:alice", &item, "bookmark").unwrap(); 1900 let source: String = conn 1901 .query_row( 1902 "SELECT source FROM posts WHERE storage_key = ?1", 1903 [storage_key( 1904 "did:plc:alice", 1905 "bookmark", 1906 "at://did:plc:a/app.bsky.feed.post/1", 1907 )], 1908 |r| r.get(0), 1909 ) 1910 .unwrap(); 1911 assert_eq!(source, "bookmark"); 1912 } 1913 1914 #[test] 1915 fn upsert_rejects_item_missing_uri() { 1916 let conn = test_db(); 1917 let bad = serde_json::json!({ "post": { "cid": "cid1", "author": { "did": "x" } } }); 1918 assert!(db_upsert_post(&conn, "did:plc:alice", &bad, "like").is_err()); 1919 } 1920 1921 #[test] 1922 fn post_count_is_per_owner_and_source() { 1923 let conn = test_db(); 1924 insert_post( 1925 &conn, 1926 "did:plc:alice", 1927 "at://a/app.bsky.feed.post/1", 1928 "like", 1929 "rust sqlite", 1930 "2024-01-01T00:00:00Z", 1931 ); 1932 insert_post( 1933 &conn, 1934 "did:plc:alice", 1935 "at://a/app.bsky.feed.post/2", 1936 "bookmark", 1937 "saved post", 1938 "2024-01-02T00:00:00Z", 1939 ); 1940 insert_post( 1941 &conn, 1942 "did:plc:bob", 1943 "at://a/app.bsky.feed.post/3", 1944 "like", 1945 "other account", 1946 "2024-01-03T00:00:00Z", 1947 ); 1948 assert_eq!(db_post_count(&conn, "did:plc:alice", "like").unwrap(), 1); 1949 assert_eq!(db_post_count(&conn, "did:plc:alice", "bookmark").unwrap(), 1); 1950 assert_eq!(db_post_count(&conn, "did:plc:bob", "like").unwrap(), 1); 1951 } 1952 1953 #[test] 1954 fn list_saved_posts_is_scoped_and_sorted_by_created_at_then_uri() { 1955 let conn = test_db(); 1956 insert_post( 1957 &conn, 1958 "did:plc:alice", 1959 "at://alice/app.bsky.feed.post/2", 1960 "bookmark", 1961 "second", 1962 "2024-01-02T00:00:00Z", 1963 ); 1964 insert_post( 1965 &conn, 1966 "did:plc:alice", 1967 "at://alice/app.bsky.feed.post/3", 1968 "bookmark", 1969 "third", 1970 "2024-01-02T00:00:00Z", 1971 ); 1972 insert_post( 1973 &conn, 1974 "did:plc:alice", 1975 "at://alice/app.bsky.feed.post/1", 1976 "bookmark", 1977 "first", 1978 "2024-01-01T00:00:00Z", 1979 ); 1980 insert_post( 1981 &conn, 1982 "did:plc:alice", 1983 "at://alice/app.bsky.feed.post/4", 1984 "like", 1985 "liked", 1986 "2024-01-03T00:00:00Z", 1987 ); 1988 insert_post( 1989 &conn, 1990 "did:plc:bob", 1991 "at://bob/app.bsky.feed.post/1", 1992 "bookmark", 1993 "bob saved", 1994 "2024-01-04T00:00:00Z", 1995 ); 1996 1997 let page = db_list_saved_posts(&conn, "did:plc:alice", "bookmark", 10, 0, None).unwrap(); 1998 let uris: Vec<&str> = page.posts.iter().map(|post| post.uri.as_str()).collect(); 1999 2000 assert_eq!( 2001 uris, 2002 vec![ 2003 "at://alice/app.bsky.feed.post/3", 2004 "at://alice/app.bsky.feed.post/2", 2005 "at://alice/app.bsky.feed.post/1", 2006 ] 2007 ); 2008 assert_eq!(page.total, 3); 2009 assert!(page.next_offset.is_none()); 2010 } 2011 2012 #[test] 2013 fn list_saved_posts_returns_next_offset_when_more_results_exist() { 2014 let conn = test_db(); 2015 insert_post( 2016 &conn, 2017 "did:plc:alice", 2018 "at://alice/app.bsky.feed.post/1", 2019 "like", 2020 "first", 2021 "2024-01-01T00:00:00Z", 2022 ); 2023 insert_post( 2024 &conn, 2025 "did:plc:alice", 2026 "at://alice/app.bsky.feed.post/2", 2027 "like", 2028 "second", 2029 "2024-01-02T00:00:00Z", 2030 ); 2031 2032 let page = db_list_saved_posts(&conn, "did:plc:alice", "like", 1, 0, None).unwrap(); 2033 2034 assert_eq!(page.posts.len(), 1); 2035 assert_eq!(page.total, 2); 2036 assert_eq!(page.next_offset, Some(1)); 2037 } 2038 2039 #[test] 2040 fn list_saved_posts_can_filter_with_query() { 2041 let conn = test_db(); 2042 insert_post( 2043 &conn, 2044 "did:plc:alice", 2045 "at://alice/app.bsky.feed.post/1", 2046 "bookmark", 2047 "rust sqlite", 2048 "2024-01-01T00:00:00Z", 2049 ); 2050 insert_post( 2051 &conn, 2052 "did:plc:alice", 2053 "at://alice/app.bsky.feed.post/2", 2054 "bookmark", 2055 "garden notes", 2056 "2024-01-02T00:00:00Z", 2057 ); 2058 2059 let page = db_list_saved_posts(&conn, "did:plc:alice", "bookmark", 10, 0, Some("rust")).unwrap(); 2060 2061 assert_eq!(page.total, 1); 2062 assert_eq!(page.posts.len(), 1); 2063 assert_eq!(page.posts[0].uri, "at://alice/app.bsky.feed.post/1"); 2064 } 2065 2066 #[test] 2067 fn embeddings_enabled_defaults_to_false_when_row_absent() { 2068 let conn = test_db(); 2069 assert!(!db_get_embeddings_enabled(&conn).unwrap()); 2070 } 2071 2072 #[test] 2073 fn set_embeddings_enabled_false_persists() { 2074 let conn = test_db(); 2075 db_set_embeddings_enabled(&conn, false).unwrap(); 2076 assert!(!db_get_embeddings_enabled(&conn).unwrap()); 2077 } 2078 2079 #[test] 2080 fn set_embeddings_enabled_true_persists() { 2081 let conn = test_db(); 2082 db_set_embeddings_enabled(&conn, false).unwrap(); 2083 db_set_embeddings_enabled(&conn, true).unwrap(); 2084 assert!(db_get_embeddings_enabled(&conn).unwrap()); 2085 } 2086 2087 #[test] 2088 fn embeddings_enabled_toggle_is_idempotent() { 2089 let conn = test_db(); 2090 conn.execute( 2091 "INSERT INTO app_settings(key, value) VALUES('embeddings_enabled', '1')", 2092 [], 2093 ) 2094 .unwrap(); 2095 db_set_embeddings_enabled(&conn, false).unwrap(); 2096 db_set_embeddings_enabled(&conn, false).unwrap(); 2097 assert!(!db_get_embeddings_enabled(&conn).unwrap()); 2098 } 2099 2100 #[test] 2101 fn embeddings_preflight_seen_defaults_to_false_when_row_absent() { 2102 let conn = test_db(); 2103 assert!(!db_get_embeddings_preflight_seen(&conn).unwrap()); 2104 } 2105 2106 #[test] 2107 fn set_embeddings_preflight_seen_persists() { 2108 let conn = test_db(); 2109 db_set_embeddings_preflight_seen(&conn, true).unwrap(); 2110 assert!(db_get_embeddings_preflight_seen(&conn).unwrap()); 2111 } 2112 2113 #[test] 2114 fn embeddings_preflight_seen_toggle_is_idempotent() { 2115 let conn = test_db(); 2116 db_set_embeddings_preflight_seen(&conn, true).unwrap(); 2117 db_set_embeddings_preflight_seen(&conn, true).unwrap(); 2118 assert!(db_get_embeddings_preflight_seen(&conn).unwrap()); 2119 db_set_embeddings_preflight_seen(&conn, false).unwrap(); 2120 assert!(!db_get_embeddings_preflight_seen(&conn).unwrap()); 2121 } 2122 2123 #[test] 2124 fn keyword_search_returns_owner_scoped_matches() { 2125 let conn = test_db(); 2126 insert_post( 2127 &conn, 2128 "did:plc:alice", 2129 "at://alice/app.bsky.feed.post/1", 2130 "like", 2131 "rust sqlite vectors", 2132 "2024-01-01T00:00:00Z", 2133 ); 2134 insert_post( 2135 &conn, 2136 "did:plc:bob", 2137 "at://bob/app.bsky.feed.post/1", 2138 "like", 2139 "rust sqlite vectors", 2140 "2024-01-02T00:00:00Z", 2141 ); 2142 2143 let results = run_local_search( 2144 &conn, 2145 "did:plc:alice", 2146 "rust sqlite", 2147 SearchMode::Keyword, 2148 10, 2149 true, 2150 None, 2151 ) 2152 .expect("keyword search should succeed"); 2153 2154 assert_eq!(results.len(), 1); 2155 assert_eq!(results[0].uri, "at://alice/app.bsky.feed.post/1"); 2156 } 2157 2158 #[test] 2159 fn semantic_search_returns_nearest_embeddings() { 2160 let conn = test_db(); 2161 insert_post( 2162 &conn, 2163 "did:plc:alice", 2164 "at://alice/app.bsky.feed.post/1", 2165 "like", 2166 "rust vectors", 2167 "2024-01-01T00:00:00Z", 2168 ); 2169 insert_post( 2170 &conn, 2171 "did:plc:alice", 2172 "at://alice/app.bsky.feed.post/2", 2173 "like", 2174 "sql joins", 2175 "2024-01-02T00:00:00Z", 2176 ); 2177 insert_embedding( 2178 &conn, 2179 "did:plc:alice", 2180 "like", 2181 "at://alice/app.bsky.feed.post/1", 2182 &[1.0, 0.0, 0.0], 2183 ); 2184 insert_embedding( 2185 &conn, 2186 "did:plc:alice", 2187 "like", 2188 "at://alice/app.bsky.feed.post/2", 2189 &[0.0, 1.0, 0.0], 2190 ); 2191 2192 let results = 2193 db_semantic_search(&conn, "did:plc:alice", &[1.0, 0.0, 0.0], 10).expect("semantic search should succeed"); 2194 2195 assert_eq!(results.len(), 2); 2196 assert_eq!(results[0].post.uri, "at://alice/app.bsky.feed.post/1"); 2197 assert!(results[0].post.score > results[1].post.score); 2198 } 2199 2200 #[test] 2201 fn semantic_search_requires_embeddings_when_disabled() { 2202 let conn = test_db(); 2203 let error = run_local_search( 2204 &conn, 2205 "did:plc:alice", 2206 "rust", 2207 SearchMode::Semantic, 2208 10, 2209 false, 2210 Some(&[1.0, 0.0, 0.0]), 2211 ) 2212 .expect_err("semantic search should fail when embeddings are disabled"); 2213 2214 assert!(error.to_string().contains("semantic search is unavailable")); 2215 } 2216 2217 #[test] 2218 fn hybrid_search_falls_back_to_keyword_when_embeddings_are_disabled() { 2219 let conn = test_db(); 2220 insert_post( 2221 &conn, 2222 "did:plc:alice", 2223 "at://alice/app.bsky.feed.post/1", 2224 "like", 2225 "rust sqlite", 2226 "2024-01-01T00:00:00Z", 2227 ); 2228 2229 let results = run_local_search(&conn, "did:plc:alice", "rust", SearchMode::Hybrid, 10, false, None) 2230 .expect("hybrid fallback should succeed"); 2231 2232 assert_eq!(results.len(), 1); 2233 assert_eq!(results[0].uri, "at://alice/app.bsky.feed.post/1"); 2234 } 2235 2236 #[test] 2237 fn hybrid_search_merges_keyword_and_semantic_results() { 2238 let conn = test_db(); 2239 insert_post( 2240 &conn, 2241 "did:plc:alice", 2242 "at://alice/app.bsky.feed.post/1", 2243 "like", 2244 "rust sqlite search", 2245 "2024-01-01T00:00:00Z", 2246 ); 2247 insert_post( 2248 &conn, 2249 "did:plc:alice", 2250 "at://alice/app.bsky.feed.post/2", 2251 "like", 2252 "semantic-only match", 2253 "2024-01-02T00:00:00Z", 2254 ); 2255 insert_embedding( 2256 &conn, 2257 "did:plc:alice", 2258 "like", 2259 "at://alice/app.bsky.feed.post/1", 2260 &[0.5, 0.5, 0.0], 2261 ); 2262 insert_embedding( 2263 &conn, 2264 "did:plc:alice", 2265 "like", 2266 "at://alice/app.bsky.feed.post/2", 2267 &[1.0, 0.0, 0.0], 2268 ); 2269 2270 let results = run_local_search( 2271 &conn, 2272 "did:plc:alice", 2273 "rust", 2274 SearchMode::Hybrid, 2275 10, 2276 true, 2277 Some(&[1.0, 0.0, 0.0]), 2278 ) 2279 .expect("hybrid search should succeed"); 2280 2281 let uris: Vec<&str> = results.iter().map(|result| result.uri.as_str()).collect(); 2282 assert!(uris.contains(&"at://alice/app.bsky.feed.post/1")); 2283 assert!(uris.contains(&"at://alice/app.bsky.feed.post/2")); 2284 } 2285 2286 #[test] 2287 fn sync_status_returns_counts_for_both_sources_per_did() { 2288 let conn = test_db(); 2289 insert_post( 2290 &conn, 2291 "did:plc:alice", 2292 "at://alice/app.bsky.feed.post/1", 2293 "like", 2294 "liked post", 2295 "2024-01-01T00:00:00Z", 2296 ); 2297 insert_post( 2298 &conn, 2299 "did:plc:alice", 2300 "at://alice/app.bsky.feed.post/2", 2301 "bookmark", 2302 "saved post", 2303 "2024-01-02T00:00:00Z", 2304 ); 2305 insert_post( 2306 &conn, 2307 "did:plc:bob", 2308 "at://bob/app.bsky.feed.post/3", 2309 "like", 2310 "bob post", 2311 "2024-01-03T00:00:00Z", 2312 ); 2313 db_save_sync_state(&conn, "did:plc:alice", "like", Some("cursor-like")).unwrap(); 2314 2315 let like_status = db_sync_status(&conn, "did:plc:alice", "like").unwrap(); 2316 let bookmark_status = db_sync_status(&conn, "did:plc:alice", "bookmark").unwrap(); 2317 2318 assert_eq!(like_status.post_count, 1); 2319 assert_eq!(like_status.cursor.as_deref(), Some("cursor-like")); 2320 assert_eq!(bookmark_status.post_count, 1); 2321 assert!(bookmark_status.cursor.is_none()); 2322 } 2323 2324 #[test] 2325 fn sync_due_is_true_for_new_active_account() { 2326 assert!(sync_due(Some("did:plc:alice"), None, None)); 2327 } 2328 2329 #[test] 2330 fn sync_due_is_false_when_recent_sync_exists() { 2331 assert!(!sync_due( 2332 Some("did:plc:alice"), 2333 Some("did:plc:alice"), 2334 Some(std::time::Instant::now()), 2335 )); 2336 } 2337}