1//! WebSocket client abstraction
2
3use crate::CowStr;
4use crate::stream::StreamError;
5use bytes::Bytes;
6use n0_future::Stream;
7use std::borrow::Borrow;
8use std::fmt::{self, Display};
9use std::future::Future;
10use std::ops::Deref;
11use std::pin::Pin;
12use url::Url;
13
14/// UTF-8 validated bytes for WebSocket text messages
15#[repr(transparent)]
16#[derive(Debug, Clone, Eq, PartialEq, Hash, PartialOrd, Ord)]
17pub struct WsText(Bytes);
18
19impl WsText {
20 /// Create from static string
21 pub const fn from_static(s: &'static str) -> Self {
22 Self(Bytes::from_static(s.as_bytes()))
23 }
24
25 /// Get as string slice
26 pub fn as_str(&self) -> &str {
27 unsafe { std::str::from_utf8_unchecked(&self.0) }
28 }
29
30 /// Create from bytes without validation (caller must ensure UTF-8)
31 ///
32 /// # Safety
33 /// Bytes must be valid UTF-8
34 pub unsafe fn from_bytes_unchecked(bytes: Bytes) -> Self {
35 Self(bytes)
36 }
37
38 /// Convert into underlying bytes
39 pub fn into_bytes(self) -> Bytes {
40 self.0
41 }
42}
43
44impl Deref for WsText {
45 type Target = str;
46 fn deref(&self) -> &str {
47 self.as_str()
48 }
49}
50
51impl AsRef<str> for WsText {
52 fn as_ref(&self) -> &str {
53 self.as_str()
54 }
55}
56
57impl AsRef<[u8]> for WsText {
58 fn as_ref(&self) -> &[u8] {
59 &self.0
60 }
61}
62
63impl AsRef<Bytes> for WsText {
64 fn as_ref(&self) -> &Bytes {
65 &self.0
66 }
67}
68
69impl Borrow<str> for WsText {
70 fn borrow(&self) -> &str {
71 self.as_str()
72 }
73}
74
75impl Display for WsText {
76 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
77 Display::fmt(self.as_str(), f)
78 }
79}
80
81impl From<String> for WsText {
82 fn from(s: String) -> Self {
83 Self(Bytes::from(s))
84 }
85}
86
87impl From<&str> for WsText {
88 fn from(s: &str) -> Self {
89 Self(Bytes::copy_from_slice(s.as_bytes()))
90 }
91}
92
93impl From<&String> for WsText {
94 fn from(s: &String) -> Self {
95 Self::from(s.as_str())
96 }
97}
98
99impl TryFrom<Bytes> for WsText {
100 type Error = std::str::Utf8Error;
101 fn try_from(bytes: Bytes) -> Result<Self, Self::Error> {
102 std::str::from_utf8(&bytes)?;
103 Ok(Self(bytes))
104 }
105}
106
107impl TryFrom<Vec<u8>> for WsText {
108 type Error = std::str::Utf8Error;
109 fn try_from(vec: Vec<u8>) -> Result<Self, Self::Error> {
110 Self::try_from(Bytes::from(vec))
111 }
112}
113
114impl From<WsText> for Bytes {
115 fn from(t: WsText) -> Bytes {
116 t.0
117 }
118}
119
120impl Default for WsText {
121 fn default() -> Self {
122 Self(Bytes::new())
123 }
124}
125
/// WebSocket close status code.
///
/// Codes with a dedicated variant map to/from their numeric value; any other
/// value is carried verbatim in [`CloseCode::Other`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u16)]
pub enum CloseCode {
    /// Normal closure
    Normal = 1000,
    /// Endpoint going away
    Away = 1001,
    /// Protocol error
    Protocol = 1002,
    /// Unsupported data
    Unsupported = 1003,
    /// Invalid frame payload data
    Invalid = 1007,
    /// Policy violation
    Policy = 1008,
    /// Message too big
    Size = 1009,
    /// Extension negotiation failure
    Extension = 1010,
    /// Unexpected condition
    Error = 1011,
    /// TLS handshake failure
    Tls = 1015,
    /// Other code
    Other(u16),
}

