"""Multi-transport coordinator.

Ported from net.i2p.router.transport.TransportManager.

Manages NTCP2 and SSU2 transport instances, selects the best transport for
each outbound message via competitive bidding, and publishes addresses in
RouterInfo.
"""

import asyncio
import logging

from i2p_transport.transport_base import (
    Transport,
    TransportBid,
    TransportStyle,
    ReachabilityStatus,
    _reachability_rank,
)

logger = logging.getLogger(__name__)


class TransportManager:
    """Coordinates multiple transports for message delivery.

    Each registered transport bids on outbound messages. The manager picks
    the lowest bid and delegates the send to that transport.

    Peers whose sends have failed ``max_send_failures`` times without an
    intervening success are refused by :meth:`send` until
    :meth:`clear_failed_count` is called for them.
    """

    def __init__(
        self,
        banlist=None,
        router_hash: bytes | None = None,
        *,
        max_send_failures: int = 2,
    ) -> None:
        """Create an empty manager.

        Args:
            banlist: Optional object exposing ``is_banlisted(peer_hash)``;
                peers for which it returns a truthy value are never sent to.
            router_hash: This router's own hash; sends addressed to it are
                rejected. ``None`` disables the self-send check.
            max_send_failures: Number of recorded send failures after which
                a peer is refused. The default of 2 matches the historical
                behavior (reject once the count exceeds 1).

        Raises:
            ValueError: If ``max_send_failures`` is less than 1 (every peer
                would be refused).
        """
        if max_send_failures < 1:
            raise ValueError("max_send_failures must be >= 1")
        self._transports: dict[TransportStyle, Transport] = {}
        self._running = False
        self._banlist = banlist
        self._router_hash = router_hash
        self._max_send_failures = max_send_failures
        # Consecutive send failures per peer hash; reset on success.
        self._failed_counts: dict[bytes, int] = {}

    # ------------------------------------------------------------------
    # Registration
    # ------------------------------------------------------------------

    def register(self, transport: Transport) -> None:
        """Register a transport instance, keyed by its style.

        A transport already registered under the same style is replaced.
        """
        previous = self._transports.get(transport.style)
        if previous is not None and previous is not transport:
            logger.warning("Replacing existing %s transport", transport.style.value)
        self._transports[transport.style] = transport

    def get_transport(self, style: TransportStyle) -> Transport | None:
        """Return the transport registered for ``style``, or ``None``."""
        return self._transports.get(style)

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    @staticmethod
    def _collect_errors(
        action: str,
        transports: list[Transport],
        results: list[object],
    ) -> list[BaseException]:
        """Log and return the exceptions from ``gather(return_exceptions=True)``."""
        errors: list[BaseException] = []
        for transport, result in zip(transports, results):
            if isinstance(result, BaseException):
                logger.error(
                    "Transport %s failed to %s",
                    transport.style.value,
                    action,
                    exc_info=result,
                )
                errors.append(result)
        return errors

    async def start_all(self) -> None:
        """Start all registered transports concurrently.

        Every ``start()`` call is awaited to completion even if a sibling
        fails, so no start coroutine is left running unobserved. If any
        transport failed to start, the manager is not marked running and the
        first failure is re-raised.
        """
        if not self._transports:
            return
        transports = list(self._transports.values())
        results = await asyncio.gather(
            *(t.start() for t in transports), return_exceptions=True
        )
        errors = self._collect_errors("start", transports, results)
        if errors:
            raise errors[0]
        self._running = True
        logger.info("TransportManager started %d transports", len(transports))

    async def stop_all(self) -> None:
        """Stop all registered transports concurrently.

        Every ``stop()`` call is awaited to completion and the manager is
        marked as not running even if some of them fail. The first failure
        is then re-raised so callers still observe it.
        """
        transports = list(self._transports.values())
        results = await asyncio.gather(
            *(t.stop() for t in transports), return_exceptions=True
        )
        self._running = False
        errors = self._collect_errors("stop", transports, results)
        logger.info("TransportManager stopped")
        if errors:
            raise errors[0]

    # ------------------------------------------------------------------
    # Sending
    # ------------------------------------------------------------------

    def _record_failure(self, peer_hash: bytes) -> None:
        """Increment the failure counter for ``peer_hash``."""
        self._failed_counts[peer_hash] = self._failed_counts.get(peer_hash, 0) + 1

    async def send(self, peer_hash: bytes, data: bytes) -> bool:
        """Send data to peer using the best available transport.

        Collects bids from all transports, selects the lowest bid, and sends
        via that transport.

        Returns:
            The chosen transport's ``send`` result (truthy on success), or
            ``False`` if the send was refused (self-send, banlisted peer,
            too many prior failures) or no transport bid for the peer.

        Raises:
            Whatever the chosen transport's ``send`` raises. Such an
            exception is counted as a failure for the peer before it
            propagates.
        """
        if self._router_hash is not None and peer_hash == self._router_hash:
            logger.debug("Rejecting self-send to %s", peer_hash.hex()[:16])
            return False

        if self._banlist is not None and self._banlist.is_banlisted(peer_hash):
            logger.debug("Rejecting send to banlisted peer %s", peer_hash.hex()[:16])
            return False

        if self._failed_counts.get(peer_hash, 0) >= self._max_send_failures:
            logger.debug("Rejecting send to repeatedly-failed peer %s", peer_hash.hex()[:16])
            return False

        best = await self.get_best_transport(peer_hash)
        if best is None:
            logger.debug("No transport available for peer %s", peer_hash.hex()[:16])
            return False

        try:
            result = await best.send(peer_hash, data)
        except Exception:
            # A raising transport is a failed send too; without this the
            # peer would never hit the repeated-failure threshold.
            self._record_failure(peer_hash)
            raise

        if result:
            self._failed_counts.pop(peer_hash, None)
        else:
            self._record_failure(peer_hash)
        return result

    def clear_failed_count(self, peer_hash: bytes) -> None:
        """Clear failure tracking for a peer so it may be sent to again."""
        self._failed_counts.pop(peer_hash, None)

    async def get_best_transport(self, peer_hash: bytes) -> Transport | None:
        """Select the best transport for a peer via competitive bidding.

        Transports are asked for bids one after another. A transport whose
        ``bid`` raises is skipped (the error is logged at debug level), and
        bids whose ``latency_ms`` equals ``TransportBid.WILL_NOT_SEND`` are
        ignored. The winner is ``min()`` over the remaining bids, so
        ``TransportBid`` ordering decides.

        Returns:
            The winning bid's transport, or ``None`` if no transport is
            willing or able to send.
        """
        if not self._transports:
            return None

        bids: list[TransportBid] = []
        for transport in self._transports.values():
            try:
                bid = await transport.bid(peer_hash)
            except Exception:
                logger.debug(
                    "Transport %s failed to bid for peer %s",
                    transport.style.value,
                    peer_hash.hex()[:16],
                    exc_info=True,
                )
                continue
            if bid.latency_ms != TransportBid.WILL_NOT_SEND:
                bids.append(bid)

        if not bids:
            return None
        return min(bids).transport

    # ------------------------------------------------------------------
    # Address publication
    # ------------------------------------------------------------------

    def get_addresses(self) -> list[dict]:
        """Return the non-``None`` ``current_address`` of every transport."""
        return [
            addr
            for transport in self._transports.values()
            if (addr := transport.current_address) is not None
        ]

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def reachability(self) -> ReachabilityStatus:
        """Overall reachability: the best (lowest rank) of all transports.

        ``ReachabilityStatus.UNKNOWN`` when no transport is registered.
        """
        if not self._transports:
            return ReachabilityStatus.UNKNOWN
        return min(
            (t.reachability for t in self._transports.values()),
            key=_reachability_rank,
        )

    @property
    def is_running(self) -> bool:
        """True after a fully successful :meth:`start_all`, until :meth:`stop_all`."""
        return self._running

    @property
    def transport_count(self) -> int:
        """Number of registered transports."""
        return len(self._transports)