this repo has no description
1use std::str::FromStr;
2
3use anyhow::{anyhow, Context, Result};
4use chrono::DateTime;
5use futures_util::SinkExt;
6use futures_util::StreamExt;
7use http::HeaderValue;
8use http::Uri;
9use tokio::time::{sleep, Instant};
10use tokio_util::sync::CancellationToken;
11use tokio_websockets::{ClientBuilder, Message};
12
13use crate::config;
14use crate::matcher::FeedMatchers;
15use crate::matcher::Match;
16use crate::matcher::MatchOperation;
17use crate::storage;
18use crate::storage::consumer_control_get;
19use crate::storage::consumer_control_insert;
20use crate::storage::denylist_exists;
21use crate::storage::feed_content_update;
22use crate::storage::feed_content_upsert;
23use crate::storage::StoragePool;
24
// Maximum jetstream message size (bytes) requested via the subscription
// options; also used to bound the zstd decompression buffer below.
// NOTE(review): presumably chosen to comfortably cover a single commit
// event — confirm against the jetstream server's own limits.
const MAX_MESSAGE_SIZE: usize = 25000;
26
/// Configuration for a [`ConsumerTask`].
#[derive(Clone)]
pub struct ConsumerTaskConfig {
    // User-Agent header value sent on the websocket handshake.
    pub user_agent: String,
    // When true, subscribe with compress=true and zstd-decompress
    // each binary frame using the dictionary below.
    pub compression: bool,
    // Filesystem path of the jetstream zstd dictionary (only read
    // when `compression` is true).
    pub zstd_dictionary_location: String,
    // Hostname of the jetstream server; also used as the key for
    // cursor persistence in consumer_control_{get,insert}.
    pub jetstream_hostname: String,
    // Feed definitions compiled into matchers at task construction.
    pub feeds: config::Feeds,
    // Collections (NSIDs) requested via wantedCollections.
    pub collections: Vec<String>,
}
36
/// Long-running task that consumes jetstream events, matches them against
/// configured feeds, and writes matching content to storage.
pub struct ConsumerTask {
    // Signals shutdown; run_background exits its loop when cancelled.
    cancellation_token: CancellationToken,
    // Database handle used for cursor persistence, denylist checks, and
    // feed content writes.
    pool: StoragePool,
    config: ConsumerTaskConfig,
    // Matchers compiled from config.feeds at construction time.
    feed_matchers: FeedMatchers,
}
43
44impl ConsumerTask {
45 pub fn new(
46 pool: StoragePool,
47 config: ConsumerTaskConfig,
48 cancellation_token: CancellationToken,
49 ) -> Result<Self> {
50 let feed_matchers = FeedMatchers::from_config(&config.feeds)?;
51
52 Ok(Self {
53 pool,
54 cancellation_token,
55 config,
56 feed_matchers,
57 })
58 }
59
60 pub async fn run_background(&self) -> Result<()> {
61 tracing::debug!("ConsumerTask started");
62
63 let last_time_us =
64 consumer_control_get(&self.pool, &self.config.jetstream_hostname).await?;
65
66 tracing::info!(cursor = ?last_time_us, "loaded cursor from database");
67
68 let cursor_param = if let Some(cursor) = last_time_us {
69 format!("&cursor={}", cursor)
70 } else {
71 String::new()
72 };
73
74 let uri = Uri::from_str(&format!(
75 "wss://{}/subscribe?compress={}&requireHello=true{}",
76 self.config.jetstream_hostname, self.config.compression, cursor_param
77 ))
78 .context("invalid jetstream URL")?;
79
80 tracing::info!(uri = %uri, "connecting to jetstream");
81
82 let (mut client, _) = ClientBuilder::from_uri(uri)
83 .add_header(
84 http::header::USER_AGENT,
85 HeaderValue::from_str(&self.config.user_agent)?,
86 )
87 .connect()
88 .await
89 .map_err(|err| anyhow::Error::new(err).context("cannot connect to jetstream"))?;
90
91 let update = model::SubscriberSourcedMessage::Update {
92 wanted_collections: self.config.collections.clone(),
93 wanted_dids: vec![],
94 max_message_size_bytes: MAX_MESSAGE_SIZE as u64,
95 cursor: None,
96 };
97 let serialized_update = serde_json::to_string(&update)
98 .map_err(|err| anyhow::Error::msg(err).context("cannot serialize update"))?;
99
100 tracing::info!(message = %serialized_update, "sending subscription update to jetstream");
101
102 client
103 .send(Message::text(serialized_update))
104 .await
105 .map_err(|err| anyhow::Error::msg(err).context("cannot send update"))?;
106
107 let mut decompressor = if self.config.compression {
108 // mkdir -p data/ && curl -o data/zstd_dictionary https://github.com/bluesky-social/jetstream/raw/refs/heads/main/pkg/models/zstd_dictionary
109 let data: Vec<u8> = std::fs::read(self.config.zstd_dictionary_location.clone())
110 .context("unable to load zstd dictionary")?;
111 zstd::bulk::Decompressor::with_dictionary(&data)
112 .map_err(|err| anyhow::Error::msg(err).context("cannot create decompressor"))?
113 } else {
114 zstd::bulk::Decompressor::new()
115 .map_err(|err| anyhow::Error::msg(err).context("cannot create decompressor"))?
116 };
117
118 let interval = std::time::Duration::from_secs(120);
119 let sleeper = sleep(interval);
120 tokio::pin!(sleeper);
121
122 let heartbeat_interval = std::time::Duration::from_secs(15);
123 let heartbeat_sleeper = sleep(heartbeat_interval);
124 tokio::pin!(heartbeat_sleeper);
125
126 let mut time_usec = 0i64;
127
128 loop {
129 tokio::select! {
130 () = self.cancellation_token.cancelled() => {
131 break;
132 },
133 () = &mut sleeper => {
134 consumer_control_insert(&self.pool, &self.config.jetstream_hostname, time_usec).await?;
135 sleeper.as_mut().reset(Instant::now() + interval);
136 },
137 () = &mut heartbeat_sleeper => {
138 if time_usec > 0 {
139 let datetime = DateTime::from_timestamp_micros(time_usec)
140 .map(|dt| dt.to_rfc3339())
141 .unwrap_or_else(|| format!("{} microseconds", time_usec));
142 tracing::info!(time_us = time_usec, timestamp = %datetime, "consumer heartbeat");
143 }
144 heartbeat_sleeper.as_mut().reset(Instant::now() + heartbeat_interval);
145 },
146 item = client.next() => {
147 if item.is_none() {
148 tracing::warn!("jetstream connection closed");
149 break;
150 }
151 let item = item.unwrap();
152
153 if let Err(err) = item {
154 tracing::error!(error = ?err, "error processing jetstream message");
155 continue;
156 }
157 let item = item.unwrap();
158
159 let event = if self.config.compression {
160 if !item.is_binary() {
161 // Skip WebSocket control frames (ping, pong, close)
162 if item.is_ping() || item.is_pong() || item.is_close() {
163 continue;
164 }
165 // Log unexpected non-binary message types
166 tracing::warn!("received unexpected non-binary message from jetstream (not ping/pong/close)");
167 continue;
168 }
169 let payload = item.into_payload();
170
171 let decoded = decompressor.decompress(&payload, MAX_MESSAGE_SIZE * 3);
172 if let Err(err) = decoded {
173 tracing::debug!(err = ?err, "cannot decompress message");
174 continue;
175 }
176 let decoded = decoded.unwrap();
177 serde_json::from_slice::<model::Event>(&decoded)
178 .context(anyhow!("cannot deserialize message"))
179 } else {
180 if !item.is_text() {
181 // Skip WebSocket control frames (ping, pong, close)
182 if item.is_ping() || item.is_pong() || item.is_close() {
183 continue;
184 }
185 // Log unexpected non-text message types
186 tracing::warn!("received unexpected non-text message from jetstream (not ping/pong/close)");
187 continue;
188 }
189 item.as_text()
190 .ok_or(anyhow!("cannot convert message to text"))
191 .and_then(|value| {
192 serde_json::from_str::<model::Event>(value)
193 .context(anyhow!("cannot deserialize message"))
194 })
195 };
196 if let Err(err) = event {
197 tracing::error!(error = ?err, "error processing jetstream message");
198
199 continue;
200 }
201 let event = event.unwrap();
202
203 let previous_time_usec = time_usec;
204 time_usec = std::cmp::max(time_usec, event.time_us);
205
206 if previous_time_usec == 0 {
207 let datetime = DateTime::from_timestamp_micros(event.time_us)
208 .map(|dt| dt.to_rfc3339())
209 .unwrap_or_else(|| format!("{} microseconds", event.time_us));
210 tracing::info!(time_us = event.time_us, timestamp = %datetime, "received first event from jetstream");
211 }
212
213 if event.clone().kind != "commit" {
214 continue;
215 }
216
217 let event_value = serde_json::to_value(event.clone());
218 if let Err(err) = event_value {
219 tracing::error!(error = ?err, "error processing jetstream message");
220 continue;
221 }
222 let event_value = event_value.unwrap();
223
224 // Assumption: Performing a query for each event will cost more in the
225 // long-term than evaluating each event against all matchers and if there's a
226 // match, then checking both the event DID and the AT-URI DID.
227 'matchers_loop: for feed_matcher in self.feed_matchers.0.iter() {
228 if let Some(Match(op, aturi)) = feed_matcher.matches(&event_value) {
229 tracing::debug!(feed_id = ?feed_matcher.feed, "matched event");
230
231 let aturi_did = did_from_aturi(&aturi);
232 let dids = vec![event.did.as_str(), aturi_did.as_str()];
233 if denylist_exists(&self.pool, &dids).await? {
234 break 'matchers_loop;
235 }
236
237 let feed_content = storage::model::FeedContent{
238 feed_id: feed_matcher.feed.clone(),
239 uri: aturi,
240 indexed_at: event.clone().time_us,
241 score: 1,
242 };
243 match op {
244 MatchOperation::Upsert => {
245 feed_content_upsert(&self.pool, &feed_content).await?;
246 },
247 MatchOperation::Update => {
248 feed_content_update(&self.pool, &feed_content).await?;
249 },
250 }
251
252 }
253 }
254 }
255 }
256 }
257
258 tracing::debug!("ConsumerTask stopped");
259
260 Ok(())
261 }
262}
263
/// Extracts the DID (authority) segment from an AT-URI of the form
/// `at://<did>/<collection>/<rkey>`.
///
/// Returns an empty string for inputs that do not start with `at://`.
/// Unlike a blind `aturi[5..]` byte slice, `strip_prefix` cannot panic on
/// a non-ASCII byte at position 5 and never returns garbage for
/// non-AT-URI input.
fn did_from_aturi(aturi: &str) -> String {
    match aturi.strip_prefix("at://") {
        // Everything up to the first '/' (or the whole remainder when the
        // URI has no collection/rkey suffix) is the DID.
        Some(rest) => rest.split('/').next().unwrap_or(rest).to_string(),
        None => String::new(),
    }
}
275
pub(crate) mod model {
    //! Wire-format types for the jetstream websocket protocol: JSON
    //! messages sent to the server (subscription options) and events
    //! received from it.

