P2P support library for the beaver compute environment
/* SPDX-License-Identifier: AGPL-3.0-or-later */
2
3mod message_protocol;
4mod packet;
5mod pairing_hook;
6mod pairing_protocol;
7mod state;
8
9use packet::PostcardPacket;
10use std::sync::Arc;
11use tokio::task::AbortHandle;
12
13use crate::packet::BasePacket;
14use iroh::address_lookup::UserData;
15use iroh::address_lookup::mdns::MdnsAddressLookup;
16use iroh::endpoint::{ClosedStream, ConnectError, ConnectionError, WriteError};
17use iroh::{Endpoint, EndpointId, RelayMode, protocol::Router};
18use log::{error, info};
19use n0_future::StreamExt;
20use std::sync::mpsc::Sender;
21use thiserror::Error;
22use tokio::sync::Mutex;
23
24use crate::pairing_hook::{MESSAGE_ALPN, PAIRING_ALPN};
25pub use crate::state::EndpointStatus;
26pub use crate::state::PeerEvent;
27use crate::state::{EndpointDescription, PairingCommand, SharedState, State};
28
29#[derive(Debug, Error)]
30pub enum PairingError {
31 #[error("Unknown remote endpoint")]
32 UnknownRemote,
33 #[error("Endpoint pairing already requested")]
34 AlreadyRequested,
35 #[error("Invalid state during pairing")]
36 InvalidState,
37 #[error("Failure to receive command Ack")]
38 FailedAck,
39 #[error("Pairing rejected")]
40 Rejected,
41 #[error("Failed to connect")]
42 Connect(#[from] ConnectError),
43 #[error("Connection error")]
44 Connection(#[from] ConnectionError),
45 #[error("Write error")]
46 Write(#[from] WriteError),
47 #[error("Closed stream")]
48 ClosedStream(#[from] ClosedStream),
49 #[error("Postcard error")]
50 Postcard(#[from] postcard::Error),
51}
52
53#[derive(Debug, Error)]
54pub enum MessageError {
55 #[error("Invalid state for messaging")]
56 InvalidState,
57 #[error("Unknown remote endpoint")]
58 UnknownRemote,
59 #[error("Peer is not paired")]
60 Unpaired,
61 #[error("Packet sending error")]
62 Packet(#[from] packet::PacketError),
63 #[error("Failed to connect")]
64 Connect(#[from] ConnectError),
65 #[error("Connection error")]
66 Connection(#[from] ConnectionError),
67}
68
69#[derive(Clone)]
70pub struct PairingManagerInner {
71 state: SharedState,
72
73 router: Router,
74 mdns_handle: AbortHandle,
75}
76
77#[derive(Clone)]
78pub struct PairingManager {
79 inner: Option<PairingManagerInner>,
80}
81
82impl PairingManager {
83 pub async fn create(sender: Sender<PeerEvent>, name: &str) -> Self {
84 let state = Arc::new(Mutex::new(State::new(sender.clone())));
85
86 let user_data: UserData = name.to_owned().try_into().unwrap();
87 let endpoint = Endpoint::empty_builder(RelayMode::Disabled)
88 .user_data_for_address_lookup(user_data)
89 .hooks(pairing_hook::PairingHook::new(Arc::clone(&state)))
90 .bind()
91 .await
92 .unwrap();
93
94 println!("Endpoint {name}, {}", endpoint.id());
95
96 let mdns = MdnsAddressLookup::builder().build(endpoint.id()).unwrap();
97 endpoint.address_lookup().add(mdns.clone());
98
99 let disco_state = Arc::clone(&state);
100 let mdns_handle = tokio::spawn(async move {
101 let mut events = mdns.subscribe().await;
102 while let Some(event) = events.next().await {
103 // Update the state.
104 disco_state.lock().await.on_discovery(&event);
105 }
106 });
107
108 let router = Router::builder(endpoint)
109 .accept(
110 PAIRING_ALPN,
111 pairing_protocol::PairingProtocol::new(Arc::clone(&state)),
112 )
113 .accept(
114 MESSAGE_ALPN,
115 message_protocol::MessageProtocol::new(Arc::clone(&state)),
116 )
117 .spawn();
118
119 let inner = PairingManagerInner {
120 state,
121 router,
122 mdns_handle: mdns_handle.abort_handle(),
123 };
124 Self { inner: Some(inner) }
125 }
126
127 // Send a pairing request to an endpoint.
128 pub async fn request_pairing(&self, id: &EndpointId) -> Result<(), PairingError> {
129 info!("Pairing with {id}");
130 let Some(ref inner) = self.inner else {
131 error!("Not initialized");
132 return Err(PairingError::InvalidState);
133 };
134
135 let addr = {
136 let state = inner.state.lock().await;
137 let Some(remote) = state.by_id(id) else {
138 error!("Can't request pairing from unknown remote {id}");
139 return Err(PairingError::UnknownRemote);
140 };
141 remote.addr()
142 };
143
144 // Don't request pairing twice with the same endpoint.
145 {
146 let mut state = inner.state.lock().await;
147 if !state.has_requested(id) {
148 state.set_pairing_requested(id);
149 } else {
150 error!("Already requested pairing with {id}");
151 return Err(PairingError::AlreadyRequested);
152 }
153 }
154
155 let connection = inner.router.endpoint().connect(addr, PAIRING_ALPN).await?;
156
157 // Send the the request.
158 let (mut sender, mut receiver) = connection.open_bi().await?;
159
160 let command = PairingCommand::Request;
161 PostcardPacket::send(command, &mut sender)
162 .await
163 .expect("Failed to send");
164
165 // Wait for accept or reject
166 let command: PairingCommand = PostcardPacket::recv(&mut receiver)
167 .await
168 .expect("Failed to receive");
169
170 let accepted = match command {
171 PairingCommand::Accept => true,
172 PairingCommand::Reject => false,
173 PairingCommand::Request => {
174 error!("Unexpected Request in response to a pairing request");
175 return Err(PairingError::InvalidState);
176 }
177 PairingCommand::Ack => {
178 error!("Unexpected Ack in response to a pairing request");
179 return Err(PairingError::InvalidState);
180 }
181 };
182
183 // Send Ack
184 PostcardPacket::send(PairingCommand::Ack, &mut sender)
185 .await
186 .expect("Failed to send Ack");
187
188 let mut state = inner.state.lock().await;
189 state.remove_pairing_requested(id);
190 sender.finish()?;
191
192 if accepted {
193 state.notify(PeerEvent::PairingAccepted(*id));
194 state.set_status(id, EndpointStatus::PairedConnected);
195 Ok(())
196 } else {
197 state.remove_pairing_requested(id);
198 state.notify(PeerEvent::PairingRejected(*id));
199 Err(PairingError::Rejected)
200 }
201 }
202
203 // Common code for accept/reject of a pairing request
204 async fn send_pairing_response(
205 &self,
206 from: &EndpointId,
207 command: PairingCommand,
208 ) -> Result<(), PairingError> {
209 info!("Sending pairing response: {command:?}");
210 let Some(ref inner) = self.inner else {
211 error!("Not initialized");
212 return Err(PairingError::InvalidState);
213 };
214
215 if inner.state.lock().await.by_id(from).is_none() {
216 error!("Can't send pairing response to unknown remote {from}");
217 return Err(PairingError::UnknownRemote);
218 };
219
220 let (sender, mut ack_receiver) = {
221 let mut state = inner.state.lock().await;
222 let Some((sender, ack_receiver)) = state.take_pairing_responder(from) else {
223 error!("No responder for {from}");
224 return Err(PairingError::UnknownRemote);
225 };
226 (sender, ack_receiver)
227 };
228
229 sender
230 .send(command)
231 .await
232 .expect("Failed to send to the responder");
233
234 // Wait for the Ack to be received.
235 let ack = ack_receiver.recv().await.unwrap_or(false);
236
237 if ack {
238 Ok(())
239 } else {
240 Err(PairingError::FailedAck)
241 }
242 }
243
244 // Accept a pairing request
245 pub async fn accept_pairing(&self, from: &EndpointId) -> Result<(), PairingError> {
246 let Some(ref inner) = self.inner else {
247 error!("Not initialized");
248 return Err(PairingError::InvalidState);
249 };
250
251 let res = self
252 .send_pairing_response(from, PairingCommand::Accept)
253 .await;
254 println!("accept_pairing ok 3");
255 let mut state = inner.state.lock().await;
256 let res = res.map(|_| {
257 // Add the endpoint to the set of pending acks if successfully sending.
258 state.set_pending_ack(from);
259 });
260 res
261 }
262
263 // Reject a pairing request
264 pub async fn reject_pairing(&self, from: &EndpointId) -> Result<(), PairingError> {
265 self.send_pairing_response(from, PairingCommand::Reject)
266 .await
267 }
268
269 // Get the list of current known peers.
270 pub async fn peers(&self) -> Vec<EndpointDescription> {
271 let Some(ref inner) = self.inner else {
272 error!("Not initialized");
273 return vec![];
274 };
275
276 inner
277 .state
278 .lock()
279 .await
280 .endpoints()
281 .values()
282 .map(|e| e.into())
283 .collect()
284 }
285
286 // Send a message to a paired peer.
287 pub async fn send_message(&self, to: &EndpointId, message: &[u8]) -> Result<(), MessageError> {
288 let Some(ref inner) = self.inner else {
289 error!("Not initialized");
290 return Err(MessageError::InvalidState);
291 };
292
293 let addr = {
294 let state = inner.state.lock().await;
295 let Some(remote) = state.by_id(to) else {
296 error!("Can't send messages to unknown remote {to}");
297 return Err(MessageError::UnknownRemote);
298 };
299 if !remote.is_paired() {
300 return Err(MessageError::Unpaired);
301 }
302 remote.addr()
303 };
304
305 if inner.state.lock().await.by_id(to).is_none() {
306 error!("Can't send message to unknown remote {to}");
307 return Err(MessageError::UnknownRemote);
308 };
309
310 // If we have an existing sender for that endpoint, use it.
311 if let Some(ref mut sender) = inner.state.lock().await.get_message_sender(to) {
312 BasePacket::send(message, sender).await?;
313 return Ok(());
314 }
315
316 // Otherwise, create a new connection
317 let connection = inner.router.endpoint().connect(addr, MESSAGE_ALPN).await?;
318 let (mut sender, mut receiver) = connection.open_bi().await?;
319 BasePacket::send(message, &mut sender).await?;
320 inner.state.lock().await.set_message_sender(to, sender);
321
322 // create the relaying task
323 // TODO: share with message_protocol
324 let state = Arc::clone(&inner.state);
325 let remote_id = *to;
326 tokio::spawn(async move {
327 loop {
328 match BasePacket::recv(&mut receiver).await {
329 Ok(payload) => {
330 state
331 .lock()
332 .await
333 .notify(PeerEvent::Message(remote_id, payload));
334 }
335 Err(err) => {
336 error!("Error reading message packet: {err}");
337 break;
338 }
339 }
340 }
341 });
342
343 Ok(())
344 }
345
346 pub async fn stop(&mut self) {
347 let Some(ref inner) = self.inner else {
348 error!("Not initialized");
349 return;
350 };
351
352 if !inner.router.is_shutdown() {
353 inner.mdns_handle.abort();
354 let _ = inner.router.shutdown().await;
355 }
356
357 // Force dropping the mdns advertising tasks.
358 self.inner = None;
359 }
360
361 pub async fn set_status(&self, endpoint: &EndpointId, status: EndpointStatus) {
362 let Some(ref inner) = self.inner else {
363 error!("Not initialized");
364 return;
365 };
366
367 inner.state.lock().await.set_status(endpoint, status);
368 }
369}