A library for ATProtocol identities.
22
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 173 lines 6.3 kB view raw
1//! WebSocket connection management for TAP streams. 2//! 3//! This module handles the low-level WebSocket connection to a TAP service, 4//! including authentication and message sending/receiving. 5 6use crate::config::TapConfig; 7use crate::errors::TapError; 8use base64::Engine; 9use base64::engine::general_purpose::STANDARD as BASE64; 10use futures::{SinkExt, StreamExt}; 11use http::Uri; 12use std::str::FromStr; 13use tokio::net::TcpStream; 14use tokio_websockets::MaybeTlsStream; 15use tokio_websockets::{ClientBuilder, Message, WebSocketStream}; 16 17/// WebSocket connection to a TAP service. 18pub(crate) struct TapConnection { 19 /// The underlying WebSocket stream. 20 ws: WebSocketStream<MaybeTlsStream<TcpStream>>, 21 /// Pre-allocated buffer for acknowledgment messages. 22 ack_buffer: Vec<u8>, 23} 24 25impl TapConnection { 26 /// Establish a new WebSocket connection to the TAP service. 27 pub async fn connect(config: &TapConfig) -> Result<Self, TapError> { 28 let uri = 29 Uri::from_str(&config.ws_url()).map_err(|e| TapError::InvalidUrl(e.to_string()))?; 30 31 let mut builder = ClientBuilder::from_uri(uri); 32 33 // Add User-Agent header 34 builder = builder 35 .add_header( 36 http::header::USER_AGENT, 37 http::HeaderValue::from_str(&config.user_agent).map_err(|e| { 38 TapError::ConnectionFailed(format!("Invalid user agent: {}", e)) 39 })?, 40 ) 41 .map_err(|e| TapError::ConnectionFailed(format!("Failed to add header: {}", e)))?; 42 43 // Add Basic Auth header if password is configured 44 if let Some(password) = &config.admin_password { 45 let credentials = format!("admin:{}", password); 46 let encoded = BASE64.encode(credentials.as_bytes()); 47 let auth_value = format!("Basic {}", encoded); 48 49 builder = builder 50 .add_header( 51 http::header::AUTHORIZATION, 52 http::HeaderValue::from_str(&auth_value).map_err(|e| { 53 TapError::ConnectionFailed(format!("Invalid auth header: {}", e)) 54 })?, 55 ) 56 .map_err(|e| { 57 TapError::ConnectionFailed(format!("Failed to add auth header: {}", e)) 58 })?; 59 } 60 61 // Connect 62 let (ws, _response) = builder 63 .connect() 64 .await 65 .map_err(|e| TapError::ConnectionFailed(e.to_string()))?; 66 67 tracing::debug!(hostname = %config.hostname, "Connected to TAP service"); 68 69 Ok(Self { 70 ws, 71 ack_buffer: Vec::with_capacity(48), // {"type":"ack","id":18446744073709551615} is 40 bytes max 72 }) 73 } 74 75 /// Receive the next message from the WebSocket. 76 /// 77 /// Returns `None` if the connection was closed cleanly. 78 pub async fn recv(&mut self) -> Result<Option<String>, TapError> { 79 match self.ws.next().await { 80 Some(Ok(msg)) => { 81 if msg.is_text() { 82 msg.as_text().map(|s| Some(s.to_string())).ok_or_else(|| { 83 TapError::ParseError("Failed to get text from message".into()) 84 }) 85 } else if msg.is_close() { 86 tracing::debug!("Received close frame from TAP service"); 87 Ok(None) 88 } else { 89 // Ignore ping/pong and binary messages 90 tracing::trace!("Received non-text message, ignoring"); 91 // Recurse to get the next text message 92 Box::pin(self.recv()).await 93 } 94 } 95 Some(Err(e)) => Err(TapError::ConnectionFailed(e.to_string())), 96 None => { 97 tracing::debug!("WebSocket stream ended"); 98 Ok(None) 99 } 100 } 101 } 102 103 /// Send an acknowledgment for the given event ID. 104 /// 105 /// Uses a pre-allocated buffer and itoa for allocation-free formatting. 106 /// Format: `{"type":"ack","id":12345}` 107 pub async fn send_ack(&mut self, id: u64) -> Result<(), TapError> { 108 self.ack_buffer.clear(); 109 self.ack_buffer 110 .extend_from_slice(b"{\"type\":\"ack\",\"id\":"); 111 let mut itoa_buf = itoa::Buffer::new(); 112 self.ack_buffer 113 .extend_from_slice(itoa_buf.format(id).as_bytes()); 114 self.ack_buffer.push(b'}'); 115 116 // All bytes are ASCII so this is always valid UTF-8 117 let msg = std::str::from_utf8(&self.ack_buffer).expect("ack buffer contains only ASCII"); 118 119 self.ws 120 .send(Message::text(msg.to_string())) 121 .await 122 .map_err(|e| TapError::AckFailed(e.to_string()))?; 123 124 // Flush to ensure the ack is sent immediately 125 self.ws 126 .flush() 127 .await 128 .map_err(|e| TapError::AckFailed(format!("Failed to flush ack: {}", e)))?; 129 130 tracing::trace!(id, "Sent ack"); 131 Ok(()) 132 } 133 134 /// Close the WebSocket connection gracefully. 135 pub async fn close(&mut self) -> Result<(), TapError> { 136 self.ws 137 .close() 138 .await 139 .map_err(|e| TapError::ConnectionFailed(format!("Failed to close: {}", e)))?; 140 Ok(()) 141 } 142} 143 144#[cfg(test)] 145mod tests { 146 #[test] 147 fn test_ack_buffer_format() { 148 // Test that our manual JSON formatting is correct 149 // Format: {"type":"ack","id":12345} 150 let mut buffer = Vec::with_capacity(64); 151 152 let id: u64 = 12345; 153 buffer.clear(); 154 buffer.extend_from_slice(b"{\"type\":\"ack\",\"id\":"); 155 let mut itoa_buf = itoa::Buffer::new(); 156 buffer.extend_from_slice(itoa_buf.format(id).as_bytes()); 157 buffer.push(b'}'); 158 159 let result = std::str::from_utf8(&buffer).unwrap(); 160 assert_eq!(result, r#"{"type":"ack","id":12345}"#); 161 162 // Test max u64 163 let id: u64 = u64::MAX; 164 buffer.clear(); 165 buffer.extend_from_slice(b"{\"type\":\"ack\",\"id\":"); 166 buffer.extend_from_slice(itoa_buf.format(id).as_bytes()); 167 buffer.push(b'}'); 168 169 let result = std::str::from_utf8(&buffer).unwrap(); 170 assert_eq!(result, r#"{"type":"ack","id":18446744073709551615}"#); 171 assert!(buffer.len() <= 64); // Fits in our pre-allocated buffer 172 } 173}