"""Tests for NTCP2 listener, connector, and connection manager."""

import asyncio
import socket
import struct
import unittest

from i2p_crypto.x25519 import X25519DH
from i2p_transport.ntcp2 import NTCP2Frame, FrameType
from i2p_transport.ntcp2_connection import NTCP2Connection
from i2p_transport.ntcp2_server import (
    ConnectionManager,
    NTCP2Listener,
    NTCP2Connector,
)

# Upper bound (seconds) for any single network step in the integration
# tests, so a stalled handshake or a lost frame fails the test instead of
# hanging the whole run.
_STEP_TIMEOUT = 5.0

# Delay (seconds) between checks while waiting for the listener's
# on_connection callback to fire.
_POLL_INTERVAL = 0.01


# ---------------------------------------------------------------------------
# Mock connection for ConnectionManager tests (no asyncio needed)
# ---------------------------------------------------------------------------


class MockWriter:
    """Minimal stand-in for a stream writer: supports close() / is_closing()."""

    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True

    def is_closing(self):
        return self.closed


def _make_mock_connection(remote_hash: bytes = b"") -> NTCP2Connection:
    """Create an NTCP2Connection with mock streams (no real I/O).

    The reader is ``None`` and the writer is a :class:`MockWriter`. In this
    file these connections are only used for ConnectionManager bookkeeping
    tests; no frames are sent or received through them.
    """
    from i2p_crypto.noise import CipherState

    return NTCP2Connection(
        reader=None,
        writer=MockWriter(),
        cipher_send=CipherState(b"\x00" * 32),
        cipher_recv=CipherState(b"\x01" * 32),
        remote_hash=remote_hash,
    )


# ===========================================================================
# ConnectionManager tests — synchronous
# ===========================================================================


class TestConnectionManager(unittest.TestCase):
    """Unit tests for ConnectionManager."""

    def test_add_and_get(self):
        mgr = ConnectionManager()
        peer = b"\xaa" * 32
        conn = _make_mock_connection(peer)
        mgr.add(peer, conn)
        self.assertIs(mgr.get(peer), conn)

    def test_get_unknown_returns_none(self):
        mgr = ConnectionManager()
        self.assertIsNone(mgr.get(b"\xbb" * 32))

    def test_remove(self):
        mgr = ConnectionManager()
        peer = b"\xcc" * 32
        conn = _make_mock_connection(peer)
        mgr.add(peer, conn)
        mgr.remove(peer)
        self.assertIsNone(mgr.get(peer))

    def test_remove_nonexistent_is_noop(self):
        mgr = ConnectionManager()
        mgr.remove(b"\xdd" * 32)  # should not raise

    def test_active_count(self):
        mgr = ConnectionManager()
        self.assertEqual(mgr.active_count(), 0)
        mgr.add(b"\x01" * 32, _make_mock_connection())
        mgr.add(b"\x02" * 32, _make_mock_connection())
        self.assertEqual(mgr.active_count(), 2)
        mgr.remove(b"\x01" * 32)
        self.assertEqual(mgr.active_count(), 1)

    def test_all_peer_hashes(self):
        mgr = ConnectionManager()
        h1 = b"\x01" * 32
        h2 = b"\x02" * 32
        mgr.add(h1, _make_mock_connection())
        mgr.add(h2, _make_mock_connection())
        hashes = mgr.all_peer_hashes()
        self.assertEqual(len(hashes), 2)
        self.assertIn(h1, hashes)
        self.assertIn(h2, hashes)

    def test_close_all(self):
        mgr = ConnectionManager()
        conns = []
        for i in range(3):
            c = _make_mock_connection()
            mgr.add(bytes([i]) * 32, c)
            conns.append(c)
        mgr.close_all()
        # All writers should be closed
        for c in conns:
            self.assertTrue(c._writer.closed)
        # Manager should be empty
        self.assertEqual(mgr.active_count(), 0)


# ===========================================================================
# NTCP2Listener + NTCP2Connector integration tests — asyncio
# ===========================================================================


async def _start_listener(static_keypair):
    """Start an NTCP2Listener on an ephemeral 127.0.0.1 port.

    Returns ``(server, port, accepted)``: the server object returned by
    ``NTCP2Listener.start()``, the port it bound, and a list to which the
    listener's ``on_connection`` callback appends every accepted connection.
    """
    accepted = []

    async def on_conn(conn):
        accepted.append(conn)

    listener = NTCP2Listener(
        "127.0.0.1", 0, static_keypair, on_connection=on_conn
    )
    server = await listener.start()
    port = server.sockets[0].getsockname()[1]
    return server, port, accepted


async def _connect(port, static_keypair, remote_static_pub):
    """Connect to 127.0.0.1:*port* with a fresh NTCP2Connector.

    The connect call is bounded by ``_STEP_TIMEOUT``. On timeout,
    ``asyncio.TimeoutError`` is raised instead of the test hanging.
    """
    return await asyncio.wait_for(
        NTCP2Connector().connect(
            "127.0.0.1", port, static_keypair, remote_static_pub
        ),
        _STEP_TIMEOUT,
    )