/// Parses a raw status code; unknown values become [`CloseCode::Other`].
impl From<u16> for CloseCode {
    fn from(raw: u16) -> Self {
        match raw {
            1000 => Self::Normal,
            1001 => Self::Away,
            1002 => Self::Protocol,
            1003 => Self::Unsupported,
            1007 => Self::Invalid,
            1008 => Self::Policy,
            1009 => Self::Size,
            1010 => Self::Extension,
            1011 => Self::Error,
            1015 => Self::Tls,
            unknown => Self::Other(unknown),
        }
    }
}

/// Returns the numeric status code for the wire.
impl From<CloseCode> for u16 {
    fn from(code: CloseCode) -> u16 {
        use CloseCode as C;

        match code {
            C::Other(raw) => raw,
            C::Normal => 1000,
            C::Away => 1001,
            C::Protocol => 1002,
            C::Unsupported => 1003,
            C::Invalid => 1007,
            C::Policy => 1008,
            C::Size => 1009,
            C::Extension => 1010,
            C::Error => 1011,
            C::Tls => 1015,
        }
    }
}
189
190/// WebSocket close frame
191#[derive(Debug, Clone, PartialEq, Eq)]
192pub struct CloseFrame<'a> {
193 /// Close code
194 pub code: CloseCode,
195 /// Close reason text
196 pub reason: CowStr<'a>,
197}
198
199impl<'a> CloseFrame<'a> {
200 /// Create a new close frame
201 pub fn new(code: CloseCode, reason: impl Into<CowStr<'a>>) -> Self {
202 Self {
203 code,
204 reason: reason.into(),
205 }
206 }
207}
208
209/// WebSocket message
210#[derive(Debug, Clone, PartialEq, Eq)]
211pub enum WsMessage {
212 /// Text message (UTF-8)
213 Text(WsText),
214 /// Binary message
215 Binary(Bytes),
216 /// Close frame
217 Close(Option<CloseFrame<'static>>),
218}
219
220impl WsMessage {
221 /// Check if this is a text message
222 pub fn is_text(&self) -> bool {
223 matches!(self, WsMessage::Text(_))
224 }
225
226 /// Check if this is a binary message
227 pub fn is_binary(&self) -> bool {
228 matches!(self, WsMessage::Binary(_))
229 }
230
231 /// Check if this is a close message
232 pub fn is_close(&self) -> bool {
233 matches!(self, WsMessage::Close(_))
234 }
235
236 /// Get as text, if this is a text message
237 pub fn as_text(&self) -> Option<&str> {
238 match self {
239 WsMessage::Text(t) => Some(t.as_str()),
240 _ => None,
241 }
242 }
243
244 /// Get as bytes
245 pub fn as_bytes(&self) -> Option<&[u8]> {
246 match self {
247 WsMessage::Text(t) => Some(t.as_ref()),
248 WsMessage::Binary(b) => Some(b),
249 WsMessage::Close(_) => None,
250 }
251 }
252}
253
254impl From<WsText> for WsMessage {
255 fn from(text: WsText) -> Self {
256 WsMessage::Text(text)
257 }
258}
259
260impl From<String> for WsMessage {
261 fn from(s: String) -> Self {
262 WsMessage::Text(WsText::from(s))
263 }
264}
265
266impl From<&str> for WsMessage {
267 fn from(s: &str) -> Self {
268 WsMessage::Text(WsText::from(s))
269 }
270}
271
272impl From<Bytes> for WsMessage {
273 fn from(bytes: Bytes) -> Self {
274 WsMessage::Binary(bytes)
275 }
276}
277
278impl From<Vec<u8>> for WsMessage {
279 fn from(vec: Vec<u8>) -> Self {
280 WsMessage::Binary(Bytes::from(vec))
281 }
282}
283
284/// WebSocket message stream
285#[cfg(not(target_arch = "wasm32"))]
286pub struct WsStream(Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>> + Send>>);
287
288/// WebSocket message stream
289#[cfg(target_arch = "wasm32")]
290pub struct WsStream(Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>>>>);
291
292impl WsStream {
293 /// Create a new message stream
294 #[cfg(not(target_arch = "wasm32"))]
295 pub fn new<S>(stream: S) -> Self
296 where
297 S: Stream<Item = Result<WsMessage, StreamError>> + Send + 'static,
298 {
299 Self(Box::pin(stream))
300 }
301
302 /// Create a new message stream
303 #[cfg(target_arch = "wasm32")]
304 pub fn new<S>(stream: S) -> Self
305 where
306 S: Stream<Item = Result<WsMessage, StreamError>> + 'static,
307 {
308 Self(Box::pin(stream))
309 }
310
311 /// Convert into the inner pinned boxed stream
312 #[cfg(not(target_arch = "wasm32"))]
313 pub fn into_inner(self) -> Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>> + Send>> {
314 self.0
315 }
316
317 /// Convert into the inner pinned boxed stream
318 #[cfg(target_arch = "wasm32")]
319 pub fn into_inner(self) -> Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>>>> {
320 self.0
321 }
322
323 /// Split this stream into two streams that both receive all messages
324 ///
325 /// Messages are cloned (cheaply via Bytes rc). Spawns a forwarder task.
326 /// Both returned streams will receive all messages from the original stream.
327 /// The forwarder continues as long as at least one stream is alive.
328 /// If the underlying stream errors, both teed streams will end.
329 pub fn tee(self) -> (WsStream, WsStream) {
330 use futures::channel::mpsc;
331 use n0_future::StreamExt as _;
332
333 let (tx1, rx1) = mpsc::unbounded();
334 let (tx2, rx2) = mpsc::unbounded();
335
336 n0_future::task::spawn(async move {
337 let mut stream = self.0;
338 while let Some(result) = stream.next().await {
339 match result {
340 Ok(msg) => {
341 // Clone message (cheap - Bytes is rc'd)
342 let msg2 = msg.clone();
343
344 // Send to both channels, continue if at least one succeeds
345 let send1 = tx1.unbounded_send(Ok(msg));
346 let send2 = tx2.unbounded_send(Ok(msg2));
347
348 // Only stop if both channels are closed
349 if send1.is_err() && send2.is_err() {
350 break;
351 }
352 }
353 Err(_e) => {
354 // Underlying stream errored, stop forwarding.
355 // Both channels will close, ending both streams.
356 break;
357 }
358 }
359 }
360 });
361
362 (WsStream::new(rx1), WsStream::new(rx2))
363 }
364}
365
366impl fmt::Debug for WsStream {
367 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
368 f.debug_struct("WsStream").finish_non_exhaustive()
369 }
370}
371
372/// WebSocket message sink
373#[cfg(not(target_arch = "wasm32"))]
374pub struct WsSink(Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>>);
375
376/// WebSocket message sink
377#[cfg(target_arch = "wasm32")]
378pub struct WsSink(Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError>>>);
379
380impl WsSink {
381 /// Create a new message sink
382 #[cfg(not(target_arch = "wasm32"))]
383 pub fn new<S>(sink: S) -> Self
384 where
385 S: n0_future::Sink<WsMessage, Error = StreamError> + Send + 'static,
386 {
387 Self(Box::pin(sink))
388 }
389
390 /// Create a new message sink
391 #[cfg(target_arch = "wasm32")]
392 pub fn new<S>(sink: S) -> Self
393 where
394 S: n0_future::Sink<WsMessage, Error = StreamError> + 'static,
395 {
396 Self(Box::pin(sink))
397 }
398
399 /// Convert into the inner boxed sink
400 #[cfg(not(target_arch = "wasm32"))]
401 pub fn into_inner(
402 self,
403 ) -> Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>> {
404 self.0
405 }
406
407 /// Convert into the inner boxed sink
408 #[cfg(target_arch = "wasm32")]
409 pub fn into_inner(self) -> Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError>>> {
410 self.0
411 }
412
413 /// get a mutable reference to the inner boxed sink
414 #[cfg(not(target_arch = "wasm32"))]
415 pub fn get_mut(
416 &mut self,
417 ) -> &mut Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>> {
418 use std::borrow::BorrowMut;
419
420 self.0.borrow_mut()
421 }
422
423 /// get a mutable reference to the inner boxed sink
424 #[cfg(target_arch = "wasm32")]
425 pub fn get_mut(
426 &mut self,
427 ) -> &mut Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + 'static>> {
428 use std::borrow::BorrowMut;
429
430 self.0.borrow_mut()
431 }
432}
433
434impl fmt::Debug for WsSink {
435 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
436 f.debug_struct("WsSink").finish_non_exhaustive()
437 }
438}
439
/// WebSocket client trait.
///
/// Implementors open a connection to a URL and return a
/// [`WebSocketConnection`] (an outgoing sink plus an incoming stream).
///
/// On non-wasm32 targets the trait is passed through
/// `trait_variant::make(Send)`, presumably to require the returned futures to
/// be `Send` (NOTE(review): confirm against the trait-variant version in use).
/// The `Sync` supertrait presumably exists so futures that capture `&self`
/// can be `Send` — TODO confirm.
#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
pub trait WebSocketClient: Sync {
    /// Error type for WebSocket operations
    type Error: std::error::Error + Send + Sync + 'static;

