"""Tests for peer_connector — extract_ntcp2_address, ConnectionPool, PeerConnector.

TDD: tests written before implementation.
"""

import base64
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from i2p_router.peer_connector import extract_ntcp2_address, ConnectionPool, PeerConnector
from i2p_data.router import RouterInfo, RouterAddress


def _b64(raw: bytes) -> str:
    """Return *raw* base64-encoded as text, as the tests store it in 's'/'i' options."""
    return base64.b64encode(raw).decode()


def _make_router_info(addresses: list[RouterAddress]) -> MagicMock:
    """Create a mock RouterInfo with the given addresses."""
    ri = MagicMock(spec=RouterInfo)
    ri.addresses = addresses
    return ri


def _make_ntcp2_address(host: str, port: int, static_key_b64: str, iv_b64: str) -> RouterAddress:
    """Create a RouterAddress with NTCP2 transport and s/i options.

    The port is stored as a string in the options mapping; the key and IV
    are expected to be passed already base64-encoded.
    """
    return RouterAddress(
        cost=10,
        expiration=0,
        transport="NTCP2",
        options={
            "host": host,
            "port": str(port),
            "s": static_key_b64,
            "i": iv_b64,
        },
    )


# ---------------------------------------------------------------------------
# extract_ntcp2_address tests
# ---------------------------------------------------------------------------


class TestExtractNtcp2Address:
    """Tests for extract_ntcp2_address()."""

    def test_extracts_ntcp2_address(self) -> None:
        """extract_ntcp2_address returns (host, port, static_key, iv) for valid NTCP2."""
        static_key = b"\x01" * 32
        iv = b"\x02" * 16
        addr = _make_ntcp2_address("192.168.1.1", 15555, _b64(static_key), _b64(iv))
        ri = _make_router_info([addr])

        result = extract_ntcp2_address(ri)

        assert result is not None
        host, port, sk, iv_out = result
        assert host == "192.168.1.1"
        assert port == 15555
        assert sk == static_key
        assert iv_out == iv

    def test_returns_none_for_no_ntcp2(self) -> None:
        """extract_ntcp2_address returns None when no NTCP2 address exists."""
        addr = RouterAddress(
            cost=5,
            expiration=0,
            transport="SSU2",
            options={"host": "1.2.3.4", "port": "1234"},
        )
        ri = _make_router_info([addr])

        result = extract_ntcp2_address(ri)

        assert result is None

    def test_returns_none_for_empty_addresses(self) -> None:
        """extract_ntcp2_address returns None for RouterInfo with no addresses."""
        ri = _make_router_info([])
        assert extract_ntcp2_address(ri) is None

    def test_decodes_base64_s_and_i(self) -> None:
        """The 's' and 'i' options are base64-decoded to raw bytes of correct length."""
        # Non-repeating byte patterns so a truncated or shifted decode is caught.
        static_key = bytes(range(32))
        iv = bytes(range(16))
        addr = _make_ntcp2_address("10.0.0.1", 9999, _b64(static_key), _b64(iv))
        ri = _make_router_info([addr])

        result = extract_ntcp2_address(ri)

        assert result is not None
        _, _, sk, iv_out = result
        assert len(sk) == 32
        assert sk == static_key
        assert len(iv_out) == 16
        assert iv_out == iv

    def test_returns_none_when_s_option_missing(self) -> None:
        """Returns None if NTCP2 address lacks the 's' option."""
        iv = b"\x02" * 16
        addr = RouterAddress(
            cost=10,
            expiration=0,
            transport="NTCP2",
            options={"host": "1.2.3.4", "port": "1234", "i": _b64(iv)},
        )
        ri = _make_router_info([addr])
        assert extract_ntcp2_address(ri) is None

    def test_returns_none_when_i_option_missing(self) -> None:
        """Returns None if NTCP2 address lacks the 'i' option."""
        sk = b"\x01" * 32
        addr = RouterAddress(
            cost=10,
            expiration=0,
            transport="NTCP2",
            options={"host": "1.2.3.4", "port": "1234", "s": _b64(sk)},
        )
        ri = _make_router_info([addr])
        assert extract_ntcp2_address(ri) is None

    def test_skips_non_ntcp2_picks_ntcp2(self) -> None:
        """When multiple addresses exist, the NTCP2 one is found."""
        ssu_addr = RouterAddress(cost=5, expiration=0, transport="SSU2", options={})
        static_key = b"\xaa" * 32
        iv = b"\xbb" * 16
        ntcp2_addr = _make_ntcp2_address("5.6.7.8", 7777, _b64(static_key), _b64(iv))
        # The non-NTCP2 address comes first so a naive "take addresses[0]" fails.
        ri = _make_router_info([ssu_addr, ntcp2_addr])

        result = extract_ntcp2_address(ri)

        assert result is not None
        assert result[0] == "5.6.7.8"
        assert result[1] == 7777
        # Also confirm the key material belongs to the NTCP2 entry.
        assert result[2] == static_key
        assert result[3] == iv


# ---------------------------------------------------------------------------
# ConnectionPool tests
# ---------------------------------------------------------------------------


