A better Rust ATProto crate
at main 38 kB view raw
1//! WebSocket subscription support for XRPC 2//! 3//! This module defines traits and types for typed WebSocket subscriptions, 4//! mirroring the request/response pattern used for HTTP XRPC endpoints. 5 6use alloc::borrow::ToOwned; 7use alloc::string::String; 8use alloc::string::ToString; 9use alloc::vec::Vec; 10use core::error::Error; 11use core::future::Future; 12use core::marker::PhantomData; 13#[cfg(not(target_arch = "wasm32"))] 14use n0_future::stream::Boxed; 15#[cfg(target_arch = "wasm32")] 16use n0_future::stream::BoxedLocal as Boxed; 17use serde::{Deserialize, Serialize}; 18use url::Url; 19 20use crate::cowstr::ToCowStr; 21use crate::error::DecodeError; 22use crate::stream::StreamError; 23use crate::websocket::{WebSocketClient, WebSocketConnection, WsSink, WsStream}; 24use crate::{CowStr, Data, IntoStatic, RawData, WsMessage}; 25 26/// Encoding format for subscription messages 27#[derive(Debug, Clone, Copy, PartialEq, Eq)] 28pub enum MessageEncoding { 29 /// JSON text frames 30 Json, 31 /// DAG-CBOR binary frames 32 DagCbor, 33} 34 35/// XRPC subscription stream response trait 36/// 37/// Analogous to `XrpcResp` but for WebSocket subscriptions. 38/// Defines the message and error types for a subscription stream. 39/// 40/// This trait is implemented on a marker struct to keep it lifetime-free 41/// while using GATs for the message/error types. 42pub trait SubscriptionResp { 43 /// The NSID for this subscription 44 const NSID: &'static str; 45 46 /// Message encoding (JSON or DAG-CBOR) 47 const ENCODING: MessageEncoding; 48 49 /// Message union type 50 type Message<'de>: Deserialize<'de> + IntoStatic; 51 52 /// Error union type 53 type Error<'de>: Error + Deserialize<'de> + IntoStatic; 54 55 /// Decode a message from bytes. 56 /// 57 /// Default implementation uses simple deserialization via serde. 58 /// Subscriptions that use framed encoding (header + body) can override 59 /// this to do two-stage deserialization. 60 fn decode_message<'de>(bytes: &'de [u8]) -> Result<Self::Message<'de>, DecodeError> { 61 match Self::ENCODING { 62 MessageEncoding::Json => serde_json::from_slice(bytes).map_err(DecodeError::from), 63 MessageEncoding::DagCbor => { 64 serde_ipld_dagcbor::from_slice(bytes).map_err(DecodeError::from) 65 } 66 } 67 } 68} 69 70/// XRPC subscription (WebSocket) 71/// 72/// This trait is analogous to `XrpcRequest` but for WebSocket subscriptions. 73/// It defines the NSID and associated stream response type. 74/// 75/// The trait is implemented on the subscription parameters type. 76pub trait XrpcSubscription: Serialize { 77 /// The NSID for this XRPC subscription 78 const NSID: &'static str; 79 80 /// Message encoding (JSON or DAG-CBOR) 81 const ENCODING: MessageEncoding; 82 83 /// Custom path override (e.g., "/subscribe" for Jetstream). 84 /// If None, defaults to "/xrpc/{NSID}" 85 const CUSTOM_PATH: Option<&'static str> = None; 86 87 /// Stream response type (marker struct) 88 type Stream: SubscriptionResp; 89 90 /// Encode query params for WebSocket URL 91 /// 92 /// Default implementation uses serde_html_form to encode the struct as query parameters. 93 fn query_params(&self) -> Vec<(String, String)> { 94 // Default: use serde_html_form to encode self 95 serde_html_form::to_string(self) 96 .ok() 97 .map(|s| { 98 s.split('&') 99 .filter_map(|pair| { 100 let mut parts = pair.splitn(2, '='); 101 Some((parts.next()?.to_string(), parts.next()?.to_string())) 102 }) 103 .collect() 104 }) 105 .unwrap_or_default() 106 } 107} 108 109/// Header for framed DAG-CBOR subscription messages. 110/// 111/// Used in ATProto subscription streams where each message has a CBOR-encoded header 112/// followed by the message body. 113#[derive(Debug, serde::Deserialize)] 114pub struct EventHeader { 115 /// Operation code 116 pub op: i64, 117 /// Event type discriminator (e.g., "#commit", "#identity") 118 pub t: smol_str::SmolStr, 119} 120 121/// A minimal cursor for no_std that tracks read position. 122/// 123/// Implements `ciborium_io::Read` to work with ciborium's CBOR parser. 124#[cfg(not(feature = "std"))] 125struct SliceCursor<'a> { 126 slice: &'a [u8], 127 position: usize, 128} 129 130#[cfg(not(feature = "std"))] 131impl<'a> SliceCursor<'a> { 132 fn new(slice: &'a [u8]) -> Self { 133 Self { slice, position: 0 } 134 } 135 136 fn position(&self) -> usize { 137 self.position 138 } 139} 140 141#[cfg(not(feature = "std"))] 142impl ciborium_io::Read for SliceCursor<'_> { 143 type Error = core::convert::Infallible; 144 145 fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), Self::Error> { 146 let end = self.position + buf.len(); 147 buf.copy_from_slice(&self.slice[self.position..end]); 148 self.position = end; 149 Ok(()) 150 } 151} 152 153/// Parse a framed DAG-CBOR message header and return the header plus remaining body bytes. 154/// 155/// Used for two-stage deserialization of subscription messages in formats like 156/// `com.atproto.sync.subscribeRepos`. 157#[cfg(feature = "std")] 158pub fn parse_event_header<'a>(bytes: &'a [u8]) -> Result<(EventHeader, &'a [u8]), DecodeError> { 159 let mut cursor = std::io::Cursor::new(bytes); 160 let header: EventHeader = ciborium::de::from_reader(&mut cursor)?; 161 let position = cursor.position() as usize; 162 drop(cursor); // explicit drop before reborrowing bytes 163 164 Ok((header, &bytes[position..])) 165} 166 167/// Parse a framed DAG-CBOR message header and return the header plus remaining body bytes. 168/// 169/// Used for two-stage deserialization of subscription messages in formats like 170/// `com.atproto.sync.subscribeRepos`. 171#[cfg(not(feature = "std"))] 172pub fn parse_event_header<'a>(bytes: &'a [u8]) -> Result<(EventHeader, &'a [u8]), DecodeError> { 173 let mut cursor = SliceCursor::new(bytes); 174 let header: EventHeader = ciborium::de::from_reader(&mut cursor)?; 175 let position = cursor.position(); 176 177 Ok((header, &bytes[position..])) 178} 179 180/// Decode JSON messages from a WebSocket stream 181pub fn decode_json_msg<S: SubscriptionResp>( 182 msg_result: Result<crate::websocket::WsMessage, StreamError>, 183) -> Option<Result<StreamMessage<'static, S>, StreamError>> 184where 185 for<'a> StreamMessage<'a, S>: IntoStatic<Output = StreamMessage<'static, S>>, 186{ 187 use crate::websocket::WsMessage; 188 189 match msg_result { 190 Ok(WsMessage::Text(text)) => Some( 191 S::decode_message(text.as_ref()) 192 .map(|v| v.into_static()) 193 .map_err(StreamError::decode), 194 ), 195 Ok(WsMessage::Binary(bytes)) => { 196 #[cfg(feature = "zstd")] 197 { 198 // Try to decompress with zstd first (Jetstream uses zstd compression) 199 match decompress_zstd(&bytes) { 200 Ok(decompressed) => Some( 201 S::decode_message(&decompressed) 202 .map(|v| v.into_static()) 203 .map_err(StreamError::decode), 204 ), 205 Err(_) => { 206 // Not zstd-compressed, try direct decode 207 Some( 208 S::decode_message(&bytes) 209 .map(|v| v.into_static()) 210 .map_err(StreamError::decode), 211 ) 212 } 213 } 214 } 215 #[cfg(not(feature = "zstd"))] 216 { 217 Some( 218 S::decode_message(&bytes) 219 .map(|v| v.into_static()) 220 .map_err(StreamError::decode), 221 ) 222 } 223 } 224 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())), 225 Err(e) => Some(Err(e)), 226 } 227} 228 229#[cfg(feature = "zstd")] 230fn decompress_zstd(bytes: &[u8]) -> Result<Vec<u8>, std::io::Error> { 231 use std::sync::OnceLock; 232 use zstd::stream::decode_all; 233 234 static DICTIONARY: OnceLock<Vec<u8>> = OnceLock::new(); 235 236 let dict = DICTIONARY.get_or_init(|| include_bytes!("../../zstd_dictionary").to_vec()); 237 238 decode_all(std::io::Cursor::new(bytes)).or_else(|_| { 239 // Try with dictionary 240 let mut decoder = zstd::Decoder::with_dictionary(std::io::Cursor::new(bytes), dict)?; 241 let mut result = Vec::new(); 242 std::io::Read::read_to_end(&mut decoder, &mut result)?; 243 Ok(result) 244 }) 245} 246 247/// Decode CBOR messages from a WebSocket stream 248pub fn decode_cbor_msg<S: SubscriptionResp>( 249 msg_result: Result<crate::websocket::WsMessage, StreamError>, 250) -> Option<Result<StreamMessage<'static, S>, StreamError>> 251where 252 for<'a> StreamMessage<'a, S>: IntoStatic<Output = StreamMessage<'static, S>>, 253{ 254 use crate::websocket::WsMessage; 255 256 match msg_result { 257 Ok(WsMessage::Binary(bytes)) => Some( 258 S::decode_message(&bytes) 259 .map(|v| v.into_static()) 260 .map_err(StreamError::decode), 261 ), 262 Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format( 263 "expected binary frame for CBOR, got text", 264 ))), 265 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())), 266 Err(e) => Some(Err(e)), 267 } 268} 269 270/// Websocket subscriber-sent control message 271/// 272/// Note: this is not meaningful for atproto event stream endpoints as 273/// those do not support control after the fact. Jetstream does, however. 274/// 275/// If you wish to control an ongoing Jetstream connection, wrap the [`WsSink`] 276/// returned from one of the `into_*` methods of the [`SubscriptionStream`] 277/// in a [`SubscriptionController`] with the corresponding message implementing 278/// this trait as a generic parameter. 279pub trait SubscriptionControlMessage: Serialize { 280 /// The subscription this is associated with 281 type Subscription: XrpcSubscription; 282 283 /// Encode the control message for transmission 284 /// 285 /// Defaults to json text (matches Jetstream) 286 fn encode(&self) -> Result<WsMessage, StreamError> { 287 Ok(WsMessage::from( 288 serde_json::to_string(&self).map_err(StreamError::encode)?, 289 )) 290 } 291 292 /// Decode the control message 293 fn decode<'de>(frame: &'de [u8]) -> Result<Self, StreamError> 294 where 295 Self: Deserialize<'de>, 296 { 297 Ok(serde_json::from_slice(frame).map_err(StreamError::decode)?) 298 } 299} 300 301/// Control a websocket stream with a given subscription control message 302pub struct SubscriptionController<S: SubscriptionControlMessage> { 303 controller: WsSink, 304 _marker: PhantomData<fn() -> S>, 305} 306 307impl<S: SubscriptionControlMessage> SubscriptionController<S> { 308 /// Create a new subscription controller from a WebSocket sink. 309 pub fn new(controller: WsSink) -> Self { 310 Self { 311 controller, 312 _marker: PhantomData, 313 } 314 } 315 316 /// Configure the upstream connection via the websocket 317 pub async fn configure(&mut self, params: &S) -> Result<(), StreamError> { 318 let message = params.encode()?; 319 320 n0_future::SinkExt::send(self.controller.get_mut(), message) 321 .await 322 .map_err(StreamError::transport) 323 } 324} 325 326/// Typed subscription stream wrapping a WebSocket connection. 327/// 328/// Analogous to `Response<R>` for XRPC but for subscription streams. 329/// Automatically decodes messages based on the subscription's encoding format. 330pub struct SubscriptionStream<S: SubscriptionResp> { 331 _marker: PhantomData<fn() -> S>, 332 connection: WebSocketConnection, 333} 334 335impl<S: SubscriptionResp> SubscriptionStream<S> { 336 /// Create a new subscription stream from a WebSocket connection. 337 pub fn new(connection: WebSocketConnection) -> Self { 338 Self { 339 _marker: PhantomData, 340 connection, 341 } 342 } 343 344 /// Get a reference to the underlying WebSocket connection. 345 pub fn connection(&self) -> &WebSocketConnection { 346 &self.connection 347 } 348 349 /// Get a mutable reference to the underlying WebSocket connection. 350 pub fn connection_mut(&mut self) -> &mut WebSocketConnection { 351 &mut self.connection 352 } 353 354 /// Split the connection and decode messages into a typed stream. 355 /// 356 /// Returns a tuple of (sender, typed message stream). 357 /// Messages are decoded according to the subscription's ENCODING. 358 pub fn into_stream( 359 self, 360 ) -> ( 361 WsSink, 362 Boxed<Result<StreamMessage<'static, S>, StreamError>>, 363 ) 364 where 365 for<'a> StreamMessage<'a, S>: IntoStatic<Output = StreamMessage<'static, S>>, 366 { 367 use n0_future::StreamExt as _; 368 369 let (tx, rx) = self.connection.split(); 370 371 #[cfg(not(target_arch = "wasm32"))] 372 let stream = match S::ENCODING { 373 MessageEncoding::Json => rx 374 .into_inner() 375 .filter_map(|msg| decode_json_msg::<S>(msg)) 376 .boxed(), 377 MessageEncoding::DagCbor => rx 378 .into_inner() 379 .filter_map(|msg| decode_cbor_msg::<S>(msg)) 380 .boxed(), 381 }; 382 383 #[cfg(target_arch = "wasm32")] 384 let stream = match S::ENCODING { 385 MessageEncoding::Json => rx 386 .into_inner() 387 .filter_map(|msg| decode_json_msg::<S>(msg)) 388 .boxed_local(), 389 MessageEncoding::DagCbor => rx 390 .into_inner() 391 .filter_map(|msg| decode_cbor_msg::<S>(msg)) 392 .boxed_local(), 393 }; 394 395 (tx, stream) 396 } 397 398 /// Converts the subscription into a stream of raw atproto data. 399 pub fn into_raw_data_stream(self) -> (WsSink, Boxed<Result<RawData<'static>, StreamError>>) { 400 use n0_future::StreamExt as _; 401 402 let (tx, rx) = self.connection.split(); 403 404 fn parse_msg<'a>(bytes: &'a [u8]) -> Result<RawData<'a>, serde_json::Error> { 405 serde_json::from_slice(bytes) 406 } 407 fn parse_cbor<'a>( 408 bytes: &'a [u8], 409 ) -> Result<RawData<'a>, serde_ipld_dagcbor::DecodeError<core::convert::Infallible>> 410 { 411 serde_ipld_dagcbor::from_slice(bytes) 412 } 413 414 #[cfg(not(target_arch = "wasm32"))] 415 let stream = match S::ENCODING { 416 MessageEncoding::Json => rx 417 .into_inner() 418 .filter_map(|msg_result| match msg_result { 419 Ok(WsMessage::Text(text)) => Some( 420 parse_msg(text.as_ref()) 421 .map(|v| v.into_static()) 422 .map_err(StreamError::decode), 423 ), 424 Ok(WsMessage::Binary(bytes)) => { 425 #[cfg(feature = "zstd")] 426 { 427 match decompress_zstd(&bytes) { 428 Ok(decompressed) => Some( 429 parse_msg(&decompressed) 430 .map(|v| v.into_static()) 431 .map_err(StreamError::decode), 432 ), 433 Err(_) => Some( 434 parse_msg(&bytes) 435 .map(|v| v.into_static()) 436 .map_err(StreamError::decode), 437 ), 438 } 439 } 440 #[cfg(not(feature = "zstd"))] 441 { 442 Some( 443 parse_msg(&bytes) 444 .map(|v| v.into_static()) 445 .map_err(StreamError::decode), 446 ) 447 } 448 } 449 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())), 450 Err(e) => Some(Err(e)), 451 }) 452 .boxed(), 453 MessageEncoding::DagCbor => rx 454 .into_inner() 455 .filter_map(|msg_result| match msg_result { 456 Ok(WsMessage::Binary(bytes)) => Some( 457 parse_cbor(&bytes) 458 .map(|v| v.into_static()) 459 .map_err(|e| StreamError::decode(crate::error::DecodeError::from(e))), 460 ), 461 Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format( 462 "expected binary frame for CBOR, got text", 463 ))), 464 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())), 465 Err(e) => Some(Err(e)), 466 }) 467 .boxed(), 468 }; 469 470 #[cfg(target_arch = "wasm32")] 471 let stream = match S::ENCODING { 472 MessageEncoding::Json => rx 473 .into_inner() 474 .filter_map(|msg_result| match msg_result { 475 Ok(WsMessage::Text(text)) => Some( 476 parse_msg(text.as_ref()) 477 .map(|v| v.into_static()) 478 .map_err(StreamError::decode), 479 ), 480 Ok(WsMessage::Binary(bytes)) => { 481 #[cfg(feature = "zstd")] 482 { 483 match decompress_zstd(&bytes) { 484 Ok(decompressed) => Some( 485 parse_msg(&decompressed) 486 .map(|v| v.into_static()) 487 .map_err(StreamError::decode), 488 ), 489 Err(_) => Some( 490 parse_msg(&bytes) 491 .map(|v| v.into_static()) 492 .map_err(StreamError::decode), 493 ), 494 } 495 } 496 #[cfg(not(feature = "zstd"))] 497 { 498 Some( 499 parse_msg(&bytes) 500 .map(|v| v.into_static()) 501 .map_err(StreamError::decode), 502 ) 503 } 504 } 505 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())), 506 Err(e) => Some(Err(e)), 507 }) 508 .boxed_local(), 509 MessageEncoding::DagCbor => rx 510 .into_inner() 511 .filter_map(|msg_result| match msg_result { 512 Ok(WsMessage::Binary(bytes)) => Some( 513 parse_cbor(&bytes) 514 .map(|v| v.into_static()) 515 .map_err(|e| StreamError::decode(crate::error::DecodeError::from(e))), 516 ), 517 Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format( 518 "expected binary frame for CBOR, got text", 519 ))), 520 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())), 521 Err(e) => Some(Err(e)), 522 }) 523 .boxed_local(), 524 }; 525 526 (tx, stream) 527 } 528 529 /// Converts the subscription into a stream of loosely-typed atproto data. 530 pub fn into_data_stream(self) -> (WsSink, Boxed<Result<Data<'static>, StreamError>>) { 531 use n0_future::StreamExt as _; 532 533 let (tx, rx) = self.connection.split(); 534 535 fn parse_msg<'a>(bytes: &'a [u8]) -> Result<Data<'a>, serde_json::Error> { 536 serde_json::from_slice(bytes) 537 } 538 fn parse_cbor<'a>( 539 bytes: &'a [u8], 540 ) -> Result<Data<'a>, serde_ipld_dagcbor::DecodeError<core::convert::Infallible>> { 541 serde_ipld_dagcbor::from_slice(bytes) 542 } 543 544 #[cfg(not(target_arch = "wasm32"))] 545 let stream = match S::ENCODING { 546 MessageEncoding::Json => rx 547 .into_inner() 548 .filter_map(|msg_result| match msg_result { 549 Ok(WsMessage::Text(text)) => Some( 550 parse_msg(text.as_ref()) 551 .map(|v| v.into_static()) 552 .map_err(StreamError::decode), 553 ), 554 Ok(WsMessage::Binary(bytes)) => { 555 #[cfg(feature = "zstd")] 556 { 557 match decompress_zstd(&bytes) { 558 Ok(decompressed) => Some( 559 parse_msg(&decompressed) 560 .map(|v| v.into_static()) 561 .map_err(StreamError::decode), 562 ), 563 Err(_) => Some( 564 parse_msg(&bytes) 565 .map(|v| v.into_static()) 566 .map_err(StreamError::decode), 567 ), 568 } 569 } 570 #[cfg(not(feature = "zstd"))] 571 { 572 Some( 573 parse_msg(&bytes) 574 .map(|v| v.into_static()) 575 .map_err(StreamError::decode), 576 ) 577 } 578 } 579 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())), 580 Err(e) => Some(Err(e)), 581 }) 582 .boxed(), 583 MessageEncoding::DagCbor => rx 584 .into_inner() 585 .filter_map(|msg_result| match msg_result { 586 Ok(WsMessage::Binary(bytes)) => Some( 587 parse_cbor(&bytes) 588 .map(|v| v.into_static()) 589 .map_err(|e| StreamError::decode(crate::error::DecodeError::from(e))), 590 ), 591 Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format( 592 "expected binary frame for CBOR, got text", 593 ))), 594 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())), 595 Err(e) => Some(Err(e)), 596 }) 597 .boxed(), 598 }; 599 600 #[cfg(target_arch = "wasm32")] 601 let stream = match S::ENCODING { 602 MessageEncoding::Json => rx 603 .into_inner() 604 .filter_map(|msg_result| match msg_result { 605 Ok(WsMessage::Text(text)) => Some( 606 parse_msg(text.as_ref()) 607 .map(|v| v.into_static()) 608 .map_err(StreamError::decode), 609 ), 610 Ok(WsMessage::Binary(bytes)) => { 611 #[cfg(feature = "zstd")] 612 { 613 match decompress_zstd(&bytes) { 614 Ok(decompressed) => Some( 615 parse_msg(&decompressed) 616 .map(|v| v.into_static()) 617 .map_err(StreamError::decode), 618 ), 619 Err(_) => Some( 620 parse_msg(&bytes) 621 .map(|v| v.into_static()) 622 .map_err(StreamError::decode), 623 ), 624 } 625 } 626 #[cfg(not(feature = "zstd"))] 627 { 628 Some( 629 parse_msg(&bytes) 630 .map(|v| v.into_static()) 631 .map_err(StreamError::decode), 632 ) 633 } 634 } 635 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())), 636 Err(e) => Some(Err(e)), 637 }) 638 .boxed_local(), 639 MessageEncoding::DagCbor => rx 640 .into_inner() 641 .filter_map(|msg_result| match msg_result { 642 Ok(WsMessage::Binary(bytes)) => Some( 643 parse_cbor(&bytes) 644 .map(|v| v.into_static()) 645 .map_err(|e| StreamError::decode(crate::error::DecodeError::from(e))), 646 ), 647 Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format( 648 "expected binary frame for CBOR, got text", 649 ))), 650 Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())), 651 Err(e) => Some(Err(e)), 652 }) 653 .boxed_local(), 654 }; 655 656 (tx, stream) 657 } 658 659 /// Consume the stream and return the underlying connection. 660 pub fn into_connection(self) -> WebSocketConnection { 661 self.connection 662 } 663 664 /// Tee the stream, keeping the raw stream in self and returning a typed stream. 665 /// 666 /// Replaces the internal WebSocket stream with one copy and returns a typed decoded 667 /// stream. Both streams receive all messages. Useful for observing raw messages 668 /// while also processing typed messages. 669 pub fn tee(&mut self) -> Boxed<Result<StreamMessage<'static, S>, StreamError>> 670 where 671 for<'a> StreamMessage<'a, S>: IntoStatic<Output = StreamMessage<'static, S>>, 672 { 673 use n0_future::StreamExt as _; 674 675 let rx = self.connection.receiver_mut(); 676 let (raw_rx, typed_rx_source) = 677 core::mem::replace(rx, WsStream::new(n0_future::stream::empty())).tee(); 678 679 // Put the raw stream back 680 *rx = raw_rx; 681 682 #[cfg(not(target_arch = "wasm32"))] 683 let stream = match S::ENCODING { 684 MessageEncoding::Json => typed_rx_source 685 .into_inner() 686 .filter_map(|msg| decode_json_msg::<S>(msg)) 687 .boxed(), 688 MessageEncoding::DagCbor => typed_rx_source 689 .into_inner() 690 .filter_map(|msg| decode_cbor_msg::<S>(msg)) 691 .boxed(), 692 }; 693 694 #[cfg(target_arch = "wasm32")] 695 let stream = match S::ENCODING { 696 MessageEncoding::Json => typed_rx_source 697 .into_inner() 698 .filter_map(|msg| decode_json_msg::<S>(msg)) 699 .boxed_local(), 700 MessageEncoding::DagCbor => typed_rx_source 701 .into_inner() 702 .filter_map(|msg| decode_cbor_msg::<S>(msg)) 703 .boxed_local(), 704 }; 705 stream 706 } 707} 708 709type StreamMessage<'a, R> = <R as SubscriptionResp>::Message<'a>; 710 711/// XRPC subscription endpoint trait (server-side) 712/// 713/// Analogous to `XrpcEndpoint` but for WebSocket subscriptions. 714/// Defines the fully-qualified path and associated parameter/stream types. 715/// 716/// This exists primarily for server-side frameworks (like Axum) to extract 717/// typed subscription parameters without lifetime issues. 718pub trait SubscriptionEndpoint { 719 /// Fully-qualified path ('/xrpc/[nsid]') where this subscription endpoint lives 720 const PATH: &'static str; 721 722 /// Message encoding (JSON or DAG-CBOR) 723 const ENCODING: MessageEncoding; 724 725 /// Subscription parameters type 726 type Params<'de>: XrpcSubscription + Deserialize<'de> + IntoStatic; 727 728 /// Stream response type 729 type Stream: SubscriptionResp; 730} 731 732/// Per-subscription options for WebSocket subscriptions. 733#[derive(Debug, Default, Clone)] 734pub struct SubscriptionOptions<'a> { 735 /// Extra headers to attach to this subscription (e.g., Authorization). 736 pub headers: Vec<(CowStr<'a>, CowStr<'a>)>, 737} 738 739impl IntoStatic for SubscriptionOptions<'_> { 740 type Output = SubscriptionOptions<'static>; 741 742 fn into_static(self) -> Self::Output { 743 SubscriptionOptions { 744 headers: self 745 .headers 746 .into_iter() 747 .map(|(k, v)| (k.into_static(), v.into_static())) 748 .collect(), 749 } 750 } 751} 752 753/// Extension for stateless subscription calls on any `WebSocketClient`. 754/// 755/// Provides a builder pattern for establishing WebSocket subscriptions with custom options. 756pub trait SubscriptionExt: WebSocketClient { 757 /// Start building a subscription call for the given base URL. 758 fn subscription<'a>(&'a self, base: Url) -> SubscriptionCall<'a, Self> 759 where 760 Self: Sized, 761 { 762 SubscriptionCall { 763 client: self, 764 base, 765 opts: SubscriptionOptions::default(), 766 } 767 } 768} 769 770impl<T: WebSocketClient> SubscriptionExt for T {} 771 772/// Stateless subscription call builder. 773/// 774/// Provides methods for adding headers and establishing typed subscriptions. 775pub struct SubscriptionCall<'a, C: WebSocketClient> { 776 pub(crate) client: &'a C, 777 pub(crate) base: Url, 778 pub(crate) opts: SubscriptionOptions<'a>, 779} 780 781impl<'a, C: WebSocketClient> SubscriptionCall<'a, C> { 782 /// Add an extra header. 783 pub fn header(mut self, name: impl Into<CowStr<'a>>, value: impl Into<CowStr<'a>>) -> Self { 784 self.opts.headers.push((name.into(), value.into())); 785 self 786 } 787 788 /// Replace the builder's options entirely. 789 pub fn with_options(mut self, opts: SubscriptionOptions<'a>) -> Self { 790 self.opts = opts; 791 self 792 } 793 794 /// Subscribe to the given XRPC subscription endpoint. 795 /// 796 /// Builds a WebSocket URL from the base, appends the NSID path, 797 /// encodes query parameters from the subscription type, and connects. 798 /// Returns a typed SubscriptionStream that automatically decodes messages. 799 pub async fn subscribe<Sub>( 800 self, 801 params: &Sub, 802 ) -> Result<SubscriptionStream<Sub::Stream>, C::Error> 803 where 804 Sub: XrpcSubscription, 805 { 806 let mut url = self.base.clone(); 807 808 // Use custom path if provided, otherwise construct from NSID 809 let mut path = url.path().trim_end_matches('/').to_owned(); 810 if let Some(custom_path) = Sub::CUSTOM_PATH { 811 path.push_str(custom_path); 812 } else { 813 path.push_str("/xrpc/"); 814 path.push_str(Sub::NSID); 815 } 816 url.set_path(&path); 817 818 let query_params = params.query_params(); 819 if !query_params.is_empty() { 820 let qs = query_params 821 .iter() 822 .map(|(k, v)| format!("{}={}", k, v)) 823 .collect::<Vec<_>>() 824 .join("&"); 825 url.set_query(Some(&qs)); 826 } else { 827 url.set_query(None); 828 } 829 830 let connection = self 831 .client 832 .connect_with_headers(url, self.opts.headers) 833 .await?; 834 835 Ok(SubscriptionStream::new(connection)) 836 } 837} 838 839/// Stateful subscription client trait. 840/// 841/// Analogous to `XrpcClient` but for WebSocket subscriptions. 842/// Provides a stateful interface for subscribing with configured base URI and options. 843#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] 844pub trait SubscriptionClient: WebSocketClient { 845 /// Get the base URI for the client. 846 fn base_uri(&self) -> impl Future<Output = CowStr<'static>>; 847 848 /// Get the subscription options for the client. 849 fn subscription_opts(&self) -> impl Future<Output = SubscriptionOptions<'_>> { 850 async { SubscriptionOptions::default() } 851 } 852 853 /// Subscribe to an XRPC subscription endpoint using the client's base URI and options. 854 #[cfg(not(target_arch = "wasm32"))] 855 fn subscribe<Sub>( 856 &self, 857 params: &Sub, 858 ) -> impl Future<Output = Result<SubscriptionStream<Sub::Stream>, Self::Error>> 859 where 860 Sub: XrpcSubscription + Send + Sync, 861 Self: Sync; 862 863 /// Subscribe to an XRPC subscription endpoint using the client's base URI and options. 864 #[cfg(target_arch = "wasm32")] 865 fn subscribe<Sub>( 866 &self, 867 params: &Sub, 868 ) -> impl Future<Output = Result<SubscriptionStream<Sub::Stream>, Self::Error>> 869 where 870 Sub: XrpcSubscription + Send + Sync; 871 872 /// Subscribe with custom options. 873 #[cfg(not(target_arch = "wasm32"))] 874 fn subscribe_with_opts<Sub>( 875 &self, 876 params: &Sub, 877 opts: SubscriptionOptions<'_>, 878 ) -> impl Future<Output = Result<SubscriptionStream<Sub::Stream>, Self::Error>> 879 where 880 Sub: XrpcSubscription + Send + Sync, 881 Self: Sync; 882 883 /// Subscribe with custom options. 884 #[cfg(target_arch = "wasm32")] 885 fn subscribe_with_opts<Sub>( 886 &self, 887 params: &Sub, 888 opts: SubscriptionOptions<'_>, 889 ) -> impl Future<Output = Result<SubscriptionStream<Sub::Stream>, Self::Error>> 890 where 891 Sub: XrpcSubscription + Send + Sync; 892} 893 894/// Simple stateless subscription client wrapping a WebSocketClient. 895/// 896/// Analogous to a basic HTTP client but for WebSocket subscriptions. 897/// Does not manage sessions or authentication - useful for public subscriptions 898/// or when you want to handle auth manually via headers. 899pub struct BasicSubscriptionClient<W: WebSocketClient> { 900 client: W, 901 base_uri: CowStr<'static>, 902 opts: SubscriptionOptions<'static>, 903} 904 905impl<W: WebSocketClient> BasicSubscriptionClient<W> { 906 /// Create a new basic subscription client with the given WebSocket client and base URI. 907 pub fn new(client: W, base_uri: Url) -> Self { 908 let base_uri = base_uri.as_str().trim_end_matches("/"); 909 Self { 910 client, 911 base_uri: base_uri.to_cowstr().into_static(), 912 opts: SubscriptionOptions::default(), 913 } 914 } 915 916 /// Create with default options. 917 pub fn with_options(mut self, opts: SubscriptionOptions<'_>) -> Self { 918 self.opts = opts.into_static(); 919 self 920 } 921 922 /// Get a reference to the inner WebSocket client. 923 pub fn inner(&self) -> &W { 924 &self.client 925 } 926} 927 928impl<W: WebSocketClient> WebSocketClient for BasicSubscriptionClient<W> { 929 type Error = W::Error; 930 931 async fn connect(&self, url: Url) -> Result<WebSocketConnection, Self::Error> { 932 self.client.connect(url).await 933 } 934 935 async fn connect_with_headers( 936 &self, 937 url: Url, 938 headers: Vec<(CowStr<'_>, CowStr<'_>)>, 939 ) -> Result<WebSocketConnection, Self::Error> { 940 self.client.connect_with_headers(url, headers).await 941 } 942} 943 944impl<W: WebSocketClient> SubscriptionClient for BasicSubscriptionClient<W> { 945 async fn base_uri(&self) -> CowStr<'static> { 946 self.base_uri.clone() 947 } 948 949 async fn subscription_opts(&self) -> SubscriptionOptions<'_> { 950 self.opts.clone() 951 } 952 953 #[cfg(not(target_arch = "wasm32"))] 954 async fn subscribe<Sub>( 955 &self, 956 params: &Sub, 957 ) -> Result<SubscriptionStream<Sub::Stream>, Self::Error> 958 where 959 Sub: XrpcSubscription + Send + Sync, 960 Self: Sync, 961 { 962 let opts = self.subscription_opts().await; 963 self.subscribe_with_opts(params, opts).await 964 } 965 966 #[cfg(target_arch = "wasm32")] 967 async fn subscribe<Sub>( 968 &self, 969 params: &Sub, 970 ) -> Result<SubscriptionStream<Sub::Stream>, Self::Error> 971 where 972 Sub: XrpcSubscription + Send + Sync, 973 { 974 let opts = self.subscription_opts().await; 975 self.subscribe_with_opts(params, opts).await 976 } 977 978 #[cfg(not(target_arch = "wasm32"))] 979 async fn subscribe_with_opts<Sub>( 980 &self, 981 params: &Sub, 982 opts: SubscriptionOptions<'_>, 983 ) -> Result<SubscriptionStream<Sub::Stream>, Self::Error> 984 where 985 Sub: XrpcSubscription + Send + Sync, 986 Self: Sync, 987 { 988 let base = self.base_uri().await; 989 let base = Url::parse(&base).expect("Failed to parse base URL"); 990 self.subscription(base) 991 .with_options(opts) 992 .subscribe(params) 993 .await 994 } 995 996 #[cfg(target_arch = "wasm32")] 997 async fn subscribe_with_opts<Sub>( 998 &self, 999 params: &Sub, 1000 opts: SubscriptionOptions<'_>, 1001 ) -> Result<SubscriptionStream<Sub::Stream>, Self::Error> 1002 where 1003 Sub: XrpcSubscription + Send + Sync, 1004 { 1005 let base = self.base_uri().await; 1006 let base = Url::parse(&base).expect("Failed to parse base URL"); 1007 self.subscription(base) 1008 .with_options(opts) 1009 .subscribe(params) 1010 .await 1011 } 1012} 1013 1014/// Type alias for a basic subscription client using the default TungsteniteClient. 1015/// 1016/// Provides a simple, stateless WebSocket subscription client without session management. 1017/// Useful for public subscriptions or when handling authentication manually. 1018/// 1019/// # Example 1020/// 1021/// ```no_run 1022/// # use jacquard_common::xrpc::{TungsteniteSubscriptionClient, SubscriptionClient}; 1023/// # use url::Url; 1024/// # #[tokio::main] 1025/// # async fn main() -> Result<(), Box<dyn std::error::Error>> { 1026/// let base = Url::parse("wss://bsky.network")?; 1027/// let client = TungsteniteSubscriptionClient::from_base_uri(base); 1028/// // let conn = client.subscribe(&params).await?; 1029/// # Ok(()) 1030/// # } 1031/// ``` 1032pub type TungsteniteSubscriptionClient = 1033 BasicSubscriptionClient<crate::websocket::tungstenite_client::TungsteniteClient>; 1034 1035impl TungsteniteSubscriptionClient { 1036 /// Create a new Tungstenite-backed subscription client with the given base URI. 1037 pub fn from_base_uri(base_uri: Url) -> Self { 1038 let client = crate::websocket::tungstenite_client::TungsteniteClient::new(); 1039 BasicSubscriptionClient::new(client, base_uri) 1040 } 1041}