"""NTCP2 listener, connector, and connection manager.

Provides the server-side listener (NTCP2Listener) that accepts incoming NTCP2
connections, the client-side connector (NTCP2Connector) that initiates
outbound connections, and a ConnectionManager for tracking active peer
connections.

All handshake messages are length-framed on the wire using a 2-byte
big-endian length prefix during the Noise_XK handshake phase, so a single
handshake message can carry at most 65535 bytes.

NOTE(review): this length-prefix framing appears to be specific to this
implementation; the I2P NTCP2 specification frames its handshake messages
differently (fixed-size fields plus padding whose length is carried in the
options block). Confirm whether interoperability with other routers is a goal.
"""

import asyncio
import logging
import struct
from typing import Awaitable, Callable

from i2p_transport.ntcp2_connection import NTCP2Connection
from i2p_transport.ntcp2_handshake import NTCP2Handshake

logger = logging.getLogger(__name__)

# Largest payload length expressible by the 2-byte ("!H") length prefix.
_MAX_HANDSHAKE_MSG_LEN = 0xFFFF


# ---------------------------------------------------------------------------
# Handshake wire helpers — 2-byte big-endian length prefix
# ---------------------------------------------------------------------------


async def _send_handshake_msg(writer: asyncio.StreamWriter, msg: bytes) -> None:
    """Send a handshake message with a 2-byte big-endian length prefix.

    Raises:
        ValueError: If ``msg`` is longer than 65535 bytes and therefore cannot
            be described by the 2-byte prefix. Nothing is written in that case.
    """
    if len(msg) > _MAX_HANDSHAKE_MSG_LEN:
        raise ValueError(
            f"handshake message too long: {len(msg)} bytes "
            f"(max {_MAX_HANDSHAKE_MSG_LEN})"
        )
    writer.write(struct.pack("!H", len(msg)) + msg)
    await writer.drain()


async def _recv_handshake_msg(reader: asyncio.StreamReader) -> bytes:
    """Receive one handshake message (2-byte length prefix, then payload).

    Raises:
        asyncio.IncompleteReadError: If the stream ends before the prefix or
            the full payload has been read.
    """
    (length,) = struct.unpack("!H", await reader.readexactly(2))
    return await reader.readexactly(length)


# ---------------------------------------------------------------------------
# ConnectionManager
# ---------------------------------------------------------------------------


class ConnectionManager:
    """Tracks active NTCP2 connections keyed by peer hash.

    This is a thin wrapper around a dict. Apart from :meth:`close_all`, it
    never closes connections itself.
    """

    def __init__(self) -> None:
        self._connections: dict[bytes, NTCP2Connection] = {}

    def add(self, peer_hash: bytes, connection: NTCP2Connection) -> None:
        """Store a connection for the given peer hash.

        An existing entry for the same hash is replaced without being closed.
        """
        self._connections[peer_hash] = connection

    def get(self, peer_hash: bytes) -> NTCP2Connection | None:
        """Retrieve a connection by peer hash, or None if not found."""
        return self._connections.get(peer_hash)

    def remove(self, peer_hash: bytes) -> None:
        """Remove a connection by peer hash (no-op if not present).

        The removed connection is not closed.
        """
        self._connections.pop(peer_hash, None)

    def active_count(self) -> int:
        """Return the number of tracked connections."""
        return len(self._connections)

    def all_peer_hashes(self) -> list[bytes]:
        """Return a list of all tracked peer hashes."""
        return list(self._connections.keys())

    def close_all(self) -> None:
        """Close every tracked connection synchronously and forget all of them.

        The registry is emptied first. A failure to close one connection is
        logged and does not prevent the remaining connections from being
        closed.
        """
        connections = list(self._connections.values())
        self._connections.clear()
        for conn in connections:
            try:
                # NOTE(review): reaches into NTCP2Connection's private writer;
                # a public close() on NTCP2Connection would be cleaner.
                conn._writer.close()
            except Exception:
                logger.exception("failed to close NTCP2 connection")


# ---------------------------------------------------------------------------
# NTCP2Listener — accepts incoming connections
# ---------------------------------------------------------------------------


