A better Rust ATProto crate
at main 20 kB view raw
1//! WebSocket client abstraction 2 3use crate::CowStr; 4use crate::stream::StreamError; 5use bytes::Bytes; 6use n0_future::Stream; 7use std::borrow::Borrow; 8use std::fmt::{self, Display}; 9use std::future::Future; 10use std::ops::Deref; 11use std::pin::Pin; 12use url::Url; 13 14/// UTF-8 validated bytes for WebSocket text messages 15#[repr(transparent)] 16#[derive(Debug, Clone, Eq, PartialEq, Hash, PartialOrd, Ord)] 17pub struct WsText(Bytes); 18 19impl WsText { 20 /// Create from static string 21 pub const fn from_static(s: &'static str) -> Self { 22 Self(Bytes::from_static(s.as_bytes())) 23 } 24 25 /// Get as string slice 26 pub fn as_str(&self) -> &str { 27 unsafe { std::str::from_utf8_unchecked(&self.0) } 28 } 29 30 /// Create from bytes without validation (caller must ensure UTF-8) 31 /// 32 /// # Safety 33 /// Bytes must be valid UTF-8 34 pub unsafe fn from_bytes_unchecked(bytes: Bytes) -> Self { 35 Self(bytes) 36 } 37 38 /// Convert into underlying bytes 39 pub fn into_bytes(self) -> Bytes { 40 self.0 41 } 42} 43 44impl Deref for WsText { 45 type Target = str; 46 fn deref(&self) -> &str { 47 self.as_str() 48 } 49} 50 51impl AsRef<str> for WsText { 52 fn as_ref(&self) -> &str { 53 self.as_str() 54 } 55} 56 57impl AsRef<[u8]> for WsText { 58 fn as_ref(&self) -> &[u8] { 59 &self.0 60 } 61} 62 63impl AsRef<Bytes> for WsText { 64 fn as_ref(&self) -> &Bytes { 65 &self.0 66 } 67} 68 69impl Borrow<str> for WsText { 70 fn borrow(&self) -> &str { 71 self.as_str() 72 } 73} 74 75impl Display for WsText { 76 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 77 Display::fmt(self.as_str(), f) 78 } 79} 80 81impl From<String> for WsText { 82 fn from(s: String) -> Self { 83 Self(Bytes::from(s)) 84 } 85} 86 87impl From<&str> for WsText { 88 fn from(s: &str) -> Self { 89 Self(Bytes::copy_from_slice(s.as_bytes())) 90 } 91} 92 93impl From<&String> for WsText { 94 fn from(s: &String) -> Self { 95 Self::from(s.as_str()) 96 } 97} 98 99impl TryFrom<Bytes> for WsText { 100 type Error = std::str::Utf8Error; 101 fn try_from(bytes: Bytes) -> Result<Self, Self::Error> { 102 std::str::from_utf8(&bytes)?; 103 Ok(Self(bytes)) 104 } 105} 106 107impl TryFrom<Vec<u8>> for WsText { 108 type Error = std::str::Utf8Error; 109 fn try_from(vec: Vec<u8>) -> Result<Self, Self::Error> { 110 Self::try_from(Bytes::from(vec)) 111 } 112} 113 114impl From<WsText> for Bytes { 115 fn from(t: WsText) -> Bytes { 116 t.0 117 } 118} 119 120impl Default for WsText { 121 fn default() -> Self { 122 Self(Bytes::new()) 123 } 124} 125 126/// WebSocket close code 127#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 128#[repr(u16)] 129pub enum CloseCode { 130 /// Normal closure 131 Normal = 1000, 132 /// Endpoint going away 133 Away = 1001, 134 /// Protocol error 135 Protocol = 1002, 136 /// Unsupported data 137 Unsupported = 1003, 138 /// Invalid frame payload data 139 Invalid = 1007, 140 /// Policy violation 141 Policy = 1008, 142 /// Message too big 143 Size = 1009, 144 /// Extension negotiation failure 145 Extension = 1010, 146 /// Unexpected condition 147 Error = 1011, 148 /// TLS handshake failure 149 Tls = 1015, 150 /// Other code 151 Other(u16), 152} 153 154impl From<u16> for CloseCode { 155 fn from(code: u16) -> Self { 156 match code { 157 1000 => CloseCode::Normal, 158 1001 => CloseCode::Away, 159 1002 => CloseCode::Protocol, 160 1003 => CloseCode::Unsupported, 161 1007 => CloseCode::Invalid, 162 1008 => CloseCode::Policy, 163 1009 => CloseCode::Size, 164 1010 => CloseCode::Extension, 165 1011 => CloseCode::Error, 166 1015 => CloseCode::Tls, 167 other => CloseCode::Other(other), 168 } 169 } 170} 171 172impl From<CloseCode> for u16 { 173 fn from(code: CloseCode) -> u16 { 174 match code { 175 CloseCode::Normal => 1000, 176 CloseCode::Away => 1001, 177 CloseCode::Protocol => 1002, 178 CloseCode::Unsupported => 1003, 179 CloseCode::Invalid => 1007, 180 CloseCode::Policy => 1008, 181 CloseCode::Size => 1009, 182 CloseCode::Extension => 1010, 183 CloseCode::Error => 1011, 184 CloseCode::Tls => 1015, 185 CloseCode::Other(code) => code, 186 } 187 } 188} 189 190/// WebSocket close frame 191#[derive(Debug, Clone, PartialEq, Eq)] 192pub struct CloseFrame<'a> { 193 /// Close code 194 pub code: CloseCode, 195 /// Close reason text 196 pub reason: CowStr<'a>, 197} 198 199impl<'a> CloseFrame<'a> { 200 /// Create a new close frame 201 pub fn new(code: CloseCode, reason: impl Into<CowStr<'a>>) -> Self { 202 Self { 203 code, 204 reason: reason.into(), 205 } 206 } 207} 208 209/// WebSocket message 210#[derive(Debug, Clone, PartialEq, Eq)] 211pub enum WsMessage { 212 /// Text message (UTF-8) 213 Text(WsText), 214 /// Binary message 215 Binary(Bytes), 216 /// Close frame 217 Close(Option<CloseFrame<'static>>), 218} 219 220impl WsMessage { 221 /// Check if this is a text message 222 pub fn is_text(&self) -> bool { 223 matches!(self, WsMessage::Text(_)) 224 } 225 226 /// Check if this is a binary message 227 pub fn is_binary(&self) -> bool { 228 matches!(self, WsMessage::Binary(_)) 229 } 230 231 /// Check if this is a close message 232 pub fn is_close(&self) -> bool { 233 matches!(self, WsMessage::Close(_)) 234 } 235 236 /// Get as text, if this is a text message 237 pub fn as_text(&self) -> Option<&str> { 238 match self { 239 WsMessage::Text(t) => Some(t.as_str()), 240 _ => None, 241 } 242 } 243 244 /// Get as bytes 245 pub fn as_bytes(&self) -> Option<&[u8]> { 246 match self { 247 WsMessage::Text(t) => Some(t.as_ref()), 248 WsMessage::Binary(b) => Some(b), 249 WsMessage::Close(_) => None, 250 } 251 } 252} 253 254impl From<WsText> for WsMessage { 255 fn from(text: WsText) -> Self { 256 WsMessage::Text(text) 257 } 258} 259 260impl From<String> for WsMessage { 261 fn from(s: String) -> Self { 262 WsMessage::Text(WsText::from(s)) 263 } 264} 265 266impl From<&str> for WsMessage { 267 fn from(s: &str) -> Self { 268 WsMessage::Text(WsText::from(s)) 269 } 270} 271 272impl From<Bytes> for WsMessage { 273 fn from(bytes: Bytes) -> Self { 274 WsMessage::Binary(bytes) 275 } 276} 277 278impl From<Vec<u8>> for WsMessage { 279 fn from(vec: Vec<u8>) -> Self { 280 WsMessage::Binary(Bytes::from(vec)) 281 } 282} 283 284/// WebSocket message stream 285#[cfg(not(target_arch = "wasm32"))] 286pub struct WsStream(Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>> + Send>>); 287 288/// WebSocket message stream 289#[cfg(target_arch = "wasm32")] 290pub struct WsStream(Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>>>>); 291 292impl WsStream { 293 /// Create a new message stream 294 #[cfg(not(target_arch = "wasm32"))] 295 pub fn new<S>(stream: S) -> Self 296 where 297 S: Stream<Item = Result<WsMessage, StreamError>> + Send + 'static, 298 { 299 Self(Box::pin(stream)) 300 } 301 302 /// Create a new message stream 303 #[cfg(target_arch = "wasm32")] 304 pub fn new<S>(stream: S) -> Self 305 where 306 S: Stream<Item = Result<WsMessage, StreamError>> + 'static, 307 { 308 Self(Box::pin(stream)) 309 } 310 311 /// Convert into the inner pinned boxed stream 312 #[cfg(not(target_arch = "wasm32"))] 313 pub fn into_inner(self) -> Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>> + Send>> { 314 self.0 315 } 316 317 /// Convert into the inner pinned boxed stream 318 #[cfg(target_arch = "wasm32")] 319 pub fn into_inner(self) -> Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>>>> { 320 self.0 321 } 322 323 /// Split this stream into two streams that both receive all messages 324 /// 325 /// Messages are cloned (cheaply via Bytes rc). Spawns a forwarder task. 326 /// Both returned streams will receive all messages from the original stream. 327 /// The forwarder continues as long as at least one stream is alive. 328 /// If the underlying stream errors, both teed streams will end. 329 pub fn tee(self) -> (WsStream, WsStream) { 330 use futures::channel::mpsc; 331 use n0_future::StreamExt as _; 332 333 let (tx1, rx1) = mpsc::unbounded(); 334 let (tx2, rx2) = mpsc::unbounded(); 335 336 n0_future::task::spawn(async move { 337 let mut stream = self.0; 338 while let Some(result) = stream.next().await { 339 match result { 340 Ok(msg) => { 341 // Clone message (cheap - Bytes is rc'd) 342 let msg2 = msg.clone(); 343 344 // Send to both channels, continue if at least one succeeds 345 let send1 = tx1.unbounded_send(Ok(msg)); 346 let send2 = tx2.unbounded_send(Ok(msg2)); 347 348 // Only stop if both channels are closed 349 if send1.is_err() && send2.is_err() { 350 break; 351 } 352 } 353 Err(_e) => { 354 // Underlying stream errored, stop forwarding. 355 // Both channels will close, ending both streams. 356 break; 357 } 358 } 359 } 360 }); 361 362 (WsStream::new(rx1), WsStream::new(rx2)) 363 } 364} 365 366impl fmt::Debug for WsStream { 367 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 368 f.debug_struct("WsStream").finish_non_exhaustive() 369 } 370} 371 372/// WebSocket message sink 373#[cfg(not(target_arch = "wasm32"))] 374pub struct WsSink(Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>>); 375 376/// WebSocket message sink 377#[cfg(target_arch = "wasm32")] 378pub struct WsSink(Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError>>>); 379 380impl WsSink { 381 /// Create a new message sink 382 #[cfg(not(target_arch = "wasm32"))] 383 pub fn new<S>(sink: S) -> Self 384 where 385 S: n0_future::Sink<WsMessage, Error = StreamError> + Send + 'static, 386 { 387 Self(Box::pin(sink)) 388 } 389 390 /// Create a new message sink 391 #[cfg(target_arch = "wasm32")] 392 pub fn new<S>(sink: S) -> Self 393 where 394 S: n0_future::Sink<WsMessage, Error = StreamError> + 'static, 395 { 396 Self(Box::pin(sink)) 397 } 398 399 /// Convert into the inner boxed sink 400 #[cfg(not(target_arch = "wasm32"))] 401 pub fn into_inner( 402 self, 403 ) -> Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>> { 404 self.0 405 } 406 407 /// Convert into the inner boxed sink 408 #[cfg(target_arch = "wasm32")] 409 pub fn into_inner(self) -> Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError>>> { 410 self.0 411 } 412 413 /// get a mutable reference to the inner boxed sink 414 #[cfg(not(target_arch = "wasm32"))] 415 pub fn get_mut( 416 &mut self, 417 ) -> &mut Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>> { 418 use std::borrow::BorrowMut; 419 420 self.0.borrow_mut() 421 } 422 423 /// get a mutable reference to the inner boxed sink 424 #[cfg(target_arch = "wasm32")] 425 pub fn get_mut( 426 &mut self, 427 ) -> &mut Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + 'static>> { 428 use std::borrow::BorrowMut; 429 430 self.0.borrow_mut() 431 } 432} 433 434impl fmt::Debug for WsSink { 435 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 436 f.debug_struct("WsSink").finish_non_exhaustive() 437 } 438} 439 440/// WebSocket client trait 441#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] 442pub trait WebSocketClient: Sync { 443 /// Error type for WebSocket operations 444 type Error: std::error::Error + Send + Sync + 'static; 445 446 /// Connect to a WebSocket endpoint 447 fn connect(&self, url: Url) -> impl Future<Output = Result<WebSocketConnection, Self::Error>>; 448 449 /// Connect to a WebSocket endpoint with custom headers 450 /// 451 /// Default implementation ignores headers and calls `connect()`. 452 /// Override this method to support authentication headers for subscriptions. 453 fn connect_with_headers( 454 &self, 455 url: Url, 456 _headers: Vec<(CowStr<'_>, CowStr<'_>)>, 457 ) -> impl Future<Output = Result<WebSocketConnection, Self::Error>> { 458 async move { self.connect(url).await } 459 } 460} 461 462/// WebSocket connection with bidirectional streams 463pub struct WebSocketConnection { 464 tx: WsSink, 465 rx: WsStream, 466} 467 468impl WebSocketConnection { 469 /// Create a new WebSocket connection 470 pub fn new(tx: WsSink, rx: WsStream) -> Self { 471 Self { tx, rx } 472 } 473 474 /// Get mutable access to the sender 475 pub fn sender_mut(&mut self) -> &mut WsSink { 476 &mut self.tx 477 } 478 479 /// Get mutable access to the receiver 480 pub fn receiver_mut(&mut self) -> &mut WsStream { 481 &mut self.rx 482 } 483 484 /// Get a reference to the receiver 485 pub fn receiver(&self) -> &WsStream { 486 &self.rx 487 } 488 489 /// Get a reference to the sender 490 pub fn sender(&self) -> &WsSink { 491 &self.tx 492 } 493 494 /// Split into sender and receiver 495 pub fn split(self) -> (WsSink, WsStream) { 496 (self.tx, self.rx) 497 } 498 499 /// Check if connection is open (always true for this abstraction) 500 pub fn is_open(&self) -> bool { 501 true 502 } 503} 504 505impl fmt::Debug for WebSocketConnection { 506 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 507 f.debug_struct("WebSocketConnection") 508 .finish_non_exhaustive() 509 } 510} 511 512/// Concrete WebSocket client implementation using tokio-tungstenite-wasm 513pub mod tungstenite_client { 514 use super::*; 515 use crate::IntoStatic; 516 use futures::{SinkExt, StreamExt}; 517 518 /// WebSocket client backed by tokio-tungstenite-wasm 519 #[derive(Debug, Clone, Default)] 520 pub struct TungsteniteClient; 521 522 impl TungsteniteClient { 523 /// Create a new tungstenite WebSocket client 524 pub fn new() -> Self { 525 Self 526 } 527 } 528 529 impl WebSocketClient for TungsteniteClient { 530 type Error = tokio_tungstenite_wasm::Error; 531 532 async fn connect(&self, url: Url) -> Result<WebSocketConnection, Self::Error> { 533 let ws_stream = tokio_tungstenite_wasm::connect(url.as_str()).await?; 534 535 let (sink, stream) = ws_stream.split(); 536 537 // Convert tungstenite messages to our WsMessage 538 let rx_stream = stream.filter_map(|result| async move { 539 match result { 540 Ok(msg) => match convert_message(msg) { 541 Some(ws_msg) => Some(Ok(ws_msg)), 542 None => None, // Skip ping/pong 543 }, 544 Err(e) => Some(Err(StreamError::transport(e))), 545 } 546 }); 547 548 let rx = WsStream::new(rx_stream); 549 550 // Convert our WsMessage to tungstenite messages 551 let tx_sink = sink.with(|msg: WsMessage| async move { 552 Ok::<_, tokio_tungstenite_wasm::Error>(msg.into()) 553 }); 554 555 let tx_sink_mapped = tx_sink.sink_map_err(|e| StreamError::transport(e)); 556 let tx = WsSink::new(tx_sink_mapped); 557 558 Ok(WebSocketConnection::new(tx, rx)) 559 } 560 } 561 562 /// Convert tokio-tungstenite-wasm Message to our WsMessage 563 /// Returns None for Ping/Pong which we auto-handle 564 fn convert_message(msg: tokio_tungstenite_wasm::Message) -> Option<WsMessage> { 565 use tokio_tungstenite_wasm::Message; 566 567 match msg { 568 Message::Text(vec) => { 569 // tokio-tungstenite-wasm Text contains Vec<u8> (UTF-8 validated) 570 let bytes = Bytes::from(vec); 571 Some(WsMessage::Text(unsafe { 572 WsText::from_bytes_unchecked(bytes) 573 })) 574 } 575 Message::Binary(vec) => Some(WsMessage::Binary(Bytes::from(vec))), 576 Message::Close(frame) => { 577 let close_frame = frame.map(|f| { 578 let code = convert_close_code(f.code); 579 CloseFrame::new(code, CowStr::from(f.reason.into_owned())) 580 }); 581 Some(WsMessage::Close(close_frame)) 582 } 583 } 584 } 585 586 /// Convert tokio-tungstenite-wasm CloseCode to our CloseCode 587 fn convert_close_code(code: tokio_tungstenite_wasm::CloseCode) -> CloseCode { 588 use tokio_tungstenite_wasm::CloseCode as TungsteniteCode; 589 590 match code { 591 TungsteniteCode::Normal => CloseCode::Normal, 592 TungsteniteCode::Away => CloseCode::Away, 593 TungsteniteCode::Protocol => CloseCode::Protocol, 594 TungsteniteCode::Unsupported => CloseCode::Unsupported, 595 TungsteniteCode::Invalid => CloseCode::Invalid, 596 TungsteniteCode::Policy => CloseCode::Policy, 597 TungsteniteCode::Size => CloseCode::Size, 598 TungsteniteCode::Extension => CloseCode::Extension, 599 TungsteniteCode::Error => CloseCode::Error, 600 TungsteniteCode::Tls => CloseCode::Tls, 601 // For other variants, extract raw code 602 other => { 603 let raw: u16 = other.into(); 604 CloseCode::from(raw) 605 } 606 } 607 } 608 609 impl From<WsMessage> for tokio_tungstenite_wasm::Message { 610 fn from(msg: WsMessage) -> Self { 611 use tokio_tungstenite_wasm::Message; 612 613 match msg { 614 WsMessage::Text(text) => { 615 // tokio-tungstenite-wasm Text expects String 616 let bytes = text.into_bytes(); 617 // Safe: WsText is already UTF-8 validated 618 let string = unsafe { String::from_utf8_unchecked(bytes.to_vec()) }; 619 Message::Text(string) 620 } 621 WsMessage::Binary(bytes) => Message::Binary(bytes.to_vec()), 622 WsMessage::Close(frame) => { 623 let close_frame = frame.map(|f| { 624 let code = u16::from(f.code).into(); 625 tokio_tungstenite_wasm::CloseFrame { 626 code, 627 reason: f.reason.into_static().to_string().into(), 628 } 629 }); 630 Message::Close(close_frame) 631 } 632 } 633 } 634 } 635} 636 637#[cfg(test)] 638mod tests { 639 use super::*; 640 641 #[test] 642 fn ws_text_from_string() { 643 let text = WsText::from("hello"); 644 assert_eq!(text.as_str(), "hello"); 645 } 646 647 #[test] 648 fn ws_text_deref() { 649 let text = WsText::from(String::from("world")); 650 assert_eq!(&*text, "world"); 651 } 652 653 #[test] 654 fn ws_text_try_from_bytes() { 655 let bytes = Bytes::from("test"); 656 let text = WsText::try_from(bytes).unwrap(); 657 assert_eq!(text.as_str(), "test"); 658 } 659 660 #[test] 661 fn ws_text_invalid_utf8() { 662 let bytes = Bytes::from(vec![0xFF, 0xFE]); 663 assert!(WsText::try_from(bytes).is_err()); 664 } 665 666 #[test] 667 fn ws_message_text() { 668 let msg = WsMessage::from("hello"); 669 assert!(msg.is_text()); 670 assert_eq!(msg.as_text(), Some("hello")); 671 } 672 673 #[test] 674 fn ws_message_binary() { 675 let msg = WsMessage::from(vec![1, 2, 3]); 676 assert!(msg.is_binary()); 677 assert_eq!(msg.as_bytes(), Some(&[1u8, 2, 3][..])); 678 } 679 680 #[test] 681 fn close_code_conversion() { 682 assert_eq!(u16::from(CloseCode::Normal), 1000); 683 assert_eq!(CloseCode::from(1000), CloseCode::Normal); 684 assert_eq!(CloseCode::from(9999), CloseCode::Other(9999)); 685 } 686 687 #[test] 688 fn websocket_connection_has_tx_and_rx() { 689 use futures::sink::SinkExt; 690 use futures::stream; 691 692 let rx_stream = stream::iter(vec![Ok(WsMessage::from("test"))]); 693 let rx = WsStream::new(rx_stream); 694 695 let drain_sink = futures::sink::drain() 696 .sink_map_err(|_: std::convert::Infallible| StreamError::closed()); 697 let tx = WsSink::new(drain_sink); 698 699 let conn = WebSocketConnection::new(tx, rx); 700 assert!(conn.is_open()); 701 } 702}