1use anyhow::{anyhow, Result};
2use atrium_api::com::atproto::sync::subscribe_repos::{Commit, NSID};
3use atrium_api::client::AtpServiceClient;
4use atrium_api::com;
5use atrium_api::types;
6use atrium_xrpc_client::isahc::IsahcClient;
7use futures::StreamExt;
8use ipld_core::ipld::Ipld;
9use std::io::Cursor;
10use std::sync::{Arc, Mutex};
11use tokio_tungstenite::connect_async;
12use tokio_tungstenite::tungstenite::Message;
13
14use crate::fs::{PdsFsCollection, PdsFsEntry, PdsFsRecord};
15use indexmap::{IndexMap, IndexSet};
16
/// Frame header types for WebSocket messages
///
/// Decoded from the dag-cbor header object that prefixes every firehose
/// frame: `op = 1` is a regular message (with an optional type tag `t`,
/// e.g. "#commit"); `op = -1` is an error frame.
#[derive(Debug, Clone, PartialEq, Eq)]
enum FrameHeader {
    // `op = 1`: carries the optional `t` field naming the message type.
    Message(Option<String>),
    // `op = -1`: the stream reported an error.
    Error,
}
23
24impl TryFrom<Ipld> for FrameHeader {
25 type Error = anyhow::Error;
26
27 fn try_from(value: Ipld) -> Result<Self> {
28 if let Ipld::Map(map) = value {
29 if let Some(Ipld::Integer(i)) = map.get("op") {
30 match i {
31 1 => {
32 let t = if let Some(Ipld::String(s)) = map.get("t") {
33 Some(s.clone())
34 } else {
35 None
36 };
37 return Ok(FrameHeader::Message(t));
38 }
39 -1 => return Ok(FrameHeader::Error),
40 _ => {}
41 }
42 }
43 }
44 Err(anyhow!("invalid frame type"))
45 }
46}
47
/// Frame types for parsed WebSocket messages
///
/// A parsed frame pairs the header's message tag with the remaining
/// (still-encoded) body bytes; error frames carry no payload.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Frame {
    // Optional type tag (e.g. "#commit") plus the undecoded body.
    Message(Option<String>, MessageFrame),
    // The stream signalled an error (`op = -1`).
    Error(ErrorFrame),
}
54
/// Body of a message frame: the raw dag-cbor bytes that follow the header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MessageFrame {
    // Undecoded dag-cbor payload (e.g. a `Commit` for "#commit" frames).
    pub body: Vec<u8>,
}
59
/// Marker for an error frame; the error payload is not retained on parse.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ErrorFrame {}
62
63impl TryFrom<&[u8]> for Frame {
64 type Error = anyhow::Error;
65
66 fn try_from(value: &[u8]) -> Result<Self> {
67 let mut cursor = Cursor::new(value);
68 let (left, right) = match serde_ipld_dagcbor::from_reader::<Ipld, _>(&mut cursor) {
69 Err(serde_ipld_dagcbor::DecodeError::TrailingData) => {
70 value.split_at(cursor.position() as usize)
71 }
72 _ => {
73 return Err(anyhow!("invalid frame type"));
74 }
75 };
76 let header = FrameHeader::try_from(serde_ipld_dagcbor::from_slice::<Ipld>(left)?)?;
77 if let FrameHeader::Message(t) = &header {
78 Ok(Frame::Message(t.clone(), MessageFrame { body: right.to_vec() }))
79 } else {
80 Ok(Frame::Error(ErrorFrame {}))
81 }
82 }
83}
84
85/// Subscribe to a repo's firehose and update inodes on changes
86pub async fn subscribe_to_repo<R>(
87 did: String,
88 pds: String,
89 inodes: Arc<Mutex<IndexSet<PdsFsEntry>>>,
90 sizes: Arc<Mutex<IndexMap<usize, u64>>>,
91 content_cache: Arc<Mutex<IndexMap<String, String>>>,
92 notifier: fuser::Notifier,
93) -> Result<()>
94where
95 R: atrium_repo::blockstore::AsyncBlockStoreRead,
96{
97 // Strip https:// or http:// prefix from PDS URL if present
98 let pds_host = pds.trim_start_matches("https://").trim_start_matches("http://");
99 let url = format!("wss://{}/xrpc/{}", pds_host, NSID);
100 println!("Connecting to firehose: {}", url);
101
102 let (mut stream, _) = connect_async(url).await?;
103 println!("Connected to firehose for {}", did);
104
105 loop {
106 match stream.next().await {
107 Some(Ok(Message::Binary(data))) => {
108 if let Ok(Frame::Message(Some(t), msg)) = Frame::try_from(data.as_slice()) {
109 if t.as_str() == "#commit" {
110 if let Ok(commit) = serde_ipld_dagcbor::from_reader::<Commit, _>(msg.body.as_slice()) {
111 // Only process commits for our DID
112 if commit.repo.as_str() == did {
113 if let Err(e) = handle_commit(&commit, &inodes, &sizes, &content_cache, &did, &pds, ¬ifier).await {
114 eprintln!("Error handling commit: {:?}", e);
115 }
116 }
117 }
118 }
119 }
120 }
121 Some(Ok(_)) => {} // Ignore other message types
122 Some(Err(e)) => {
123 eprintln!("WebSocket error: {}", e);
124 break;
125 }
126 None => {
127 eprintln!("WebSocket closed");
128 break;
129 }
130 }
131 }
132
133 Ok(())
134}
135
/// Handle a commit by updating the inode tree and notifying Finder
///
/// Mirrors each op in the commit into the shared filesystem state:
/// - `create`: fetch the record from the PDS, insert collection/record
///   entries into `inodes`, cache content and size, then invalidate the
///   parent directory entry so the kernel re-reads it.
/// - `delete`: remove the record entry and its caches, then send a delete
///   notification for the child inode.
/// - `update`: drop the cached size/content and invalidate the record's
///   inode so content is re-fetched on the next read.
///
/// Inode numbers are the entries' indices in the `IndexSet` — NOTE(review):
/// deletes use `shift_remove`, which shifts later indices; confirm the FUSE
/// layer tolerates inode renumbering after a delete.
///
/// Locking discipline: every `Mutex` guard is confined to a block scope and
/// released before any notifier call or `.await`, so no `std::sync::Mutex`
/// guard is ever held across an await point.
async fn handle_commit(
    commit: &Commit,
    inodes: &Arc<Mutex<IndexSet<PdsFsEntry>>>,
    sizes: &Arc<Mutex<IndexMap<usize, u64>>>,
    content_cache: &Arc<Mutex<IndexMap<String, String>>>,
    did: &str,
    pds: &str,
    notifier: &fuser::Notifier,
) -> Result<()> {
    // Find the DID inode
    let did_entry = PdsFsEntry::Did(did.to_string());
    let did_inode = {
        let inodes_lock = inodes.lock().unwrap();
        inodes_lock.get_index_of(&did_entry)
    };

    let Some(did_inode) = did_inode else {
        return Err(anyhow!("DID not found in inodes"));
    };

    for op in &commit.ops {
        // Op paths look like "collection/rkey"; skip anything else.
        let Some((collection, rkey)) = op.path.split_once('/') else {
            continue;
        };

        match op.action.as_str() {
            "create" => {
                // Fetch the record from PDS
                let record_key = format!("{}/{}", collection, rkey);
                // Cache key shape: "did/collection/rkey" (matches delete/update).
                let cache_key = format!("{}/{}", did, record_key);

                // Fetch record content from PDS
                match fetch_record(pds, did, collection, rkey).await {
                    Ok(content) => {
                        let content_len = content.len() as u64;

                        // Add the record to inodes
                        let (collection_inode, record_inode) = {
                            let mut inodes_lock = inodes.lock().unwrap();

                            // Ensure collection exists
                            // (insert_full returns the existing index if it does)
                            let collection_entry = PdsFsEntry::Collection(PdsFsCollection {
                                parent: did_inode,
                                nsid: collection.to_string(),
                            });
                            let (collection_inode, _) = inodes_lock.insert_full(collection_entry);

                            // Add the record
                            let record_entry = PdsFsEntry::Record(PdsFsRecord {
                                parent: collection_inode,
                                rkey: rkey.to_string(),
                            });
                            let (record_inode, _) = inodes_lock.insert_full(record_entry);
                            (collection_inode, record_inode)
                        };

                        // Cache the content and size
                        content_cache.lock().unwrap().insert(cache_key, content);
                        sizes.lock().unwrap().insert(record_inode, content_len);

                        // Notify Finder about the new file (release lock first)
                        let filename = format!("{}.json", rkey);
                        if let Err(e) = notifier.inval_entry(collection_inode as u64, filename.as_ref()) {
                            eprintln!("Failed to invalidate entry for {}: {}", filename, e);
                        }

                        println!("Created: {}/{}", collection, rkey);
                    }
                    Err(e) => {
                        // Best-effort: log and keep processing remaining ops.
                        eprintln!("Failed to fetch record {}/{}: {}", collection, rkey, e);
                    }
                }
            }
            "delete" => {
                // Get inodes before removing
                let (collection_inode_opt, child_inode_opt) = {
                    let mut inodes_lock = inodes.lock().unwrap();

                    // Find the collection
                    let collection_entry = PdsFsEntry::Collection(PdsFsCollection {
                        parent: did_inode,
                        nsid: collection.to_string(),
                    });
                    let collection_inode = inodes_lock.get_index_of(&collection_entry);

                    // Find and remove the record
                    let child_inode = if let Some(coll_ino) = collection_inode {
                        let record_entry = PdsFsEntry::Record(PdsFsRecord {
                            parent: coll_ino,
                            rkey: rkey.to_string(),
                        });
                        // Capture the index before removal invalidates it.
                        let child_ino = inodes_lock.get_index_of(&record_entry);
                        inodes_lock.shift_remove(&record_entry);
                        child_ino
                    } else {
                        None
                    };

                    (collection_inode, child_inode)
                };

                // Notify Finder about the deletion (release lock first)
                if let (Some(coll_ino), Some(child_ino)) = (collection_inode_opt, child_inode_opt) {
                    // Remove from caches
                    sizes.lock().unwrap().shift_remove(&child_ino);
                    let cache_key = format!("{}/{}/{}", did, collection, rkey);
                    content_cache.lock().unwrap().shift_remove(&cache_key);

                    let filename = format!("{}.json", rkey);
                    if let Err(e) = notifier.delete(coll_ino as u64, child_ino as u64, filename.as_ref()) {
                        eprintln!("Failed to notify deletion for {}: {}", filename, e);
                    }
                }

                // Logged even if the record was not present in the tree.
                println!("Deleted: {}/{}", collection, rkey);
            }
            "update" => {
                // For updates, invalidate the inode so content is re-fetched
                let record_inode_opt = {
                    let inodes_lock = inodes.lock().unwrap();
                    let collection_entry = PdsFsEntry::Collection(PdsFsCollection {
                        parent: did_inode,
                        nsid: collection.to_string(),
                    });

                    if let Some(collection_inode) = inodes_lock.get_index_of(&collection_entry) {
                        let record_entry = PdsFsEntry::Record(PdsFsRecord {
                            parent: collection_inode,
                            rkey: rkey.to_string(),
                        });
                        inodes_lock.get_index_of(&record_entry)
                    } else {
                        None
                    }
                };

                // Notify Finder to invalidate the inode (release lock first)
                if let Some(record_ino) = record_inode_opt {
                    // Clear caches so content is recalculated
                    sizes.lock().unwrap().shift_remove(&record_ino);
                    let cache_key = format!("{}/{}/{}", did, collection, rkey);
                    content_cache.lock().unwrap().shift_remove(&cache_key);

                    // Invalidate the entire inode (metadata and all data)
                    if let Err(e) = notifier.inval_inode(record_ino as u64, 0, 0) {
                        eprintln!("Failed to invalidate inode for {}/{}: {}", collection, rkey, e);
                    }
                }

                println!("Updated: {}/{}", collection, rkey);
            }
            // Unknown actions are ignored.
            _ => {}
        }
    }

    Ok(())
}
294
295/// Fetch a record from the PDS
296async fn fetch_record(pds: &str, did: &str, collection: &str, rkey: &str) -> Result<String> {
297 let client = AtpServiceClient::new(IsahcClient::new(pds));
298 let did = types::string::Did::new(did.to_string()).map_err(|e| anyhow!(e))?;
299 let collection_nsid = types::string::Nsid::new(collection.to_string()).map_err(|e| anyhow!(e))?;
300 let record_key = types::string::RecordKey::new(rkey.to_string()).map_err(|e| anyhow!(e))?;
301
302 let response = client
303 .service
304 .com
305 .atproto
306 .repo
307 .get_record(com::atproto::repo::get_record::Parameters::from(
308 com::atproto::repo::get_record::ParametersData {
309 cid: None,
310 collection: collection_nsid,
311 repo: types::string::AtIdentifier::Did(did),
312 rkey: record_key,
313 }
314 ))
315 .await?;
316
317 Ok(serde_json::to_string_pretty(&response.value)?)
318}