use async_stream::try_stream; use cid::Cid; use futures_util::{ TryStreamExt, stream::{Stream, StreamExt}, }; use serde::{Deserialize, Serialize}; use serde_json::Value as JsonValue; use std::collections::HashMap; use std::time::Duration; use thiserror::Error; use tokio_tungstenite::{connect_async, tungstenite::protocol::Message}; #[derive(Error, Debug)] pub enum FirehoseError { #[error("WebSocket connection error: {0}")] WebSocket(#[from] tokio_tungstenite::tungstenite::Error), #[error("URL parsing error: {0}")] Url(#[from] url::ParseError), #[error("Failed to deserialize CBOR data: {0}")] Cbor(#[from] std::io::Error), #[error("Failed to read CAR file: {0}")] CarRead(#[from] rs_car::CarDecodeError), #[error("Invalid message frame received from relay")] InvalidFrame, #[error("Received an error message from the relay: {name} - {message}")] RelayError { name: String, message: String }, #[error("Unknown message type received: {0}")] UnknownMessageType(String), #[error("Connection timed out after {0:?} without receiving a message")] Timeout(Duration), } #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "action", rename_all = "camelCase")] pub enum RepoOp { Create { path: String, cid: Cid, record: JsonValue, }, Update { path: String, cid: Cid, record: JsonValue, }, Delete { path: String, }, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CommitEvent { #[serde(rename = "seq")] pub sequence: i64, pub repo: String, pub commit: Cid, pub rev: String, pub since: Option, pub ops: Vec, #[serde(rename = "time")] pub timestamp: String, } #[derive(Debug, Clone, Serialize)] #[serde(untagged)] pub enum FirehoseEvent { Commit(CommitEvent), Info(JsonValue), Account(JsonValue), Identity(JsonValue), Unknown(JsonValue), } #[derive(Debug, Clone)] pub struct FirehoseOptions { pub relay_url: String, pub cursor: Option, pub auto_reconnect: bool, } impl Default for FirehoseOptions { fn default() -> Self { Self { relay_url: "wss://relay.upcloud.world".to_string(), cursor: None, auto_reconnect: true, } } } #[derive(Deserialize, Debug)] pub struct FrameHeader { #[serde(rename = "op")] operation: i64, #[serde(rename = "t")] message_type: String, } #[derive(Deserialize, Debug)] struct ErrorBody { error: String, message: String, } #[derive(Deserialize, Debug, Clone)] struct RawRepoOp { action: String, path: String, cid: Option, } #[derive(Deserialize, Debug)] struct RawCommitEvent { seq: i64, repo: String, commit: Cid, rev: String, since: Option, blocks: serde_bytes::ByteBuf, ops: Vec, time: String, } pub fn subscribe_repos( options: FirehoseOptions, ) -> impl Stream> { let mut last_cursor = options.cursor; try_stream! { loop { let mut url_str = format!( "{}/xrpc/com.atproto.sync.subscribeRepos", options.relay_url ); if let Some(cursor_val) = last_cursor { url_str.push_str(&format!("?cursor={}", cursor_val)); } tracing::info!("Connecting to {}...", url_str); let (mut ws_stream, _) = connect_async(&url_str).await?; tracing::info!("Successfully connected to firehose."); loop { let next_message = tokio::time::timeout( Duration::from_secs(15), ws_stream.next() ); match next_message.await { Err(_) => { tracing::info!("Connection timed out. Attempting to reconnect..."); ws_stream.close(None).await?; break; }, Ok(Some(Ok(msg))) => { if let Message::Binary(data) = msg { let mut deserializer = serde_ipld_dagcbor::de::Deserializer::from_slice(&data); let map_cbor_err = |e: serde_ipld_dagcbor::DecodeError| { std::io::Error::other(e.to_string()) }; let header: FrameHeader = match serde::Deserialize::deserialize(&mut deserializer).map_err(map_cbor_err) { Ok(h) => h, Err(e) => { tracing::error!("Failed to deserialize frame header: {}. Skipping message.", e); continue; } }; if header.operation == -1 { let body: ErrorBody = match serde::Deserialize::deserialize(&mut deserializer).map_err(map_cbor_err) { Ok(b) => b, Err(e) => { tracing::error!("Failed to deserialize relay error body: {}. Skipping message.", e); continue; } }; Err(FirehoseError::RelayError { name: body.error, message: body.message })?; } let event = match header.message_type.as_str() { "#commit" => { let raw_commit: RawCommitEvent = match serde::Deserialize::deserialize(&mut deserializer).map_err(map_cbor_err) { Ok(c) => c, Err(e) => { tracing::error!("Failed to deserialize commit body: {}. Skipping message.", e); continue; } }; last_cursor = Some(raw_commit.seq); let commit_event = process_commit_event(raw_commit).await?; FirehoseEvent::Commit(commit_event) } t => { let body: JsonValue = match serde::Deserialize::deserialize(&mut deserializer).map_err(map_cbor_err) { Ok(b) => b, Err(e) => { tracing::error!("Failed to deserialize event body for type '{}': {}. Skipping message.", t, e); continue; } }; if let Some(seq) = body.get("seq").and_then(|v| v.as_i64()) { last_cursor = Some(seq); } match t { "#info" => FirehoseEvent::Info(body), "#account" => FirehoseEvent::Account(body), "#identity" => FirehoseEvent::Identity(body), _ => FirehoseEvent::Unknown(body) } } }; yield event; } } Ok(Some(Err(e))) => { tracing::info!("WebSocket error: {}", e); break; } Ok(None) => { tracing::info!("WebSocket stream closed by server."); break; } } } if !options.auto_reconnect { tracing::info!("Auto-reconnect is disabled. Exiting."); break; } tokio::time::sleep(Duration::from_secs(1)).await; } } } async fn process_commit_event(raw: RawCommitEvent) -> Result { let mut blocks_reader = raw.blocks.as_ref(); let car_reader = rs_car::CarReader::new(&mut blocks_reader, false).await?; let records_map: HashMap> = car_reader.try_collect::>().await?; let mut ops = Vec::new(); for op in raw.ops { let repo_op = match op.action.as_str() { "create" | "update" => { let cid = op.cid.expect("Create/Update operation must have a CID"); let record_bytes = records_map.get(&cid).cloned().unwrap_or_default(); let record: JsonValue = serde_ipld_dagcbor::from_slice(&record_bytes).unwrap_or(JsonValue::Null); if op.action == "create" { RepoOp::Create { path: op.path, cid, record, } } else { RepoOp::Update { path: op.path, cid, record, } } } "delete" => RepoOp::Delete { path: op.path }, _ => continue, }; ops.push(repo_op); } Ok(CommitEvent { sequence: raw.seq, repo: raw.repo, commit: raw.commit, rev: raw.rev, since: raw.since, ops, timestamp: raw.time, }) }