1mod cli;
2mod query;
3
4use std::{
5 collections::{HashSet, VecDeque},
6 io, process,
7 time::Duration,
8};
9
10use futures_util::StreamExt;
11use serde_json::Value;
12use sqlx::PgPool;
13use tokio::{sync::mpsc, task::JoinSet};
14use tokio_util::sync::CancellationToken;
15use tracing::{Instrument as _, level_filters::LevelFilter};
16use tracing_subscriber::{EnvFilter, layer::SubscriberExt as _, util::SubscriberInitExt as _};
17use trap::tap::{IdentityEvent, RecordAction, RecordEvent, TapChannel, TapClient, TapEvent};
18
/// Where a DID submitted for tracking came from.
///
/// The value travels next to each DID on the tracking channel. Within this
/// file it is only ever rendered through `Debug` when a newly tracked DID is
/// logged.
#[derive(Debug)]
// The payload of `Record` is read solely through the derived `Debug` impl,
// which the dead-code analysis does not count as a use, so the lint is
// silenced here. Only `dead_code` is allowed (rather than the whole `unused`
// group) so unrelated warnings are not hidden.
#[allow(dead_code)]
enum DidSource {
    /// Taken from the seed list in the parsed command-line arguments.
    Seed,
    /// Found inside a record; holds the NSID of that record's collection.
    Record(Box<str>),
}
26
27#[tokio::main]
28async fn main() -> anyhow::Result<()> {
29 setup_tracing()?;
30
31 let arguments = cli::parse();
32 let pool = PgPool::connect(arguments.db.as_str()).await?;
33 let version = db_version(&pool).await?;
34 tracing::info!(%version, "connected to db");
35
36 sqlx::migrate!().run(&pool).await?;
37
38 let shutdown = CancellationToken::new();
39 let (did_tx, did_rx) = mpsc::unbounded_channel::<(String, DidSource)>();
40
41 let tap = TapClient::new(arguments.tap, arguments.tap_password.as_deref())?;
42 let (tap_channel, tap_task) = tap.channel();
43
44 let mut tasks = JoinSet::new();
45 tasks.spawn(async move {
46 if let Err(error) = tap_task.await {
47 tracing::error!(?error);
48 process::abort();
49 }
50 Ok(())
51 });
52
53 tasks.spawn(event_consumer(
54 tap_channel,
55 pool.clone(),
56 did_tx.clone(),
57 shutdown.child_token(),
58 ));
59
60 tasks.spawn(did_task(tap, pool, did_rx, shutdown.child_token()));
61 tasks.spawn(shutdown_task(shutdown.clone()));
62
63 // Submit seed DIDs to the Tap service.
64 for did in arguments.seed.into_iter().filter(|s| possible_did(s)) {
65 did_tx.send((did, DidSource::Seed))?;
66 }
67
68 for task in tasks.join_all().await {
69 if let Err(error) = task {
70 tracing::error!(?error, "task failed");
71 shutdown.cancel();
72 }
73 }
74
75 Ok(())
76}
77
78fn setup_tracing() -> anyhow::Result<()> {
79 tracing_subscriber::registry()
80 .with(
81 EnvFilter::builder()
82 .with_default_directive(LevelFilter::INFO.into())
83 .from_env()?,
84 )
85 .with(tracing_subscriber::fmt::layer().with_writer(io::stderr))
86 .try_init()?;
87
88 Ok(())
89}
90
91async fn db_version(pool: &PgPool) -> sqlx::Result<String> {
92 let row: (String,) = sqlx::query_as("SELECT version()").fetch_one(pool).await?;
93 Ok(row.0)
94}
95
96#[tracing::instrument(skip(channel, pool, tx, shutdown))]
97async fn event_consumer(
98 mut channel: TapChannel,
99 pool: PgPool,
100 tx: mpsc::UnboundedSender<(String, DidSource)>,
101 shutdown: CancellationToken,
102) -> anyhow::Result<()> {
103 while let Some(Some((span, event, ack))) = shutdown.run_until_cancelled(channel.recv()).await {
104 async {
105 let mut transaction = pool.begin().await?;
106 match event {
107 TapEvent::Record(record) => {
108 let (record, parsed_record) = handle_record(record, &mut transaction).await?;
109
110 // Expand the network of tracked DIDs.
111 let nsid = record.collection.into_boxed_str();
112 for did in extract_dids(&parsed_record) {
113 tx.send((did, DidSource::Record(nsid.clone())))?;
114 }
115 }
116 TapEvent::Identity(identity) => {
117 handle_identity(identity, &mut transaction).await?;
118 }
119 }
120
121 transaction.commit().await?;
122 ack.acknowledge().await?;
123 Ok::<_, anyhow::Error>(())
124 }
125 .instrument(span)
126 .await?;
127 }
128
129 tracing::info!("complete");
130 Ok(())
131}
132
133async fn handle_identity(
134 identity_event: IdentityEvent,
135 transaction: &mut sqlx::Transaction<'static, sqlx::Postgres>,
136) -> anyhow::Result<()> {
137 let IdentityEvent {
138 id: _,
139 did,
140 handle,
141 is_active,
142 status,
143 } = identity_event;
144
145 query::upsert_identity(&did, &handle, &status, is_active)
146 .execute(&mut **transaction)
147 .await?;
148
149 Ok(())
150}
151
152async fn handle_record(
153 record_event: RecordEvent,
154 transaction: &mut sqlx::Transaction<'static, sqlx::Postgres>,
155) -> anyhow::Result<(RecordEvent, Value)> {
156 let RecordEvent {
157 id: _,
158 did,
159 rev,
160 collection,
161 rkey,
162 action,
163 record,
164 cid,
165 live,
166 } = &record_event;
167
168 let parsed_record: Value = record
169 .as_ref()
170 .map(|record| serde_json::from_str(record.get()))
171 .transpose()?
172 .unwrap_or_default();
173
174 match action {
175 RecordAction::Create | RecordAction::Update => {
176 sqlx::query_file!(
177 "queries/upsert_record.sql",
178 did.as_str(),
179 collection,
180 rkey,
181 rev,
182 cid.as_deref(),
183 live,
184 parsed_record
185 )
186 .execute(&mut **transaction)
187 .await?;
188 }
189 RecordAction::Delete => {
190 query::delete_record(did, collection, rkey, rev)
191 .execute(&mut **transaction)
192 .await?;
193 }
194 }
195
196 Ok((record_event, parsed_record))
197}
198
199#[tracing::instrument(skip(tap, pool, did_rx, shutdown))]
200async fn did_task(
201 tap: TapClient,
202 pool: PgPool,
203 mut did_rx: mpsc::UnboundedReceiver<(String, DidSource)>,
204 shutdown: CancellationToken,
205) -> anyhow::Result<()> {
206 const BATCH: usize = 64;
207
208 let mut seen: HashSet<String> = HashSet::with_capacity(10_000);
209 let mut dids = Vec::new();
210
211 // Query known DIDs from the database.
212 let mut query = sqlx::query!("SELECT did FROM identity").fetch(&pool);
213 while let Some(Ok(row)) = query.next().await {
214 seen.insert(row.did);
215 }
216
217 tracing::debug!(count = seen.len(), "loaded tracked DIDs from database");
218
219 loop {
220 tokio::time::sleep(Duration::from_millis(200)).await;
221 match shutdown
222 .run_until_cancelled(did_rx.recv_many(&mut dids, BATCH))
223 .await
224 {
225 Some(0) | None => break,
226 Some(_) => {}
227 }
228
229 // Convert Vec<Box<Did>> to a Vec<&Did>.
230 let mut dedup: HashSet<&str> = HashSet::new();
231 let mut slice = Vec::with_capacity(dids.len());
232 for (did, source) in &dids {
233 if !dedup.insert(did) {
234 continue;
235 }
236
237 if !seen.contains(did) || slice.contains(&did.as_ref()) {
238 tracing::info!(?did, ?source, "tracking DID");
239 slice.push(did.as_ref());
240 }
241 }
242
243 tap.add_repos(&slice).await?;
244
245 dids.drain(..).for_each(|(did, _)| _ = seen.insert(did));
246 }
247
248 tracing::info!("complete");
249 Ok(())
250}
251
252#[tracing::instrument(skip(shutdown))]
253async fn shutdown_task(shutdown: CancellationToken) -> anyhow::Result<()> {
254 tokio::signal::ctrl_c().await?;
255 eprintln!();
256 tracing::info!("shutdown signal received");
257 shutdown.cancel();
258 Ok(())
259}
260
261/// Extract any strings that look like DIDs from a JSON document.
262fn extract_dids(value: &Value) -> HashSet<String> {
263 let mut dids = HashSet::new();
264
265 let mut queue = VecDeque::from_iter([value]);
266 while let Some(value) = queue.pop_front() {
267 match value {
268 Value::Null | Value::Bool(_) | Value::Number(_) => {}
269 Value::Array(values) => {
270 for value in values {
271 queue.push_back(value);
272 }
273 }
274 Value::Object(map) => {
275 for (_, value) in map {
276 queue.push_back(value);
277 }
278 }
279 Value::String(maybe_did) => {
280 if possible_did(maybe_did) {
281 dids.insert(maybe_did.to_string());
282 continue;
283 }
284
285 // First segment of an "at://..." URI might be a DID.
286 if let Some(uri) = maybe_did.strip_prefix("at://")
287 && let Some((maybe_did, _)) = uri.split_once('/')
288 && possible_did(maybe_did)
289 {
290 dids.insert(maybe_did.to_string());
291 continue;
292 }
293 }
294 }
295 }
296
297 dids
298}
299
/// Cheap syntactic check for whether `s` could be a `did:plc` or `did:web` DID.
///
/// Returns `true` only when `s` starts with `did:plc:` or `did:web:`, has a
/// non-empty identifier after that prefix, does not end in `:`, and contains
/// only ASCII letters, digits and the characters `.`, `_`, `:`, `%`, `-`.
/// This rejects bare method prefixes, other methods that merely share the
/// leading characters, and free text that happens to begin with a DID.
/// It does not check that the DID actually resolves.
fn possible_did(s: &str) -> bool {
    s.strip_prefix("did:plc:")
        .or_else(|| s.strip_prefix("did:web:"))
        .is_some_and(|identifier| {
            !identifier.is_empty()
                && !identifier.ends_with(':')
                && identifier.bytes().all(|byte| {
                    byte.is_ascii_alphanumeric()
                        || matches!(byte, b'.' | b'_' | b':' | b'%' | b'-')
                })
        })
}