async def _wait_for_count(items, expected):
    """Wait until ``len(items) >= expected``, or until ``_STEP_TIMEOUT`` expires.

    This replaces fixed ``asyncio.sleep`` delays. Those made the tests
    flaky on slow machines, because the listener callback might not have
    run yet, and slow on fast ones. Raises ``asyncio.TimeoutError`` if the
    count is not reached in time.
    """

    async def _poll():
        while len(items) < expected:
            await asyncio.sleep(_POLL_INTERVAL)

    await asyncio.wait_for(_poll(), _STEP_TIMEOUT)


async def _shutdown(server, conns):
    """Close every connection's writer, then close the server and wait for it.

    Connections are closed first because, on Python 3.12+,
    ``Server.wait_closed()`` otherwise blocks on the still-open clients.
    """
    for conn in conns:
        conn._writer.close()
    server.close()
    await server.wait_closed()


class TestListenerConnectorIntegration(unittest.TestCase):
    """Integration tests that run a real TCP listener and connector."""

    def test_listener_connector_handshake(self):
        """Start listener, connect with connector, verify handshake completes."""
        asyncio.run(self._test_handshake())

    async def _test_handshake(self):
        listener_static = X25519DH.generate_keypair()
        connector_static = X25519DH.generate_keypair()
        server, port, accepted = await _start_listener(listener_static)
        clients = []
        try:
            conn = await _connect(port, connector_static, listener_static[1])
            clients.append(conn)
            # The listener's callback may run after connect() returns.
            await _wait_for_count(accepted, 1)

            # Handshake completed — connector got a connection
            self.assertIsInstance(conn, NTCP2Connection)
            self.assertTrue(conn.is_alive())

            # Listener side received the connection via callback
            self.assertEqual(len(accepted), 1)
            self.assertIsInstance(accepted[0], NTCP2Connection)
            self.assertTrue(accepted[0].is_alive())
        finally:
            # Always release sockets, even when an assertion above failed.
            await _shutdown(server, clients + accepted)

    def test_frame_exchange_after_handshake(self):
        """After handshake, send a frame from connector and receive on listener side."""
        asyncio.run(self._test_frame_exchange())

    async def _test_frame_exchange(self):
        listener_static = X25519DH.generate_keypair()
        connector_static = X25519DH.generate_keypair()
        server, port, accepted = await _start_listener(listener_static)
        clients = []
        try:
            client_conn = await _connect(
                port, connector_static, listener_static[1]
            )
            clients.append(client_conn)
            await _wait_for_count(accepted, 1)
            server_conn = accepted[0]

            # Send a frame from client to server
            test_payload = b"hello from connector"
            await client_conn.send_frame(
                NTCP2Frame(FrameType.I2NP, test_payload)
            )
            received = await asyncio.wait_for(
                server_conn.recv_frame(), _STEP_TIMEOUT
            )
            self.assertEqual(received.frame_type, FrameType.I2NP)
            self.assertEqual(received.payload, test_payload)

            # Send a frame from server to client
            reply_payload = b"hello from listener"
            await server_conn.send_frame(
                NTCP2Frame(FrameType.I2NP, reply_payload)
            )
            received_reply = await asyncio.wait_for(
                client_conn.recv_frame(), _STEP_TIMEOUT
            )
            self.assertEqual(received_reply.frame_type, FrameType.I2NP)
            self.assertEqual(received_reply.payload, reply_payload)
        finally:
            await _shutdown(server, clients + accepted)

    def test_multiple_connections(self):
        """Listener accepts multiple concurrent connections."""
        asyncio.run(self._test_multiple_connections())

    async def _test_multiple_connections(self):
        listener_static = X25519DH.generate_keypair()
        server, port, accepted = await _start_listener(listener_static)
        clients = []
        try:
            # Connect three clients, each with its own static key; all stay
            # open at the same time.
            for _ in range(3):
                clients.append(
                    await _connect(
                        port, X25519DH.generate_keypair(), listener_static[1]
                    )
                )
            await _wait_for_count(accepted, 3)

            self.assertEqual(len(accepted), 3)
            for c in clients:
                self.assertTrue(c.is_alive())
        finally:
            await _shutdown(server, clients + accepted)

    def test_connector_without_listener_raises(self):
        """Connecting to a port with no listener should raise."""
        asyncio.run(self._test_connector_no_listener())

    async def _test_connector_no_listener(self):
        cs = X25519DH.generate_keypair()
        # Reserve an ephemeral port by binding a socket without calling
        # listen(). Nothing accepts on it, and no other process can claim
        # the port while the test runs. A fixed port number such as 19999
        # could already be in use on the test machine.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as placeholder:
            placeholder.bind(("127.0.0.1", 0))
            port = placeholder.getsockname()[1]
            with self.assertRaises(ConnectionRefusedError):
                await _connect(port, cs, b"\x00" * 32)


if __name__ == "__main__":
    unittest.main()