"""SSU2 asyncio UDP server. Ported from: - net.i2p.router.transport.udp.UDPTransport - net.i2p.router.transport.udp.EstablishmentManager - net.i2p.router.transport.udp.PeerTestManager - net.i2p.router.transport.udp.IntroductionManager Provides: - SSU2Transport: asyncio DatagramProtocol, implements Transport interface - EstablishmentManager: dispatches handshake packets to correct state machine - PeerStateMap: tracks established SSU2 connections by peer/address/conn_id - PeerTestManager: three-party NAT detection (Alice/Bob/Charlie) - RelayManager: introduction/relay for firewalled peers """ from __future__ import annotations import asyncio import enum import logging import os import struct import time from dataclasses import dataclass, field from i2p_crypto.x25519 import X25519DH from i2p_transport.transport_base import ( Transport, TransportBid, TransportStyle, ReachabilityStatus, ) from i2p_transport.ssu2_handshake import ( HandshakeKeys, TokenManager, OutboundHandshake, InboundHandshake, LONG_HEADER_SIZE, SHORT_HEADER_SIZE, PKT_TOKEN_REQUEST, PKT_SESSION_REQUEST, PKT_DATA, ) from i2p_transport.ssu2_connection import SSU2Connection logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Packet classification # --------------------------------------------------------------------------- class PacketClass(enum.Enum): HANDSHAKE = "handshake" DATA = "data" INVALID = "invalid" def classify_packet(packet: bytes) -> PacketClass: """Classify an incoming SSU2 packet. - Packets shorter than SHORT_HEADER_SIZE are invalid. - Packets >= LONG_HEADER_SIZE with handshake-type bytes are HANDSHAKE. - Otherwise DATA. """ if len(packet) < SHORT_HEADER_SIZE: return PacketClass.INVALID if len(packet) >= LONG_HEADER_SIZE: # Check type byte at offset 12 (after dest_conn_id(8) + pkt_num(4)) pkt_type = packet[12] if pkt_type in (PKT_TOKEN_REQUEST, PKT_SESSION_REQUEST, 1, 9): return PacketClass.HANDSHAKE return PacketClass.DATA # --------------------------------------------------------------------------- # PeerStateMap — tracks established connections # --------------------------------------------------------------------------- class PeerStateMap: """Tracks established SSU2 connections by peer hash, address, and conn ID.""" def __init__(self) -> None: self._by_peer: dict[bytes, SSU2Connection] = {} self._by_address: dict[tuple[str, int], SSU2Connection] = {} self._by_conn_id: dict[int, SSU2Connection] = {} def add(self, peer_hash: bytes, conn: SSU2Connection, address: tuple[str, int]) -> None: self._by_peer[peer_hash] = conn self._by_address[address] = conn self._by_conn_id[conn.src_conn_id] = conn def get_by_peer(self, peer_hash: bytes) -> SSU2Connection | None: return self._by_peer.get(peer_hash) def get_by_address(self, address: tuple[str, int]) -> SSU2Connection | None: return self._by_address.get(address) def get_by_conn_id(self, conn_id: int) -> SSU2Connection | None: return self._by_conn_id.get(conn_id) def remove(self, peer_hash: bytes) -> None: conn = self._by_peer.pop(peer_hash, None) if conn is not None: self._by_address.pop(conn.remote_address, None) self._by_conn_id.pop(conn.src_conn_id, None) @property def active_count(self) -> int: return len(self._by_peer) def all_peers(self) -> list[bytes]: return list(self._by_peer.keys()) # --------------------------------------------------------------------------- # EstablishmentManager — dispatches handshake packets # --------------------------------------------------------------------------- class EstablishmentManager: """Manages pending SSU2 handshakes. Creates InboundHandshake (responder) or OutboundHandshake (initiator) instances and tracks them by connection ID until they complete. """ def __init__(self, local_static_key: bytes, local_intro_key: bytes, token_manager: TokenManager) -> None: self._local_static = local_static_key self._local_intro_key = local_intro_key self._token_manager = token_manager self._pending: dict[int, InboundHandshake | OutboundHandshake] = {} def create_inbound_handshake(self) -> InboundHandshake: return InboundHandshake( local_static_key=self._local_static, local_intro_key=self._local_intro_key, token_manager=self._token_manager, ) def create_outbound_handshake(self, remote_static_key: bytes, remote_intro_key: bytes, token: int | None = None) -> OutboundHandshake: return OutboundHandshake( local_static_key=self._local_static, remote_static_key=remote_static_key, remote_intro_key=remote_intro_key, token=token, ) def add_pending(self, conn_id: int, hs: InboundHandshake | OutboundHandshake) -> None: self._pending[conn_id] = hs def get_pending(self, conn_id: int) -> InboundHandshake | OutboundHandshake | None: return self._pending.get(conn_id) def remove_pending(self, conn_id: int) -> None: self._pending.pop(conn_id, None) @property def pending_count(self) -> int: return len(self._pending) # --------------------------------------------------------------------------- # Peer test protocol — three-party NAT detection # --------------------------------------------------------------------------- class PeerTestRole(enum.Enum): ALICE = "alice" # Initiator: wants to know if reachable BOB = "bob" # Relay: forwards test from Alice to Charlie CHARLIE = "charlie" # Tester: sends probe to Alice class PeerTestManager: """Manages SSU2 peer tests for NAT detection. Protocol: 1. Alice asks Bob to test her reachability 2. Bob relays request to Charlie 3. Charlie sends probe directly to Alice 4. Alice reports result back through Bob """ def __init__(self) -> None: self._pending_tests: dict[int, dict] = {} def create_test_request(self) -> tuple[int, bytes]: """Create a peer test request (Alice role). Returns (nonce, serialized_message). """ nonce = int.from_bytes(os.urandom(4), "big") | 1 # Ensure non-zero msg = struct.pack("!IB", nonce, PeerTestRole.ALICE.value.encode()[0]) self._pending_tests[nonce] = { "role": PeerTestRole.ALICE, "created_at": time.monotonic(), "nonce": nonce, } return nonce, msg def get_pending_test(self, nonce: int) -> dict | None: return self._pending_tests.get(nonce) def process_test_response(self, nonce: int, result_code: int, ip: bytes, port: int) -> dict | None: """Process a peer test response. result_code 0 = reachable, non-zero = unreachable or error. Returns result dict or None if nonce unknown. """ pending = self._pending_tests.pop(nonce, None) if pending is None: return None return { "nonce": nonce, "reachable": result_code == 0, "result_code": result_code, "ip": ip, "port": port, } def create_relay_to_charlie(self, nonce: int, alice_ip: bytes, alice_port: int) -> bytes: """Create a relay message from Bob to Charlie. Bob received a test request from Alice and relays it to Charlie. """ msg = struct.pack("!I", nonce) msg += struct.pack("!H", len(alice_ip)) + alice_ip msg += struct.pack("!H", alice_port) return msg def cleanup_stale(self, max_age_seconds: float = 30.0) -> int: """Remove stale pending tests.""" now = time.monotonic() stale = [n for n, t in self._pending_tests.items() if (now - t["created_at"]) >= max_age_seconds] for n in stale: del self._pending_tests[n] return len(stale) # --------------------------------------------------------------------------- # Introduction/relay protocol # --------------------------------------------------------------------------- class RelayManager: """Manages relay tags and introduction for firewalled peers. Firewalled peers cannot receive direct connections. They register a relay tag with an introducer. When someone wants to connect, they send a RelayRequest to the introducer, who forwards it to the target via the existing session. """ def __init__(self) -> None: self._relay_tags: dict[int, bytes] = {} # tag -> peer_hash self._next_tag = 1 self._pending_relays: dict[int, dict] = {} # nonce -> request info def assign_relay_tag(self, peer_hash: bytes) -> int: """Assign a relay tag to a peer.""" tag = self._next_tag self._next_tag += 1 self._relay_tags[tag] = peer_hash return tag def get_peer_for_tag(self, tag: int) -> bytes | None: return self._relay_tags.get(tag) def remove_relay_tag(self, tag: int) -> None: self._relay_tags.pop(tag, None) def create_relay_request(self, relay_tag: int, target_hash: bytes) -> tuple[int, bytes]: """Create a relay request to send to an introducer. Returns (nonce, serialized_message). """ nonce = int.from_bytes(os.urandom(4), "big") | 1 msg = struct.pack("!II", nonce, relay_tag) + target_hash self._pending_relays[nonce] = { "relay_tag": relay_tag, "target_hash": target_hash, "created_at": time.monotonic(), } return nonce, msg def process_relay_intro(self, nonce: int, requester_ip: bytes, requester_port: int, target_hash: bytes) -> dict: """Process a relay intro as the introducer. The introducer forwards the connection request to the target and returns info for the response. """ return { "nonce": nonce, "requester_ip": requester_ip, "requester_port": requester_port, "target_hash": target_hash, "action": "forward_to_target", } def process_relay_response(self, nonce: int, result_code: int, target_ip: bytes, target_port: int) -> dict | None: """Process a relay response. result_code 0 = success (target accepted the introduction). """ pending = self._pending_relays.pop(nonce, None) if pending is None: return None return { "nonce": nonce, "success": result_code == 0, "result_code": result_code, "target_ip": target_ip, "target_port": target_port, } # --------------------------------------------------------------------------- # SSU2 DatagramProtocol # --------------------------------------------------------------------------- class SSU2DatagramProtocol(asyncio.DatagramProtocol): """asyncio DatagramProtocol for SSU2 UDP packets. Receives datagrams, classifies them (handshake vs data), and dispatches to the appropriate handler. """ def __init__(self, transport_ref: SSU2Transport) -> None: self._transport_ref = transport_ref self._udp_transport: asyncio.DatagramTransport | None = None def connection_made(self, transport: asyncio.BaseTransport) -> None: self._udp_transport = transport # type: ignore[assignment] def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None: pkt_class = classify_packet(data) if pkt_class == PacketClass.INVALID: logger.debug("Dropping invalid packet from %s (too small)", addr) return if pkt_class == PacketClass.HANDSHAKE: self._transport_ref._handle_handshake_packet(data, addr) else: self._transport_ref._handle_data_packet(data, addr) def error_received(self, exc: Exception) -> None: logger.warning("UDP error: %s", exc) def connection_lost(self, exc: Exception | None) -> None: logger.info("UDP transport closed") def send_to(self, data: bytes, addr: tuple[str, int]) -> None: if self._udp_transport is not None: self._udp_transport.sendto(data, addr) # --------------------------------------------------------------------------- # SSU2Transport — main transport implementation # --------------------------------------------------------------------------- class SSU2Transport(Transport): """SSU2 UDP transport implementing the Transport interface. Manages: - UDP socket via asyncio DatagramProtocol - Peer state map for established connections - EstablishmentManager for handshake dispatch - PeerTestManager for NAT detection - RelayManager for firewalled peer introductions """ def __init__(self, host: str, port: int, static_key: bytes, intro_key: bytes) -> None: self._host = host self._port = port self._static_key = static_key self._static_pub = X25519DH.public_from_private(static_key) self._intro_key = intro_key self._token_manager = TokenManager() self._establishment = EstablishmentManager( local_static_key=static_key, local_intro_key=intro_key, token_manager=self._token_manager, ) self._peers = PeerStateMap() self._peer_test = PeerTestManager() self._relay = RelayManager() self._protocol: SSU2DatagramProtocol | None = None self._udp_transport: asyncio.DatagramTransport | None = None self._running = False self._actual_port: int = 0 self._reachability = ReachabilityStatus.UNKNOWN # ------------------------------------------------------------------ # Transport interface # ------------------------------------------------------------------ @property def style(self) -> TransportStyle: return TransportStyle.SSU2 async def start(self) -> None: loop = asyncio.get_running_loop() transport, protocol = await loop.create_datagram_endpoint( lambda: SSU2DatagramProtocol(self), local_addr=(self._host, self._port), ) self._udp_transport = transport self._protocol = protocol # Resolve actual port (if 0 was passed) sock = transport.get_extra_info("socket") if sock is not None: self._actual_port = sock.getsockname()[1] else: self._actual_port = self._port self._running = True logger.info("SSU2 transport started on %s:%d", self._host, self._actual_port) async def stop(self) -> None: if self._udp_transport is not None: self._udp_transport.close() self._udp_transport = None self._protocol = None self._running = False logger.info("SSU2 transport stopped") @property def is_running(self) -> bool: return self._running async def bid(self, peer_hash: bytes) -> TransportBid: conn = self._peers.get_by_peer(peer_hash) if conn is not None and conn.is_established: return TransportBid(latency_ms=50, transport=self, preference=1) return TransportBid( latency_ms=TransportBid.WILL_NOT_SEND, transport=self, preference=100, ) async def send(self, peer_hash: bytes, data: bytes) -> bool: conn = self._peers.get_by_peer(peer_hash) if conn is None or not conn.is_established: return False from i2p_transport.ssu2_payload import I2NPBlock blocks = [I2NPBlock(i2np_data=data)] packet = conn.encrypt_data_packet(blocks) if self._protocol is not None: self._protocol.send_to(packet, conn.remote_address) return True return False @property def reachability(self) -> ReachabilityStatus: return self._reachability @property def current_address(self) -> dict | None: if not self._running: return None return { "style": "SSU2", "host": self._host, "port": self._actual_port, "intro_key": self._intro_key.hex(), } # ------------------------------------------------------------------ # Packet handlers # ------------------------------------------------------------------ def _handle_handshake_packet(self, data: bytes, addr: tuple[str, int]) -> None: """Dispatch a handshake packet to the EstablishmentManager.""" logger.debug("Handshake packet from %s (%d bytes)", addr, len(data)) # Parse dest_conn_id to find pending handshake dest_conn_id = struct.unpack("!Q", data[:8])[0] hs = self._establishment.get_pending(dest_conn_id) if hs is None: # New inbound handshake hs = self._establishment.create_inbound_handshake() self._establishment.add_pending(hs._src_conn_id, hs) # Further processing would continue the Noise_XK state machine def _handle_data_packet(self, data: bytes, addr: tuple[str, int]) -> None: """Dispatch a data packet to the appropriate SSU2Connection.""" conn = self._peers.get_by_address(addr) if conn is None: logger.debug("Data packet from unknown address %s, dropping", addr) return try: pkt_num, blocks = conn.decrypt_data_packet(data) logger.debug("Received data packet #%d with %d blocks from %s", pkt_num, len(blocks), addr) except Exception: logger.debug("Failed to decrypt data packet from %s", addr, exc_info=True)