A library for ATProtocol identities.
//! WebSocket connection management for TAP streams.
//!
//! This module handles the low-level WebSocket connection to a TAP service,
//! including authentication and sending and receiving messages.
5
6use crate::config::TapConfig;
7use crate::errors::TapError;
8use base64::Engine;
9use base64::engine::general_purpose::STANDARD as BASE64;
10use futures::{SinkExt, StreamExt};
11use http::Uri;
12use std::str::FromStr;
13use tokio::net::TcpStream;
14use tokio_websockets::MaybeTlsStream;
15use tokio_websockets::{ClientBuilder, Message, WebSocketStream};
16
17/// WebSocket connection to a TAP service.
18pub(crate) struct TapConnection {
19 /// The underlying WebSocket stream.
20 ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
21 /// Pre-allocated buffer for acknowledgment messages.
22 ack_buffer: Vec<u8>,
23}
24
25impl TapConnection {
26 /// Establish a new WebSocket connection to the TAP service.
27 pub async fn connect(config: &TapConfig) -> Result<Self, TapError> {
28 let uri =
29 Uri::from_str(&config.ws_url()).map_err(|e| TapError::InvalidUrl(e.to_string()))?;
30
31 let mut builder = ClientBuilder::from_uri(uri);
32
33 // Add User-Agent header
34 builder = builder
35 .add_header(
36 http::header::USER_AGENT,
37 http::HeaderValue::from_str(&config.user_agent).map_err(|e| {
38 TapError::ConnectionFailed(format!("Invalid user agent: {}", e))
39 })?,
40 )
41 .map_err(|e| TapError::ConnectionFailed(format!("Failed to add header: {}", e)))?;
42
43 // Add Basic Auth header if password is configured
44 if let Some(password) = &config.admin_password {
45 let credentials = format!("admin:{}", password);
46 let encoded = BASE64.encode(credentials.as_bytes());
47 let auth_value = format!("Basic {}", encoded);
48
49 builder = builder
50 .add_header(
51 http::header::AUTHORIZATION,
52 http::HeaderValue::from_str(&auth_value).map_err(|e| {
53 TapError::ConnectionFailed(format!("Invalid auth header: {}", e))
54 })?,
55 )
56 .map_err(|e| {
57 TapError::ConnectionFailed(format!("Failed to add auth header: {}", e))
58 })?;
59 }
60
61 // Connect
62 let (ws, _response) = builder
63 .connect()
64 .await
65 .map_err(|e| TapError::ConnectionFailed(e.to_string()))?;
66
67 tracing::debug!(hostname = %config.hostname, "Connected to TAP service");
68
69 Ok(Self {
70 ws,
71 ack_buffer: Vec::with_capacity(48), // {"type":"ack","id":18446744073709551615} is 40 bytes max
72 })
73 }
74
75 /// Receive the next message from the WebSocket.
76 ///
77 /// Returns `None` if the connection was closed cleanly.
78 pub async fn recv(&mut self) -> Result<Option<String>, TapError> {
79 match self.ws.next().await {
80 Some(Ok(msg)) => {
81 if msg.is_text() {
82 msg.as_text().map(|s| Some(s.to_string())).ok_or_else(|| {
83 TapError::ParseError("Failed to get text from message".into())
84 })
85 } else if msg.is_close() {
86 tracing::debug!("Received close frame from TAP service");
87 Ok(None)
88 } else {
89 // Ignore ping/pong and binary messages
90 tracing::trace!("Received non-text message, ignoring");
91 // Recurse to get the next text message
92 Box::pin(self.recv()).await
93 }
94 }
95 Some(Err(e)) => Err(TapError::ConnectionFailed(e.to_string())),
96 None => {
97 tracing::debug!("WebSocket stream ended");
98 Ok(None)
99 }
100 }
101 }
102
103 /// Send an acknowledgment for the given event ID.
104 ///
105 /// Uses a pre-allocated buffer and itoa for allocation-free formatting.
106 /// Format: `{"type":"ack","id":12345}`
107 pub async fn send_ack(&mut self, id: u64) -> Result<(), TapError> {
108 self.ack_buffer.clear();
109 self.ack_buffer
110 .extend_from_slice(b"{\"type\":\"ack\",\"id\":");
111 let mut itoa_buf = itoa::Buffer::new();
112 self.ack_buffer
113 .extend_from_slice(itoa_buf.format(id).as_bytes());
114 self.ack_buffer.push(b'}');
115
116 // All bytes are ASCII so this is always valid UTF-8
117 let msg = std::str::from_utf8(&self.ack_buffer).expect("ack buffer contains only ASCII");
118
119 self.ws
120 .send(Message::text(msg.to_string()))
121 .await
122 .map_err(|e| TapError::AckFailed(e.to_string()))?;
123
124 // Flush to ensure the ack is sent immediately
125 self.ws
126 .flush()
127 .await
128 .map_err(|e| TapError::AckFailed(format!("Failed to flush ack: {}", e)))?;
129
130 tracing::trace!(id, "Sent ack");
131 Ok(())
132 }
133
134 /// Close the WebSocket connection gracefully.
135 pub async fn close(&mut self) -> Result<(), TapError> {
136 self.ws
137 .close()
138 .await
139 .map_err(|e| TapError::ConnectionFailed(format!("Failed to close: {}", e)))?;
140 Ok(())
141 }
142}
143
#[cfg(test)]
mod tests {
    /// Capacity reserved for `ack_buffer` in `TapConnection::connect`.
    ///
    /// This must be kept in sync with that constructor.
    const ACK_BUFFER_CAPACITY: usize = 48;

    /// Longest possible ack payload: `{"type":"ack","id":` (19 bytes),
    /// then 20 digits for `u64::MAX`, then `}`.
    const MAX_ACK_LEN: usize = 40;

    /// Encodes an ack into `buffer` using the same byte-level steps as
    /// `TapConnection::send_ack`.
    fn format_ack(buffer: &mut Vec<u8>, id: u64) {
        buffer.clear();
        buffer.extend_from_slice(b"{\"type\":\"ack\",\"id\":");
        let mut itoa_buf = itoa::Buffer::new();
        buffer.extend_from_slice(itoa_buf.format(id).as_bytes());
        buffer.push(b'}');
    }

    #[test]
    fn test_ack_buffer_format() {
        // Test that our manual JSON formatting is correct
        // Format: {"type":"ack","id":12345}
        let mut buffer = Vec::with_capacity(ACK_BUFFER_CAPACITY);

        format_ack(&mut buffer, 12345);
        assert_eq!(
            std::str::from_utf8(&buffer).unwrap(),
            r#"{"type":"ack","id":12345}"#
        );

        // Smallest id: a single digit.
        format_ack(&mut buffer, 0);
        assert_eq!(std::str::from_utf8(&buffer).unwrap(), r#"{"type":"ack","id":0}"#);

        // Largest id: the longest payload send_ack can produce.
        format_ack(&mut buffer, u64::MAX);
        assert_eq!(
            std::str::from_utf8(&buffer).unwrap(),
            r#"{"type":"ack","id":18446744073709551615}"#
        );
        assert_eq!(buffer.len(), MAX_ACK_LEN);
        // The capacity actually reserved in `connect` must hold the longest payload.
        assert!(buffer.len() <= ACK_BUFFER_CAPACITY);
    }
}