A better Rust ATProto crate
1//! WebSocket client abstraction
2
3use crate::CowStr;
4use crate::stream::StreamError;
5use alloc::boxed::Box;
6use alloc::string::String;
7use alloc::string::ToString;
8use alloc::vec::Vec;
9use bytes::Bytes;
10use core::borrow::Borrow;
11use core::fmt::{self, Display};
12use core::future::Future;
13use core::ops::Deref;
14use core::pin::Pin;
15use n0_future::Stream;
16use url::Url;
17
18/// UTF-8 validated bytes for WebSocket text messages
19#[repr(transparent)]
20#[derive(Debug, Clone, Eq, PartialEq, Hash, PartialOrd, Ord)]
21pub struct WsText(Bytes);
22
23impl WsText {
24 /// Create from static string
25 pub const fn from_static(s: &'static str) -> Self {
26 Self(Bytes::from_static(s.as_bytes()))
27 }
28
29 /// Get as string slice
30 pub fn as_str(&self) -> &str {
31 unsafe { core::str::from_utf8_unchecked(&self.0) }
32 }
33
34 /// Create from bytes without validation (caller must ensure UTF-8)
35 ///
36 /// # Safety
37 /// Bytes must be valid UTF-8
38 pub unsafe fn from_bytes_unchecked(bytes: Bytes) -> Self {
39 Self(bytes)
40 }
41
42 /// Convert into underlying bytes
43 pub fn into_bytes(self) -> Bytes {
44 self.0
45 }
46}
47
48impl Deref for WsText {
49 type Target = str;
50 fn deref(&self) -> &str {
51 self.as_str()
52 }
53}
54
55impl AsRef<str> for WsText {
56 fn as_ref(&self) -> &str {
57 self.as_str()
58 }
59}
60
61impl AsRef<[u8]> for WsText {
62 fn as_ref(&self) -> &[u8] {
63 &self.0
64 }
65}
66
67impl AsRef<Bytes> for WsText {
68 fn as_ref(&self) -> &Bytes {
69 &self.0
70 }
71}
72
73impl Borrow<str> for WsText {
74 fn borrow(&self) -> &str {
75 self.as_str()
76 }
77}
78
79impl Display for WsText {
80 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
81 Display::fmt(self.as_str(), f)
82 }
83}
84
85impl From<String> for WsText {
86 fn from(s: String) -> Self {
87 Self(Bytes::from(s))
88 }
89}
90
91impl From<&str> for WsText {
92 fn from(s: &str) -> Self {
93 Self(Bytes::copy_from_slice(s.as_bytes()))
94 }
95}
96
97impl From<&String> for WsText {
98 fn from(s: &String) -> Self {
99 Self::from(s.as_str())
100 }
101}
102
103impl TryFrom<Bytes> for WsText {
104 type Error = core::str::Utf8Error;
105 fn try_from(bytes: Bytes) -> Result<Self, Self::Error> {
106 core::str::from_utf8(&bytes)?;
107 Ok(Self(bytes))
108 }
109}
110
111impl TryFrom<Vec<u8>> for WsText {
112 type Error = core::str::Utf8Error;
113 fn try_from(vec: Vec<u8>) -> Result<Self, Self::Error> {
114 Self::try_from(Bytes::from(vec))
115 }
116}
117
118impl From<WsText> for Bytes {
119 fn from(t: WsText) -> Bytes {
120 t.0
121 }
122}
123
124impl Default for WsText {
125 fn default() -> Self {
126 Self(Bytes::new())
127 }
128}
129
/// Status code carried by a WebSocket close frame.
///
/// The named variants cover the codes this crate distinguishes; any other
/// numeric value is kept verbatim in [`CloseCode::Other`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u16)]
pub enum CloseCode {
    /// Normal closure (1000)
    Normal = 1000,
    /// Endpoint going away (1001)
    Away = 1001,
    /// Protocol error (1002)
    Protocol = 1002,
    /// Unsupported data (1003)
    Unsupported = 1003,
    /// Invalid frame payload data (1007)
    Invalid = 1007,
    /// Policy violation (1008)
    Policy = 1008,
    /// Message too big (1009)
    Size = 1009,
    /// Extension negotiation failure (1010)
    Extension = 1010,
    /// Unexpected condition (1011)
    Error = 1011,
    /// TLS handshake failure (1015)
    Tls = 1015,
    /// Any code without a dedicated variant, stored as received
    Other(u16),
}

