A better Rust ATProto crate
at main 706 lines 20 kB view raw
1//! WebSocket client abstraction 2 3use crate::CowStr; 4use crate::stream::StreamError; 5use alloc::boxed::Box; 6use alloc::string::String; 7use alloc::string::ToString; 8use alloc::vec::Vec; 9use bytes::Bytes; 10use core::borrow::Borrow; 11use core::fmt::{self, Display}; 12use core::future::Future; 13use core::ops::Deref; 14use core::pin::Pin; 15use n0_future::Stream; 16use url::Url; 17 18/// UTF-8 validated bytes for WebSocket text messages 19#[repr(transparent)] 20#[derive(Debug, Clone, Eq, PartialEq, Hash, PartialOrd, Ord)] 21pub struct WsText(Bytes); 22 23impl WsText { 24 /// Create from static string 25 pub const fn from_static(s: &'static str) -> Self { 26 Self(Bytes::from_static(s.as_bytes())) 27 } 28 29 /// Get as string slice 30 pub fn as_str(&self) -> &str { 31 unsafe { core::str::from_utf8_unchecked(&self.0) } 32 } 33 34 /// Create from bytes without validation (caller must ensure UTF-8) 35 /// 36 /// # Safety 37 /// Bytes must be valid UTF-8 38 pub unsafe fn from_bytes_unchecked(bytes: Bytes) -> Self { 39 Self(bytes) 40 } 41 42 /// Convert into underlying bytes 43 pub fn into_bytes(self) -> Bytes { 44 self.0 45 } 46} 47 48impl Deref for WsText { 49 type Target = str; 50 fn deref(&self) -> &str { 51 self.as_str() 52 } 53} 54 55impl AsRef<str> for WsText { 56 fn as_ref(&self) -> &str { 57 self.as_str() 58 } 59} 60 61impl AsRef<[u8]> for WsText { 62 fn as_ref(&self) -> &[u8] { 63 &self.0 64 } 65} 66 67impl AsRef<Bytes> for WsText { 68 fn as_ref(&self) -> &Bytes { 69 &self.0 70 } 71} 72 73impl Borrow<str> for WsText { 74 fn borrow(&self) -> &str { 75 self.as_str() 76 } 77} 78 79impl Display for WsText { 80 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 81 Display::fmt(self.as_str(), f) 82 } 83} 84 85impl From<String> for WsText { 86 fn from(s: String) -> Self { 87 Self(Bytes::from(s)) 88 } 89} 90 91impl From<&str> for WsText { 92 fn from(s: &str) -> Self { 93 Self(Bytes::copy_from_slice(s.as_bytes())) 94 } 95} 96 97impl From<&String> for WsText { 98 fn from(s: &String) -> Self { 99 Self::from(s.as_str()) 100 } 101} 102 103impl TryFrom<Bytes> for WsText { 104 type Error = core::str::Utf8Error; 105 fn try_from(bytes: Bytes) -> Result<Self, Self::Error> { 106 core::str::from_utf8(&bytes)?; 107 Ok(Self(bytes)) 108 } 109} 110 111impl TryFrom<Vec<u8>> for WsText { 112 type Error = core::str::Utf8Error; 113 fn try_from(vec: Vec<u8>) -> Result<Self, Self::Error> { 114 Self::try_from(Bytes::from(vec)) 115 } 116} 117 118impl From<WsText> for Bytes { 119 fn from(t: WsText) -> Bytes { 120 t.0 121 } 122} 123 124impl Default for WsText { 125 fn default() -> Self { 126 Self(Bytes::new()) 127 } 128} 129 130/// WebSocket close code 131#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 132#[repr(u16)] 133pub enum CloseCode { 134 /// Normal closure 135 Normal = 1000, 136 /// Endpoint going away 137 Away = 1001, 138 /// Protocol error 139 Protocol = 1002, 140 /// Unsupported data 141 Unsupported = 1003, 142 /// Invalid frame payload data 143 Invalid = 1007, 144 /// Policy violation 145 Policy = 1008, 146 /// Message too big 147 Size = 1009, 148 /// Extension negotiation failure 149 Extension = 1010, 150 /// Unexpected condition 151 Error = 1011, 152 /// TLS handshake failure 153 Tls = 1015, 154 /// Other code 155 Other(u16), 156} 157 158impl From<u16> for CloseCode { 159 fn from(code: u16) -> Self { 160 match code { 161 1000 => CloseCode::Normal, 162 1001 => CloseCode::Away, 163 1002 => CloseCode::Protocol, 164 1003 => CloseCode::Unsupported, 165 1007 => CloseCode::Invalid, 166 1008 => CloseCode::Policy, 167 1009 => CloseCode::Size, 168 1010 => CloseCode::Extension, 169 1011 => CloseCode::Error, 170 1015 => CloseCode::Tls, 171 other => CloseCode::Other(other), 172 } 173 } 174} 175 176impl From<CloseCode> for u16 { 177 fn from(code: CloseCode) -> u16 { 178 match code { 179 CloseCode::Normal => 1000, 180 CloseCode::Away => 1001, 181 CloseCode::Protocol => 1002, 182 CloseCode::Unsupported => 1003, 183 CloseCode::Invalid => 1007, 184 CloseCode::Policy => 1008, 185 CloseCode::Size => 1009, 186 CloseCode::Extension => 1010, 187 CloseCode::Error => 1011, 188 CloseCode::Tls => 1015, 189 CloseCode::Other(code) => code, 190 } 191 } 192} 193 194/// WebSocket close frame 195#[derive(Debug, Clone, PartialEq, Eq)] 196pub struct CloseFrame<'a> { 197 /// Close code 198 pub code: CloseCode, 199 /// Close reason text 200 pub reason: CowStr<'a>, 201} 202 203impl<'a> CloseFrame<'a> { 204 /// Create a new close frame 205 pub fn new(code: CloseCode, reason: impl Into<CowStr<'a>>) -> Self { 206 Self { 207 code, 208 reason: reason.into(), 209 } 210 } 211} 212 213/// WebSocket message 214#[derive(Debug, Clone, PartialEq, Eq)] 215pub enum WsMessage { 216 /// Text message (UTF-8) 217 Text(WsText), 218 /// Binary message 219 Binary(Bytes), 220 /// Close frame 221 Close(Option<CloseFrame<'static>>), 222} 223 224impl WsMessage { 225 /// Check if this is a text message 226 pub fn is_text(&self) -> bool { 227 matches!(self, WsMessage::Text(_)) 228 } 229 230 /// Check if this is a binary message 231 pub fn is_binary(&self) -> bool { 232 matches!(self, WsMessage::Binary(_)) 233 } 234 235 /// Check if this is a close message 236 pub fn is_close(&self) -> bool { 237 matches!(self, WsMessage::Close(_)) 238 } 239 240 /// Get as text, if this is a text message 241 pub fn as_text(&self) -> Option<&str> { 242 match self { 243 WsMessage::Text(t) => Some(t.as_str()), 244 _ => None, 245 } 246 } 247 248 /// Get as bytes 249 pub fn as_bytes(&self) -> Option<&[u8]> { 250 match self { 251 WsMessage::Text(t) => Some(t.as_ref()), 252 WsMessage::Binary(b) => Some(b), 253 WsMessage::Close(_) => None, 254 } 255 } 256} 257 258impl From<WsText> for WsMessage { 259 fn from(text: WsText) -> Self { 260 WsMessage::Text(text) 261 } 262} 263 264impl From<String> for WsMessage { 265 fn from(s: String) -> Self { 266 WsMessage::Text(WsText::from(s)) 267 } 268} 269 270impl From<&str> for WsMessage { 271 fn from(s: &str) -> Self { 272 WsMessage::Text(WsText::from(s)) 273 } 274} 275 276impl From<Bytes> for WsMessage { 277 fn from(bytes: Bytes) -> Self { 278 WsMessage::Binary(bytes) 279 } 280} 281 282impl From<Vec<u8>> for WsMessage { 283 fn from(vec: Vec<u8>) -> Self { 284 WsMessage::Binary(Bytes::from(vec)) 285 } 286} 287 288/// WebSocket message stream 289#[cfg(not(target_arch = "wasm32"))] 290pub struct WsStream(Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>> + Send>>); 291 292/// WebSocket message stream 293#[cfg(target_arch = "wasm32")] 294pub struct WsStream(Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>>>>); 295 296impl WsStream { 297 /// Create a new message stream 298 #[cfg(not(target_arch = "wasm32"))] 299 pub fn new<S>(stream: S) -> Self 300 where 301 S: Stream<Item = Result<WsMessage, StreamError>> + Send + 'static, 302 { 303 Self(Box::pin(stream)) 304 } 305 306 /// Create a new message stream 307 #[cfg(target_arch = "wasm32")] 308 pub fn new<S>(stream: S) -> Self 309 where 310 S: Stream<Item = Result<WsMessage, StreamError>> + 'static, 311 { 312 Self(Box::pin(stream)) 313 } 314 315 /// Convert into the inner pinned boxed stream 316 #[cfg(not(target_arch = "wasm32"))] 317 pub fn into_inner(self) -> Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>> + Send>> { 318 self.0 319 } 320 321 /// Convert into the inner pinned boxed stream 322 #[cfg(target_arch = "wasm32")] 323 pub fn into_inner(self) -> Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>>>> { 324 self.0 325 } 326 327 /// Split this stream into two streams that both receive all messages 328 /// 329 /// Messages are cloned (cheaply via Bytes rc). Spawns a forwarder task. 330 /// Both returned streams will receive all messages from the original stream. 331 /// The forwarder continues as long as at least one stream is alive. 332 /// If the underlying stream errors, both teed streams will end. 333 pub fn tee(self) -> (WsStream, WsStream) { 334 use futures::channel::mpsc; 335 use n0_future::StreamExt as _; 336 337 let (tx1, rx1) = mpsc::unbounded(); 338 let (tx2, rx2) = mpsc::unbounded(); 339 340 n0_future::task::spawn(async move { 341 let mut stream = self.0; 342 while let Some(result) = stream.next().await { 343 match result { 344 Ok(msg) => { 345 // Clone message (cheap - Bytes is rc'd) 346 let msg2 = msg.clone(); 347 348 // Send to both channels, continue if at least one succeeds 349 let send1 = tx1.unbounded_send(Ok(msg)); 350 let send2 = tx2.unbounded_send(Ok(msg2)); 351 352 // Only stop if both channels are closed 353 if send1.is_err() && send2.is_err() { 354 break; 355 } 356 } 357 Err(_e) => { 358 // Underlying stream errored, stop forwarding. 359 // Both channels will close, ending both streams. 360 break; 361 } 362 } 363 } 364 }); 365 366 (WsStream::new(rx1), WsStream::new(rx2)) 367 } 368} 369 370impl fmt::Debug for WsStream { 371 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 372 f.debug_struct("WsStream").finish_non_exhaustive() 373 } 374} 375 376/// WebSocket message sink 377#[cfg(not(target_arch = "wasm32"))] 378pub struct WsSink(Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>>); 379 380/// WebSocket message sink 381#[cfg(target_arch = "wasm32")] 382pub struct WsSink(Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError>>>); 383 384impl WsSink { 385 /// Create a new message sink 386 #[cfg(not(target_arch = "wasm32"))] 387 pub fn new<S>(sink: S) -> Self 388 where 389 S: n0_future::Sink<WsMessage, Error = StreamError> + Send + 'static, 390 { 391 Self(Box::pin(sink)) 392 } 393 394 /// Create a new message sink 395 #[cfg(target_arch = "wasm32")] 396 pub fn new<S>(sink: S) -> Self 397 where 398 S: n0_future::Sink<WsMessage, Error = StreamError> + 'static, 399 { 400 Self(Box::pin(sink)) 401 } 402 403 /// Convert into the inner boxed sink 404 #[cfg(not(target_arch = "wasm32"))] 405 pub fn into_inner( 406 self, 407 ) -> Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>> { 408 self.0 409 } 410 411 /// Convert into the inner boxed sink 412 #[cfg(target_arch = "wasm32")] 413 pub fn into_inner(self) -> Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError>>> { 414 self.0 415 } 416 417 /// get a mutable reference to the inner boxed sink 418 #[cfg(not(target_arch = "wasm32"))] 419 pub fn get_mut( 420 &mut self, 421 ) -> &mut Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>> { 422 use core::borrow::BorrowMut; 423 424 self.0.borrow_mut() 425 } 426 427 /// get a mutable reference to the inner boxed sink 428 #[cfg(target_arch = "wasm32")] 429 pub fn get_mut( 430 &mut self, 431 ) -> &mut Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + 'static>> { 432 use core::borrow::BorrowMut; 433 434 self.0.borrow_mut() 435 } 436} 437 438impl fmt::Debug for WsSink { 439 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 440 f.debug_struct("WsSink").finish_non_exhaustive() 441 } 442} 443 444/// WebSocket client trait 445#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] 446pub trait WebSocketClient: Sync { 447 /// Error type for WebSocket operations 448 type Error: core::error::Error + Send + Sync + 'static; 449 450 /// Connect to a WebSocket endpoint 451 fn connect(&self, url: Url) -> impl Future<Output = Result<WebSocketConnection, Self::Error>>; 452 453 /// Connect to a WebSocket endpoint with custom headers 454 /// 455 /// Default implementation ignores headers and calls `connect()`. 456 /// Override this method to support authentication headers for subscriptions. 457 fn connect_with_headers( 458 &self, 459 url: Url, 460 _headers: Vec<(CowStr<'_>, CowStr<'_>)>, 461 ) -> impl Future<Output = Result<WebSocketConnection, Self::Error>> { 462 async move { self.connect(url).await } 463 } 464} 465 466/// WebSocket connection with bidirectional streams 467pub struct WebSocketConnection { 468 tx: WsSink, 469 rx: WsStream, 470} 471 472impl WebSocketConnection { 473 /// Create a new WebSocket connection 474 pub fn new(tx: WsSink, rx: WsStream) -> Self { 475 Self { tx, rx } 476 } 477 478 /// Get mutable access to the sender 479 pub fn sender_mut(&mut self) -> &mut WsSink { 480 &mut self.tx 481 } 482 483 /// Get mutable access to the receiver 484 pub fn receiver_mut(&mut self) -> &mut WsStream { 485 &mut self.rx 486 } 487 488 /// Get a reference to the receiver 489 pub fn receiver(&self) -> &WsStream { 490 &self.rx 491 } 492 493 /// Get a reference to the sender 494 pub fn sender(&self) -> &WsSink { 495 &self.tx 496 } 497 498 /// Split into sender and receiver 499 pub fn split(self) -> (WsSink, WsStream) { 500 (self.tx, self.rx) 501 } 502 503 /// Check if connection is open (always true for this abstraction) 504 pub fn is_open(&self) -> bool { 505 true 506 } 507} 508 509impl fmt::Debug for WebSocketConnection { 510 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 511 f.debug_struct("WebSocketConnection") 512 .finish_non_exhaustive() 513 } 514} 515 516/// Concrete WebSocket client implementation using tokio-tungstenite-wasm 517pub mod tungstenite_client { 518 use super::*; 519 use crate::IntoStatic; 520 use futures::{SinkExt, StreamExt}; 521 522 /// WebSocket client backed by tokio-tungstenite-wasm 523 #[derive(Debug, Clone, Default)] 524 pub struct TungsteniteClient; 525 526 impl TungsteniteClient { 527 /// Create a new tungstenite WebSocket client 528 pub fn new() -> Self { 529 Self 530 } 531 } 532 533 impl WebSocketClient for TungsteniteClient { 534 type Error = tokio_tungstenite_wasm::Error; 535 536 async fn connect(&self, url: Url) -> Result<WebSocketConnection, Self::Error> { 537 let ws_stream = tokio_tungstenite_wasm::connect(url.as_str()).await?; 538 539 let (sink, stream) = ws_stream.split(); 540 541 // Convert tungstenite messages to our WsMessage 542 let rx_stream = stream.filter_map(|result| async move { 543 match result { 544 Ok(msg) => match convert_message(msg) { 545 Some(ws_msg) => Some(Ok(ws_msg)), 546 None => None, // Skip ping/pong 547 }, 548 Err(e) => Some(Err(StreamError::transport(e))), 549 } 550 }); 551 552 let rx = WsStream::new(rx_stream); 553 554 // Convert our WsMessage to tungstenite messages 555 let tx_sink = sink.with(|msg: WsMessage| async move { 556 Ok::<_, tokio_tungstenite_wasm::Error>(msg.into()) 557 }); 558 559 let tx_sink_mapped = tx_sink.sink_map_err(|e| StreamError::transport(e)); 560 let tx = WsSink::new(tx_sink_mapped); 561 562 Ok(WebSocketConnection::new(tx, rx)) 563 } 564 } 565 566 /// Convert tokio-tungstenite-wasm Message to our WsMessage 567 /// Returns None for Ping/Pong which we auto-handle 568 fn convert_message(msg: tokio_tungstenite_wasm::Message) -> Option<WsMessage> { 569 use tokio_tungstenite_wasm::Message; 570 571 match msg { 572 Message::Text(vec) => { 573 // tokio-tungstenite-wasm Text contains Vec<u8> (UTF-8 validated) 574 let bytes = Bytes::from(vec); 575 Some(WsMessage::Text(unsafe { 576 WsText::from_bytes_unchecked(bytes) 577 })) 578 } 579 Message::Binary(vec) => Some(WsMessage::Binary(Bytes::from(vec))), 580 Message::Close(frame) => { 581 let close_frame = frame.map(|f| { 582 let code = convert_close_code(f.code); 583 CloseFrame::new(code, CowStr::from(f.reason.into_owned())) 584 }); 585 Some(WsMessage::Close(close_frame)) 586 } 587 } 588 } 589 590 /// Convert tokio-tungstenite-wasm CloseCode to our CloseCode 591 fn convert_close_code(code: tokio_tungstenite_wasm::CloseCode) -> CloseCode { 592 use tokio_tungstenite_wasm::CloseCode as TungsteniteCode; 593 594 match code { 595 TungsteniteCode::Normal => CloseCode::Normal, 596 TungsteniteCode::Away => CloseCode::Away, 597 TungsteniteCode::Protocol => CloseCode::Protocol, 598 TungsteniteCode::Unsupported => CloseCode::Unsupported, 599 TungsteniteCode::Invalid => CloseCode::Invalid, 600 TungsteniteCode::Policy => CloseCode::Policy, 601 TungsteniteCode::Size => CloseCode::Size, 602 TungsteniteCode::Extension => CloseCode::Extension, 603 TungsteniteCode::Error => CloseCode::Error, 604 TungsteniteCode::Tls => CloseCode::Tls, 605 // For other variants, extract raw code 606 other => { 607 let raw: u16 = other.into(); 608 CloseCode::from(raw) 609 } 610 } 611 } 612 613 impl From<WsMessage> for tokio_tungstenite_wasm::Message { 614 fn from(msg: WsMessage) -> Self { 615 use tokio_tungstenite_wasm::Message; 616 617 match msg { 618 WsMessage::Text(text) => { 619 // tokio-tungstenite-wasm Text expects String 620 let bytes = text.into_bytes(); 621 // Safe: WsText is already UTF-8 validated 622 let string = unsafe { String::from_utf8_unchecked(bytes.to_vec()) }; 623 Message::Text(string) 624 } 625 WsMessage::Binary(bytes) => Message::Binary(bytes.to_vec()), 626 WsMessage::Close(frame) => { 627 let close_frame = frame.map(|f| { 628 let code = u16::from(f.code).into(); 629 tokio_tungstenite_wasm::CloseFrame { 630 code, 631 reason: f.reason.into_static().to_string().into(), 632 } 633 }); 634 Message::Close(close_frame) 635 } 636 } 637 } 638 } 639} 640 641#[cfg(test)] 642mod tests { 643 use super::*; 644 645 #[test] 646 fn ws_text_from_string() { 647 let text = WsText::from("hello"); 648 assert_eq!(text.as_str(), "hello"); 649 } 650 651 #[test] 652 fn ws_text_deref() { 653 let text = WsText::from(String::from("world")); 654 assert_eq!(&*text, "world"); 655 } 656 657 #[test] 658 fn ws_text_try_from_bytes() { 659 let bytes = Bytes::from("test"); 660 let text = WsText::try_from(bytes).unwrap(); 661 assert_eq!(text.as_str(), "test"); 662 } 663 664 #[test] 665 fn ws_text_invalid_utf8() { 666 let bytes = Bytes::from(vec![0xFF, 0xFE]); 667 assert!(WsText::try_from(bytes).is_err()); 668 } 669 670 #[test] 671 fn ws_message_text() { 672 let msg = WsMessage::from("hello"); 673 assert!(msg.is_text()); 674 assert_eq!(msg.as_text(), Some("hello")); 675 } 676 677 #[test] 678 fn ws_message_binary() { 679 let msg = WsMessage::from(vec![1, 2, 3]); 680 assert!(msg.is_binary()); 681 assert_eq!(msg.as_bytes(), Some(&[1u8, 2, 3][..])); 682 } 683 684 #[test] 685 fn close_code_conversion() { 686 assert_eq!(u16::from(CloseCode::Normal), 1000); 687 assert_eq!(CloseCode::from(1000), CloseCode::Normal); 688 assert_eq!(CloseCode::from(9999), CloseCode::Other(9999)); 689 } 690 691 #[test] 692 fn websocket_connection_has_tx_and_rx() { 693 use futures::sink::SinkExt; 694 use futures::stream; 695 696 let rx_stream = stream::iter(vec![Ok(WsMessage::from("test"))]); 697 let rx = WsStream::new(rx_stream); 698 699 let drain_sink = futures::sink::drain() 700 .sink_map_err(|_: std::convert::Infallible| StreamError::closed()); 701 let tx = WsSink::new(drain_sink); 702 703 let conn = WebSocketConnection::new(tx, rx); 704 assert!(conn.is_open()); 705 } 706}