// A better Rust ATProto crate
//! WebSocket subscription support for XRPC
//!
//! This module defines traits and types for typed WebSocket subscriptions,
//! mirroring the request/response pattern used for HTTP XRPC endpoints.
5
6use alloc::borrow::ToOwned;
7use alloc::string::String;
8use alloc::string::ToString;
9use alloc::vec::Vec;
10use core::error::Error;
11use core::future::Future;
12use core::marker::PhantomData;
13#[cfg(not(target_arch = "wasm32"))]
14use n0_future::stream::Boxed;
15#[cfg(target_arch = "wasm32")]
16use n0_future::stream::BoxedLocal as Boxed;
17use serde::{Deserialize, Serialize};
18use url::Url;
19
20use crate::cowstr::ToCowStr;
21use crate::error::DecodeError;
22use crate::stream::StreamError;
23use crate::websocket::{WebSocketClient, WebSocketConnection, WsSink, WsStream};
24use crate::{CowStr, Data, IntoStatic, RawData, WsMessage};
25
/// Wire encoding used by the messages of a subscription stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageEncoding {
    /// Messages are JSON, carried in text frames.
    Json,
    /// Messages are DAG-CBOR, carried in binary frames.
    DagCbor,
}
34
35/// XRPC subscription stream response trait
36///
37/// Analogous to `XrpcResp` but for WebSocket subscriptions.
38/// Defines the message and error types for a subscription stream.
39///
40/// This trait is implemented on a marker struct to keep it lifetime-free
41/// while using GATs for the message/error types.
42pub trait SubscriptionResp {
43 /// The NSID for this subscription
44 const NSID: &'static str;
45
46 /// Message encoding (JSON or DAG-CBOR)
47 const ENCODING: MessageEncoding;
48
49 /// Message union type
50 type Message<'de>: Deserialize<'de> + IntoStatic;
51
52 /// Error union type
53 type Error<'de>: Error + Deserialize<'de> + IntoStatic;
54
55 /// Decode a message from bytes.
56 ///
57 /// Default implementation uses simple deserialization via serde.
58 /// Subscriptions that use framed encoding (header + body) can override
59 /// this to do two-stage deserialization.
60 fn decode_message<'de>(bytes: &'de [u8]) -> Result<Self::Message<'de>, DecodeError> {
61 match Self::ENCODING {
62 MessageEncoding::Json => serde_json::from_slice(bytes).map_err(DecodeError::from),
63 MessageEncoding::DagCbor => {
64 serde_ipld_dagcbor::from_slice(bytes).map_err(DecodeError::from)
65 }
66 }
67 }
68}
69
70/// XRPC subscription (WebSocket)
71///
72/// This trait is analogous to `XrpcRequest` but for WebSocket subscriptions.
73/// It defines the NSID and associated stream response type.
74///
75/// The trait is implemented on the subscription parameters type.
76pub trait XrpcSubscription: Serialize {
77 /// The NSID for this XRPC subscription
78 const NSID: &'static str;
79
80 /// Message encoding (JSON or DAG-CBOR)
81 const ENCODING: MessageEncoding;
82
83 /// Custom path override (e.g., "/subscribe" for Jetstream).
84 /// If None, defaults to "/xrpc/{NSID}"
85 const CUSTOM_PATH: Option<&'static str> = None;
86
87 /// Stream response type (marker struct)
88 type Stream: SubscriptionResp;
89
90 /// Encode query params for WebSocket URL
91 ///
92 /// Default implementation uses serde_html_form to encode the struct as query parameters.
93 fn query_params(&self) -> Vec<(String, String)> {
94 // Default: use serde_html_form to encode self
95 serde_html_form::to_string(self)
96 .ok()
97 .map(|s| {
98 s.split('&')
99 .filter_map(|pair| {
100 let mut parts = pair.splitn(2, '=');
101 Some((parts.next()?.to_string(), parts.next()?.to_string()))
102 })
103 .collect()
104 })
105 .unwrap_or_default()
106 }
107}
108
109/// Header for framed DAG-CBOR subscription messages.
110///
111/// Used in ATProto subscription streams where each message has a CBOR-encoded header
112/// followed by the message body.
113#[derive(Debug, serde::Deserialize)]
114pub struct EventHeader {
115 /// Operation code
116 pub op: i64,
117 /// Event type discriminator (e.g., "#commit", "#identity")
118 pub t: smol_str::SmolStr,
119}
120
121/// A minimal cursor for no_std that tracks read position.
122///
123/// Implements `ciborium_io::Read` to work with ciborium's CBOR parser.
124#[cfg(not(feature = "std"))]
125struct SliceCursor<'a> {
126 slice: &'a [u8],
127 position: usize,
128}
129
130#[cfg(not(feature = "std"))]
131impl<'a> SliceCursor<'a> {
132 fn new(slice: &'a [u8]) -> Self {
133 Self { slice, position: 0 }
134 }
135
136 fn position(&self) -> usize {
137 self.position
138 }
139}
140
141#[cfg(not(feature = "std"))]
142impl ciborium_io::Read for SliceCursor<'_> {
143 type Error = core::convert::Infallible;
144
145 fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), Self::Error> {
146 let end = self.position + buf.len();
147 buf.copy_from_slice(&self.slice[self.position..end]);
148 self.position = end;
149 Ok(())
150 }
151}
152
/// Split a framed DAG-CBOR message into its header and the remaining body bytes.
///
/// This is the first stage of two-stage deserialization for subscription
/// formats such as `com.atproto.sync.subscribeRepos`: the CBOR header is
/// decoded from the front of `bytes`, and everything after it is returned
/// untouched.
///
/// # Errors
///
/// Returns a [`DecodeError`] if the header cannot be decoded.
#[cfg(feature = "std")]
pub fn parse_event_header<'a>(bytes: &'a [u8]) -> Result<(EventHeader, &'a [u8]), DecodeError> {
    let mut reader = std::io::Cursor::new(bytes);
    let header: EventHeader = ciborium::de::from_reader(&mut reader)?;

    // Bytes consumed by the header; the body starts right after them.
    let consumed = reader.position() as usize;
    let body = &bytes[consumed..];

    Ok((header, body))
}
166
167/// Parse a framed DAG-CBOR message header and return the header plus remaining body bytes.
168///
169/// Used for two-stage deserialization of subscription messages in formats like
170/// `com.atproto.sync.subscribeRepos`.
171#[cfg(not(feature = "std"))]
172pub fn parse_event_header<'a>(bytes: &'a [u8]) -> Result<(EventHeader, &'a [u8]), DecodeError> {
173 let mut cursor = SliceCursor::new(bytes);
174 let header: EventHeader = ciborium::de::from_reader(&mut cursor)?;
175 let position = cursor.position();
176
177 Ok((header, &bytes[position..]))
178}
179
180/// Decode JSON messages from a WebSocket stream
181pub fn decode_json_msg<S: SubscriptionResp>(
182 msg_result: Result<crate::websocket::WsMessage, StreamError>,
183) -> Option<Result<StreamMessage<'static, S>, StreamError>>
184where
185 for<'a> StreamMessage<'a, S>: IntoStatic<Output = StreamMessage<'static, S>>,
186{
187 use crate::websocket::WsMessage;
188
189 match msg_result {
190 Ok(WsMessage::Text(text)) => Some(
191 S::decode_message(text.as_ref())
192 .map(|v| v.into_static())
193 .map_err(StreamError::decode),
194 ),
195 Ok(WsMessage::Binary(bytes)) => {
196 #[cfg(feature = "zstd")]
197 {
198 // Try to decompress with zstd first (Jetstream uses zstd compression)
199 match decompress_zstd(&bytes) {
200 Ok(decompressed) => Some(
201 S::decode_message(&decompressed)
202 .map(|v| v.into_static())
203 .map_err(StreamError::decode),
204 ),
205 Err(_) => {
206 // Not zstd-compressed, try direct decode
207 Some(
208 S::decode_message(&bytes)
209 .map(|v| v.into_static())
210 .map_err(StreamError::decode),
211 )
212 }
213 }
214 }
215 #[cfg(not(feature = "zstd"))]
216 {
217 Some(
218 S::decode_message(&bytes)
219 .map(|v| v.into_static())
220 .map_err(StreamError::decode),
221 )
222 }
223 }
224 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
225 Err(e) => Some(Err(e)),
226 }
227}
228
/// Decompress a zstd frame, retrying with the bundled dictionary.
///
/// Plain zstd decoding is tried first. If that fails, the frame is decoded
/// again using the dictionary embedded from `../../zstd_dictionary`.
///
/// # Errors
///
/// Returns the I/O error of the dictionary-based attempt when both attempts
/// fail.
#[cfg(feature = "zstd")]
fn decompress_zstd(bytes: &[u8]) -> Result<Vec<u8>, std::io::Error> {
    use zstd::stream::decode_all;

    // `include_bytes!` already yields a `'static` slice, so it can be used
    // directly: no lazy initialisation and no heap copy of the dictionary.
    static DICTIONARY: &[u8] = include_bytes!("../../zstd_dictionary");

    decode_all(std::io::Cursor::new(bytes)).or_else(|_| {
        // Try with dictionary
        let mut decoder =
            zstd::Decoder::with_dictionary(std::io::Cursor::new(bytes), DICTIONARY)?;
        let mut result = Vec::new();
        std::io::Read::read_to_end(&mut decoder, &mut result)?;
        Ok(result)
    })
}
246
247/// Decode CBOR messages from a WebSocket stream
248pub fn decode_cbor_msg<S: SubscriptionResp>(
249 msg_result: Result<crate::websocket::WsMessage, StreamError>,
250) -> Option<Result<StreamMessage<'static, S>, StreamError>>
251where
252 for<'a> StreamMessage<'a, S>: IntoStatic<Output = StreamMessage<'static, S>>,
253{
254 use crate::websocket::WsMessage;
255
256 match msg_result {
257 Ok(WsMessage::Binary(bytes)) => Some(
258 S::decode_message(&bytes)
259 .map(|v| v.into_static())
260 .map_err(StreamError::decode),
261 ),
262 Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format(
263 "expected binary frame for CBOR, got text",
264 ))),
265 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
266 Err(e) => Some(Err(e)),
267 }
268}
269
270/// Websocket subscriber-sent control message
271///
272/// Note: this is not meaningful for atproto event stream endpoints as
273/// those do not support control after the fact. Jetstream does, however.
274///
275/// If you wish to control an ongoing Jetstream connection, wrap the [`WsSink`]
276/// returned from one of the `into_*` methods of the [`SubscriptionStream`]
277/// in a [`SubscriptionController`] with the corresponding message implementing
278/// this trait as a generic parameter.
279pub trait SubscriptionControlMessage: Serialize {
280 /// The subscription this is associated with
281 type Subscription: XrpcSubscription;
282
283 /// Encode the control message for transmission
284 ///
285 /// Defaults to json text (matches Jetstream)
286 fn encode(&self) -> Result<WsMessage, StreamError> {
287 Ok(WsMessage::from(
288 serde_json::to_string(&self).map_err(StreamError::encode)?,
289 ))
290 }
291
292 /// Decode the control message
293 fn decode<'de>(frame: &'de [u8]) -> Result<Self, StreamError>
294 where
295 Self: Deserialize<'de>,
296 {
297 Ok(serde_json::from_slice(frame).map_err(StreamError::decode)?)
298 }
299}
300
301/// Control a websocket stream with a given subscription control message
302pub struct SubscriptionController<S: SubscriptionControlMessage> {
303 controller: WsSink,
304 _marker: PhantomData<fn() -> S>,
305}
306
307impl<S: SubscriptionControlMessage> SubscriptionController<S> {
308 /// Create a new subscription controller from a WebSocket sink.
309 pub fn new(controller: WsSink) -> Self {
310 Self {
311 controller,
312 _marker: PhantomData,
313 }
314 }
315
316 /// Configure the upstream connection via the websocket
317 pub async fn configure(&mut self, params: &S) -> Result<(), StreamError> {
318 let message = params.encode()?;
319
320 n0_future::SinkExt::send(self.controller.get_mut(), message)
321 .await
322 .map_err(StreamError::transport)
323 }
324}
325
326/// Typed subscription stream wrapping a WebSocket connection.
327///
328/// Analogous to `Response<R>` for XRPC but for subscription streams.
329/// Automatically decodes messages based on the subscription's encoding format.
330pub struct SubscriptionStream<S: SubscriptionResp> {
331 _marker: PhantomData<fn() -> S>,
332 connection: WebSocketConnection,
333}
334
335impl<S: SubscriptionResp> SubscriptionStream<S> {
336 /// Create a new subscription stream from a WebSocket connection.
337 pub fn new(connection: WebSocketConnection) -> Self {
338 Self {
339 _marker: PhantomData,
340 connection,
341 }
342 }
343
344 /// Get a reference to the underlying WebSocket connection.
345 pub fn connection(&self) -> &WebSocketConnection {
346 &self.connection
347 }
348
349 /// Get a mutable reference to the underlying WebSocket connection.
350 pub fn connection_mut(&mut self) -> &mut WebSocketConnection {
351 &mut self.connection
352 }
353
354 /// Split the connection and decode messages into a typed stream.
355 ///
356 /// Returns a tuple of (sender, typed message stream).
357 /// Messages are decoded according to the subscription's ENCODING.
358 pub fn into_stream(
359 self,
360 ) -> (
361 WsSink,
362 Boxed<Result<StreamMessage<'static, S>, StreamError>>,
363 )
364 where
365 for<'a> StreamMessage<'a, S>: IntoStatic<Output = StreamMessage<'static, S>>,
366 {
367 use n0_future::StreamExt as _;
368
369 let (tx, rx) = self.connection.split();
370
371 #[cfg(not(target_arch = "wasm32"))]
372 let stream = match S::ENCODING {
373 MessageEncoding::Json => rx
374 .into_inner()
375 .filter_map(|msg| decode_json_msg::<S>(msg))
376 .boxed(),
377 MessageEncoding::DagCbor => rx
378 .into_inner()
379 .filter_map(|msg| decode_cbor_msg::<S>(msg))
380 .boxed(),
381 };
382
383 #[cfg(target_arch = "wasm32")]
384 let stream = match S::ENCODING {
385 MessageEncoding::Json => rx
386 .into_inner()
387 .filter_map(|msg| decode_json_msg::<S>(msg))
388 .boxed_local(),
389 MessageEncoding::DagCbor => rx
390 .into_inner()
391 .filter_map(|msg| decode_cbor_msg::<S>(msg))
392 .boxed_local(),
393 };
394
395 (tx, stream)
396 }
397
398 /// Converts the subscription into a stream of raw atproto data.
399 pub fn into_raw_data_stream(self) -> (WsSink, Boxed<Result<RawData<'static>, StreamError>>) {
400 use n0_future::StreamExt as _;
401
402 let (tx, rx) = self.connection.split();
403
404 fn parse_msg<'a>(bytes: &'a [u8]) -> Result<RawData<'a>, serde_json::Error> {
405 serde_json::from_slice(bytes)
406 }
407 fn parse_cbor<'a>(
408 bytes: &'a [u8],
409 ) -> Result<RawData<'a>, serde_ipld_dagcbor::DecodeError<core::convert::Infallible>>
410 {
411 serde_ipld_dagcbor::from_slice(bytes)
412 }
413
414 #[cfg(not(target_arch = "wasm32"))]
415 let stream = match S::ENCODING {
416 MessageEncoding::Json => rx
417 .into_inner()
418 .filter_map(|msg_result| match msg_result {
419 Ok(WsMessage::Text(text)) => Some(
420 parse_msg(text.as_ref())
421 .map(|v| v.into_static())
422 .map_err(StreamError::decode),
423 ),
424 Ok(WsMessage::Binary(bytes)) => {
425 #[cfg(feature = "zstd")]
426 {
427 match decompress_zstd(&bytes) {
428 Ok(decompressed) => Some(
429 parse_msg(&decompressed)
430 .map(|v| v.into_static())
431 .map_err(StreamError::decode),
432 ),
433 Err(_) => Some(
434 parse_msg(&bytes)
435 .map(|v| v.into_static())
436 .map_err(StreamError::decode),
437 ),
438 }
439 }
440 #[cfg(not(feature = "zstd"))]
441 {
442 Some(
443 parse_msg(&bytes)
444 .map(|v| v.into_static())
445 .map_err(StreamError::decode),
446 )
447 }
448 }
449 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
450 Err(e) => Some(Err(e)),
451 })
452 .boxed(),
453 MessageEncoding::DagCbor => rx
454 .into_inner()
455 .filter_map(|msg_result| match msg_result {
456 Ok(WsMessage::Binary(bytes)) => Some(
457 parse_cbor(&bytes)
458 .map(|v| v.into_static())
459 .map_err(|e| StreamError::decode(crate::error::DecodeError::from(e))),
460 ),
461 Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format(
462 "expected binary frame for CBOR, got text",
463 ))),
464 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
465 Err(e) => Some(Err(e)),
466 })
467 .boxed(),
468 };
469
470 #[cfg(target_arch = "wasm32")]
471 let stream = match S::ENCODING {
472 MessageEncoding::Json => rx
473 .into_inner()
474 .filter_map(|msg_result| match msg_result {
475 Ok(WsMessage::Text(text)) => Some(
476 parse_msg(text.as_ref())
477 .map(|v| v.into_static())
478 .map_err(StreamError::decode),
479 ),
480 Ok(WsMessage::Binary(bytes)) => {
481 #[cfg(feature = "zstd")]
482 {
483 match decompress_zstd(&bytes) {
484 Ok(decompressed) => Some(
485 parse_msg(&decompressed)
486 .map(|v| v.into_static())
487 .map_err(StreamError::decode),
488 ),
489 Err(_) => Some(
490 parse_msg(&bytes)
491 .map(|v| v.into_static())
492 .map_err(StreamError::decode),
493 ),
494 }
495 }
496 #[cfg(not(feature = "zstd"))]
497 {
498 Some(
499 parse_msg(&bytes)
500 .map(|v| v.into_static())
501 .map_err(StreamError::decode),
502 )
503 }
504 }
505 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
506 Err(e) => Some(Err(e)),
507 })
508 .boxed_local(),
509 MessageEncoding::DagCbor => rx
510 .into_inner()
511 .filter_map(|msg_result| match msg_result {
512 Ok(WsMessage::Binary(bytes)) => Some(
513 parse_cbor(&bytes)
514 .map(|v| v.into_static())
515 .map_err(|e| StreamError::decode(crate::error::DecodeError::from(e))),
516 ),
517 Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format(
518 "expected binary frame for CBOR, got text",
519 ))),
520 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
521 Err(e) => Some(Err(e)),
522 })
523 .boxed_local(),
524 };
525
526 (tx, stream)
527 }
528
529 /// Converts the subscription into a stream of loosely-typed atproto data.
530 pub fn into_data_stream(self) -> (WsSink, Boxed<Result<Data<'static>, StreamError>>) {
531 use n0_future::StreamExt as _;
532
533 let (tx, rx) = self.connection.split();
534
535 fn parse_msg<'a>(bytes: &'a [u8]) -> Result<Data<'a>, serde_json::Error> {
536 serde_json::from_slice(bytes)
537 }
538 fn parse_cbor<'a>(
539 bytes: &'a [u8],
540 ) -> Result<Data<'a>, serde_ipld_dagcbor::DecodeError<core::convert::Infallible>> {
541 serde_ipld_dagcbor::from_slice(bytes)
542 }
543
544 #[cfg(not(target_arch = "wasm32"))]
545 let stream = match S::ENCODING {
546 MessageEncoding::Json => rx
547 .into_inner()
548 .filter_map(|msg_result| match msg_result {
549 Ok(WsMessage::Text(text)) => Some(
550 parse_msg(text.as_ref())
551 .map(|v| v.into_static())
552 .map_err(StreamError::decode),
553 ),
554 Ok(WsMessage::Binary(bytes)) => {
555 #[cfg(feature = "zstd")]
556 {
557 match decompress_zstd(&bytes) {
558 Ok(decompressed) => Some(
559 parse_msg(&decompressed)
560 .map(|v| v.into_static())
561 .map_err(StreamError::decode),
562 ),
563 Err(_) => Some(
564 parse_msg(&bytes)
565 .map(|v| v.into_static())
566 .map_err(StreamError::decode),
567 ),
568 }
569 }
570 #[cfg(not(feature = "zstd"))]
571 {
572 Some(
573 parse_msg(&bytes)
574 .map(|v| v.into_static())
575 .map_err(StreamError::decode),
576 )
577 }
578 }
579 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
580 Err(e) => Some(Err(e)),
581 })
582 .boxed(),
583 MessageEncoding::DagCbor => rx
584 .into_inner()
585 .filter_map(|msg_result| match msg_result {
586 Ok(WsMessage::Binary(bytes)) => Some(
587 parse_cbor(&bytes)
588 .map(|v| v.into_static())
589 .map_err(|e| StreamError::decode(crate::error::DecodeError::from(e))),
590 ),
591 Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format(
592 "expected binary frame for CBOR, got text",
593 ))),
594 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
595 Err(e) => Some(Err(e)),
596 })
597 .boxed(),
598 };
599
600 #[cfg(target_arch = "wasm32")]
601 let stream = match S::ENCODING {
602 MessageEncoding::Json => rx
603 .into_inner()
604 .filter_map(|msg_result| match msg_result {
605 Ok(WsMessage::Text(text)) => Some(
606 parse_msg(text.as_ref())
607 .map(|v| v.into_static())
608 .map_err(StreamError::decode),
609 ),
610 Ok(WsMessage::Binary(bytes)) => {
611 #[cfg(feature = "zstd")]
612 {
613 match decompress_zstd(&bytes) {
614 Ok(decompressed) => Some(
615 parse_msg(&decompressed)
616 .map(|v| v.into_static())
617 .map_err(StreamError::decode),
618 ),
619 Err(_) => Some(
620 parse_msg(&bytes)
621 .map(|v| v.into_static())
622 .map_err(StreamError::decode),
623 ),
624 }
625 }
626 #[cfg(not(feature = "zstd"))]
627 {
628 Some(
629 parse_msg(&bytes)
630 .map(|v| v.into_static())
631 .map_err(StreamError::decode),
632 )
633 }
634 }
635 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
636 Err(e) => Some(Err(e)),
637 })
638 .boxed_local(),
639 MessageEncoding::DagCbor => rx
640 .into_inner()
641 .filter_map(|msg_result| match msg_result {
642 Ok(WsMessage::Binary(bytes)) => Some(
643 parse_cbor(&bytes)
644 .map(|v| v.into_static())
645 .map_err(|e| StreamError::decode(crate::error::DecodeError::from(e))),
646 ),
647 Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format(
648 "expected binary frame for CBOR, got text",
649 ))),
650 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
651 Err(e) => Some(Err(e)),
652 })
653 .boxed_local(),
654 };
655
656 (tx, stream)
657 }
658
659 /// Consume the stream and return the underlying connection.
660 pub fn into_connection(self) -> WebSocketConnection {
661 self.connection
662 }
663
664 /// Tee the stream, keeping the raw stream in self and returning a typed stream.
665 ///
666 /// Replaces the internal WebSocket stream with one copy and returns a typed decoded
667 /// stream. Both streams receive all messages. Useful for observing raw messages
668 /// while also processing typed messages.
669 pub fn tee(&mut self) -> Boxed<Result<StreamMessage<'static, S>, StreamError>>
670 where
671 for<'a> StreamMessage<'a, S>: IntoStatic<Output = StreamMessage<'static, S>>,
672 {
673 use n0_future::StreamExt as _;
674
675 let rx = self.connection.receiver_mut();
676 let (raw_rx, typed_rx_source) =
677 core::mem::replace(rx, WsStream::new(n0_future::stream::empty())).tee();
678
679 // Put the raw stream back
680 *rx = raw_rx;
681
682 #[cfg(not(target_arch = "wasm32"))]
683 let stream = match S::ENCODING {
684 MessageEncoding::Json => typed_rx_source
685 .into_inner()
686 .filter_map(|msg| decode_json_msg::<S>(msg))
687 .boxed(),
688 MessageEncoding::DagCbor => typed_rx_source
689 .into_inner()
690 .filter_map(|msg| decode_cbor_msg::<S>(msg))
691 .boxed(),
692 };
693
694 #[cfg(target_arch = "wasm32")]
695 let stream = match S::ENCODING {
696 MessageEncoding::Json => typed_rx_source
697 .into_inner()
698 .filter_map(|msg| decode_json_msg::<S>(msg))
699 .boxed_local(),
700 MessageEncoding::DagCbor => typed_rx_source
701 .into_inner()
702 .filter_map(|msg| decode_cbor_msg::<S>(msg))
703 .boxed_local(),
704 };
705 stream
706 }
707}
708
709type StreamMessage<'a, R> = <R as SubscriptionResp>::Message<'a>;
710
711/// XRPC subscription endpoint trait (server-side)
712///
713/// Analogous to `XrpcEndpoint` but for WebSocket subscriptions.
714/// Defines the fully-qualified path and associated parameter/stream types.
715///
716/// This exists primarily for server-side frameworks (like Axum) to extract
717/// typed subscription parameters without lifetime issues.
718pub trait SubscriptionEndpoint {
719 /// Fully-qualified path ('/xrpc/[nsid]') where this subscription endpoint lives
720 const PATH: &'static str;
721
722 /// Message encoding (JSON or DAG-CBOR)
723 const ENCODING: MessageEncoding;
724
725 /// Subscription parameters type
726 type Params<'de>: XrpcSubscription + Deserialize<'de> + IntoStatic;
727
728 /// Stream response type
729 type Stream: SubscriptionResp;
730}
731
732/// Per-subscription options for WebSocket subscriptions.
733#[derive(Debug, Default, Clone)]
734pub struct SubscriptionOptions<'a> {
735 /// Extra headers to attach to this subscription (e.g., Authorization).
736 pub headers: Vec<(CowStr<'a>, CowStr<'a>)>,
737}
738
739impl IntoStatic for SubscriptionOptions<'_> {
740 type Output = SubscriptionOptions<'static>;
741
742 fn into_static(self) -> Self::Output {
743 SubscriptionOptions {
744 headers: self
745 .headers
746 .into_iter()
747 .map(|(k, v)| (k.into_static(), v.into_static()))
748 .collect(),
749 }
750 }
751}
752
753/// Extension for stateless subscription calls on any `WebSocketClient`.
754///
755/// Provides a builder pattern for establishing WebSocket subscriptions with custom options.
756pub trait SubscriptionExt: WebSocketClient {
757 /// Start building a subscription call for the given base URL.
758 fn subscription<'a>(&'a self, base: Url) -> SubscriptionCall<'a, Self>
759 where
760 Self: Sized,
761 {
762 SubscriptionCall {
763 client: self,
764 base,
765 opts: SubscriptionOptions::default(),
766 }
767 }
768}
769
770impl<T: WebSocketClient> SubscriptionExt for T {}
771
772/// Stateless subscription call builder.
773///
774/// Provides methods for adding headers and establishing typed subscriptions.
775pub struct SubscriptionCall<'a, C: WebSocketClient> {
776 pub(crate) client: &'a C,
777 pub(crate) base: Url,
778 pub(crate) opts: SubscriptionOptions<'a>,
779}
780
781impl<'a, C: WebSocketClient> SubscriptionCall<'a, C> {
782 /// Add an extra header.
783 pub fn header(mut self, name: impl Into<CowStr<'a>>, value: impl Into<CowStr<'a>>) -> Self {
784 self.opts.headers.push((name.into(), value.into()));
785 self
786 }
787
788 /// Replace the builder's options entirely.
789 pub fn with_options(mut self, opts: SubscriptionOptions<'a>) -> Self {
790 self.opts = opts;
791 self
792 }
793
794 /// Subscribe to the given XRPC subscription endpoint.
795 ///
796 /// Builds a WebSocket URL from the base, appends the NSID path,
797 /// encodes query parameters from the subscription type, and connects.
798 /// Returns a typed SubscriptionStream that automatically decodes messages.
799 pub async fn subscribe<Sub>(
800 self,
801 params: &Sub,
802 ) -> Result<SubscriptionStream<Sub::Stream>, C::Error>
803 where
804 Sub: XrpcSubscription,
805 {
806 let mut url = self.base.clone();
807
808 // Use custom path if provided, otherwise construct from NSID
809 let mut path = url.path().trim_end_matches('/').to_owned();
810 if let Some(custom_path) = Sub::CUSTOM_PATH {
811 path.push_str(custom_path);
812 } else {
813 path.push_str("/xrpc/");
814 path.push_str(Sub::NSID);
815 }
816 url.set_path(&path);
817
818 let query_params = params.query_params();
819 if !query_params.is_empty() {
820 let qs = query_params
821 .iter()
822 .map(|(k, v)| format!("{}={}", k, v))
823 .collect::<Vec<_>>()
824 .join("&");
825 url.set_query(Some(&qs));
826 } else {
827 url.set_query(None);
828 }
829
830 let connection = self
831 .client
832 .connect_with_headers(url, self.opts.headers)
833 .await?;
834
835 Ok(SubscriptionStream::new(connection))
836 }
837}
838
839/// Stateful subscription client trait.
840///
841/// Analogous to `XrpcClient` but for WebSocket subscriptions.
842/// Provides a stateful interface for subscribing with configured base URI and options.
843#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
844pub trait SubscriptionClient: WebSocketClient {
845 /// Get the base URI for the client.
846 fn base_uri(&self) -> impl Future<Output = CowStr<'static>>;
847
848 /// Get the subscription options for the client.
849 fn subscription_opts(&self) -> impl Future<Output = SubscriptionOptions<'_>> {
850 async { SubscriptionOptions::default() }
851 }
852
853 /// Subscribe to an XRPC subscription endpoint using the client's base URI and options.
854 #[cfg(not(target_arch = "wasm32"))]
855 fn subscribe<Sub>(
856 &self,
857 params: &Sub,
858 ) -> impl Future<Output = Result<SubscriptionStream<Sub::Stream>, Self::Error>>
859 where
860 Sub: XrpcSubscription + Send + Sync,
861 Self: Sync;
862
863 /// Subscribe to an XRPC subscription endpoint using the client's base URI and options.
864 #[cfg(target_arch = "wasm32")]
865 fn subscribe<Sub>(
866 &self,
867 params: &Sub,
868 ) -> impl Future<Output = Result<SubscriptionStream<Sub::Stream>, Self::Error>>
869 where
870 Sub: XrpcSubscription + Send + Sync;
871
872 /// Subscribe with custom options.
873 #[cfg(not(target_arch = "wasm32"))]
874 fn subscribe_with_opts<Sub>(
875 &self,
876 params: &Sub,
877 opts: SubscriptionOptions<'_>,
878 ) -> impl Future<Output = Result<SubscriptionStream<Sub::Stream>, Self::Error>>
879 where
880 Sub: XrpcSubscription + Send + Sync,
881 Self: Sync;
882
883 /// Subscribe with custom options.
884 #[cfg(target_arch = "wasm32")]
885 fn subscribe_with_opts<Sub>(
886 &self,
887 params: &Sub,
888 opts: SubscriptionOptions<'_>,
889 ) -> impl Future<Output = Result<SubscriptionStream<Sub::Stream>, Self::Error>>
890 where
891 Sub: XrpcSubscription + Send + Sync;
892}
893
894/// Simple stateless subscription client wrapping a WebSocketClient.
895///
896/// Analogous to a basic HTTP client but for WebSocket subscriptions.
897/// Does not manage sessions or authentication - useful for public subscriptions
898/// or when you want to handle auth manually via headers.
899pub struct BasicSubscriptionClient<W: WebSocketClient> {
900 client: W,
901 base_uri: CowStr<'static>,
902 opts: SubscriptionOptions<'static>,
903}
904
905impl<W: WebSocketClient> BasicSubscriptionClient<W> {
906 /// Create a new basic subscription client with the given WebSocket client and base URI.
907 pub fn new(client: W, base_uri: Url) -> Self {
908 let base_uri = base_uri.as_str().trim_end_matches("/");
909 Self {
910 client,
911 base_uri: base_uri.to_cowstr().into_static(),
912 opts: SubscriptionOptions::default(),
913 }
914 }
915
916 /// Create with default options.
917 pub fn with_options(mut self, opts: SubscriptionOptions<'_>) -> Self {
918 self.opts = opts.into_static();
919 self
920 }
921
922 /// Get a reference to the inner WebSocket client.
923 pub fn inner(&self) -> &W {
924 &self.client
925 }
926}
927
928impl<W: WebSocketClient> WebSocketClient for BasicSubscriptionClient<W> {
929 type Error = W::Error;
930
931 async fn connect(&self, url: Url) -> Result<WebSocketConnection, Self::Error> {
932 self.client.connect(url).await
933 }
934
935 async fn connect_with_headers(
936 &self,
937 url: Url,
938 headers: Vec<(CowStr<'_>, CowStr<'_>)>,
939 ) -> Result<WebSocketConnection, Self::Error> {
940 self.client.connect_with_headers(url, headers).await
941 }
942}
943
944impl<W: WebSocketClient> SubscriptionClient for BasicSubscriptionClient<W> {
945 async fn base_uri(&self) -> CowStr<'static> {
946 self.base_uri.clone()
947 }
948
949 async fn subscription_opts(&self) -> SubscriptionOptions<'_> {
950 self.opts.clone()
951 }
952
953 #[cfg(not(target_arch = "wasm32"))]
954 async fn subscribe<Sub>(
955 &self,
956 params: &Sub,
957 ) -> Result<SubscriptionStream<Sub::Stream>, Self::Error>
958 where
959 Sub: XrpcSubscription + Send + Sync,
960 Self: Sync,
961 {
962 let opts = self.subscription_opts().await;
963 self.subscribe_with_opts(params, opts).await
964 }
965
966 #[cfg(target_arch = "wasm32")]
967 async fn subscribe<Sub>(
968 &self,
969 params: &Sub,
970 ) -> Result<SubscriptionStream<Sub::Stream>, Self::Error>
971 where
972 Sub: XrpcSubscription + Send + Sync,
973 {
974 let opts = self.subscription_opts().await;
975 self.subscribe_with_opts(params, opts).await
976 }
977
978 #[cfg(not(target_arch = "wasm32"))]
979 async fn subscribe_with_opts<Sub>(
980 &self,
981 params: &Sub,
982 opts: SubscriptionOptions<'_>,
983 ) -> Result<SubscriptionStream<Sub::Stream>, Self::Error>
984 where
985 Sub: XrpcSubscription + Send + Sync,
986 Self: Sync,
987 {
988 let base = self.base_uri().await;
989 let base = Url::parse(&base).expect("Failed to parse base URL");
990 self.subscription(base)
991 .with_options(opts)
992 .subscribe(params)
993 .await
994 }
995
996 #[cfg(target_arch = "wasm32")]
997 async fn subscribe_with_opts<Sub>(
998 &self,
999 params: &Sub,
1000 opts: SubscriptionOptions<'_>,
1001 ) -> Result<SubscriptionStream<Sub::Stream>, Self::Error>
1002 where
1003 Sub: XrpcSubscription + Send + Sync,
1004 {
1005 let base = self.base_uri().await;
1006 let base = Url::parse(&base).expect("Failed to parse base URL");
1007 self.subscription(base)
1008 .with_options(opts)
1009 .subscribe(params)
1010 .await
1011 }
1012}
1013
1014/// Type alias for a basic subscription client using the default TungsteniteClient.
1015///
1016/// Provides a simple, stateless WebSocket subscription client without session management.
1017/// Useful for public subscriptions or when handling authentication manually.
1018///
1019/// # Example
1020///
1021/// ```no_run
1022/// # use jacquard_common::xrpc::{TungsteniteSubscriptionClient, SubscriptionClient};
1023/// # use url::Url;
1024/// # #[tokio::main]
1025/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
1026/// let base = Url::parse("wss://bsky.network")?;
1027/// let client = TungsteniteSubscriptionClient::from_base_uri(base);
1028/// // let conn = client.subscribe(¶ms).await?;
1029/// # Ok(())
1030/// # }
1031/// ```
1032pub type TungsteniteSubscriptionClient =
1033 BasicSubscriptionClient<crate::websocket::tungstenite_client::TungsteniteClient>;
1034
1035impl TungsteniteSubscriptionClient {
1036 /// Create a new Tungstenite-backed subscription client with the given base URI.
1037 pub fn from_base_uri(base_uri: Url) -> Self {
1038 let client = crate::websocket::tungstenite_client::TungsteniteClient::new();
1039 BasicSubscriptionClient::new(client, base_uri)
1040 }
1041}