class TestConnectionPool:
    """Tests for ConnectionPool."""

    def test_add_and_get(self) -> None:
        """add() stores a connection retrievable by get()."""
        pool = ConnectionPool(max_connections=10)
        peer = b"\x01" * 32
        conn = MagicMock()

        assert pool.add(peer, conn) is True
        assert pool.get(peer) is conn

    def test_get_returns_none_for_unknown(self) -> None:
        """get() returns None for a peer not in the pool."""
        pool = ConnectionPool()
        assert pool.get(b"\xff" * 32) is None

    def test_remove_decreases_count(self) -> None:
        """remove() removes the connection and decreases active_count."""
        pool = ConnectionPool()
        peer = b"\x02" * 32
        conn = MagicMock()
        pool.add(peer, conn)
        assert pool.active_count == 1

        pool.remove(peer)

        assert pool.active_count == 0
        assert pool.get(peer) is None

    def test_active_count_tracks_connections(self) -> None:
        """active_count reflects the number of stored connections."""
        pool = ConnectionPool()
        for i in range(5):
            pool.add(bytes([i]) * 32, MagicMock())
        assert pool.active_count == 5

    def test_is_connected_true_false(self) -> None:
        """is_connected returns True for connected peers, False otherwise."""
        pool = ConnectionPool()
        peer = b"\x03" * 32
        assert pool.is_connected(peer) is False

        pool.add(peer, MagicMock())

        assert pool.is_connected(peer) is True

    def test_respects_max_connections(self) -> None:
        """add() returns False and rejects when at capacity."""
        pool = ConnectionPool(max_connections=2)
        # Filling the pool must succeed, otherwise the rejection below proves nothing.
        assert pool.add(b"\x01" * 32, MagicMock()) is True
        assert pool.add(b"\x02" * 32, MagicMock()) is True

        result = pool.add(b"\x03" * 32, MagicMock())

        assert result is False
        assert pool.active_count == 2
        assert pool.is_connected(b"\x03" * 32) is False

    def test_get_all_peer_hashes(self) -> None:
        """get_all_peer_hashes() returns all connected peer hashes."""
        pool = ConnectionPool()
        peers = [bytes([i]) * 32 for i in range(3)]
        for p in peers:
            pool.add(p, MagicMock())

        result = pool.get_all_peer_hashes()

        # Sorted on both sides: the test does not depend on return order.
        assert sorted(result) == sorted(peers)

    def test_remove_nonexistent_is_noop(self) -> None:
        """remove() on an unknown peer does not raise."""
        pool = ConnectionPool()
        pool.remove(b"\xff" * 32)  # should not raise
        assert pool.active_count == 0


# ---------------------------------------------------------------------------
# PeerConnector tests (mocked network)
# ---------------------------------------------------------------------------


class TestPeerConnector:
    """Tests for PeerConnector.connect() with mocked NTCP2RealConnector."""

    @pytest.fixture
    def connector(self) -> PeerConnector:
        """Return a PeerConnector built from fixed dummy key material."""
        static_priv = b"\x10" * 32
        static_pub = b"\x20" * 32
        iv = b"\x30" * 16
        return PeerConnector(
            our_static_keypair=(static_priv, static_pub),
            our_iv=iv,
        )

    @pytest.fixture
    def ntcp2_router_info(self) -> MagicMock:
        """Return a mock RouterInfo with one NTCP2 address plus stubbed serialisation."""
        addr = _make_ntcp2_address(
            "10.0.0.1",
            15555,
            _b64(b"\xaa" * 32),
            _b64(b"\xbb" * 16),
        )
        ri = _make_router_info([addr])
        # Stub to_bytes() on the RouterInfo and its identity with fixed payloads.
        ri.to_bytes = MagicMock(return_value=b"\x00" * 100)
        mock_identity = MagicMock()
        mock_identity.to_bytes = MagicMock(return_value=b"\x00" * 387)
        ri.identity = mock_identity
        return ri

    @pytest.mark.asyncio
    async def test_connect_returns_none_when_no_ntcp2_address(self, connector) -> None:
        """connect() returns None if router_info has no NTCP2 address."""
        ri = _make_router_info([])
        result = await connector.connect(ri)
        assert result is None

    @pytest.mark.asyncio
    async def test_connect_returns_connection_on_success(
        self, connector, ntcp2_router_info
    ) -> None:
        """connect() returns a connection when handshake succeeds."""
        mock_conn = MagicMock()
        mock_ntcp2_connector = AsyncMock()
        mock_ntcp2_connector.connect = AsyncMock(return_value=mock_conn)

        with patch(
            "i2p_router.peer_connector.NTCP2RealConnector",
            return_value=mock_ntcp2_connector,
        ):
            result = await connector.connect(ntcp2_router_info)

        assert result is mock_conn
        mock_ntcp2_connector.connect.assert_called_once()
        # For an AsyncMock, "called" only means a coroutine was created;
        # also require that it was actually awaited.
        mock_ntcp2_connector.connect.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_connect_returns_none_on_error(self, connector, ntcp2_router_info) -> None:
        """connect() returns None when the underlying connector raises.

        Only the return value is checked here; any logging is not asserted.
        """
        mock_ntcp2_connector = AsyncMock()
        mock_ntcp2_connector.connect = AsyncMock(
            side_effect=ConnectionRefusedError("refused")
        )

        with patch(
            "i2p_router.peer_connector.NTCP2RealConnector",
            return_value=mock_ntcp2_connector,
        ):
            result = await connector.connect(ntcp2_router_info)

        assert result is None