1use core::fmt;
2use std::{
3 collections::HashSet,
4 time::{Duration, SystemTime},
5};
6
7use futures_util::{SinkExt as _, StreamExt};
8use serde::Serialize;
9use tokio::{
10 net::TcpStream,
11 sync::{mpsc, watch},
12};
13use tokio_tungstenite::{
14 MaybeTlsStream, WebSocketStream,
15 tungstenite::{Bytes, ClientRequestBuilder, Message},
16};
17
18use crate::tap::{TapClient, TapEvent};
19
/// Upper bound on unanswered Ping messages, checked before each new Ping is
/// sent: if more than this many are still awaiting a Pong, the connection is
/// considered broken.
const MAX_INFLIGHT_PINGS: usize = 2;

/// Period of the keep-alive timer; every tick triggers the in-flight Ping
/// check above and then sends a fresh Ping to the server.
const TIMEOUT: Duration = Duration::from_millis(30_000);
25
26#[derive(Debug, thiserror::Error)]
27#[error("Failed to enqueue acknowledgement for event #{0}")]
28pub struct AckError(u64);
29
30impl From<mpsc::error::SendError<Ack>> for AckError {
31 fn from(error: mpsc::error::SendError<Ack>) -> Self {
32 Self(error.0.0)
33 }
34}
35
36pub struct AckHandle {
37 id: u64,
38 tx: mpsc::Sender<Ack>,
39}
40
41impl AckHandle {
42 /// Acknowledge receipt of the associated event.
43 ///
44 /// Success does *not* mean the Tap server has successfully received the
45 /// acknowledgement, only that the ack has be queued by the client.
46 pub async fn acknowledge(self) -> Result<(), AckError> {
47 self.tx.send(Ack::new(self.id)).await?;
48 Ok(())
49 }
50}
51
/// Acknowledgement for the event whose id it wraps; produced by an
/// [`AckHandle`] and sent to the Tap server by the channel task.
#[derive(Debug)]
pub struct Ack(u64);

impl Ack {
    /// Wraps an event id.
    fn new(id: u64) -> Self {
        Ack(id)
    }
}