    /// Connect to a WebSocket endpoint
    fn connect(&self, url: Url) -> impl Future<Output = Result<WebSocketConnection, Self::Error>>;

    /// Connect to a WebSocket endpoint with custom headers
    ///
    /// Default implementation ignores headers and calls `connect()`.
    /// Override this method to support authentication headers for subscriptions.
    /// The headers are presumably (name, value) pairs — confirm with callers.
    fn connect_with_headers(
        &self,
        url: Url,
        _headers: Vec<(CowStr<'_>, CowStr<'_>)>,
    ) -> impl Future<Output = Result<WebSocketConnection, Self::Error>> {
        // `_headers` is intentionally unused here; it is simply dropped.
        async move { self.connect(url).await }
    }
}
461
462/// WebSocket connection with bidirectional streams
463pub struct WebSocketConnection {
464 tx: WsSink,
465 rx: WsStream,
466}
467
468impl WebSocketConnection {
469 /// Create a new WebSocket connection
470 pub fn new(tx: WsSink, rx: WsStream) -> Self {
471 Self { tx, rx }
472 }
473
474 /// Get mutable access to the sender
475 pub fn sender_mut(&mut self) -> &mut WsSink {
476 &mut self.tx
477 }
478
479 /// Get mutable access to the receiver
480 pub fn receiver_mut(&mut self) -> &mut WsStream {
481 &mut self.rx
482 }
483
484 /// Get a reference to the receiver
485 pub fn receiver(&self) -> &WsStream {
486 &self.rx
487 }
488
489 /// Get a reference to the sender
490 pub fn sender(&self) -> &WsSink {
491 &self.tx
492 }
493
494 /// Split into sender and receiver
495 pub fn split(self) -> (WsSink, WsStream) {
496 (self.tx, self.rx)
497 }
498
499 /// Check if connection is open (always true for this abstraction)
500 pub fn is_open(&self) -> bool {
501 true
502 }
503}
504
505impl fmt::Debug for WebSocketConnection {
506 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
507 f.debug_struct("WebSocketConnection")
508 .finish_non_exhaustive()
509 }
510}
511
512/// Concrete WebSocket client implementation using tokio-tungstenite-wasm
513pub mod tungstenite_client {
514 use super::*;
515 use crate::IntoStatic;
516 use futures::{SinkExt, StreamExt};
517
518 /// WebSocket client backed by tokio-tungstenite-wasm
519 #[derive(Debug, Clone, Default)]
520 pub struct TungsteniteClient;
521
522 impl TungsteniteClient {
523 /// Create a new tungstenite WebSocket client
524 pub fn new() -> Self {
525 Self
526 }
527 }
528
529 impl WebSocketClient for TungsteniteClient {
530 type Error = tokio_tungstenite_wasm::Error;
531
532 async fn connect(&self, url: Url) -> Result<WebSocketConnection, Self::Error> {
533 let ws_stream = tokio_tungstenite_wasm::connect(url.as_str()).await?;
534
535 let (sink, stream) = ws_stream.split();
536
537 // Convert tungstenite messages to our WsMessage
538 let rx_stream = stream.filter_map(|result| async move {
539 match result {
540 Ok(msg) => match convert_message(msg) {
541 Some(ws_msg) => Some(Ok(ws_msg)),
542 None => None, // Skip ping/pong
543 },
544 Err(e) => Some(Err(StreamError::transport(e))),
545 }
546 });
547
548 let rx = WsStream::new(rx_stream);
549
550 // Convert our WsMessage to tungstenite messages
551 let tx_sink = sink.with(|msg: WsMessage| async move {
552 Ok::<_, tokio_tungstenite_wasm::Error>(msg.into())
553 });
554
555 let tx_sink_mapped = tx_sink.sink_map_err(|e| StreamError::transport(e));
556 let tx = WsSink::new(tx_sink_mapped);
557
558 Ok(WebSocketConnection::new(tx, rx))
559 }
560 }
561
562 /// Convert tokio-tungstenite-wasm Message to our WsMessage
563 /// Returns None for Ping/Pong which we auto-handle
564 fn convert_message(msg: tokio_tungstenite_wasm::Message) -> Option<WsMessage> {
565 use tokio_tungstenite_wasm::Message;
566
567 match msg {
568 Message::Text(vec) => {
569 // tokio-tungstenite-wasm Text contains Vec<u8> (UTF-8 validated)
570 let bytes = Bytes::from(vec);
571 Some(WsMessage::Text(unsafe {
572 WsText::from_bytes_unchecked(bytes)
573 }))
574 }
575 Message::Binary(vec) => Some(WsMessage::Binary(Bytes::from(vec))),
576 Message::Close(frame) => {
577 let close_frame = frame.map(|f| {
578 let code = convert_close_code(f.code);
579 CloseFrame::new(code, CowStr::from(f.reason.into_owned()))
580 });
581 Some(WsMessage::Close(close_frame))
582 }
583 }
584 }
585
586 /// Convert tokio-tungstenite-wasm CloseCode to our CloseCode
587 fn convert_close_code(code: tokio_tungstenite_wasm::CloseCode) -> CloseCode {
588 use tokio_tungstenite_wasm::CloseCode as TungsteniteCode;
589
590 match code {
591 TungsteniteCode::Normal => CloseCode::Normal,
592 TungsteniteCode::Away => CloseCode::Away,
593 TungsteniteCode::Protocol => CloseCode::Protocol,
594 TungsteniteCode::Unsupported => CloseCode::Unsupported,
595 TungsteniteCode::Invalid => CloseCode::Invalid,
596 TungsteniteCode::Policy => CloseCode::Policy,
597 TungsteniteCode::Size => CloseCode::Size,
598 TungsteniteCode::Extension => CloseCode::Extension,
599 TungsteniteCode::Error => CloseCode::Error,
600 TungsteniteCode::Tls => CloseCode::Tls,
601 // For other variants, extract raw code
602 other => {
603 let raw: u16 = other.into();
604 CloseCode::from(raw)
605 }
606 }
607 }
608
609 impl From<WsMessage> for tokio_tungstenite_wasm::Message {
610 fn from(msg: WsMessage) -> Self {
611 use tokio_tungstenite_wasm::Message;
612
613 match msg {
614 WsMessage::Text(text) => {
615 // tokio-tungstenite-wasm Text expects String
616 let bytes = text.into_bytes();
617 // Safe: WsText is already UTF-8 validated
618 let string = unsafe { String::from_utf8_unchecked(bytes.to_vec()) };
619 Message::Text(string)
620 }
621 WsMessage::Binary(bytes) => Message::Binary(bytes.to_vec()),
622 WsMessage::Close(frame) => {
623 let close_frame = frame.map(|f| {
624 let code = u16::from(f.code).into();
625 tokio_tungstenite_wasm::CloseFrame {
626 code,
627 reason: f.reason.into_static().to_string().into(),
628 }
629 });
630 Message::Close(close_frame)
631 }
632 }
633 }
634 }
635}
636
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ws_text_from_string() {
        let text = WsText::from("hello");
        assert_eq!(text.as_str(), "hello");
    }

    #[test]
    fn ws_text_deref() {
        let text = WsText::from(String::from("world"));
        assert_eq!(&*text, "world");
    }

    #[test]
    fn ws_text_try_from_bytes() {
        let bytes = Bytes::from("test");
        let text = WsText::try_from(bytes).unwrap();
        assert_eq!(text.as_str(), "test");
    }

    #[test]
    fn ws_text_invalid_utf8() {
        let bytes = Bytes::from(vec![0xFF, 0xFE]);
        assert!(WsText::try_from(bytes).is_err());
    }

    #[test]
    fn ws_text_default_is_empty() {
        assert!(WsText::default().is_empty());
    }

    #[test]
    fn ws_text_display_honours_formatting_flags() {
        let text = WsText::from("ab");
        assert_eq!(text.to_string(), "ab");
        assert_eq!(format!("{text:>4}"), "  ab");
    }

    #[test]
    fn ws_message_text() {
        let msg = WsMessage::from("hello");
        assert!(msg.is_text());
        assert_eq!(msg.as_text(), Some("hello"));
    }

    #[test]
    fn ws_message_text_bytes_match_utf8() {
        let msg = WsMessage::from(String::from("héllo"));
        assert_eq!(msg.as_bytes(), Some("héllo".as_bytes()));
    }

    #[test]
    fn ws_message_binary() {
        let msg = WsMessage::from(vec![1, 2, 3]);
        assert!(msg.is_binary());
        assert_eq!(msg.as_bytes(), Some(&[1u8, 2, 3][..]));
    }

    #[test]
    fn ws_message_close_has_no_payload() {
        let msg = WsMessage::Close(None);
        assert!(msg.is_close());
        assert!(!msg.is_text());
        assert!(!msg.is_binary());
        assert_eq!(msg.as_text(), None);
        assert_eq!(msg.as_bytes(), None);
    }

    #[test]
    fn close_code_conversion() {
        assert_eq!(u16::from(CloseCode::Normal), 1000);
        assert_eq!(CloseCode::from(1000), CloseCode::Normal);
        assert_eq!(CloseCode::from(9999), CloseCode::Other(9999));
    }

    #[test]
    fn close_code_round_trips_every_u16() {
        for raw in 0..=u16::MAX {
            assert_eq!(u16::from(CloseCode::from(raw)), raw);
        }
    }

    #[test]
    fn websocket_connection_has_tx_and_rx() {
        use futures::sink::SinkExt;
        use futures::stream;

        let rx_stream = stream::iter(vec![Ok(WsMessage::from("test"))]);
        let rx = WsStream::new(rx_stream);

        let drain_sink = futures::sink::drain()
            .sink_map_err(|_: std::convert::Infallible| StreamError::closed());
        let tx = WsSink::new(drain_sink);

        let conn = WebSocketConnection::new(tx, rx);
        assert!(conn.is_open());
    }
}