    use std::collections::HashMap;

    use serde::{Deserialize, Serialize};

    /// Messages sent from this subscriber to the jetstream server.
    /// Serialized as `{"type": ..., "payload": {...}}`.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    #[serde(tag = "type", content = "payload")]
    pub(crate) enum SubscriberSourcedMessage {
        /// Subscription options update (`"type": "options_update"`).
        #[serde(rename = "options_update")]
        Update {
            // Collections (NSIDs) the server should deliver events for.
            #[serde(rename = "wantedCollections")]
            wanted_collections: Vec<String>,

            // Optional DID filter; omitted from the payload when empty.
            #[serde(rename = "wantedDids", skip_serializing_if = "Vec::is_empty", default)]
            wanted_dids: Vec<String>,

            // Upper bound (bytes) on messages the server should send.
            #[serde(rename = "maxMessageSizeBytes")]
            max_message_size_bytes: u64,

            // Optional replay cursor (microsecond timestamp); omitted when None.
            #[serde(skip_serializing_if = "Option::is_none")]
            cursor: Option<i64>,
        },
    }

    /// A rich-text facet; only the feature maps are retained here.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub(crate) struct Facet {
        pub(crate) features: Vec<HashMap<String, String>>,
    }

    /// A reference to another record by AT-URI (CID is not retained).
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub(crate) struct StrongRef {
        pub(crate) uri: String,
    }

    /// Reply context of a post: thread root and direct parent.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub(crate) struct Reply {
        pub(crate) root: Option<StrongRef>,
        pub(crate) parent: Option<StrongRef>,
    }

    /// A record payload, discriminated by its `$type` field. Known types
    /// get their own variant; everything else falls into `Other`. All
    /// variants keep the raw fields as a flattened map.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    #[serde(tag = "$type")]
    pub(crate) enum Record {
        #[serde(rename = "app.bsky.feed.post")]
        Post {
            #[serde(flatten)]
            extra: HashMap<String, serde_json::Value>,
        },
        #[serde(rename = "app.bsky.feed.like")]
        Like {
            #[serde(flatten)]
            extra: HashMap<String, serde_json::Value>,
        },

        // Catch-all for record types without a dedicated variant.
        #[serde(untagged)]
        Other {
            #[serde(flatten)]
            extra: HashMap<String, serde_json::Value>,
        },
    }

    /// A repo commit operation, discriminated by its `operation` field.
    /// `delete` carries no record or CID.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    #[serde(tag = "operation")]
    pub(crate) enum CommitOp {
        #[serde(rename = "create")]
        Create {
            rev: String,
            collection: String,
            rkey: String,
            record: Record,
            cid: String,
        },
        #[serde(rename = "update")]
        Update {
            rev: String,
            collection: String,
            rkey: String,
            record: Record,
            cid: String,
        },
        #[serde(rename = "delete")]
        Delete {
            rev: String,
            collection: String,
            rkey: String,
        },
    }

    /// A single jetstream event.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub(crate) struct Event {
        // Repo (account) DID the event originates from.
        pub(crate) did: String,
        // Event kind; the consumer only processes "commit".
        pub(crate) kind: String,
        // Event timestamp in microseconds; used as the resume cursor.
        pub(crate) time_us: i64,
        // Present for commit events; None otherwise.
        pub(crate) commit: Option<CommitOp>,
    }
}
372}