impl From<u16> for CloseCode {
    /// Map a numeric close code to its named variant, falling back to
    /// [`CloseCode::Other`] for anything unrecognised.
    fn from(code: u16) -> Self {
        match code {
            1000 => Self::Normal,
            1001 => Self::Away,
            1002 => Self::Protocol,
            1003 => Self::Unsupported,
            1007 => Self::Invalid,
            1008 => Self::Policy,
            1009 => Self::Size,
            1010 => Self::Extension,
            1011 => Self::Error,
            1015 => Self::Tls,
            unrecognised => Self::Other(unrecognised),
        }
    }
}

impl From<CloseCode> for u16 {
    /// Recover the numeric value of a close code.
    fn from(code: CloseCode) -> u16 {
        match code {
            // Pass-through codes carry their own number.
            CloseCode::Other(raw) => raw,
            CloseCode::Normal => 1000,
            CloseCode::Away => 1001,
            CloseCode::Protocol => 1002,
            CloseCode::Unsupported => 1003,
            CloseCode::Invalid => 1007,
            CloseCode::Policy => 1008,
            CloseCode::Size => 1009,
            CloseCode::Extension => 1010,
            CloseCode::Error => 1011,
            CloseCode::Tls => 1015,
        }
    }
}
193
194/// WebSocket close frame
195#[derive(Debug, Clone, PartialEq, Eq)]
196pub struct CloseFrame<'a> {
197 /// Close code
198 pub code: CloseCode,
199 /// Close reason text
200 pub reason: CowStr<'a>,
201}
202
203impl<'a> CloseFrame<'a> {
204 /// Create a new close frame
205 pub fn new(code: CloseCode, reason: impl Into<CowStr<'a>>) -> Self {
206 Self {
207 code,
208 reason: reason.into(),
209 }
210 }
211}
212
213/// WebSocket message
214#[derive(Debug, Clone, PartialEq, Eq)]
215pub enum WsMessage {
216 /// Text message (UTF-8)
217 Text(WsText),
218 /// Binary message
219 Binary(Bytes),
220 /// Close frame
221 Close(Option<CloseFrame<'static>>),
222}
223
224impl WsMessage {
225 /// Check if this is a text message
226 pub fn is_text(&self) -> bool {
227 matches!(self, WsMessage::Text(_))
228 }
229
230 /// Check if this is a binary message
231 pub fn is_binary(&self) -> bool {
232 matches!(self, WsMessage::Binary(_))
233 }
234
235 /// Check if this is a close message
236 pub fn is_close(&self) -> bool {
237 matches!(self, WsMessage::Close(_))
238 }
239
240 /// Get as text, if this is a text message
241 pub fn as_text(&self) -> Option<&str> {
242 match self {
243 WsMessage::Text(t) => Some(t.as_str()),
244 _ => None,
245 }
246 }
247
248 /// Get as bytes
249 pub fn as_bytes(&self) -> Option<&[u8]> {
250 match self {
251 WsMessage::Text(t) => Some(t.as_ref()),
252 WsMessage::Binary(b) => Some(b),
253 WsMessage::Close(_) => None,
254 }
255 }
256}
257
258impl From<WsText> for WsMessage {
259 fn from(text: WsText) -> Self {
260 WsMessage::Text(text)
261 }
262}
263
264impl From<String> for WsMessage {
265 fn from(s: String) -> Self {
266 WsMessage::Text(WsText::from(s))
267 }
268}
269
270impl From<&str> for WsMessage {
271 fn from(s: &str) -> Self {
272 WsMessage::Text(WsText::from(s))
273 }
274}
275
276impl From<Bytes> for WsMessage {
277 fn from(bytes: Bytes) -> Self {
278 WsMessage::Binary(bytes)
279 }
280}
281
282impl From<Vec<u8>> for WsMessage {
283 fn from(vec: Vec<u8>) -> Self {
284 WsMessage::Binary(Bytes::from(vec))
285 }
286}
287
288/// WebSocket message stream
289#[cfg(not(target_arch = "wasm32"))]
290pub struct WsStream(Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>> + Send>>);
291
292/// WebSocket message stream
293#[cfg(target_arch = "wasm32")]
294pub struct WsStream(Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>>>>);
295
296impl WsStream {
297 /// Create a new message stream
298 #[cfg(not(target_arch = "wasm32"))]
299 pub fn new<S>(stream: S) -> Self
300 where
301 S: Stream<Item = Result<WsMessage, StreamError>> + Send + 'static,
302 {
303 Self(Box::pin(stream))
304 }
305
306 /// Create a new message stream
307 #[cfg(target_arch = "wasm32")]
308 pub fn new<S>(stream: S) -> Self
309 where
310 S: Stream<Item = Result<WsMessage, StreamError>> + 'static,
311 {
312 Self(Box::pin(stream))
313 }
314
315 /// Convert into the inner pinned boxed stream
316 #[cfg(not(target_arch = "wasm32"))]
317 pub fn into_inner(self) -> Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>> + Send>> {
318 self.0
319 }
320
321 /// Convert into the inner pinned boxed stream
322 #[cfg(target_arch = "wasm32")]
323 pub fn into_inner(self) -> Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>>>> {
324 self.0
325 }
326
327 /// Split this stream into two streams that both receive all messages
328 ///
329 /// Messages are cloned (cheaply via Bytes rc). Spawns a forwarder task.
330 /// Both returned streams will receive all messages from the original stream.
331 /// The forwarder continues as long as at least one stream is alive.
332 /// If the underlying stream errors, both teed streams will end.
333 pub fn tee(self) -> (WsStream, WsStream) {
334 use futures::channel::mpsc;
335 use n0_future::StreamExt as _;
336
337 let (tx1, rx1) = mpsc::unbounded();
338 let (tx2, rx2) = mpsc::unbounded();
339
340 n0_future::task::spawn(async move {
341 let mut stream = self.0;
342 while let Some(result) = stream.next().await {
343 match result {
344 Ok(msg) => {
345 // Clone message (cheap - Bytes is rc'd)
346 let msg2 = msg.clone();
347
348 // Send to both channels, continue if at least one succeeds
349 let send1 = tx1.unbounded_send(Ok(msg));
350 let send2 = tx2.unbounded_send(Ok(msg2));
351
352 // Only stop if both channels are closed
353 if send1.is_err() && send2.is_err() {
354 break;
355 }
356 }
357 Err(_e) => {
358 // Underlying stream errored, stop forwarding.
359 // Both channels will close, ending both streams.
360 break;
361 }
362 }
363 }
364 });
365
366 (WsStream::new(rx1), WsStream::new(rx2))
367 }
368}
369
370impl fmt::Debug for WsStream {
371 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
372 f.debug_struct("WsStream").finish_non_exhaustive()
373 }
374}
375
376/// WebSocket message sink
377#[cfg(not(target_arch = "wasm32"))]
378pub struct WsSink(Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>>);
379
380/// WebSocket message sink
381#[cfg(target_arch = "wasm32")]
382pub struct WsSink(Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError>>>);
383
384impl WsSink {
385 /// Create a new message sink
386 #[cfg(not(target_arch = "wasm32"))]
387 pub fn new<S>(sink: S) -> Self
388 where
389 S: n0_future::Sink<WsMessage, Error = StreamError> + Send + 'static,
390 {
391 Self(Box::pin(sink))
392 }
393
394 /// Create a new message sink
395 #[cfg(target_arch = "wasm32")]
396 pub fn new<S>(sink: S) -> Self
397 where
398 S: n0_future::Sink<WsMessage, Error = StreamError> + 'static,
399 {
400 Self(Box::pin(sink))
401 }
402
403 /// Convert into the inner boxed sink
404 #[cfg(not(target_arch = "wasm32"))]
405 pub fn into_inner(
406 self,
407 ) -> Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>> {
408 self.0
409 }
410
411 /// Convert into the inner boxed sink
412 #[cfg(target_arch = "wasm32")]
413 pub fn into_inner(self) -> Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError>>> {
414 self.0
415 }
416
417 /// get a mutable reference to the inner boxed sink
418 #[cfg(not(target_arch = "wasm32"))]
419 pub fn get_mut(
420 &mut self,
421 ) -> &mut Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>> {
422 use core::borrow::BorrowMut;
423
424 self.0.borrow_mut()
425 }
426
427 /// get a mutable reference to the inner boxed sink
428 #[cfg(target_arch = "wasm32")]
429 pub fn get_mut(
430 &mut self,
431 ) -> &mut Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + 'static>> {
432 use core::borrow::BorrowMut;
433
434 self.0.borrow_mut()
435 }
436}
437
438impl fmt::Debug for WsSink {
439 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
440 f.debug_struct("WsSink").finish_non_exhaustive()
441 }
442}
443
/// WebSocket client trait
///
/// Abstracts over a backend that can open WebSocket connections and hand back
/// a [`WebSocketConnection`]. On non-wasm targets the trait is passed through
/// `trait_variant::make(Send)`; on wasm32 that attribute is not applied.
#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
pub trait WebSocketClient: Sync {
    /// Error type for WebSocket operations
    type Error: core::error::Error + Send + Sync + 'static;

