1use core::net::IpAddr;
2use std::{
3 io::{Read, Write},
4 net::{Ipv4Addr, SocketAddr, SocketAddrV4, TcpStream, UdpSocket},
5 time::Duration,
6};
7
8use async_channel::{Receiver, Sender};
9use async_io::{Async, Timer};
10use bevy::{
11 app::{Plugin, Startup},
12 ecs::{error::Result, resource::Resource, system::Commands},
13 tasks::IoTaskPool,
14};
15use futures_concurrency::future::Race;
16use sachy_mdns::{
17 GROUP_SOCK_V4, MDNS_PORT,
18 client::query_service,
19 dns::{
20 records::{QType, Record},
21 reqres::Response,
22 traits::DnsParse,
23 },
24};
25use socket2::{Domain, Protocol, Socket, Type};
26
27#[derive(Debug, Resource)]
28pub struct DiscoverResponse(pub Receiver<InstanceDetails>);
29
30pub enum StrikeUpdateState {
31 Disconnected,
32 Connecting,
33 Connected,
34 Updating(striker_proto::StrikerResponse),
35}
36
37#[derive(Debug, Resource)]
38pub struct StrikeUpdates(pub Receiver<StrikeUpdateState>);
39
40#[derive(Debug, Resource)]
41pub struct StrikeRequests(pub Sender<striker_proto::StrikerRequest>);
42
43#[derive(Debug, Resource)]
44pub struct StrikeActions(pub Sender<StrikeAction>);
45
/// Command sent through [`StrikeActions`] to control the striker connection task.
///
/// This is a small plain value, so it derives `Debug`, `Clone`, `Copy`, `PartialEq`
/// and `Eq` so callers can log, copy and compare actions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StrikeAction {
    /// Open a TCP session to the given address, retrying once per second until the
    /// connection succeeds or a `Disconnect` arrives.
    Connect(SocketAddr),
    /// Close the current session, or abandon an in-progress connection attempt.
    Disconnect,
}
50
/// A striker instance found by mDNS discovery.
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq`/`Hash` so discovered instances can be
/// logged, stored, and de-duplicated across discovery rounds.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct InstanceDetails {
    /// Owner name of the SRV record (in DNS-SD, the service instance name).
    pub host: String,
    /// Target host name from the SRV record.
    pub address: String,
    /// Port from the SRV record.
    pub port: u16,
    /// First A/AAAA address in the response's additional section.
    /// NOTE(review): not checked to belong to `address`; confirm for multi-host replies.
    pub ip: IpAddr,
}
57
58#[derive(Debug, Resource)]
59pub struct MdnsSignaler(pub Sender<()>);
60
61fn create_mdns_socket() -> std::io::Result<Async<UdpSocket>> {
62 let sock = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
63 sock.set_reuse_address(true)?;
64 sock.set_nonblocking(true)?;
65 sock.bind(&SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, MDNS_PORT)).into())?;
66 let udp_socket = UdpSocket::from(sock);
67 Async::new_nonblocking(udp_socket)
68}
69
70pub fn setup_mdns_task(mut commands: Commands) -> Result {
71 let io = IoTaskPool::get();
72
73 let (signal_tx, signal_rx) = async_channel::bounded(1);
74 let (resp_tx, resp_rx) = async_channel::bounded(64);
75
76 let udp_socket = create_mdns_socket()?;
77
78 io.spawn(async move {
79 let mut buf = vec![0u8; 1028];
80 let mut query_buf = vec![0u8; 128];
81 let query = query_service("_picostrike._tcp.local", &mut query_buf).unwrap();
82
83 while signal_rx.recv().await.is_ok() {
84 let send_fut = async {
85 // Retry three times in case packets get lost coz UDP things
86 for _ in 0..3 {
87 udp_socket.send_to(query, GROUP_SOCK_V4).await.ok();
88 Timer::after(Duration::from_millis(250)).await;
89 }
90 };
91
92 let recv_fut = async {
93 while let Ok((read, _)) = udp_socket.recv_from(&mut buf).await {
94 let input = &buf[..read];
95 let Ok(resp) = Response::parse(&mut &*input, input) else {
96 continue;
97 };
98
99 if resp.answers.iter().any(|answer| {
100 answer.atype == QType::PTR && answer.name == "_picostrike._tcp.local"
101 }) && let Some(ip) = resp.additional.iter().find_map(|answer| match &answer
102 .record
103 {
104 Record::A(a) => Some(IpAddr::V4(a.address)),
105 Record::AAAA(aaaa) => Some(IpAddr::V6(aaaa.address)),
106 _ => None,
107 }) && let Some((name, srv)) = resp.additional.iter().find_map(|answer| {
108 if let Record::SRV(srv) = &answer.record {
109 Some((answer.name, srv))
110 } else {
111 None
112 }
113 }) {
114 resp_tx
115 .send(InstanceDetails {
116 host: name.to_string(),
117 address: srv.target.to_string(),
118 port: srv.port,
119 ip,
120 })
121 .await
122 .ok();
123 }
124 }
125 };
126
127 let timer = async {
128 Timer::after(Duration::from_millis(1000)).await;
129 };
130
131 let cancel = async {
132 signal_rx.recv().await.ok();
133 };
134
135 (send_fut, recv_fut, timer, cancel).race().await;
136 }
137 })
138 .detach();
139
140 commands.insert_resource(DiscoverResponse(resp_rx));
141 commands.insert_resource(MdnsSignaler(signal_tx));
142
143 Ok(())
144}
145
146pub fn setup_strike_connection(mut commands: Commands) {
147 let io = IoTaskPool::get();
148
149 let (signal_tx, signal_rx) = async_channel::bounded(2);
150 let (req_tx, req_rx) = async_channel::bounded(1);
151 let (resp_tx, resp_rx) = async_channel::bounded(64);
152
153 io.spawn(async move {
154 let mut read_buf = vec![0u8; 4096];
155 let mut write_buf = vec![0u8; 4096];
156
157 while let Ok(StrikeAction::Connect(addr)) = signal_rx.recv().await {
158 let net_fut = async {
159 loop {
160 resp_tx.send(StrikeUpdateState::Connecting).await.ok();
161 let Ok(stream) = Async::<TcpStream>::connect(addr).await else {
162 Timer::after(Duration::from_secs(1)).await;
163 continue;
164 };
165
166 resp_tx.send(StrikeUpdateState::Connected).await.ok();
167 stream.write_with(|s| s.set_nodelay(true)).await.ok();
168
169 let read_fut = async {
170 while let Ok(read) = stream.read_with(|mut a| a.read(&mut read_buf)).await {
171 let Ok(data) = striker_proto::receive_response(&mut read_buf[..read])
172 else {
173 continue;
174 };
175
176 if resp_tx.send(StrikeUpdateState::Updating(data)).await.is_err() {
177 break;
178 }
179 }
180 };
181
182 let write_fut = async {
183 while let Ok(req) = req_rx.recv().await {
184 let Ok(payload) = striker_proto::send_request(req, &mut write_buf)
185 else {
186 continue;
187 };
188
189 if stream.write_with(|mut s| s.write(payload)).await.is_err() {
190 break;
191 }
192 }
193 };
194
195 (read_fut, write_fut).race().await;
196
197 stream
198 .write_with(|s| s.shutdown(std::net::Shutdown::Both))
199 .await
200 .ok();
201
202 break;
203 }
204 };
205
206 let cancel_fut = async {
207 while signal_rx
208 .recv()
209 .await
210 .is_ok_and(|strike| !matches!(strike, StrikeAction::Disconnect))
211 {
212 }
213 };
214
215 (net_fut, cancel_fut).race().await;
216 resp_tx.send(StrikeUpdateState::Disconnected).await.ok();
217 }
218 })
219 .detach();
220
221 commands.insert_resource(StrikeActions(signal_tx));
222 commands.insert_resource(StrikeUpdates(resp_rx));
223 commands.insert_resource(StrikeRequests(req_tx));
224}
225
226pub struct NetPlugin;
227
228impl Plugin for NetPlugin {
229 fn build(&self, app: &mut bevy::app::App) {
230 app.add_systems(Startup, (setup_mdns_task, setup_strike_connection));
231 }
232}