Mirror of https://git.olaren.dev/Olaren/moot-graph
at main 10 kB view raw
1use async_stream::try_stream; 2use cid::Cid; 3use futures_util::{ 4 TryStreamExt, 5 stream::{Stream, StreamExt}, 6}; 7use serde::{Deserialize, Serialize}; 8use serde_json::Value as JsonValue; 9use std::collections::HashMap; 10use std::time::Duration; 11use thiserror::Error; 12use tokio_tungstenite::{connect_async, tungstenite::protocol::Message}; 13 14#[derive(Error, Debug)] 15pub enum FirehoseError { 16 #[error("WebSocket connection error: {0}")] 17 WebSocket(#[from] tokio_tungstenite::tungstenite::Error), 18 #[error("URL parsing error: {0}")] 19 Url(#[from] url::ParseError), 20 #[error("Failed to deserialize CBOR data: {0}")] 21 Cbor(#[from] std::io::Error), 22 #[error("Failed to read CAR file: {0}")] 23 CarRead(#[from] rs_car::CarDecodeError), 24 #[error("Invalid message frame received from relay")] 25 InvalidFrame, 26 #[error("Received an error message from the relay: {name} - {message}")] 27 RelayError { name: String, message: String }, 28 #[error("Unknown message type received: {0}")] 29 UnknownMessageType(String), 30 #[error("Connection timed out after {0:?} without receiving a message")] 31 Timeout(Duration), 32} 33 34#[derive(Debug, Clone, Serialize, Deserialize)] 35#[serde(tag = "action", rename_all = "camelCase")] 36pub enum RepoOp { 37 Create { 38 path: String, 39 cid: Cid, 40 record: JsonValue, 41 }, 42 Update { 43 path: String, 44 cid: Cid, 45 record: JsonValue, 46 }, 47 Delete { 48 path: String, 49 }, 50} 51 52#[derive(Debug, Clone, Serialize, Deserialize)] 53pub struct CommitEvent { 54 #[serde(rename = "seq")] 55 pub sequence: i64, 56 pub repo: String, 57 pub commit: Cid, 58 pub rev: String, 59 pub since: Option<String>, 60 pub ops: Vec<RepoOp>, 61 #[serde(rename = "time")] 62 pub timestamp: String, 63} 64 65#[derive(Debug, Clone, Serialize)] 66#[serde(untagged)] 67pub enum FirehoseEvent { 68 Commit(CommitEvent), 69 Info(JsonValue), 70 Account(JsonValue), 71 Identity(JsonValue), 72 Unknown(JsonValue), 73} 74 75#[derive(Debug, Clone)] 76pub struct FirehoseOptions { 77 pub relay_url: String, 78 pub cursor: Option<i64>, 79 pub auto_reconnect: bool, 80} 81 82impl Default for FirehoseOptions { 83 fn default() -> Self { 84 Self { 85 relay_url: "wss://relay.upcloud.world".to_string(), 86 cursor: None, 87 auto_reconnect: true, 88 } 89 } 90} 91 92#[derive(Deserialize, Debug)] 93pub struct FrameHeader { 94 #[serde(rename = "op")] 95 operation: i64, 96 #[serde(rename = "t")] 97 message_type: String, 98} 99 100#[derive(Deserialize, Debug)] 101struct ErrorBody { 102 error: String, 103 message: String, 104} 105 106#[derive(Deserialize, Debug, Clone)] 107struct RawRepoOp { 108 action: String, 109 path: String, 110 cid: Option<Cid>, 111} 112 113#[derive(Deserialize, Debug)] 114struct RawCommitEvent { 115 seq: i64, 116 repo: String, 117 commit: Cid, 118 rev: String, 119 since: Option<String>, 120 blocks: serde_bytes::ByteBuf, 121 ops: Vec<RawRepoOp>, 122 time: String, 123} 124 125pub fn subscribe_repos( 126 options: FirehoseOptions, 127) -> impl Stream<Item = Result<FirehoseEvent, FirehoseError>> { 128 let mut last_cursor = options.cursor; 129 130 try_stream! { 131 loop { 132 let mut url_str = format!( 133 "{}/xrpc/com.atproto.sync.subscribeRepos", 134 options.relay_url 135 ); 136 if let Some(cursor_val) = last_cursor { 137 url_str.push_str(&format!("?cursor={}", cursor_val)); 138 } 139 140 tracing::info!("Connecting to {}...", url_str); 141 142 let (mut ws_stream, _) = connect_async(&url_str).await?; 143 144 tracing::info!("Successfully connected to firehose."); 145 146 loop { 147 let next_message = tokio::time::timeout( 148 Duration::from_secs(15), 149 ws_stream.next() 150 ); 151 152 match next_message.await { 153 Err(_) => { 154 tracing::info!("Connection timed out. Attempting to reconnect..."); 155 ws_stream.close(None).await?; 156 break; 157 }, 158 Ok(Some(Ok(msg))) => { 159 if let Message::Binary(data) = msg { 160 let mut deserializer = serde_ipld_dagcbor::de::Deserializer::from_slice(&data); 161 let map_cbor_err = |e: serde_ipld_dagcbor::DecodeError<std::convert::Infallible>| { 162 std::io::Error::other(e.to_string()) 163 }; 164 165 let header: FrameHeader = match serde::Deserialize::deserialize(&mut deserializer).map_err(map_cbor_err) { 166 Ok(h) => h, 167 Err(e) => { 168 tracing::error!("Failed to deserialize frame header: {}. Skipping message.", e); 169 continue; 170 } 171 }; 172 173 if header.operation == -1 { 174 let body: ErrorBody = match serde::Deserialize::deserialize(&mut deserializer).map_err(map_cbor_err) { 175 Ok(b) => b, 176 Err(e) => { 177 tracing::error!("Failed to deserialize relay error body: {}. Skipping message.", e); 178 continue; 179 } 180 }; 181 Err(FirehoseError::RelayError { name: body.error, message: body.message })?; 182 } 183 184 let event = match header.message_type.as_str() { 185 "#commit" => { 186 let raw_commit: RawCommitEvent = match serde::Deserialize::deserialize(&mut deserializer).map_err(map_cbor_err) { 187 Ok(c) => c, 188 Err(e) => { 189 tracing::error!("Failed to deserialize commit body: {}. Skipping message.", e); 190 continue; 191 } 192 }; 193 last_cursor = Some(raw_commit.seq); 194 let commit_event = process_commit_event(raw_commit).await?; 195 FirehoseEvent::Commit(commit_event) 196 } 197 t => { 198 let body: JsonValue = match serde::Deserialize::deserialize(&mut deserializer).map_err(map_cbor_err) { 199 Ok(b) => b, 200 Err(e) => { 201 tracing::error!("Failed to deserialize event body for type '{}': {}. Skipping message.", t, e); 202 continue; 203 } 204 }; 205 if let Some(seq) = body.get("seq").and_then(|v| v.as_i64()) { 206 last_cursor = Some(seq); 207 } 208 match t { 209 "#info" => FirehoseEvent::Info(body), 210 "#account" => FirehoseEvent::Account(body), 211 "#identity" => FirehoseEvent::Identity(body), 212 _ => FirehoseEvent::Unknown(body) 213 } 214 } 215 }; 216 yield event; 217 } 218 } 219 Ok(Some(Err(e))) => { 220 tracing::info!("WebSocket error: {}", e); 221 break; 222 } 223 Ok(None) => { 224 tracing::info!("WebSocket stream closed by server."); 225 break; 226 } 227 } 228 } 229 230 if !options.auto_reconnect { 231 tracing::info!("Auto-reconnect is disabled. Exiting."); 232 break; 233 } 234 tokio::time::sleep(Duration::from_secs(1)).await; 235 } 236 } 237} 238 239async fn process_commit_event(raw: RawCommitEvent) -> Result<CommitEvent, FirehoseError> { 240 let mut blocks_reader = raw.blocks.as_ref(); 241 let car_reader = rs_car::CarReader::new(&mut blocks_reader, false).await?; 242 243 let records_map: HashMap<Cid, Vec<u8>> = car_reader.try_collect::<HashMap<_, _>>().await?; 244 245 let mut ops = Vec::new(); 246 for op in raw.ops { 247 let repo_op = match op.action.as_str() { 248 "create" | "update" => { 249 let cid = op.cid.expect("Create/Update operation must have a CID"); 250 let record_bytes = records_map.get(&cid).cloned().unwrap_or_default(); 251 let record: JsonValue = 252 serde_ipld_dagcbor::from_slice(&record_bytes).unwrap_or(JsonValue::Null); 253 254 if op.action == "create" { 255 RepoOp::Create { 256 path: op.path, 257 cid, 258 record, 259 } 260 } else { 261 RepoOp::Update { 262 path: op.path, 263 cid, 264 record, 265 } 266 } 267 } 268 "delete" => RepoOp::Delete { path: op.path }, 269 _ => continue, 270 }; 271 ops.push(repo_op); 272 } 273 274 Ok(CommitEvent { 275 sequence: raw.seq, 276 repo: raw.repo, 277 commit: raw.commit, 278 rev: raw.rev, 279 since: raw.since, 280 ops, 281 timestamp: raw.time, 282 }) 283}