class NTCP2Listener:
    """Listens for incoming NTCP2 connections and performs the responder handshake.

    Args:
        host: Bind address (e.g. "0.0.0.0" or "127.0.0.1").
        port: Bind port (use 0 to let the OS pick a free port).
        our_static: (private_key, public_key) X25519 keypair.
        on_connection: Optional async callback invoked with each new
            NTCP2Connection after a successful handshake. If it is None, the
            established connection is not handed to anyone.
        handshake_timeout: Optional limit, in seconds, on how long a peer may
            take to complete the responder handshake. ``None`` (the default)
            waits indefinitely.
    """

    def __init__(
        self,
        host: str,
        port: int,
        our_static: tuple[bytes, bytes],
        on_connection: Callable[[NTCP2Connection], Awaitable[None]] | None = None,
        handshake_timeout: float | None = None,
    ) -> None:
        self._host = host
        self._port = port
        self._our_static = our_static
        self._on_connection = on_connection
        self._handshake_timeout = handshake_timeout

    async def start(self) -> asyncio.Server:
        """Start listening and return the asyncio.Server object.

        The caller owns the returned server and is responsible for closing
        it. When ``port`` is 0, the port actually bound can be read from
        ``server.sockets[0].getsockname()[1]``.
        """
        return await asyncio.start_server(
            self._handle_client, self._host, self._port
        )

    async def _responder_handshake(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> NTCP2Connection:
        """Run the Noise_XK handshake as responder and wrap the resulting ciphers."""
        hs = NTCP2Handshake(
            our_static=self._our_static,
            peer_static_pub=None,
            initiator=False,
        )

        # Responder: read msg1, send msg2, read msg3
        msg1 = await _recv_handshake_msg(reader)
        await _send_handshake_msg(writer, hs.process_message_1(msg1))
        hs.process_message_3(await _recv_handshake_msg(reader))

        # Derive transport cipher states
        send_cipher, recv_cipher = hs.split()
        return NTCP2Connection(
            reader=reader,
            writer=writer,
            cipher_send=send_cipher,
            cipher_recv=recv_cipher,
            # NOTE(review): falls back to b"" when the handshake exposes no
            # remote static key; an empty hash would collide in
            # ConnectionManager. Confirm this cannot happen after msg3.
            remote_hash=hs.remote_static_key() or b"",
        )

    async def _handle_client(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        """Per-connection server callback: run the handshake, then notify the callback.

        The socket is closed if the handshake fails or times out, if the
        on_connection callback raises, or if this task is cancelled while
        either is in progress.
        """
        try:
            conn = await asyncio.wait_for(
                self._responder_handshake(reader, writer),
                self._handshake_timeout,
            )
        except Exception:
            # Failed or abandoned handshakes are routine for a listening
            # socket (scanners, wrong keys), so they are logged quietly.
            logger.debug("inbound NTCP2 handshake failed", exc_info=True)
            writer.close()
            return
        except BaseException:
            # asyncio.CancelledError is not an Exception subclass (3.8+):
            # release the socket, then let the cancellation propagate.
            writer.close()
            raise

        if self._on_connection is None:
            return

        try:
            await self._on_connection(conn)
        except Exception:
            # A failing callback is a bug in our own code, not peer noise.
            logger.exception("NTCP2 on_connection callback failed")
            writer.close()
        except BaseException:
            writer.close()
            raise


# ---------------------------------------------------------------------------
# NTCP2Connector — initiates outbound connections
# ---------------------------------------------------------------------------


class NTCP2Connector:
    """Connects to a remote NTCP2 peer and performs the initiator handshake."""

    async def connect(
        self,
        host: str,
        port: int,
        our_static: tuple[bytes, bytes],
        peer_static_pub: bytes,
    ) -> NTCP2Connection:
        """Open a TCP connection and perform the Noise_XK handshake as initiator.

        Args:
            host: Remote host address.
            port: Remote port.
            our_static: (private_key, public_key) X25519 keypair.
            peer_static_pub: Remote peer's static X25519 public key. It is
                also used as the connection's ``remote_hash``.

        Returns:
            An established NTCP2Connection ready for frame exchange.

        Raises:
            ConnectionRefusedError: If the TCP connection cannot be
                established (other OSError subclasses are possible too).
            asyncio.IncompleteReadError: If the peer closes the stream in the
                middle of the handshake.
            ValueError: If a handshake message is too long to be framed.

        Any exception raised by NTCP2Handshake for an invalid message is
        propagated as well. Once the TCP connection is open, the socket is
        closed before any exception, including cancellation, propagates.
        """
        reader, writer = await asyncio.open_connection(host, port)
        try:
            hs = NTCP2Handshake(
                our_static=our_static,
                peer_static_pub=peer_static_pub,
                initiator=True,
            )

            # Initiator: send msg1, read msg2, send msg3
            await _send_handshake_msg(writer, hs.create_message_1())
            msg2 = await _recv_handshake_msg(reader)
            await _send_handshake_msg(writer, hs.process_message_2(msg2))

            # Derive transport cipher states
            send_cipher, recv_cipher = hs.split()
            return NTCP2Connection(
                reader=reader,
                writer=writer,
                cipher_send=send_cipher,
                cipher_recv=recv_cipher,
                remote_hash=peer_static_pub,
            )
        except BaseException:
            # BaseException so that cancellation (e.g. an outer
            # asyncio.wait_for timeout) also releases the socket.
            writer.close()
            raise