this repo has no description
1use std::sync::Arc;
2
3use async_tungstenite::tungstenite::protocol::frame;
4use async_tungstenite::tungstenite::{self, Utf8Bytes};
5use async_tungstenite::{async_tls::client_async_tls, tungstenite::Message};
6use bytes::Bytes;
7use futures_util::{SinkExt, StreamExt};
8use jacquard::{
9 CloseCode, CloseFrame, StreamError, WebSocketClient, WebSocketConnection, WsMessage, WsSink,
10 WsStream, WsText, client::BasicClient,
11};
12use smol::net::TcpStream;
13
14#[derive(Debug, Clone, Default)]
15pub struct AsyncTungsteniteClient;
16
17impl WebSocketClient for AsyncTungsteniteClient {
18 type Error = tungstenite::Error;
19
20 async fn connect(
21 &self,
22 uri: jacquard::deps::fluent_uri::Uri<&str>,
23 ) -> Result<WebSocketConnection, Self::Error> {
24 let uri_str = uri.as_str();
25
26 let authority = uri
27 .authority()
28 .ok_or_else(|| tungstenite::Error::Url(tungstenite::error::UrlError::NoHostName))?;
29
30 let domain = authority.host();
31 let default_port = if uri_str.starts_with("wss:") { 443 } else { 80 };
32 let port = authority
33 .port_to_u16()
34 .ok()
35 .flatten()
36 .unwrap_or(default_port);
37
38 let addr = format!("{domain}:{port}");
39
40 let tcp_stream = TcpStream::connect(&addr)
41 .await
42 .map_err(tungstenite::Error::Io)?;
43
44 let (ws_stream, _response) = client_async_tls(uri_str, tcp_stream).await?;
45
46 let (sink, stream) = ws_stream.split();
47
48 let rx_stream = stream.filter_map(|result| async move {
49 match result {
50 Ok(msg) => convert_to_ws_message(msg).map(Ok),
51 Err(e) => Some(Err(StreamError::transport(e))),
52 }
53 });
54 let rx = WsStream::new(rx_stream);
55
56 let tx_sink = sink.with(|msg: WsMessage| async move {
57 Ok::<_, tungstenite::Error>(convert_from_ws_message(msg))
58 });
59 let tx_sink_mapped = tx_sink.sink_map_err(StreamError::transport);
60 let tx = WsSink::new(tx_sink_mapped);
61
62 Ok(WebSocketConnection::new(tx, rx))
63 }
64}
65
66fn convert_to_ws_message(msg: Message) -> Option<WsMessage> {
67 match msg {
68 Message::Text(utf8_bytes) => {
69 // Both Utf8Bytes and WsText wrap Bytes with UTF-8 invariant — zero-copy
70 let bytes: Bytes = utf8_bytes.into();
71 // Safety: tungstenite already validated UTF-8
72 Some(WsMessage::Text(unsafe {
73 WsText::from_bytes_unchecked(bytes)
74 }))
75 }
76 Message::Binary(data) => Some(WsMessage::Binary(data)),
77 Message::Close(frame) => {
78 let close_frame = frame.map(|f| {
79 let code = convert_close_code(f.code);
80 let reason_bytes: Bytes = f.reason.into();
81 // Safety: CloseFrame reason was already UTF-8 validated by tungstenite
82 let reason_str = unsafe { core::str::from_utf8_unchecked(&reason_bytes) };
83 CloseFrame::new(code, reason_str.to_owned())
84 });
85 Some(WsMessage::Close(close_frame))
86 }
87 _ => None,
88 }
89}
90
91fn convert_from_ws_message(msg: WsMessage) -> Message {
92 match msg {
93 WsMessage::Text(text) => {
94 // WsText → Bytes → Utf8Bytes, both already UTF-8 validated
95 let bytes = text.into_bytes();
96 // Safety: WsText guarantees UTF-8
97 Message::Text(unsafe { Utf8Bytes::from_bytes_unchecked(bytes) })
98 }
99 WsMessage::Binary(bytes) => Message::Binary(bytes),
100 WsMessage::Close(frame) => {
101 let close_frame = frame.map(|f| {
102 let code: u16 = f.code.into();
103 frame::CloseFrame {
104 code: code.into(),
105 reason: Utf8Bytes::from(f.reason.to_string()),
106 }
107 });
108 Message::Close(close_frame)
109 }
110 }
111}
112
113fn convert_close_code(code: frame::coding::CloseCode) -> CloseCode {
114 use frame::coding::CloseCode as TC;
115 match code {
116 TC::Normal => CloseCode::Normal,
117 TC::Away => CloseCode::Away,
118 TC::Protocol => CloseCode::Protocol,
119 TC::Unsupported => CloseCode::Unsupported,
120 TC::Invalid => CloseCode::Invalid,
121 TC::Policy => CloseCode::Policy,
122 TC::Size => CloseCode::Size,
123 TC::Extension => CloseCode::Extension,
124 TC::Error => CloseCode::Error,
125 TC::Tls => CloseCode::Tls,
126 other => {
127 let raw: u16 = other.into();
128 CloseCode::from(raw)
129 }
130 }
131}
132
133use bevy::prelude::*;
134
135/// Shared AT Protocol client for XRPC calls.
136#[derive(Resource, Clone)]
137pub struct AtpClient(pub Arc<BasicClient>);
138
139pub fn setup_atp_client(mut commands: Commands) {
140 commands.insert_resource(AtpClient(Arc::new(BasicClient::unauthenticated())));
141}