mod cli; mod query; use std::{ collections::{HashSet, VecDeque}, io, process, time::Duration, }; use futures_util::StreamExt; use serde_json::Value; use sqlx::PgPool; use tokio::{sync::mpsc, task::JoinSet}; use tokio_util::sync::CancellationToken; use tracing::{Instrument as _, level_filters::LevelFilter}; use tracing_subscriber::{EnvFilter, layer::SubscriberExt as _, util::SubscriberInitExt as _}; use trap::tap::{IdentityEvent, RecordAction, RecordEvent, TapChannel, TapClient, TapEvent}; /// DID provenance. #[derive(Debug)] #[allow(unused)] enum DidSource { Seed, Record(Box), } #[tokio::main] async fn main() -> anyhow::Result<()> { setup_tracing()?; let arguments = cli::parse(); let pool = PgPool::connect(arguments.db.as_str()).await?; let version = db_version(&pool).await?; tracing::info!(%version, "connected to db"); sqlx::migrate!().run(&pool).await?; let shutdown = CancellationToken::new(); let (did_tx, did_rx) = mpsc::unbounded_channel::<(String, DidSource)>(); let tap = TapClient::new(arguments.tap, arguments.tap_password.as_deref())?; let (tap_channel, tap_task) = tap.channel(); let mut tasks = JoinSet::new(); tasks.spawn(async move { if let Err(error) = tap_task.await { tracing::error!(?error); process::abort(); } Ok(()) }); tasks.spawn(event_consumer( tap_channel, pool.clone(), did_tx.clone(), shutdown.child_token(), )); tasks.spawn(did_task(tap, pool, did_rx, shutdown.child_token())); tasks.spawn(shutdown_task(shutdown.clone())); // Submit seed DIDs to the Tap service. for did in arguments.seed.into_iter().filter(|s| possible_did(s)) { did_tx.send((did, DidSource::Seed))?; } for task in tasks.join_all().await { if let Err(error) = task { tracing::error!(?error, "task failed"); shutdown.cancel(); } } Ok(()) } fn setup_tracing() -> anyhow::Result<()> { tracing_subscriber::registry() .with( EnvFilter::builder() .with_default_directive(LevelFilter::INFO.into()) .from_env()?, ) .with(tracing_subscriber::fmt::layer().with_writer(io::stderr)) .try_init()?; Ok(()) } async fn db_version(pool: &PgPool) -> sqlx::Result { let row: (String,) = sqlx::query_as("SELECT version()").fetch_one(pool).await?; Ok(row.0) } #[tracing::instrument(skip(channel, pool, tx, shutdown))] async fn event_consumer( mut channel: TapChannel, pool: PgPool, tx: mpsc::UnboundedSender<(String, DidSource)>, shutdown: CancellationToken, ) -> anyhow::Result<()> { while let Some(Some((span, event, ack))) = shutdown.run_until_cancelled(channel.recv()).await { async { let mut transaction = pool.begin().await?; match event { TapEvent::Record(record) => { let (record, parsed_record) = handle_record(record, &mut transaction).await?; // Expand the network of tracked DIDs. let nsid = record.collection.into_boxed_str(); for did in extract_dids(&parsed_record) { tx.send((did, DidSource::Record(nsid.clone())))?; } } TapEvent::Identity(identity) => { handle_identity(identity, &mut transaction).await?; } } transaction.commit().await?; ack.acknowledge().await?; Ok::<_, anyhow::Error>(()) } .instrument(span) .await?; } tracing::info!("complete"); Ok(()) } async fn handle_identity( identity_event: IdentityEvent, transaction: &mut sqlx::Transaction<'static, sqlx::Postgres>, ) -> anyhow::Result<()> { let IdentityEvent { id: _, did, handle, is_active, status, } = identity_event; query::upsert_identity(&did, &handle, &status, is_active) .execute(&mut **transaction) .await?; Ok(()) } async fn handle_record( record_event: RecordEvent, transaction: &mut sqlx::Transaction<'static, sqlx::Postgres>, ) -> anyhow::Result<(RecordEvent, Value)> { let RecordEvent { id: _, did, rev, collection, rkey, action, record, cid, live, } = &record_event; let parsed_record: Value = record .as_ref() .map(|record| serde_json::from_str(record.get())) .transpose()? .unwrap_or_default(); match action { RecordAction::Create | RecordAction::Update => { sqlx::query_file!( "queries/upsert_record.sql", did.as_str(), collection, rkey, rev, cid.as_deref(), live, parsed_record ) .execute(&mut **transaction) .await?; } RecordAction::Delete => { query::delete_record(did, collection, rkey, rev) .execute(&mut **transaction) .await?; } } Ok((record_event, parsed_record)) } #[tracing::instrument(skip(tap, pool, did_rx, shutdown))] async fn did_task( tap: TapClient, pool: PgPool, mut did_rx: mpsc::UnboundedReceiver<(String, DidSource)>, shutdown: CancellationToken, ) -> anyhow::Result<()> { const BATCH: usize = 64; let mut seen: HashSet = HashSet::with_capacity(10_000); let mut dids = Vec::new(); // Query known DIDs from the database. let mut query = sqlx::query!("SELECT did FROM identity").fetch(&pool); while let Some(Ok(row)) = query.next().await { seen.insert(row.did); } tracing::debug!(count = seen.len(), "loaded tracked DIDs from database"); loop { tokio::time::sleep(Duration::from_millis(200)).await; match shutdown .run_until_cancelled(did_rx.recv_many(&mut dids, BATCH)) .await { Some(0) | None => break, Some(_) => {} } // Convert Vec> to a Vec<&Did>. let mut dedup: HashSet<&str> = HashSet::new(); let mut slice = Vec::with_capacity(dids.len()); for (did, source) in &dids { if !dedup.insert(did) { continue; } if !seen.contains(did) || slice.contains(&did.as_ref()) { tracing::info!(?did, ?source, "tracking DID"); slice.push(did.as_ref()); } } tap.add_repos(&slice).await?; dids.drain(..).for_each(|(did, _)| _ = seen.insert(did)); } tracing::info!("complete"); Ok(()) } #[tracing::instrument(skip(shutdown))] async fn shutdown_task(shutdown: CancellationToken) -> anyhow::Result<()> { tokio::signal::ctrl_c().await?; eprintln!(); tracing::info!("shutdown signal received"); shutdown.cancel(); Ok(()) } /// Extract any strings that look like DIDs from a JSON document. fn extract_dids(value: &Value) -> HashSet { let mut dids = HashSet::new(); let mut queue = VecDeque::from_iter([value]); while let Some(value) = queue.pop_front() { match value { Value::Null | Value::Bool(_) | Value::Number(_) => {} Value::Array(values) => { for value in values { queue.push_back(value); } } Value::Object(map) => { for (_, value) in map { queue.push_back(value); } } Value::String(maybe_did) => { if possible_did(maybe_did) { dids.insert(maybe_did.to_string()); continue; } // First segment of an "at://..." URI might be a DID. if let Some(uri) = maybe_did.strip_prefix("at://") && let Some((maybe_did, _)) = uri.split_once('/') && possible_did(maybe_did) { dids.insert(maybe_did.to_string()); continue; } } } } dids } fn possible_did(s: &str) -> bool { s.starts_with("did:plc") || s.starts_with("did:web") }