    /// Connect to a WebSocket endpoint
    ///
    /// Resolves to a [`WebSocketConnection`] on success, or to the
    /// implementation's [`Self::Error`] on failure.
    fn connect(&self, url: Url) -> impl Future<Output = Result<WebSocketConnection, Self::Error>>;

    /// Connect to a WebSocket endpoint with custom headers
    ///
    /// Default implementation ignores headers and calls `connect()`.
    /// Override this method to support authentication headers for subscriptions.
    ///
    /// NOTE: with the default implementation the supplied headers are never
    /// sent, so callers relying on them need a client that overrides this.
    fn connect_with_headers(
        &self,
        url: Url,
        _headers: Vec<(CowStr<'_>, CowStr<'_>)>,
    ) -> impl Future<Output = Result<WebSocketConnection, Self::Error>> {
        // `_headers` is intentionally unused here; only `url` is forwarded.
        async move { self.connect(url).await }
    }
}
465
466/// WebSocket connection with bidirectional streams
467pub struct WebSocketConnection {
468 tx: WsSink,
469 rx: WsStream,
470}
471
472impl WebSocketConnection {
473 /// Create a new WebSocket connection
474 pub fn new(tx: WsSink, rx: WsStream) -> Self {
475 Self { tx, rx }
476 }
477
478 /// Get mutable access to the sender
479 pub fn sender_mut(&mut self) -> &mut WsSink {
480 &mut self.tx
481 }
482
483 /// Get mutable access to the receiver
484 pub fn receiver_mut(&mut self) -> &mut WsStream {
485 &mut self.rx
486 }
487
488 /// Get a reference to the receiver
489 pub fn receiver(&self) -> &WsStream {
490 &self.rx
491 }
492
493 /// Get a reference to the sender
494 pub fn sender(&self) -> &WsSink {
495 &self.tx
496 }
497
498 /// Split into sender and receiver
499 pub fn split(self) -> (WsSink, WsStream) {
500 (self.tx, self.rx)
501 }
502
503 /// Check if connection is open (always true for this abstraction)
504 pub fn is_open(&self) -> bool {
505 true
506 }
507}
508
509impl fmt::Debug for WebSocketConnection {
510 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
511 f.debug_struct("WebSocketConnection")
512 .finish_non_exhaustive()
513 }
514}
515
516/// Concrete WebSocket client implementation using tokio-tungstenite-wasm
517pub mod tungstenite_client {
518 use super::*;
519 use crate::IntoStatic;
520 use futures::{SinkExt, StreamExt};
521
522 /// WebSocket client backed by tokio-tungstenite-wasm
523 #[derive(Debug, Clone, Default)]
524 pub struct TungsteniteClient;
525
526 impl TungsteniteClient {
527 /// Create a new tungstenite WebSocket client
528 pub fn new() -> Self {
529 Self
530 }
531 }
532
533 impl WebSocketClient for TungsteniteClient {
534 type Error = tokio_tungstenite_wasm::Error;
535
536 async fn connect(&self, url: Url) -> Result<WebSocketConnection, Self::Error> {
537 let ws_stream = tokio_tungstenite_wasm::connect(url.as_str()).await?;
538
539 let (sink, stream) = ws_stream.split();
540
541 // Convert tungstenite messages to our WsMessage
542 let rx_stream = stream.filter_map(|result| async move {
543 match result {
544 Ok(msg) => match convert_message(msg) {
545 Some(ws_msg) => Some(Ok(ws_msg)),
546 None => None, // Skip ping/pong
547 },
548 Err(e) => Some(Err(StreamError::transport(e))),
549 }
550 });
551
552 let rx = WsStream::new(rx_stream);
553
554 // Convert our WsMessage to tungstenite messages
555 let tx_sink = sink.with(|msg: WsMessage| async move {
556 Ok::<_, tokio_tungstenite_wasm::Error>(msg.into())
557 });
558
559 let tx_sink_mapped = tx_sink.sink_map_err(|e| StreamError::transport(e));
560 let tx = WsSink::new(tx_sink_mapped);
561
562 Ok(WebSocketConnection::new(tx, rx))
563 }
564 }
565
566 /// Convert tokio-tungstenite-wasm Message to our WsMessage
567 /// Returns None for Ping/Pong which we auto-handle
568 fn convert_message(msg: tokio_tungstenite_wasm::Message) -> Option<WsMessage> {
569 use tokio_tungstenite_wasm::Message;
570
571 match msg {
572 Message::Text(vec) => {
573 // tokio-tungstenite-wasm Text contains Vec<u8> (UTF-8 validated)
574 let bytes = Bytes::from(vec);
575 Some(WsMessage::Text(unsafe {
576 WsText::from_bytes_unchecked(bytes)
577 }))
578 }
579 Message::Binary(vec) => Some(WsMessage::Binary(Bytes::from(vec))),
580 Message::Close(frame) => {
581 let close_frame = frame.map(|f| {
582 let code = convert_close_code(f.code);
583 CloseFrame::new(code, CowStr::from(f.reason.into_owned()))
584 });
585 Some(WsMessage::Close(close_frame))
586 }
587 }
588 }
589
590 /// Convert tokio-tungstenite-wasm CloseCode to our CloseCode
591 fn convert_close_code(code: tokio_tungstenite_wasm::CloseCode) -> CloseCode {
592 use tokio_tungstenite_wasm::CloseCode as TungsteniteCode;
593
594 match code {
595 TungsteniteCode::Normal => CloseCode::Normal,
596 TungsteniteCode::Away => CloseCode::Away,
597 TungsteniteCode::Protocol => CloseCode::Protocol,
598 TungsteniteCode::Unsupported => CloseCode::Unsupported,
599 TungsteniteCode::Invalid => CloseCode::Invalid,
600 TungsteniteCode::Policy => CloseCode::Policy,
601 TungsteniteCode::Size => CloseCode::Size,
602 TungsteniteCode::Extension => CloseCode::Extension,
603 TungsteniteCode::Error => CloseCode::Error,
604 TungsteniteCode::Tls => CloseCode::Tls,
605 // For other variants, extract raw code
606 other => {
607 let raw: u16 = other.into();
608 CloseCode::from(raw)
609 }
610 }
611 }
612
613 impl From<WsMessage> for tokio_tungstenite_wasm::Message {
614 fn from(msg: WsMessage) -> Self {
615 use tokio_tungstenite_wasm::Message;
616
617 match msg {
618 WsMessage::Text(text) => {
619 // tokio-tungstenite-wasm Text expects String
620 let bytes = text.into_bytes();
621 // Safe: WsText is already UTF-8 validated
622 let string = unsafe { String::from_utf8_unchecked(bytes.to_vec()) };
623 Message::Text(string)
624 }
625 WsMessage::Binary(bytes) => Message::Binary(bytes.to_vec()),
626 WsMessage::Close(frame) => {
627 let close_frame = frame.map(|f| {
628 let code = u16::from(f.code).into();
629 tokio_tungstenite_wasm::CloseFrame {
630 code,
631 reason: f.reason.into_static().to_string().into(),
632 }
633 });
634 Message::Close(close_frame)
635 }
636 }
637 }
638 }
639}
640
#[cfg(test)]
mod tests {
    use super::*;

