/* SPDX Id: AGPL-3.0-or-later */ mod message_protocol; mod packet; mod pairing_hook; mod pairing_protocol; mod state; use packet::PostcardPacket; use std::sync::Arc; use tokio::task::AbortHandle; use crate::packet::BasePacket; use iroh::address_lookup::UserData; use iroh::address_lookup::mdns::MdnsAddressLookup; use iroh::endpoint::{ClosedStream, ConnectError, ConnectionError, WriteError}; use iroh::{Endpoint, EndpointId, RelayMode, protocol::Router}; use log::{error, info}; use n0_future::StreamExt; use std::sync::mpsc::Sender; use thiserror::Error; use tokio::sync::Mutex; use crate::pairing_hook::{MESSAGE_ALPN, PAIRING_ALPN}; pub use crate::state::EndpointStatus; pub use crate::state::PeerEvent; use crate::state::{EndpointDescription, PairingCommand, SharedState, State}; #[derive(Debug, Error)] pub enum PairingError { #[error("Unknown remote endpoint")] UnknownRemote, #[error("Endpoint pairing already requested")] AlreadyRequested, #[error("Invalid state during pairing")] InvalidState, #[error("Failure to receive command Ack")] FailedAck, #[error("Pairing rejected")] Rejected, #[error("Failed to connect")] Connect(#[from] ConnectError), #[error("Connection error")] Connection(#[from] ConnectionError), #[error("Write error")] Write(#[from] WriteError), #[error("Closed stream")] ClosedStream(#[from] ClosedStream), #[error("Postcard error")] Postcard(#[from] postcard::Error), } #[derive(Debug, Error)] pub enum MessageError { #[error("Invalid state for messaging")] InvalidState, #[error("Unknown remote endpoint")] UnknownRemote, #[error("Peer is not paired")] Unpaired, #[error("Packet sending error")] Packet(#[from] packet::PacketError), #[error("Failed to connect")] Connect(#[from] ConnectError), #[error("Connection error")] Connection(#[from] ConnectionError), } #[derive(Clone)] pub struct PairingManagerInner { state: SharedState, router: Router, mdns_handle: AbortHandle, } #[derive(Clone)] pub struct PairingManager { inner: Option, } impl PairingManager { pub async fn create(sender: Sender, name: &str) -> Self { let state = Arc::new(Mutex::new(State::new(sender.clone()))); let user_data: UserData = name.to_owned().try_into().unwrap(); let endpoint = Endpoint::empty_builder(RelayMode::Disabled) .user_data_for_address_lookup(user_data) .hooks(pairing_hook::PairingHook::new(Arc::clone(&state))) .bind() .await .unwrap(); println!("Endpoint {name}, {}", endpoint.id()); let mdns = MdnsAddressLookup::builder().build(endpoint.id()).unwrap(); endpoint.address_lookup().add(mdns.clone()); let disco_state = Arc::clone(&state); let mdns_handle = tokio::spawn(async move { let mut events = mdns.subscribe().await; while let Some(event) = events.next().await { // Update the state. disco_state.lock().await.on_discovery(&event); } }); let router = Router::builder(endpoint) .accept( PAIRING_ALPN, pairing_protocol::PairingProtocol::new(Arc::clone(&state)), ) .accept( MESSAGE_ALPN, message_protocol::MessageProtocol::new(Arc::clone(&state)), ) .spawn(); let inner = PairingManagerInner { state, router, mdns_handle: mdns_handle.abort_handle(), }; Self { inner: Some(inner) } } // Send a pairing request to an endpoint. pub async fn request_pairing(&self, id: &EndpointId) -> Result<(), PairingError> { info!("Pairing with {id}"); let Some(ref inner) = self.inner else { error!("Not initialized"); return Err(PairingError::InvalidState); }; let addr = { let state = inner.state.lock().await; let Some(remote) = state.by_id(id) else { error!("Can't request pairing from unknown remote {id}"); return Err(PairingError::UnknownRemote); }; remote.addr() }; // Don't request pairing twice with the same endpoint. { let mut state = inner.state.lock().await; if !state.has_requested(id) { state.set_pairing_requested(id); } else { error!("Already requested pairing with {id}"); return Err(PairingError::AlreadyRequested); } } let connection = inner.router.endpoint().connect(addr, PAIRING_ALPN).await?; // Send the the request. let (mut sender, mut receiver) = connection.open_bi().await?; let command = PairingCommand::Request; PostcardPacket::send(command, &mut sender) .await .expect("Failed to send"); // Wait for accept or reject let command: PairingCommand = PostcardPacket::recv(&mut receiver) .await .expect("Failed to receive"); let accepted = match command { PairingCommand::Accept => true, PairingCommand::Reject => false, PairingCommand::Request => { error!("Unexpected Request in response to a pairing request"); return Err(PairingError::InvalidState); } PairingCommand::Ack => { error!("Unexpected Ack in response to a pairing request"); return Err(PairingError::InvalidState); } }; // Send Ack PostcardPacket::send(PairingCommand::Ack, &mut sender) .await .expect("Failed to send Ack"); let mut state = inner.state.lock().await; state.remove_pairing_requested(id); sender.finish()?; if accepted { state.notify(PeerEvent::PairingAccepted(*id)); state.set_status(id, EndpointStatus::PairedConnected); Ok(()) } else { state.remove_pairing_requested(id); state.notify(PeerEvent::PairingRejected(*id)); Err(PairingError::Rejected) } } // Common code for accept/reject of a pairing request async fn send_pairing_response( &self, from: &EndpointId, command: PairingCommand, ) -> Result<(), PairingError> { info!("Sending pairing response: {command:?}"); let Some(ref inner) = self.inner else { error!("Not initialized"); return Err(PairingError::InvalidState); }; if inner.state.lock().await.by_id(from).is_none() { error!("Can't send pairing response to unknown remote {from}"); return Err(PairingError::UnknownRemote); }; let (sender, mut ack_receiver) = { let mut state = inner.state.lock().await; let Some((sender, ack_receiver)) = state.take_pairing_responder(from) else { error!("No responder for {from}"); return Err(PairingError::UnknownRemote); }; (sender, ack_receiver) }; sender .send(command) .await .expect("Failed to send to the responder"); // Wait for the Ack to be received. let ack = ack_receiver.recv().await.unwrap_or(false); if ack { Ok(()) } else { Err(PairingError::FailedAck) } } // Accept a pairing request pub async fn accept_pairing(&self, from: &EndpointId) -> Result<(), PairingError> { let Some(ref inner) = self.inner else { error!("Not initialized"); return Err(PairingError::InvalidState); }; let res = self .send_pairing_response(from, PairingCommand::Accept) .await; println!("accept_pairing ok 3"); let mut state = inner.state.lock().await; let res = res.map(|_| { // Add the endpoint to the set of pending acks if successfully sending. state.set_pending_ack(from); }); res } // Reject a pairing request pub async fn reject_pairing(&self, from: &EndpointId) -> Result<(), PairingError> { self.send_pairing_response(from, PairingCommand::Reject) .await } // Get the list of current known peers. pub async fn peers(&self) -> Vec { let Some(ref inner) = self.inner else { error!("Not initialized"); return vec![]; }; inner .state .lock() .await .endpoints() .values() .map(|e| e.into()) .collect() } // Send a message to a paired peer. pub async fn send_message(&self, to: &EndpointId, message: &[u8]) -> Result<(), MessageError> { let Some(ref inner) = self.inner else { error!("Not initialized"); return Err(MessageError::InvalidState); }; let addr = { let state = inner.state.lock().await; let Some(remote) = state.by_id(to) else { error!("Can't send messages to unknown remote {to}"); return Err(MessageError::UnknownRemote); }; if !remote.is_paired() { return Err(MessageError::Unpaired); } remote.addr() }; if inner.state.lock().await.by_id(to).is_none() { error!("Can't send message to unknown remote {to}"); return Err(MessageError::UnknownRemote); }; // If we have an existing sender for that endpoint, use it. if let Some(ref mut sender) = inner.state.lock().await.get_message_sender(to) { BasePacket::send(message, sender).await?; return Ok(()); } // Otherwise, create a new connection let connection = inner.router.endpoint().connect(addr, MESSAGE_ALPN).await?; let (mut sender, mut receiver) = connection.open_bi().await?; BasePacket::send(message, &mut sender).await?; inner.state.lock().await.set_message_sender(to, sender); // create the relaying task // TODO: share with message_protocol let state = Arc::clone(&inner.state); let remote_id = *to; tokio::spawn(async move { loop { match BasePacket::recv(&mut receiver).await { Ok(payload) => { state .lock() .await .notify(PeerEvent::Message(remote_id, payload)); } Err(err) => { error!("Error reading message packet: {err}"); break; } } } }); Ok(()) } pub async fn stop(&mut self) { let Some(ref inner) = self.inner else { error!("Not initialized"); return; }; if !inner.router.is_shutdown() { inner.mdns_handle.abort(); let _ = inner.router.shutdown().await; } // Force dropping the mdns advertising tasks. self.inner = None; } pub async fn set_status(&self, endpoint: &EndpointId, status: EndpointStatus) { let Some(ref inner) = self.inner else { error!("Not initialized"); return; }; inner.state.lock().await.set_status(endpoint, status); } }