Mirror of https://git.olaren.dev/Olaren/moot-graph
1use async_stream::try_stream;
2use cid::Cid;
3use futures_util::{
4 TryStreamExt,
5 stream::{Stream, StreamExt},
6};
7use serde::{Deserialize, Serialize};
8use serde_json::Value as JsonValue;
9use std::collections::HashMap;
10use std::time::Duration;
11use thiserror::Error;
12use tokio_tungstenite::{connect_async, tungstenite::protocol::Message};
13
14#[derive(Error, Debug)]
15pub enum FirehoseError {
16 #[error("WebSocket connection error: {0}")]
17 WebSocket(#[from] tokio_tungstenite::tungstenite::Error),
18 #[error("URL parsing error: {0}")]
19 Url(#[from] url::ParseError),
20 #[error("Failed to deserialize CBOR data: {0}")]
21 Cbor(#[from] std::io::Error),
22 #[error("Failed to read CAR file: {0}")]
23 CarRead(#[from] rs_car::CarDecodeError),
24 #[error("Invalid message frame received from relay")]
25 InvalidFrame,
26 #[error("Received an error message from the relay: {name} - {message}")]
27 RelayError { name: String, message: String },
28 #[error("Unknown message type received: {0}")]
29 UnknownMessageType(String),
30 #[error("Connection timed out after {0:?} without receiving a message")]
31 Timeout(Duration),
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35#[serde(tag = "action", rename_all = "camelCase")]
36pub enum RepoOp {
37 Create {
38 path: String,
39 cid: Cid,
40 record: JsonValue,
41 },
42 Update {
43 path: String,
44 cid: Cid,
45 record: JsonValue,
46 },
47 Delete {
48 path: String,
49 },
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct CommitEvent {
54 #[serde(rename = "seq")]
55 pub sequence: i64,
56 pub repo: String,
57 pub commit: Cid,
58 pub rev: String,
59 pub since: Option<String>,
60 pub ops: Vec<RepoOp>,
61 #[serde(rename = "time")]
62 pub timestamp: String,
63}
64
65#[derive(Debug, Clone, Serialize)]
66#[serde(untagged)]
67pub enum FirehoseEvent {
68 Commit(CommitEvent),
69 Info(JsonValue),
70 Account(JsonValue),
71 Identity(JsonValue),
72 Unknown(JsonValue),
73}
74
/// Settings accepted by [`subscribe_repos`].
#[derive(Clone, Debug)]
pub struct FirehoseOptions {
    /// Base URL of the relay; the `subscribeRepos` XRPC path is appended to it.
    pub relay_url: String,
    /// Sequence number sent as the `cursor` query parameter on the first
    /// connection; `None` sends no cursor.
    pub cursor: Option<i64>,
    /// Whether to connect again after the connection is lost. When `false`
    /// the stream ends once the first connection goes away.
    pub auto_reconnect: bool,
}

impl Default for FirehoseOptions {
    /// Targets `wss://relay.upcloud.world`, with no cursor and reconnection
    /// enabled.
    fn default() -> Self {
        FirehoseOptions {
            auto_reconnect: true,
            cursor: None,
            relay_url: String::from("wss://relay.upcloud.world"),
        }
    }
}
91
92#[derive(Deserialize, Debug)]
93pub struct FrameHeader {
94 #[serde(rename = "op")]
95 operation: i64,
96 #[serde(rename = "t")]
97 message_type: String,
98}
99
100#[derive(Deserialize, Debug)]
101struct ErrorBody {
102 error: String,
103 message: String,
104}
105
106#[derive(Deserialize, Debug, Clone)]
107struct RawRepoOp {
108 action: String,
109 path: String,
110 cid: Option<Cid>,
111}
112
113#[derive(Deserialize, Debug)]
114struct RawCommitEvent {
115 seq: i64,
116 repo: String,
117 commit: Cid,
118 rev: String,
119 since: Option<String>,
120 blocks: serde_bytes::ByteBuf,
121 ops: Vec<RawRepoOp>,
122 time: String,
123}
124
125pub fn subscribe_repos(
126 options: FirehoseOptions,
127) -> impl Stream<Item = Result<FirehoseEvent, FirehoseError>> {
128 let mut last_cursor = options.cursor;
129
130 try_stream! {
131 loop {
132 let mut url_str = format!(
133 "{}/xrpc/com.atproto.sync.subscribeRepos",
134 options.relay_url
135 );
136 if let Some(cursor_val) = last_cursor {
137 url_str.push_str(&format!("?cursor={}", cursor_val));
138 }
139
140 tracing::info!("Connecting to {}...", url_str);
141
142 let (mut ws_stream, _) = connect_async(&url_str).await?;
143
144 tracing::info!("Successfully connected to firehose.");
145
146 loop {
147 let next_message = tokio::time::timeout(
148 Duration::from_secs(15),
149 ws_stream.next()
150 );
151
152 match next_message.await {
153 Err(_) => {
154 tracing::info!("Connection timed out. Attempting to reconnect...");
155 ws_stream.close(None).await?;
156 break;
157 },
158 Ok(Some(Ok(msg))) => {
159 if let Message::Binary(data) = msg {
160 let mut deserializer = serde_ipld_dagcbor::de::Deserializer::from_slice(&data);
161 let map_cbor_err = |e: serde_ipld_dagcbor::DecodeError<std::convert::Infallible>| {
162 std::io::Error::other(e.to_string())
163 };
164
165 let header: FrameHeader = match serde::Deserialize::deserialize(&mut deserializer).map_err(map_cbor_err) {
166 Ok(h) => h,
167 Err(e) => {
168 tracing::error!("Failed to deserialize frame header: {}. Skipping message.", e);
169 continue;
170 }
171 };
172
173 if header.operation == -1 {
174 let body: ErrorBody = match serde::Deserialize::deserialize(&mut deserializer).map_err(map_cbor_err) {
175 Ok(b) => b,
176 Err(e) => {
177 tracing::error!("Failed to deserialize relay error body: {}. Skipping message.", e);
178 continue;
179 }
180 };
181 Err(FirehoseError::RelayError { name: body.error, message: body.message })?;
182 }
183
184 let event = match header.message_type.as_str() {
185 "#commit" => {
186 let raw_commit: RawCommitEvent = match serde::Deserialize::deserialize(&mut deserializer).map_err(map_cbor_err) {
187 Ok(c) => c,
188 Err(e) => {
189 tracing::error!("Failed to deserialize commit body: {}. Skipping message.", e);
190 continue;
191 }
192 };
193 last_cursor = Some(raw_commit.seq);
194 let commit_event = process_commit_event(raw_commit).await?;
195 FirehoseEvent::Commit(commit_event)
196 }
197 t => {
198 let body: JsonValue = match serde::Deserialize::deserialize(&mut deserializer).map_err(map_cbor_err) {
199 Ok(b) => b,
200 Err(e) => {
201 tracing::error!("Failed to deserialize event body for type '{}': {}. Skipping message.", t, e);
202 continue;
203 }
204 };
205 if let Some(seq) = body.get("seq").and_then(|v| v.as_i64()) {
206 last_cursor = Some(seq);
207 }
208 match t {
209 "#info" => FirehoseEvent::Info(body),
210 "#account" => FirehoseEvent::Account(body),
211 "#identity" => FirehoseEvent::Identity(body),
212 _ => FirehoseEvent::Unknown(body)
213 }
214 }
215 };
216 yield event;
217 }
218 }
219 Ok(Some(Err(e))) => {
220 tracing::info!("WebSocket error: {}", e);
221 break;
222 }
223 Ok(None) => {
224 tracing::info!("WebSocket stream closed by server.");
225 break;
226 }
227 }
228 }
229
230 if !options.auto_reconnect {
231 tracing::info!("Auto-reconnect is disabled. Exiting.");
232 break;
233 }
234 tokio::time::sleep(Duration::from_secs(1)).await;
235 }
236 }
237}
238
239async fn process_commit_event(raw: RawCommitEvent) -> Result<CommitEvent, FirehoseError> {
240 let mut blocks_reader = raw.blocks.as_ref();
241 let car_reader = rs_car::CarReader::new(&mut blocks_reader, false).await?;
242
243 let records_map: HashMap<Cid, Vec<u8>> = car_reader.try_collect::<HashMap<_, _>>().await?;
244
245 let mut ops = Vec::new();
246 for op in raw.ops {
247 let repo_op = match op.action.as_str() {
248 "create" | "update" => {
249 let cid = op.cid.expect("Create/Update operation must have a CID");
250 let record_bytes = records_map.get(&cid).cloned().unwrap_or_default();
251 let record: JsonValue =
252 serde_ipld_dagcbor::from_slice(&record_bytes).unwrap_or(JsonValue::Null);
253
254 if op.action == "create" {
255 RepoOp::Create {
256 path: op.path,
257 cid,
258 record,
259 }
260 } else {
261 RepoOp::Update {
262 path: op.path,
263 cid,
264 record,
265 }
266 }
267 }
268 "delete" => RepoOp::Delete { path: op.path },
269 _ => continue,
270 };
271 ops.push(repo_op);
272 }
273
274 Ok(CommitEvent {
275 sequence: raw.seq,
276 repo: raw.repo,
277 commit: raw.commit,
278 rev: raw.rev,
279 since: raw.since,
280 ops,
281 timestamp: raw.time,
282 })
283}