    /// A string slice converts into text that reads back unchanged.
    #[test]
    fn ws_text_from_string() {
        let greeting: WsText = "hello".into();
        assert_eq!(greeting.as_str(), "hello");
    }

    /// Dereferencing yields the contained `str`.
    #[test]
    fn ws_text_deref() {
        let owned = String::from("world");
        let text = WsText::from(owned);
        let view: &str = &*text;
        assert_eq!(view, "world");
    }

    /// Valid UTF-8 bytes are accepted by the fallible conversion.
    #[test]
    fn ws_text_try_from_bytes() {
        let text = WsText::try_from(Bytes::from("test")).unwrap();
        assert_eq!(text.as_str(), "test");
    }

    /// Bytes that are not UTF-8 are rejected.
    #[test]
    fn ws_text_invalid_utf8() {
        let invalid = Bytes::from(vec![0xFF, 0xFE]);
        let outcome = WsText::try_from(invalid);
        assert!(outcome.is_err());
    }

    /// A string slice becomes a text message.
    #[test]
    fn ws_message_text() {
        let message = WsMessage::from("hello");
        assert!(message.is_text());
        assert_eq!(message.as_text(), Some("hello"));
    }

    /// A byte vector becomes a binary message exposing the same bytes.
    #[test]
    fn ws_message_binary() {
        let message = WsMessage::from(vec![1, 2, 3]);
        assert!(message.is_binary());
        let expected: &[u8] = &[1, 2, 3];
        assert_eq!(message.as_bytes(), Some(expected));
    }

    /// Known codes map to named variants; unknown ones are preserved.
    #[test]
    fn close_code_conversion() {
        let normal: u16 = CloseCode::Normal.into();
        assert_eq!(normal, 1000);
        assert_eq!(CloseCode::from(1000), CloseCode::Normal);
        assert_eq!(CloseCode::from(9999), CloseCode::Other(9999));
    }

    /// A connection can be assembled from an arbitrary sink and stream.
    #[test]
    fn websocket_connection_has_tx_and_rx() {
        use futures::sink::SinkExt;
        use futures::stream;

        let incoming = stream::iter(vec![Ok(WsMessage::from("test"))]);
        let outgoing = futures::sink::drain()
            .sink_map_err(|_: std::convert::Infallible| StreamError::closed());

        let conn = WebSocketConnection::new(WsSink::new(outgoing), WsStream::new(incoming));
        assert!(conn.is_open());
    }
}
706}