at main 9.8 kB view raw
1use core::fmt; 2use std::{ 3 collections::HashSet, 4 time::{Duration, SystemTime}, 5}; 6 7use futures_util::{SinkExt as _, StreamExt}; 8use serde::Serialize; 9use tokio::{ 10 net::TcpStream, 11 sync::{mpsc, watch}, 12}; 13use tokio_tungstenite::{ 14 MaybeTlsStream, WebSocketStream, 15 tungstenite::{Bytes, ClientRequestBuilder, Message}, 16}; 17 18use crate::tap::{TapClient, TapEvent}; 19 20/// Maximum number of unanswered Ping messages to allow before the connection is 21/// considered broken. 22const MAX_INFLIGHT_PINGS: usize = 2; 23 24const TIMEOUT: Duration = Duration::from_secs(30); 25 26#[derive(Debug, thiserror::Error)] 27#[error("Failed to enqueue acknowledgement for event #{0}")] 28pub struct AckError(u64); 29 30impl From<mpsc::error::SendError<Ack>> for AckError { 31 fn from(error: mpsc::error::SendError<Ack>) -> Self { 32 Self(error.0.0) 33 } 34} 35 36pub struct AckHandle { 37 id: u64, 38 tx: mpsc::Sender<Ack>, 39} 40 41impl AckHandle { 42 /// Acknowledge receipt of the associated event. 43 /// 44 /// Success does *not* mean the Tap server has successfully received the 45 /// acknowledgement, only that the ack has be queued by the client. 46 pub async fn acknowledge(self) -> Result<(), AckError> { 47 self.tx.send(Ack::new(self.id)).await?; 48 Ok(()) 49 } 50} 51 52#[derive(Debug)] 53pub struct Ack(u64); 54 55impl fmt::Display for Ack { 56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 57 fmt::Display::fmt(&self.0, f) 58 } 59} 60 61impl Ack { 62 fn new(id: u64) -> Self { 63 Self(id) 64 } 65} 66 67/// Messages that are serialized and sent to the Tap server. 68#[derive(Debug, Serialize)] 69#[serde(tag = "type", rename_all = "snake_case")] 70enum ClientMessage { 71 Ack { id: u64 }, 72} 73 74impl From<&Ack> for ClientMessage { 75 fn from(&Ack(id): &Ack) -> Self { 76 Self::Ack { id } 77 } 78} 79 80#[derive(Debug)] 81pub struct TapChannel { 82 rx: mpsc::Receiver<(TapEvent, AckHandle)>, 83 shutdown: watch::Sender<bool>, 84} 85 86impl TapChannel { 87 pub const DEFAULT_CAPACITY: usize = 128; 88 89 pub fn new( 90 tap: &TapClient, 91 capacity: usize, 92 ) -> ( 93 Self, 94 impl Future<Output = Result<Vec<Ack>, ChannelError>> + Send + 'static, 95 ) { 96 let mut url = tap.url().clone(); 97 url.set_path("/channel"); 98 url.set_scheme(match url.scheme() { 99 "https" => "wss", 100 "http" => "ws", 101 _ => unreachable!("Tap::new should reject unknown schemes"), 102 }) 103 .expect("'http' or 'https' is a valid URL scheme"); 104 105 let uri = url 106 .as_str() 107 .parse() 108 .expect("Url has already been validated"); 109 110 let mut builder = ClientRequestBuilder::new(uri); 111 for (header_name, header_value) in &tap.headers { 112 builder = builder.with_header( 113 header_name.to_string(), 114 header_value 115 .to_str() 116 .expect("Header value has already been validated"), 117 ); 118 } 119 120 let (tx, rx) = mpsc::channel(capacity); 121 let (shutdown, shutdown_rx) = watch::channel(false); 122 let handle = channel_task(builder, tx, shutdown_rx, capacity); 123 124 (Self { rx, shutdown }, handle) 125 } 126} 127 128impl TapChannel { 129 pub async fn recv(&mut self) -> Option<(TapEvent, AckHandle)> { 130 self.rx.recv().await 131 } 132} 133 134impl Drop for TapChannel { 135 fn drop(&mut self) { 136 self.rx.close(); 137 let _ = self.shutdown.send(true); 138 } 139} 140 141#[derive(Debug, thiserror::Error)] 142pub enum ChannelError { 143 #[error("Client authorization failed")] 144 Authorization, 145 #[error("Failed to send pending Acks: {0:?}: {1}")] 146 FailedAcks(Vec<Ack>, tokio_tungstenite::tungstenite::Error), 147} 148 149async fn channel_task( 150 request_builder: ClientRequestBuilder, 151 event_tx: mpsc::Sender<(TapEvent, AckHandle)>, 152 mut shutdown_rx: watch::Receiver<bool>, 153 capacity: usize, 154) -> Result<Vec<Ack>, ChannelError> { 155 #[derive(Debug)] 156 enum Action { 157 Message(Message), 158 Timeout, 159 Ack, 160 ClearAcks, 161 } 162 163 let mut pings: HashSet<Bytes> = HashSet::with_capacity(MAX_INFLIGHT_PINGS + 1); 164 let mut timeout = tokio::time::interval(TIMEOUT); 165 timeout.tick().await; 166 167 let (ack_tx, mut ack_rx) = mpsc::channel(capacity); 168 let mut acks: Vec<_> = Default::default(); 169 170 'outer: loop { 171 let request = request_builder.clone(); 172 let (mut socket, _) = match tokio_tungstenite::connect_async(request).await { 173 Ok(result) => result, 174 Err(tokio_tungstenite::tungstenite::Error::Http(error)) 175 if error.status().is_client_error() => 176 { 177 tracing::error!(?error, "failed to connect to Tap channel"); 178 return Err(ChannelError::Authorization); 179 } 180 Err(error) => { 181 tracing::error!(?error); 182 183 // @TODO Reconnect delay 184 185 continue 'outer; 186 } 187 }; 188 189 // Send any pending Acks. 190 if let Err(error) = send_acknowledgements(&mut acks, &mut socket).await { 191 tracing::error!(?error, "failed to send Ack"); 192 continue; 193 } 194 195 loop { 196 let action = tokio::select! { 197 Some(Ok(message)) = socket.next() => Action::Message(message), 198 count = ack_rx.recv_many(&mut acks, 64) => match count { 199 0 => Action::ClearAcks, 200 _ => Action::Ack, 201 }, 202 _ = timeout.tick() => Action::Timeout, 203 _ = shutdown_rx.wait_for(|v| *v) => Action::ClearAcks, 204 else => Action::ClearAcks, 205 }; 206 207 match action { 208 Action::Message(message) => match message { 209 Message::Text(bytes) => { 210 let event = match serde_json::from_str::<TapEvent>(bytes.as_str()) { 211 Ok(event) => event, 212 Err(error) => { 213 tracing::error!(?error, bytes = %bytes.as_str(), "failed to deserialize event"); 214 continue; 215 } 216 }; 217 218 let ack = AckHandle { 219 id: event.id(), 220 tx: ack_tx.clone(), 221 }; 222 223 if event_tx.send((event, ack)).await.is_err() { 224 tracing::error!("failed to dispatch event"); 225 break 'outer; 226 } 227 } 228 Message::Binary(_) | Message::Frame(_) => { 229 tracing::error!("unexpected Binary or Frame message from server"); 230 break; 231 } 232 Message::Ping(bytes) => { 233 if let Err(error) = socket.send(Message::Pong(bytes)).await { 234 tracing::error!(?error, "failed to send Pong"); 235 continue 'outer; 236 } 237 } 238 Message::Pong(bytes) => { 239 tracing::trace!(?bytes, "received Pong from server"); 240 if !pings.remove(&bytes) { 241 tracing::error!("unsolicited Pong"); 242 break; 243 } 244 } 245 Message::Close(close_frame) => { 246 tracing::debug!(?close_frame, "received close frame"); 247 break; 248 } 249 }, 250 Action::Ack => { 251 if let Err(error) = send_acknowledgements(&mut acks, &mut socket).await { 252 tracing::error!(?error, "failed to send Ack"); 253 break; 254 } 255 } 256 Action::Timeout => { 257 if pings.len() > MAX_INFLIGHT_PINGS { 258 tracing::error!("too many missed pings"); 259 break; 260 } 261 262 let timestamp = SystemTime::now() 263 .duration_since(SystemTime::UNIX_EPOCH) 264 .expect("system time precedes UNIX epoch") 265 .as_micros(); 266 267 let payload = format!("{timestamp}"); 268 let payload: Bytes = payload.into(); 269 pings.insert(payload.clone()); 270 271 if socket.send(Message::Ping(payload)).await.is_err() { 272 tracing::error!("failed to send Ping to server"); 273 break; 274 } 275 } 276 Action::ClearAcks => { 277 drop(ack_tx); 278 while let Some(ack) = ack_rx.recv().await { 279 acks.push(ack); 280 } 281 282 if let Err(error) = send_acknowledgements(&mut acks, &mut socket).await { 283 return Err(ChannelError::FailedAcks(acks, error)); 284 } 285 286 break 'outer; 287 } 288 } 289 } 290 291 tracing::warn!("disconnected"); 292 } 293 294 tracing::info!("complete"); 295 Ok(acks) 296} 297 298async fn send_acknowledgements( 299 acks: &mut Vec<Ack>, 300 socket: &mut WebSocketStream<MaybeTlsStream<TcpStream>>, 301) -> Result<usize, tokio_tungstenite::tungstenite::Error> { 302 let mut count = 0; 303 for ack in acks.drain(..) { 304 tracing::debug!(?ack, "sending ack"); 305 let message = serde_json::to_string(&ClientMessage::from(&ack)) 306 .expect("ClientMessage should be serializable"); 307 308 socket.send(Message::text(message)).await?; 309 count += 1; 310 } 311 312 Ok(count) 313}