Bevy + Ratatui-powered monitoring of Pico-Strike devices
at noise-proto 203 lines 6.6 kB view raw
1use std::{ 2 io::{Read, Write}, 3 net::{SocketAddr, TcpStream}, 4 sync::{Arc, LazyLock, Mutex, OnceLock}, 5 time::Duration, 6}; 7 8use async_channel::{Receiver, Sender}; 9use async_io::{Async, Timer}; 10use bevy::{ 11 app::{Plugin, Startup}, 12 ecs::{resource::Resource, system::Commands}, 13 tasks::IoTaskPool, 14}; 15use futures_concurrency::future::Race; 16use snow::{Builder, TransportState, params::NoiseParams}; 17 18use crate::constants::NOISE_PSK; 19 20static PARAMS: LazyLock<NoiseParams> = 21 LazyLock::new(|| "Noise_XXpsk3_25519_ChaChaPoly_BLAKE2s".parse().unwrap()); 22static LOCAL_PRIVATE_KEY: OnceLock<snow::Keypair> = OnceLock::new(); 23 24pub enum StrikeUpdateState { 25 Disconnected, 26 Connecting, 27 Connected, 28 Updating(striker_proto::StrikerResponse), 29} 30 31#[derive(Debug, Resource)] 32pub struct StrikeUpdates(pub Receiver<StrikeUpdateState>); 33 34#[derive(Debug, Resource)] 35pub struct StrikeRequests(pub Sender<striker_proto::StrikerRequest>); 36 37#[derive(Debug, Resource)] 38pub struct StrikeActions(pub Sender<StrikeAction>); 39 40pub enum StrikeAction { 41 Connect(SocketAddr), 42 Disconnect, 43} 44 45pub fn setup_strike_connection(mut commands: Commands) { 46 let io = IoTaskPool::get(); 47 48 let (signal_tx, signal_rx) = async_channel::bounded(2); 49 let (req_tx, req_rx) = async_channel::bounded(1); 50 let (resp_tx, resp_rx) = async_channel::bounded(64); 51 52 io.spawn(async move { 53 let mut read_buf = vec![0u8; 4096]; 54 let mut enc_buf = vec![0u8; 4096]; 55 let mut write_buf = vec![0u8; 4096]; 56 57 while let Ok(StrikeAction::Connect(addr)) = signal_rx.recv().await { 58 let net_fut = async { 59 loop { 60 resp_tx.send(StrikeUpdateState::Connecting).await.ok(); 61 let Ok(stream) = Async::<TcpStream>::connect(addr).await else { 62 Timer::after(Duration::from_secs(1)).await; 63 continue; 64 }; 65 66 resp_tx.send(StrikeUpdateState::Connected).await.ok(); 67 stream.write_with(|s| s.set_nodelay(true)).await.ok(); 68 69 let Ok(transport) = 
noise_handshake(&stream).await else { 70 resp_tx.send(StrikeUpdateState::Disconnected).await.ok(); 71 continue; 72 }; 73 74 let transport_access = Arc::new(Mutex::new(transport)); 75 76 let read_fut = async { 77 while let Ok(buh) = recv(&stream).await { 78 if let Ok(decrypted) = { 79 transport_access 80 .lock() 81 .unwrap() 82 .read_message(&buh, &mut read_buf) 83 } && let Ok(data) = 84 striker_proto::receive_response(&mut read_buf[..decrypted]) 85 && resp_tx 86 .send(StrikeUpdateState::Updating(data)) 87 .await 88 .is_err() 89 { 90 break; 91 } 92 } 93 }; 94 95 let write_fut = async { 96 while let Ok(req) = req_rx.recv().await { 97 if let Ok(payload) = striker_proto::send_request(req, &mut write_buf) 98 && let Ok(encrypted) = { 99 transport_access 100 .lock() 101 .unwrap() 102 .write_message(payload, &mut enc_buf) 103 } 104 && send(&stream, &enc_buf[..encrypted]).await.is_err() 105 { 106 break; 107 }; 108 } 109 }; 110 111 (read_fut, write_fut).race().await; 112 113 stream 114 .write_with(|s| s.shutdown(std::net::Shutdown::Both)) 115 .await 116 .ok(); 117 118 break; 119 } 120 }; 121 122 let cancel_fut = async { 123 while signal_rx 124 .recv() 125 .await 126 .is_ok_and(|strike| !matches!(strike, StrikeAction::Disconnect)) 127 { 128 } 129 }; 130 131 (net_fut, cancel_fut).race().await; 132 resp_tx.send(StrikeUpdateState::Disconnected).await.ok(); 133 } 134 }) 135 .detach(); 136 137 commands.insert_resource(StrikeActions(signal_tx)); 138 commands.insert_resource(StrikeUpdates(resp_rx)); 139 commands.insert_resource(StrikeRequests(req_tx)); 140} 141 142pub struct NetPlugin; 143 144impl Plugin for NetPlugin { 145 fn build(&self, app: &mut bevy::app::App) { 146 app.add_systems(Startup, setup_strike_connection); 147 } 148} 149 150async fn noise_handshake(tcp: &Async<TcpStream>) -> color_eyre::Result<TransportState> { 151 let builder = Builder::new(PARAMS.clone()); 152 let static_key = LOCAL_PRIVATE_KEY.get_or_init(|| builder.generate_keypair().unwrap()); 153 154 let mut 
noise = builder 155 .local_private_key(&static_key.private)? 156 .psk(3, &NOISE_PSK)? 157 .build_initiator()?; 158 159 let mut payload = vec![0u8; 2048]; 160 161 // -> e 162 let len = noise.write_message(&[], &mut payload)?; 163 164 send(tcp, &payload[..len]).await?; 165 166 // <- e, ee, s, es 167 noise.read_message(&recv(tcp).await?, &mut payload)?; 168 169 // -> s, se 170 let len = noise.write_message(&[], &mut payload)?; 171 172 send(tcp, &payload[..len]).await?; 173 174 let transport = noise.into_transport_mode()?; 175 176 Ok(transport) 177} 178 179async fn recv(stream: &Async<TcpStream>) -> std::io::Result<Vec<u8>> { 180 let mut msg_len_buf = [0_u8; 2]; 181 182 stream 183 .read_with(|mut stream| { 184 stream.read_exact(&mut msg_len_buf)?; 185 let msg_len = usize::from(u16::from_be_bytes(msg_len_buf)); 186 let mut msg = vec![0_u8; msg_len]; 187 stream.read_exact(&mut msg[..])?; 188 189 Ok(msg) 190 }) 191 .await 192} 193 194/// Hyper-basic stream transport sender. 16-bit BE size followed by payload. 195async fn send(stream: &Async<TcpStream>, buf: &[u8]) -> std::io::Result<()> { 196 stream 197 .write_with(|mut stream| { 198 let len = u16::try_from(buf.len()).expect("message too large"); 199 stream.write_all(&len.to_be_bytes())?; 200 stream.write_all(buf) 201 }) 202 .await 203}