this repo has no description
at main 14 kB view raw
1use std::str::FromStr; 2 3use anyhow::{anyhow, Context, Result}; 4use chrono::DateTime; 5use futures_util::SinkExt; 6use futures_util::StreamExt; 7use http::HeaderValue; 8use http::Uri; 9use tokio::time::{sleep, Instant}; 10use tokio_util::sync::CancellationToken; 11use tokio_websockets::{ClientBuilder, Message}; 12 13use crate::config; 14use crate::matcher::FeedMatchers; 15use crate::matcher::Match; 16use crate::matcher::MatchOperation; 17use crate::storage; 18use crate::storage::consumer_control_get; 19use crate::storage::consumer_control_insert; 20use crate::storage::denylist_exists; 21use crate::storage::feed_content_update; 22use crate::storage::feed_content_upsert; 23use crate::storage::StoragePool; 24 25const MAX_MESSAGE_SIZE: usize = 25000; 26 27#[derive(Clone)] 28pub struct ConsumerTaskConfig { 29 pub user_agent: String, 30 pub compression: bool, 31 pub zstd_dictionary_location: String, 32 pub jetstream_hostname: String, 33 pub feeds: config::Feeds, 34 pub collections: Vec<String>, 35} 36 37pub struct ConsumerTask { 38 cancellation_token: CancellationToken, 39 pool: StoragePool, 40 config: ConsumerTaskConfig, 41 feed_matchers: FeedMatchers, 42} 43 44impl ConsumerTask { 45 pub fn new( 46 pool: StoragePool, 47 config: ConsumerTaskConfig, 48 cancellation_token: CancellationToken, 49 ) -> Result<Self> { 50 let feed_matchers = FeedMatchers::from_config(&config.feeds)?; 51 52 Ok(Self { 53 pool, 54 cancellation_token, 55 config, 56 feed_matchers, 57 }) 58 } 59 60 pub async fn run_background(&self) -> Result<()> { 61 tracing::debug!("ConsumerTask started"); 62 63 let last_time_us = 64 consumer_control_get(&self.pool, &self.config.jetstream_hostname).await?; 65 66 tracing::info!(cursor = ?last_time_us, "loaded cursor from database"); 67 68 let cursor_param = if let Some(cursor) = last_time_us { 69 format!("&cursor={}", cursor) 70 } else { 71 String::new() 72 }; 73 74 let uri = Uri::from_str(&format!( 75 "wss://{}/subscribe?compress={}&requireHello=true{}", 76 self.config.jetstream_hostname, self.config.compression, cursor_param 77 )) 78 .context("invalid jetstream URL")?; 79 80 tracing::info!(uri = %uri, "connecting to jetstream"); 81 82 let (mut client, _) = ClientBuilder::from_uri(uri) 83 .add_header( 84 http::header::USER_AGENT, 85 HeaderValue::from_str(&self.config.user_agent)?, 86 ) 87 .connect() 88 .await 89 .map_err(|err| anyhow::Error::new(err).context("cannot connect to jetstream"))?; 90 91 let update = model::SubscriberSourcedMessage::Update { 92 wanted_collections: self.config.collections.clone(), 93 wanted_dids: vec![], 94 max_message_size_bytes: MAX_MESSAGE_SIZE as u64, 95 cursor: None, 96 }; 97 let serialized_update = serde_json::to_string(&update) 98 .map_err(|err| anyhow::Error::msg(err).context("cannot serialize update"))?; 99 100 tracing::info!(message = %serialized_update, "sending subscription update to jetstream"); 101 102 client 103 .send(Message::text(serialized_update)) 104 .await 105 .map_err(|err| anyhow::Error::msg(err).context("cannot send update"))?; 106 107 let mut decompressor = if self.config.compression { 108 // mkdir -p data/ && curl -o data/zstd_dictionary https://github.com/bluesky-social/jetstream/raw/refs/heads/main/pkg/models/zstd_dictionary 109 let data: Vec<u8> = std::fs::read(self.config.zstd_dictionary_location.clone()) 110 .context("unable to load zstd dictionary")?; 111 zstd::bulk::Decompressor::with_dictionary(&data) 112 .map_err(|err| anyhow::Error::msg(err).context("cannot create decompressor"))? 113 } else { 114 zstd::bulk::Decompressor::new() 115 .map_err(|err| anyhow::Error::msg(err).context("cannot create decompressor"))? 116 }; 117 118 let interval = std::time::Duration::from_secs(120); 119 let sleeper = sleep(interval); 120 tokio::pin!(sleeper); 121 122 let heartbeat_interval = std::time::Duration::from_secs(15); 123 let heartbeat_sleeper = sleep(heartbeat_interval); 124 tokio::pin!(heartbeat_sleeper); 125 126 let mut time_usec = 0i64; 127 128 loop { 129 tokio::select! { 130 () = self.cancellation_token.cancelled() => { 131 break; 132 }, 133 () = &mut sleeper => { 134 consumer_control_insert(&self.pool, &self.config.jetstream_hostname, time_usec).await?; 135 sleeper.as_mut().reset(Instant::now() + interval); 136 }, 137 () = &mut heartbeat_sleeper => { 138 if time_usec > 0 { 139 let datetime = DateTime::from_timestamp_micros(time_usec) 140 .map(|dt| dt.to_rfc3339()) 141 .unwrap_or_else(|| format!("{} microseconds", time_usec)); 142 tracing::info!(time_us = time_usec, timestamp = %datetime, "consumer heartbeat"); 143 } 144 heartbeat_sleeper.as_mut().reset(Instant::now() + heartbeat_interval); 145 }, 146 item = client.next() => { 147 if item.is_none() { 148 tracing::warn!("jetstream connection closed"); 149 break; 150 } 151 let item = item.unwrap(); 152 153 if let Err(err) = item { 154 tracing::error!(error = ?err, "error processing jetstream message"); 155 continue; 156 } 157 let item = item.unwrap(); 158 159 let event = if self.config.compression { 160 if !item.is_binary() { 161 // Skip WebSocket control frames (ping, pong, close) 162 if item.is_ping() || item.is_pong() || item.is_close() { 163 continue; 164 } 165 // Log unexpected non-binary message types 166 tracing::warn!("received unexpected non-binary message from jetstream (not ping/pong/close)"); 167 continue; 168 } 169 let payload = item.into_payload(); 170 171 let decoded = decompressor.decompress(&payload, MAX_MESSAGE_SIZE * 3); 172 if let Err(err) = decoded { 173 tracing::debug!(err = ?err, "cannot decompress message"); 174 continue; 175 } 176 let decoded = decoded.unwrap(); 177 serde_json::from_slice::<model::Event>(&decoded) 178 .context(anyhow!("cannot deserialize message")) 179 } else { 180 if !item.is_text() { 181 // Skip WebSocket control frames (ping, pong, close) 182 if item.is_ping() || item.is_pong() || item.is_close() { 183 continue; 184 } 185 // Log unexpected non-text message types 186 tracing::warn!("received unexpected non-text message from jetstream (not ping/pong/close)"); 187 continue; 188 } 189 item.as_text() 190 .ok_or(anyhow!("cannot convert message to text")) 191 .and_then(|value| { 192 serde_json::from_str::<model::Event>(value) 193 .context(anyhow!("cannot deserialize message")) 194 }) 195 }; 196 if let Err(err) = event { 197 tracing::error!(error = ?err, "error processing jetstream message"); 198 199 continue; 200 } 201 let event = event.unwrap(); 202 203 let previous_time_usec = time_usec; 204 time_usec = std::cmp::max(time_usec, event.time_us); 205 206 if previous_time_usec == 0 { 207 let datetime = DateTime::from_timestamp_micros(event.time_us) 208 .map(|dt| dt.to_rfc3339()) 209 .unwrap_or_else(|| format!("{} microseconds", event.time_us)); 210 tracing::info!(time_us = event.time_us, timestamp = %datetime, "received first event from jetstream"); 211 } 212 213 if event.clone().kind != "commit" { 214 continue; 215 } 216 217 let event_value = serde_json::to_value(event.clone()); 218 if let Err(err) = event_value { 219 tracing::error!(error = ?err, "error processing jetstream message"); 220 continue; 221 } 222 let event_value = event_value.unwrap(); 223 224 // Assumption: Performing a query for each event will cost more in the 225 // long-term than evaluating each event against all matchers and if there's a 226 // match, then checking both the event DID and the AT-URI DID. 227 'matchers_loop: for feed_matcher in self.feed_matchers.0.iter() { 228 if let Some(Match(op, aturi)) = feed_matcher.matches(&event_value) { 229 tracing::debug!(feed_id = ?feed_matcher.feed, "matched event"); 230 231 let aturi_did = did_from_aturi(&aturi); 232 let dids = vec![event.did.as_str(), aturi_did.as_str()]; 233 if denylist_exists(&self.pool, &dids).await? { 234 break 'matchers_loop; 235 } 236 237 let feed_content = storage::model::FeedContent{ 238 feed_id: feed_matcher.feed.clone(), 239 uri: aturi, 240 indexed_at: event.clone().time_us, 241 score: 1, 242 }; 243 match op { 244 MatchOperation::Upsert => { 245 feed_content_upsert(&self.pool, &feed_content).await?; 246 }, 247 MatchOperation::Update => { 248 feed_content_update(&self.pool, &feed_content).await?; 249 }, 250 } 251 252 } 253 } 254 } 255 } 256 } 257 258 tracing::debug!("ConsumerTask stopped"); 259 260 Ok(()) 261 } 262} 263 264fn did_from_aturi(aturi: &str) -> String { 265 let aturi_len = aturi.len(); 266 if aturi_len < 6 { 267 return "".to_string(); 268 } 269 let collection_start = aturi[5..] 270 .find("/") 271 .map(|value| value + 5) 272 .unwrap_or(aturi_len); 273 aturi[5..collection_start].to_string() 274} 275 276pub(crate) mod model { 277 278 use std::collections::HashMap; 279 280 use serde::{Deserialize, Serialize}; 281 282 #[derive(Debug, Clone, Serialize, Deserialize)] 283 #[serde(tag = "type", content = "payload")] 284 pub(crate) enum SubscriberSourcedMessage { 285 #[serde(rename = "options_update")] 286 Update { 287 #[serde(rename = "wantedCollections")] 288 wanted_collections: Vec<String>, 289 290 #[serde(rename = "wantedDids", skip_serializing_if = "Vec::is_empty", default)] 291 wanted_dids: Vec<String>, 292 293 #[serde(rename = "maxMessageSizeBytes")] 294 max_message_size_bytes: u64, 295 296 #[serde(skip_serializing_if = "Option::is_none")] 297 cursor: Option<i64>, 298 }, 299 } 300 301 #[derive(Debug, Clone, Serialize, Deserialize)] 302 pub(crate) struct Facet { 303 pub(crate) features: Vec<HashMap<String, String>>, 304 } 305 306 #[derive(Debug, Clone, Serialize, Deserialize)] 307 pub(crate) struct StrongRef { 308 pub(crate) uri: String, 309 } 310 311 #[derive(Debug, Clone, Serialize, Deserialize)] 312 pub(crate) struct Reply { 313 pub(crate) root: Option<StrongRef>, 314 pub(crate) parent: Option<StrongRef>, 315 } 316 317 #[derive(Debug, Clone, Serialize, Deserialize)] 318 #[serde(tag = "$type")] 319 pub(crate) enum Record { 320 #[serde(rename = "app.bsky.feed.post")] 321 Post { 322 #[serde(flatten)] 323 extra: HashMap<String, serde_json::Value>, 324 }, 325 #[serde(rename = "app.bsky.feed.like")] 326 Like { 327 #[serde(flatten)] 328 extra: HashMap<String, serde_json::Value>, 329 }, 330 331 #[serde(untagged)] 332 Other { 333 #[serde(flatten)] 334 extra: HashMap<String, serde_json::Value>, 335 }, 336 } 337 338 #[derive(Debug, Clone, Serialize, Deserialize)] 339 #[serde(tag = "operation")] 340 pub(crate) enum CommitOp { 341 #[serde(rename = "create")] 342 Create { 343 rev: String, 344 collection: String, 345 rkey: String, 346 record: Record, 347 cid: String, 348 }, 349 #[serde(rename = "update")] 350 Update { 351 rev: String, 352 collection: String, 353 rkey: String, 354 record: Record, 355 cid: String, 356 }, 357 #[serde(rename = "delete")] 358 Delete { 359 rev: String, 360 collection: String, 361 rkey: String, 362 }, 363 } 364 365 #[derive(Debug, Clone, Serialize, Deserialize)] 366 pub(crate) struct Event { 367 pub(crate) did: String, 368 pub(crate) kind: String, 369 pub(crate) time_us: i64, 370 pub(crate) commit: Option<CommitOp>, 371 } 372}