impl fmt::Display for Ack {
    /// Formats as the bare event id, honouring any width/fill/alignment flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(id) = self;
        fmt::Display::fmt(id, f)
    }
}
66
67/// Messages that are serialized and sent to the Tap server.
68#[derive(Debug, Serialize)]
69#[serde(tag = "type", rename_all = "snake_case")]
70enum ClientMessage {
71 Ack { id: u64 },
72}
73
74impl From<&Ack> for ClientMessage {
75 fn from(&Ack(id): &Ack) -> Self {
76 Self::Ack { id }
77 }
78}
79
80#[derive(Debug)]
81pub struct TapChannel {
82 rx: mpsc::Receiver<(TapEvent, AckHandle)>,
83 shutdown: watch::Sender<bool>,
84}
85
86impl TapChannel {
87 pub const DEFAULT_CAPACITY: usize = 128;
88
89 pub fn new(
90 tap: &TapClient,
91 capacity: usize,
92 ) -> (
93 Self,
94 impl Future<Output = Result<Vec<Ack>, ChannelError>> + Send + 'static,
95 ) {
96 let mut url = tap.url().clone();
97 url.set_path("/channel");
98 url.set_scheme(match url.scheme() {
99 "https" => "wss",
100 "http" => "ws",
101 _ => unreachable!("Tap::new should reject unknown schemes"),
102 })
103 .expect("'http' or 'https' is a valid URL scheme");
104
105 let uri = url
106 .as_str()
107 .parse()
108 .expect("Url has already been validated");
109
110 let mut builder = ClientRequestBuilder::new(uri);
111 for (header_name, header_value) in &tap.headers {
112 builder = builder.with_header(
113 header_name.to_string(),
114 header_value
115 .to_str()
116 .expect("Header value has already been validated"),
117 );
118 }
119
120 let (tx, rx) = mpsc::channel(capacity);
121 let (shutdown, shutdown_rx) = watch::channel(false);
122 let handle = channel_task(builder, tx, shutdown_rx, capacity);
123
124 (Self { rx, shutdown }, handle)
125 }
126}
127
128impl TapChannel {
129 pub async fn recv(&mut self) -> Option<(TapEvent, AckHandle)> {
130 self.rx.recv().await
131 }
132}
133
134impl Drop for TapChannel {
135 fn drop(&mut self) {
136 self.rx.close();
137 let _ = self.shutdown.send(true);
138 }
139}
140
141#[derive(Debug, thiserror::Error)]
142pub enum ChannelError {
143 #[error("Client authorization failed")]
144 Authorization,
145 #[error("Failed to send pending Acks: {0:?}: {1}")]
146 FailedAcks(Vec<Ack>, tokio_tungstenite::tungstenite::Error),
147}
148
149async fn channel_task(
150 request_builder: ClientRequestBuilder,
151 event_tx: mpsc::Sender<(TapEvent, AckHandle)>,
152 mut shutdown_rx: watch::Receiver<bool>,
153 capacity: usize,
154) -> Result<Vec<Ack>, ChannelError> {
155 #[derive(Debug)]
156 enum Action {
157 Message(Message),
158 Timeout,
159 Ack,
160 ClearAcks,
161 }
162
163 let mut pings: HashSet<Bytes> = HashSet::with_capacity(MAX_INFLIGHT_PINGS + 1);
164 let mut timeout = tokio::time::interval(TIMEOUT);
165 timeout.tick().await;
166
167 let (ack_tx, mut ack_rx) = mpsc::channel(capacity);
168 let mut acks: Vec<_> = Default::default();
169
170 'outer: loop {
171 let request = request_builder.clone();
172 let (mut socket, _) = match tokio_tungstenite::connect_async(request).await {
173 Ok(result) => result,
174 Err(tokio_tungstenite::tungstenite::Error::Http(error))
175 if error.status().is_client_error() =>
176 {
177 tracing::error!(?error, "failed to connect to Tap channel");
178 return Err(ChannelError::Authorization);
179 }
180 Err(error) => {
181 tracing::error!(?error);
182
183 // @TODO Reconnect delay
184
185 continue 'outer;
186 }
187 };
188
189 // Send any pending Acks.
190 if let Err(error) = send_acknowledgements(&mut acks, &mut socket).await {
191 tracing::error!(?error, "failed to send Ack");
192 continue;
193 }
194
195 loop {
196 let action = tokio::select! {
197 Some(Ok(message)) = socket.next() => Action::Message(message),
198 count = ack_rx.recv_many(&mut acks, 64) => match count {
199 0 => Action::ClearAcks,
200 _ => Action::Ack,
201 },
202 _ = timeout.tick() => Action::Timeout,
203 _ = shutdown_rx.wait_for(|v| *v) => Action::ClearAcks,
204 else => Action::ClearAcks,
205 };
206
207 match action {
208 Action::Message(message) => match message {
209 Message::Text(bytes) => {
210 let event = match serde_json::from_str::<TapEvent>(bytes.as_str()) {
211 Ok(event) => event,
212 Err(error) => {
213 tracing::error!(?error, bytes = %bytes.as_str(), "failed to deserialize event");
214 continue;
215 }
216 };
217
218 let ack = AckHandle {
219 id: event.id(),
220 tx: ack_tx.clone(),
221 };
222
223 if event_tx.send((event, ack)).await.is_err() {
224 tracing::error!("failed to dispatch event");
225 break 'outer;
226 }
227 }
228 Message::Binary(_) | Message::Frame(_) => {
229 tracing::error!("unexpected Binary or Frame message from server");
230 break;
231 }
232 Message::Ping(bytes) => {
233 if let Err(error) = socket.send(Message::Pong(bytes)).await {
234 tracing::error!(?error, "failed to send Pong");
235 continue 'outer;
236 }
237 }
238 Message::Pong(bytes) => {
239 tracing::trace!(?bytes, "received Pong from server");
240 if !pings.remove(&bytes) {
241 tracing::error!("unsolicited Pong");
242 break;
243 }
244 }
245 Message::Close(close_frame) => {
246 tracing::debug!(?close_frame, "received close frame");
247 break;
248 }
249 },
250 Action::Ack => {
251 if let Err(error) = send_acknowledgements(&mut acks, &mut socket).await {
252 tracing::error!(?error, "failed to send Ack");
253 break;
254 }
255 }
256 Action::Timeout => {
257 if pings.len() > MAX_INFLIGHT_PINGS {
258 tracing::error!("too many missed pings");
259 break;
260 }
261
262 let timestamp = SystemTime::now()
263 .duration_since(SystemTime::UNIX_EPOCH)
264 .expect("system time precedes UNIX epoch")
265 .as_micros();
266
267 let payload = format!("{timestamp}");
268 let payload: Bytes = payload.into();
269 pings.insert(payload.clone());
270
271 if socket.send(Message::Ping(payload)).await.is_err() {
272 tracing::error!("failed to send Ping to server");
273 break;
274 }
275 }
276 Action::ClearAcks => {
277 drop(ack_tx);
278 while let Some(ack) = ack_rx.recv().await {
279 acks.push(ack);
280 }
281
282 if let Err(error) = send_acknowledgements(&mut acks, &mut socket).await {
283 return Err(ChannelError::FailedAcks(acks, error));
284 }
285
286 break 'outer;
287 }
288 }
289 }
290
291 tracing::warn!("disconnected");
292 }
293
294 tracing::info!("complete");
295 Ok(acks)
296}
297
298async fn send_acknowledgements(
299 acks: &mut Vec<Ack>,
300 socket: &mut WebSocketStream<MaybeTlsStream<TcpStream>>,
301) -> Result<usize, tokio_tungstenite::tungstenite::Error> {
302 let mut count = 0;
303 for ack in acks.drain(..) {
304 tracing::debug!(?ack, "sending ack");
305 let message = serde_json::to_string(&ClientMessage::from(&ack))
306 .expect("ClientMessage should be serializable");
307
308 socket.send(Message::text(message)).await?;
309 count += 1;
310 }
311
312 Ok(count)
313}