Tap drinker

refactor: use cancellation tokens everywhere

Signed-off-by: tjh <x@tjh.dev>

tjh.dev 5e4b30d5 6d9456ef

verified
+23 -48
-1
src/lib.rs
··· 1 1 pub mod tap; 2 - pub mod util;
+22 -27
src/main.rs
··· 4 4 use std::{ 5 5 collections::{HashSet, VecDeque}, 6 6 io, 7 - ops::ControlFlow::{self, *}, 8 7 time::Duration, 9 8 }; 10 9 11 10 use futures_util::StreamExt; 12 11 use serde_json::Value; 13 12 use sqlx::PgPool; 14 - use tokio::{ 15 - sync::{mpsc, watch}, 16 - task::JoinSet, 17 - }; 13 + use tokio::{sync::mpsc, task::JoinSet}; 14 + use tokio_util::sync::CancellationToken; 18 15 use tracing::level_filters::LevelFilter; 19 16 use tracing_subscriber::{EnvFilter, layer::SubscriberExt as _, util::SubscriberInitExt as _}; 20 - use trap::{ 21 - tap::{IdentityEvent, RecordAction, RecordEvent, TapChannel, TapClient, TapEvent}, 22 - util::with_shutdown, 23 - }; 17 + use trap::tap::{IdentityEvent, RecordAction, RecordEvent, TapChannel, TapClient, TapEvent}; 24 18 25 19 /// DID provenance. 26 20 #[derive(Debug)] ··· 41 35 42 36 sqlx::migrate!().run(&pool).await?; 43 37 44 - let (shutdown_tx, shutdown_rx) = watch::channel(false); 38 + let shutdown = CancellationToken::new(); 45 39 let (did_tx, did_rx) = mpsc::unbounded_channel::<(String, DidSource)>(); 46 40 47 41 let tap = TapClient::new(arguments.tap, arguments.tap_password.as_deref())?; ··· 60 54 tap_channel, 61 55 pool.clone(), 62 56 did_tx.clone(), 63 - shutdown_rx.clone(), 57 + shutdown.child_token(), 64 58 )); 65 59 66 - tasks.spawn(did_task(tap, pool, did_rx, shutdown_rx.clone())); 67 - tasks.spawn(shutdown_task(shutdown_tx.clone())); 60 + tasks.spawn(did_task(tap, pool, did_rx, shutdown.child_token())); 61 + tasks.spawn(shutdown_task(shutdown.clone())); 68 62 69 63 // Submit seed DIDs to the Tap service. 70 64 for did in arguments.seed.into_iter().filter(|s| possible_did(s)) { ··· 74 68 for task in tasks.join_all().await { 75 69 if let Err(error) = task { 76 70 tracing::error!(?error, "task failed"); 77 - shutdown_tx.send(true)?; 71 + shutdown.cancel(); 78 72 } 79 73 } 80 74 ··· 99 93 Ok(row.0) 100 94 } 101 95 102 - #[tracing::instrument(skip(channel, pool, tx, rx))] 96 + #[tracing::instrument(skip(channel, pool, tx, shutdown))] 103 97 async fn event_consumer( 104 98 mut channel: TapChannel, 105 99 pool: PgPool, 106 100 tx: mpsc::UnboundedSender<(String, DidSource)>, 107 - mut rx: watch::Receiver<bool>, 101 + shutdown: CancellationToken, 108 102 ) -> anyhow::Result<()> { 109 - use ControlFlow::Continue; 110 - 111 - while let Continue(Some((event, ack))) = with_shutdown(channel.recv(), &mut rx).await { 103 + while let Some(Some((event, ack))) = shutdown.run_until_cancelled(channel.recv()).await { 112 104 let mut transaction = pool.begin().await?; 113 105 match event { 114 106 TapEvent::Record(record) => { ··· 199 191 Ok((record_event, parsed_record)) 200 192 } 201 193 202 - #[tracing::instrument(skip(tap, pool, did_rx, shutdown_rx))] 194 + #[tracing::instrument(skip(tap, pool, did_rx, shutdown))] 203 195 async fn did_task( 204 196 tap: TapClient, 205 197 pool: PgPool, 206 198 mut did_rx: mpsc::UnboundedReceiver<(String, DidSource)>, 207 - mut shutdown_rx: watch::Receiver<bool>, 199 + shutdown: CancellationToken, 208 200 ) -> anyhow::Result<()> { 209 201 const BATCH: usize = 64; 210 202 ··· 221 213 222 214 loop { 223 215 tokio::time::sleep(Duration::from_millis(200)).await; 224 - match with_shutdown(did_rx.recv_many(&mut dids, BATCH), &mut shutdown_rx).await { 225 - Continue(0) | Break(_) => break, 226 - Continue(_) => {} 216 + match shutdown 217 + .run_until_cancelled(did_rx.recv_many(&mut dids, BATCH)) 218 + .await 219 + { 220 + Some(0) | None => break, 221 + Some(_) => {} 227 222 } 228 223 229 224 // Convert Vec<Box<Did>> to a Vec<&Did>. ··· 249 244 Ok(()) 250 245 } 251 246 252 - #[tracing::instrument(skip(tx))] 253 - async fn shutdown_task(tx: watch::Sender<bool>) -> anyhow::Result<()> { 247 + #[tracing::instrument(skip(shutdown))] 248 + async fn shutdown_task(shutdown: CancellationToken) -> anyhow::Result<()> { 254 249 tokio::signal::ctrl_c().await?; 255 250 eprintln!(); 256 251 tracing::info!("shutdown signal received"); 257 - tx.send(true)?; 252 + shutdown.cancel(); 258 253 Ok(()) 259 254 } 260 255
+1 -1
src/tap/channel.rs
··· 138 138 let (ack_tx, mut ack_rx) = mpsc::channel(capacity); 139 139 let mut acks: Vec<_> = Default::default(); 140 140 141 - 'outer: loop { 141 + 'outer: while !shutdown.is_cancelled() { 142 142 let mut ping_inflight = false; 143 143 let mut timeout = tokio::time::interval(TIMEOUT); 144 144 timeout.tick().await;
-19
src/util.rs
··· 1 - use std::ops; 2 - 3 - use tokio::sync::watch; 4 - 5 - #[derive(Debug, thiserror::Error)] 6 - #[error("Shutdown signal received")] 7 - pub struct Shutdown; 8 - 9 - pub async fn with_shutdown<F: Future>( 10 - f: F, 11 - rx: &mut watch::Receiver<bool>, 12 - ) -> ops::ControlFlow<Shutdown, F::Output> { 13 - use ops::ControlFlow::{Break, Continue}; 14 - 15 - tokio::select! { 16 - result = f => Continue(result), 17 - Ok(_) = rx.wait_for(|&v| v) => Break(Shutdown), 18 - } 19 - }