Bevy + Ratatui-powered monitoring of Pico-Strike devices
1use std::{
2 io::{Read, Write},
3 net::{SocketAddr, TcpStream},
4 sync::{Arc, LazyLock, Mutex, OnceLock},
5 time::Duration,
6};
7
8use async_channel::{Receiver, Sender};
9use async_io::{Async, Timer};
10use bevy::{
11 app::{Plugin, Startup},
12 ecs::{resource::Resource, system::Commands},
13 tasks::IoTaskPool,
14};
15use futures_concurrency::future::Race;
16use snow::{Builder, TransportState, params::NoiseParams};
17
18use crate::constants::NOISE_PSK;
19
20static PARAMS: LazyLock<NoiseParams> =
21 LazyLock::new(|| "Noise_XXpsk3_25519_ChaChaPoly_BLAKE2s".parse().unwrap());
22static LOCAL_PRIVATE_KEY: OnceLock<snow::Keypair> = OnceLock::new();
23
24pub enum StrikeUpdateState {
25 Disconnected,
26 Connecting,
27 Connected,
28 Updating(striker_proto::StrikerResponse),
29}
30
31#[derive(Debug, Resource)]
32pub struct StrikeUpdates(pub Receiver<StrikeUpdateState>);
33
34#[derive(Debug, Resource)]
35pub struct StrikeRequests(pub Sender<striker_proto::StrikerRequest>);
36
37#[derive(Debug, Resource)]
38pub struct StrikeActions(pub Sender<StrikeAction>);
39
/// Control commands for the background network task started by [`setup_strike_connection`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StrikeAction {
    /// Connect to the device at this address, retrying until the connection and handshake
    /// succeed. Ignored while a session is already active.
    Connect(SocketAddr),
    /// End the current session, including any connect / handshake retries in progress.
    Disconnect,
}
44
45pub fn setup_strike_connection(mut commands: Commands) {
46 let io = IoTaskPool::get();
47
48 let (signal_tx, signal_rx) = async_channel::bounded(2);
49 let (req_tx, req_rx) = async_channel::bounded(1);
50 let (resp_tx, resp_rx) = async_channel::bounded(64);
51
52 io.spawn(async move {
53 let mut read_buf = vec![0u8; 4096];
54 let mut enc_buf = vec![0u8; 4096];
55 let mut write_buf = vec![0u8; 4096];
56
57 while let Ok(StrikeAction::Connect(addr)) = signal_rx.recv().await {
58 let net_fut = async {
59 loop {
60 resp_tx.send(StrikeUpdateState::Connecting).await.ok();
61 let Ok(stream) = Async::<TcpStream>::connect(addr).await else {
62 Timer::after(Duration::from_secs(1)).await;
63 continue;
64 };
65
66 resp_tx.send(StrikeUpdateState::Connected).await.ok();
67 stream.write_with(|s| s.set_nodelay(true)).await.ok();
68
69 let Ok(transport) = noise_handshake(&stream).await else {
70 resp_tx.send(StrikeUpdateState::Disconnected).await.ok();
71 continue;
72 };
73
74 let transport_access = Arc::new(Mutex::new(transport));
75
76 let read_fut = async {
77 while let Ok(buh) = recv(&stream).await {
78 if let Ok(decrypted) = {
79 transport_access
80 .lock()
81 .unwrap()
82 .read_message(&buh, &mut read_buf)
83 } && let Ok(data) =
84 striker_proto::receive_response(&mut read_buf[..decrypted])
85 && resp_tx
86 .send(StrikeUpdateState::Updating(data))
87 .await
88 .is_err()
89 {
90 break;
91 }
92 }
93 };
94
95 let write_fut = async {
96 while let Ok(req) = req_rx.recv().await {
97 if let Ok(payload) = striker_proto::send_request(req, &mut write_buf)
98 && let Ok(encrypted) = {
99 transport_access
100 .lock()
101 .unwrap()
102 .write_message(payload, &mut enc_buf)
103 }
104 && send(&stream, &enc_buf[..encrypted]).await.is_err()
105 {
106 break;
107 };
108 }
109 };
110
111 (read_fut, write_fut).race().await;
112
113 stream
114 .write_with(|s| s.shutdown(std::net::Shutdown::Both))
115 .await
116 .ok();
117
118 break;
119 }
120 };
121
122 let cancel_fut = async {
123 while signal_rx
124 .recv()
125 .await
126 .is_ok_and(|strike| !matches!(strike, StrikeAction::Disconnect))
127 {
128 }
129 };
130
131 (net_fut, cancel_fut).race().await;
132 resp_tx.send(StrikeUpdateState::Disconnected).await.ok();
133 }
134 })
135 .detach();
136
137 commands.insert_resource(StrikeActions(signal_tx));
138 commands.insert_resource(StrikeUpdates(resp_rx));
139 commands.insert_resource(StrikeRequests(req_tx));
140}
141
142pub struct NetPlugin;
143
144impl Plugin for NetPlugin {
145 fn build(&self, app: &mut bevy::app::App) {
146 app.add_systems(Startup, setup_strike_connection);
147 }
148}
149
150async fn noise_handshake(tcp: &Async<TcpStream>) -> color_eyre::Result<TransportState> {
151 let builder = Builder::new(PARAMS.clone());
152 let static_key = LOCAL_PRIVATE_KEY.get_or_init(|| builder.generate_keypair().unwrap());
153
154 let mut noise = builder
155 .local_private_key(&static_key.private)?
156 .psk(3, &NOISE_PSK)?
157 .build_initiator()?;
158
159 let mut payload = vec![0u8; 2048];
160
161 // -> e
162 let len = noise.write_message(&[], &mut payload)?;
163
164 send(tcp, &payload[..len]).await?;
165
166 // <- e, ee, s, es
167 noise.read_message(&recv(tcp).await?, &mut payload)?;
168
169 // -> s, se
170 let len = noise.write_message(&[], &mut payload)?;
171
172 send(tcp, &payload[..len]).await?;
173
174 let transport = noise.into_transport_mode()?;
175
176 Ok(transport)
177}
178
179async fn recv(stream: &Async<TcpStream>) -> std::io::Result<Vec<u8>> {
180 let mut msg_len_buf = [0_u8; 2];
181
182 stream
183 .read_with(|mut stream| {
184 stream.read_exact(&mut msg_len_buf)?;
185 let msg_len = usize::from(u16::from_be_bytes(msg_len_buf));
186 let mut msg = vec![0_u8; msg_len];
187 stream.read_exact(&mut msg[..])?;
188
189 Ok(msg)
190 })
191 .await
192}
193
194/// Hyper-basic stream transport sender. 16-bit BE size followed by payload.
195async fn send(stream: &Async<TcpStream>, buf: &[u8]) -> std::io::Result<()> {
196 stream
197 .write_with(|mut stream| {
198 let len = u16::try_from(buf.len()).expect("message too large");
199 stream.write_all(&len.to_be_bytes())?;
200 stream.write_all(buf)
201 })
202 .await
203}