P2P support library for the beaver compute environment
at main 369 lines 12 kB view raw
1/* SPDX Id: AGPL-3.0-or-later */ 2 3mod message_protocol; 4mod packet; 5mod pairing_hook; 6mod pairing_protocol; 7mod state; 8 9use packet::PostcardPacket; 10use std::sync::Arc; 11use tokio::task::AbortHandle; 12 13use crate::packet::BasePacket; 14use iroh::address_lookup::UserData; 15use iroh::address_lookup::mdns::MdnsAddressLookup; 16use iroh::endpoint::{ClosedStream, ConnectError, ConnectionError, WriteError}; 17use iroh::{Endpoint, EndpointId, RelayMode, protocol::Router}; 18use log::{error, info}; 19use n0_future::StreamExt; 20use std::sync::mpsc::Sender; 21use thiserror::Error; 22use tokio::sync::Mutex; 23 24use crate::pairing_hook::{MESSAGE_ALPN, PAIRING_ALPN}; 25pub use crate::state::EndpointStatus; 26pub use crate::state::PeerEvent; 27use crate::state::{EndpointDescription, PairingCommand, SharedState, State}; 28 29#[derive(Debug, Error)] 30pub enum PairingError { 31 #[error("Unknown remote endpoint")] 32 UnknownRemote, 33 #[error("Endpoint pairing already requested")] 34 AlreadyRequested, 35 #[error("Invalid state during pairing")] 36 InvalidState, 37 #[error("Failure to receive command Ack")] 38 FailedAck, 39 #[error("Pairing rejected")] 40 Rejected, 41 #[error("Failed to connect")] 42 Connect(#[from] ConnectError), 43 #[error("Connection error")] 44 Connection(#[from] ConnectionError), 45 #[error("Write error")] 46 Write(#[from] WriteError), 47 #[error("Closed stream")] 48 ClosedStream(#[from] ClosedStream), 49 #[error("Postcard error")] 50 Postcard(#[from] postcard::Error), 51} 52 53#[derive(Debug, Error)] 54pub enum MessageError { 55 #[error("Invalid state for messaging")] 56 InvalidState, 57 #[error("Unknown remote endpoint")] 58 UnknownRemote, 59 #[error("Peer is not paired")] 60 Unpaired, 61 #[error("Packet sending error")] 62 Packet(#[from] packet::PacketError), 63 #[error("Failed to connect")] 64 Connect(#[from] ConnectError), 65 #[error("Connection error")] 66 Connection(#[from] ConnectionError), 67} 68 69#[derive(Clone)] 70pub struct PairingManagerInner { 71 state: SharedState, 72 73 router: Router, 74 mdns_handle: AbortHandle, 75} 76 77#[derive(Clone)] 78pub struct PairingManager { 79 inner: Option<PairingManagerInner>, 80} 81 82impl PairingManager { 83 pub async fn create(sender: Sender<PeerEvent>, name: &str) -> Self { 84 let state = Arc::new(Mutex::new(State::new(sender.clone()))); 85 86 let user_data: UserData = name.to_owned().try_into().unwrap(); 87 let endpoint = Endpoint::empty_builder(RelayMode::Disabled) 88 .user_data_for_address_lookup(user_data) 89 .hooks(pairing_hook::PairingHook::new(Arc::clone(&state))) 90 .bind() 91 .await 92 .unwrap(); 93 94 println!("Endpoint {name}, {}", endpoint.id()); 95 96 let mdns = MdnsAddressLookup::builder().build(endpoint.id()).unwrap(); 97 endpoint.address_lookup().add(mdns.clone()); 98 99 let disco_state = Arc::clone(&state); 100 let mdns_handle = tokio::spawn(async move { 101 let mut events = mdns.subscribe().await; 102 while let Some(event) = events.next().await { 103 // Update the state. 104 disco_state.lock().await.on_discovery(&event); 105 } 106 }); 107 108 let router = Router::builder(endpoint) 109 .accept( 110 PAIRING_ALPN, 111 pairing_protocol::PairingProtocol::new(Arc::clone(&state)), 112 ) 113 .accept( 114 MESSAGE_ALPN, 115 message_protocol::MessageProtocol::new(Arc::clone(&state)), 116 ) 117 .spawn(); 118 119 let inner = PairingManagerInner { 120 state, 121 router, 122 mdns_handle: mdns_handle.abort_handle(), 123 }; 124 Self { inner: Some(inner) } 125 } 126 127 // Send a pairing request to an endpoint. 128 pub async fn request_pairing(&self, id: &EndpointId) -> Result<(), PairingError> { 129 info!("Pairing with {id}"); 130 let Some(ref inner) = self.inner else { 131 error!("Not initialized"); 132 return Err(PairingError::InvalidState); 133 }; 134 135 let addr = { 136 let state = inner.state.lock().await; 137 let Some(remote) = state.by_id(id) else { 138 error!("Can't request pairing from unknown remote {id}"); 139 return Err(PairingError::UnknownRemote); 140 }; 141 remote.addr() 142 }; 143 144 // Don't request pairing twice with the same endpoint. 145 { 146 let mut state = inner.state.lock().await; 147 if !state.has_requested(id) { 148 state.set_pairing_requested(id); 149 } else { 150 error!("Already requested pairing with {id}"); 151 return Err(PairingError::AlreadyRequested); 152 } 153 } 154 155 let connection = inner.router.endpoint().connect(addr, PAIRING_ALPN).await?; 156 157 // Send the the request. 158 let (mut sender, mut receiver) = connection.open_bi().await?; 159 160 let command = PairingCommand::Request; 161 PostcardPacket::send(command, &mut sender) 162 .await 163 .expect("Failed to send"); 164 165 // Wait for accept or reject 166 let command: PairingCommand = PostcardPacket::recv(&mut receiver) 167 .await 168 .expect("Failed to receive"); 169 170 let accepted = match command { 171 PairingCommand::Accept => true, 172 PairingCommand::Reject => false, 173 PairingCommand::Request => { 174 error!("Unexpected Request in response to a pairing request"); 175 return Err(PairingError::InvalidState); 176 } 177 PairingCommand::Ack => { 178 error!("Unexpected Ack in response to a pairing request"); 179 return Err(PairingError::InvalidState); 180 } 181 }; 182 183 // Send Ack 184 PostcardPacket::send(PairingCommand::Ack, &mut sender) 185 .await 186 .expect("Failed to send Ack"); 187 188 let mut state = inner.state.lock().await; 189 state.remove_pairing_requested(id); 190 sender.finish()?; 191 192 if accepted { 193 state.notify(PeerEvent::PairingAccepted(*id)); 194 state.set_status(id, EndpointStatus::PairedConnected); 195 Ok(()) 196 } else { 197 state.remove_pairing_requested(id); 198 state.notify(PeerEvent::PairingRejected(*id)); 199 Err(PairingError::Rejected) 200 } 201 } 202 203 // Common code for accept/reject of a pairing request 204 async fn send_pairing_response( 205 &self, 206 from: &EndpointId, 207 command: PairingCommand, 208 ) -> Result<(), PairingError> { 209 info!("Sending pairing response: {command:?}"); 210 let Some(ref inner) = self.inner else { 211 error!("Not initialized"); 212 return Err(PairingError::InvalidState); 213 }; 214 215 if inner.state.lock().await.by_id(from).is_none() { 216 error!("Can't send pairing response to unknown remote {from}"); 217 return Err(PairingError::UnknownRemote); 218 }; 219 220 let (sender, mut ack_receiver) = { 221 let mut state = inner.state.lock().await; 222 let Some((sender, ack_receiver)) = state.take_pairing_responder(from) else { 223 error!("No responder for {from}"); 224 return Err(PairingError::UnknownRemote); 225 }; 226 (sender, ack_receiver) 227 }; 228 229 sender 230 .send(command) 231 .await 232 .expect("Failed to send to the responder"); 233 234 // Wait for the Ack to be received. 235 let ack = ack_receiver.recv().await.unwrap_or(false); 236 237 if ack { 238 Ok(()) 239 } else { 240 Err(PairingError::FailedAck) 241 } 242 } 243 244 // Accept a pairing request 245 pub async fn accept_pairing(&self, from: &EndpointId) -> Result<(), PairingError> { 246 let Some(ref inner) = self.inner else { 247 error!("Not initialized"); 248 return Err(PairingError::InvalidState); 249 }; 250 251 let res = self 252 .send_pairing_response(from, PairingCommand::Accept) 253 .await; 254 println!("accept_pairing ok 3"); 255 let mut state = inner.state.lock().await; 256 let res = res.map(|_| { 257 // Add the endpoint to the set of pending acks if successfully sending. 258 state.set_pending_ack(from); 259 }); 260 res 261 } 262 263 // Reject a pairing request 264 pub async fn reject_pairing(&self, from: &EndpointId) -> Result<(), PairingError> { 265 self.send_pairing_response(from, PairingCommand::Reject) 266 .await 267 } 268 269 // Get the list of current known peers. 270 pub async fn peers(&self) -> Vec<EndpointDescription> { 271 let Some(ref inner) = self.inner else { 272 error!("Not initialized"); 273 return vec![]; 274 }; 275 276 inner 277 .state 278 .lock() 279 .await 280 .endpoints() 281 .values() 282 .map(|e| e.into()) 283 .collect() 284 } 285 286 // Send a message to a paired peer. 287 pub async fn send_message(&self, to: &EndpointId, message: &[u8]) -> Result<(), MessageError> { 288 let Some(ref inner) = self.inner else { 289 error!("Not initialized"); 290 return Err(MessageError::InvalidState); 291 }; 292 293 let addr = { 294 let state = inner.state.lock().await; 295 let Some(remote) = state.by_id(to) else { 296 error!("Can't send messages to unknown remote {to}"); 297 return Err(MessageError::UnknownRemote); 298 }; 299 if !remote.is_paired() { 300 return Err(MessageError::Unpaired); 301 } 302 remote.addr() 303 }; 304 305 if inner.state.lock().await.by_id(to).is_none() { 306 error!("Can't send message to unknown remote {to}"); 307 return Err(MessageError::UnknownRemote); 308 }; 309 310 // If we have an existing sender for that endpoint, use it. 311 if let Some(ref mut sender) = inner.state.lock().await.get_message_sender(to) { 312 BasePacket::send(message, sender).await?; 313 return Ok(()); 314 } 315 316 // Otherwise, create a new connection 317 let connection = inner.router.endpoint().connect(addr, MESSAGE_ALPN).await?; 318 let (mut sender, mut receiver) = connection.open_bi().await?; 319 BasePacket::send(message, &mut sender).await?; 320 inner.state.lock().await.set_message_sender(to, sender); 321 322 // create the relaying task 323 // TODO: share with message_protocol 324 let state = Arc::clone(&inner.state); 325 let remote_id = *to; 326 tokio::spawn(async move { 327 loop { 328 match BasePacket::recv(&mut receiver).await { 329 Ok(payload) => { 330 state 331 .lock() 332 .await 333 .notify(PeerEvent::Message(remote_id, payload)); 334 } 335 Err(err) => { 336 error!("Error reading message packet: {err}"); 337 break; 338 } 339 } 340 } 341 }); 342 343 Ok(()) 344 } 345 346 pub async fn stop(&mut self) { 347 let Some(ref inner) = self.inner else { 348 error!("Not initialized"); 349 return; 350 }; 351 352 if !inner.router.is_shutdown() { 353 inner.mdns_handle.abort(); 354 let _ = inner.router.shutdown().await; 355 } 356 357 // Force dropping the mdns advertising tasks. 358 self.inner = None; 359 } 360 361 pub async fn set_status(&self, endpoint: &EndpointId, status: EndpointStatus) { 362 let Some(ref inner) = self.inner else { 363 error!("Not initialized"); 364 return; 365 }; 366 367 inner.state.lock().await.set_status(endpoint, status); 368 } 369}