at main 8.5 kB view raw
1mod cli; 2mod query; 3 4use std::{ 5 collections::{HashSet, VecDeque}, 6 io, process, 7 time::Duration, 8}; 9 10use futures_util::StreamExt; 11use serde_json::Value; 12use sqlx::PgPool; 13use tokio::{sync::mpsc, task::JoinSet}; 14use tokio_util::sync::CancellationToken; 15use tracing::{Instrument as _, level_filters::LevelFilter}; 16use tracing_subscriber::{EnvFilter, layer::SubscriberExt as _, util::SubscriberInitExt as _}; 17use trap::tap::{IdentityEvent, RecordAction, RecordEvent, TapChannel, TapClient, TapEvent}; 18 19/// DID provenance. 20#[derive(Debug)] 21#[allow(unused)] 22enum DidSource { 23 Seed, 24 Record(Box<str>), 25} 26 27#[tokio::main] 28async fn main() -> anyhow::Result<()> { 29 setup_tracing()?; 30 31 let arguments = cli::parse(); 32 let pool = PgPool::connect(arguments.db.as_str()).await?; 33 let version = db_version(&pool).await?; 34 tracing::info!(%version, "connected to db"); 35 36 sqlx::migrate!().run(&pool).await?; 37 38 let shutdown = CancellationToken::new(); 39 let (did_tx, did_rx) = mpsc::unbounded_channel::<(String, DidSource)>(); 40 41 let tap = TapClient::new(arguments.tap, arguments.tap_password.as_deref())?; 42 let (tap_channel, tap_task) = tap.channel(); 43 44 let mut tasks = JoinSet::new(); 45 tasks.spawn(async move { 46 if let Err(error) = tap_task.await { 47 tracing::error!(?error); 48 process::abort(); 49 } 50 Ok(()) 51 }); 52 53 tasks.spawn(event_consumer( 54 tap_channel, 55 pool.clone(), 56 did_tx.clone(), 57 shutdown.child_token(), 58 )); 59 60 tasks.spawn(did_task(tap, pool, did_rx, shutdown.child_token())); 61 tasks.spawn(shutdown_task(shutdown.clone())); 62 63 // Submit seed DIDs to the Tap service. 64 for did in arguments.seed.into_iter().filter(|s| possible_did(s)) { 65 did_tx.send((did, DidSource::Seed))?; 66 } 67 68 for task in tasks.join_all().await { 69 if let Err(error) = task { 70 tracing::error!(?error, "task failed"); 71 shutdown.cancel(); 72 } 73 } 74 75 Ok(()) 76} 77 78fn setup_tracing() -> anyhow::Result<()> { 79 tracing_subscriber::registry() 80 .with( 81 EnvFilter::builder() 82 .with_default_directive(LevelFilter::INFO.into()) 83 .from_env()?, 84 ) 85 .with(tracing_subscriber::fmt::layer().with_writer(io::stderr)) 86 .try_init()?; 87 88 Ok(()) 89} 90 91async fn db_version(pool: &PgPool) -> sqlx::Result<String> { 92 let row: (String,) = sqlx::query_as("SELECT version()").fetch_one(pool).await?; 93 Ok(row.0) 94} 95 96#[tracing::instrument(skip(channel, pool, tx, shutdown))] 97async fn event_consumer( 98 mut channel: TapChannel, 99 pool: PgPool, 100 tx: mpsc::UnboundedSender<(String, DidSource)>, 101 shutdown: CancellationToken, 102) -> anyhow::Result<()> { 103 while let Some(Some((span, event, ack))) = shutdown.run_until_cancelled(channel.recv()).await { 104 async { 105 let mut transaction = pool.begin().await?; 106 match event { 107 TapEvent::Record(record) => { 108 let (record, parsed_record) = handle_record(record, &mut transaction).await?; 109 110 // Expand the network of tracked DIDs. 111 let nsid = record.collection.into_boxed_str(); 112 for did in extract_dids(&parsed_record) { 113 tx.send((did, DidSource::Record(nsid.clone())))?; 114 } 115 } 116 TapEvent::Identity(identity) => { 117 handle_identity(identity, &mut transaction).await?; 118 } 119 } 120 121 transaction.commit().await?; 122 ack.acknowledge().await?; 123 Ok::<_, anyhow::Error>(()) 124 } 125 .instrument(span) 126 .await?; 127 } 128 129 tracing::info!("complete"); 130 Ok(()) 131} 132 133async fn handle_identity( 134 identity_event: IdentityEvent, 135 transaction: &mut sqlx::Transaction<'static, sqlx::Postgres>, 136) -> anyhow::Result<()> { 137 let IdentityEvent { 138 id: _, 139 did, 140 handle, 141 is_active, 142 status, 143 } = identity_event; 144 145 query::upsert_identity(&did, &handle, &status, is_active) 146 .execute(&mut **transaction) 147 .await?; 148 149 Ok(()) 150} 151 152async fn handle_record( 153 record_event: RecordEvent, 154 transaction: &mut sqlx::Transaction<'static, sqlx::Postgres>, 155) -> anyhow::Result<(RecordEvent, Value)> { 156 let RecordEvent { 157 id: _, 158 did, 159 rev, 160 collection, 161 rkey, 162 action, 163 record, 164 cid, 165 live, 166 } = &record_event; 167 168 let parsed_record: Value = record 169 .as_ref() 170 .map(|record| serde_json::from_str(record.get())) 171 .transpose()? 172 .unwrap_or_default(); 173 174 match action { 175 RecordAction::Create | RecordAction::Update => { 176 sqlx::query_file!( 177 "queries/upsert_record.sql", 178 did.as_str(), 179 collection, 180 rkey, 181 rev, 182 cid.as_deref(), 183 live, 184 parsed_record 185 ) 186 .execute(&mut **transaction) 187 .await?; 188 } 189 RecordAction::Delete => { 190 query::delete_record(did, collection, rkey, rev) 191 .execute(&mut **transaction) 192 .await?; 193 } 194 } 195 196 Ok((record_event, parsed_record)) 197} 198 199#[tracing::instrument(skip(tap, pool, did_rx, shutdown))] 200async fn did_task( 201 tap: TapClient, 202 pool: PgPool, 203 mut did_rx: mpsc::UnboundedReceiver<(String, DidSource)>, 204 shutdown: CancellationToken, 205) -> anyhow::Result<()> { 206 const BATCH: usize = 64; 207 208 let mut seen: HashSet<String> = HashSet::with_capacity(10_000); 209 let mut dids = Vec::new(); 210 211 // Query known DIDs from the database. 212 let mut query = sqlx::query!("SELECT did FROM identity").fetch(&pool); 213 while let Some(Ok(row)) = query.next().await { 214 seen.insert(row.did); 215 } 216 217 tracing::debug!(count = seen.len(), "loaded tracked DIDs from database"); 218 219 loop { 220 tokio::time::sleep(Duration::from_millis(200)).await; 221 match shutdown 222 .run_until_cancelled(did_rx.recv_many(&mut dids, BATCH)) 223 .await 224 { 225 Some(0) | None => break, 226 Some(_) => {} 227 } 228 229 // Convert Vec<Box<Did>> to a Vec<&Did>. 230 let mut dedup: HashSet<&str> = HashSet::new(); 231 let mut slice = Vec::with_capacity(dids.len()); 232 for (did, source) in &dids { 233 if !dedup.insert(did) { 234 continue; 235 } 236 237 if !seen.contains(did) || slice.contains(&did.as_ref()) { 238 tracing::info!(?did, ?source, "tracking DID"); 239 slice.push(did.as_ref()); 240 } 241 } 242 243 tap.add_repos(&slice).await?; 244 245 dids.drain(..).for_each(|(did, _)| _ = seen.insert(did)); 246 } 247 248 tracing::info!("complete"); 249 Ok(()) 250} 251 252#[tracing::instrument(skip(shutdown))] 253async fn shutdown_task(shutdown: CancellationToken) -> anyhow::Result<()> { 254 tokio::signal::ctrl_c().await?; 255 eprintln!(); 256 tracing::info!("shutdown signal received"); 257 shutdown.cancel(); 258 Ok(()) 259} 260 261/// Extract any strings that look like DIDs from a JSON document. 262fn extract_dids(value: &Value) -> HashSet<String> { 263 let mut dids = HashSet::new(); 264 265 let mut queue = VecDeque::from_iter([value]); 266 while let Some(value) = queue.pop_front() { 267 match value { 268 Value::Null | Value::Bool(_) | Value::Number(_) => {} 269 Value::Array(values) => { 270 for value in values { 271 queue.push_back(value); 272 } 273 } 274 Value::Object(map) => { 275 for (_, value) in map { 276 queue.push_back(value); 277 } 278 } 279 Value::String(maybe_did) => { 280 if possible_did(maybe_did) { 281 dids.insert(maybe_did.to_string()); 282 continue; 283 } 284 285 // First segment of an "at://..." URI might be a DID. 286 if let Some(uri) = maybe_did.strip_prefix("at://") 287 && let Some((maybe_did, _)) = uri.split_once('/') 288 && possible_did(maybe_did) 289 { 290 dids.insert(maybe_did.to_string()); 291 continue; 292 } 293 } 294 } 295 } 296 297 dids 298} 299 300fn possible_did(s: &str) -> bool { 301 s.starts_with("did:plc") || s.starts_with("did:web") 302}