A Python port of the Invisible Internet Project (I2P)
at main 270 lines 10 kB view raw
"""Tests for peer_connector — extract_ntcp2_address and ConnectionPool.

TDD: tests written before implementation.
"""

import base64
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from i2p_router.peer_connector import extract_ntcp2_address, ConnectionPool, PeerConnector
from i2p_data.router import RouterInfo, RouterAddress


def _make_router_info(addresses: list[RouterAddress]) -> MagicMock:
    """Create a mock RouterInfo with the given addresses."""
    ri = MagicMock(spec=RouterInfo)
    ri.addresses = addresses
    return ri


def _make_ntcp2_address(host: str, port: int, static_key_b64: str, iv_b64: str) -> RouterAddress:
    """Create a RouterAddress with NTCP2 transport and s/i options.

    ``static_key_b64`` and ``iv_b64`` are stored verbatim. Every caller in
    this module encodes them with the standard base64 alphabet.

    NOTE(review): I2P's own Base64 uses '-' and '~' in place of '+' and '/'.
    None of the key/IV vectors in this module encode to either character,
    so alphabet handling is not exercised here. Confirm which alphabet
    extract_ntcp2_address is expected to accept.
    """
    return RouterAddress(
        cost=10,
        expiration=0,
        transport="NTCP2",
        options={
            "host": host,
            "port": str(port),
            "s": static_key_b64,
            "i": iv_b64,
        },
    )


def _make_connectable_router_info(static_key: bytes, iv: bytes) -> MagicMock:
    """Create a mock RouterInfo with a single NTCP2 address at 10.0.0.1:15555.

    ``to_bytes()`` and ``identity.to_bytes()`` are stubbed with zero-filled
    payloads (100 and 387 bytes respectively). This makes them return real
    ``bytes`` rather than MagicMock objects.
    """
    addr = _make_ntcp2_address(
        "10.0.0.1", 15555,
        base64.b64encode(static_key).decode(),
        base64.b64encode(iv).decode(),
    )
    ri = _make_router_info([addr])
    ri.to_bytes = MagicMock(return_value=b"\x00" * 100)
    mock_identity = MagicMock()
    mock_identity.to_bytes = MagicMock(return_value=b"\x00" * 387)
    ri.identity = mock_identity
    return ri


# ---------------------------------------------------------------------------
# extract_ntcp2_address tests
# ---------------------------------------------------------------------------

class TestExtractNtcp2Address:
    """Tests for extract_ntcp2_address()."""

    def test_extracts_ntcp2_address(self):
        """extract_ntcp2_address returns (host, port, static_key, iv) for valid NTCP2."""
        static_key = b"\x01" * 32
        iv = b"\x02" * 16
        s_b64 = base64.b64encode(static_key).decode()
        i_b64 = base64.b64encode(iv).decode()
        addr = _make_ntcp2_address("192.168.1.1", 15555, s_b64, i_b64)
        ri = _make_router_info([addr])

        result = extract_ntcp2_address(ri)

        assert result is not None
        host, port, sk, iv_out = result
        assert host == "192.168.1.1"
        assert port == 15555
        assert sk == static_key
        assert iv_out == iv

    def test_returns_none_for_no_ntcp2(self):
        """extract_ntcp2_address returns None when no NTCP2 address exists."""
        addr = RouterAddress(
            cost=5, expiration=0, transport="SSU2",
            options={"host": "1.2.3.4", "port": "1234"},
        )
        ri = _make_router_info([addr])

        result = extract_ntcp2_address(ri)
        assert result is None

    def test_returns_none_for_empty_addresses(self):
        """extract_ntcp2_address returns None for RouterInfo with no addresses."""
        ri = _make_router_info([])
        assert extract_ntcp2_address(ri) is None

    def test_decodes_base64_s_and_i(self):
        """The 's' and 'i' options are base64-decoded to raw bytes of correct length."""
        static_key = bytes(range(32))
        iv = bytes(range(16))
        s_b64 = base64.b64encode(static_key).decode()
        i_b64 = base64.b64encode(iv).decode()
        addr = _make_ntcp2_address("10.0.0.1", 9999, s_b64, i_b64)
        ri = _make_router_info([addr])

        result = extract_ntcp2_address(ri)
        assert result is not None
        _, _, sk, iv_out = result
        assert len(sk) == 32
        assert sk == static_key
        assert len(iv_out) == 16
        assert iv_out == iv

    def test_returns_none_when_s_option_missing(self):
        """Returns None if NTCP2 address lacks the 's' option."""
        iv = b"\x02" * 16
        addr = RouterAddress(
            cost=10, expiration=0, transport="NTCP2",
            options={"host": "1.2.3.4", "port": "1234", "i": base64.b64encode(iv).decode()},
        )
        ri = _make_router_info([addr])
        assert extract_ntcp2_address(ri) is None

    def test_returns_none_when_i_option_missing(self):
        """Returns None if NTCP2 address lacks the 'i' option."""
        sk = b"\x01" * 32
        addr = RouterAddress(
            cost=10, expiration=0, transport="NTCP2",
            options={"host": "1.2.3.4", "port": "1234", "s": base64.b64encode(sk).decode()},
        )
        ri = _make_router_info([addr])
        assert extract_ntcp2_address(ri) is None

    def test_skips_non_ntcp2_picks_ntcp2(self):
        """When multiple addresses exist, the NTCP2 one is found and fully decoded."""
        ssu_addr = RouterAddress(cost=5, expiration=0, transport="SSU2", options={})
        static_key = b"\xaa" * 32
        iv = b"\xbb" * 16
        ntcp2_addr = _make_ntcp2_address(
            "5.6.7.8", 7777,
            base64.b64encode(static_key).decode(),
            base64.b64encode(iv).decode(),
        )
        ri = _make_router_info([ssu_addr, ntcp2_addr])

        result = extract_ntcp2_address(ri)
        assert result is not None
        assert result[0] == "5.6.7.8"
        assert result[1] == 7777
        # The key material must come from the NTCP2 entry, not the SSU2 one.
        assert result[2] == static_key
        assert result[3] == iv


# ---------------------------------------------------------------------------
# ConnectionPool tests
# ---------------------------------------------------------------------------

class TestConnectionPool:
    """Tests for ConnectionPool."""

    def test_add_and_get(self):
        """add() stores a connection retrievable by get()."""
        pool = ConnectionPool(max_connections=10)
        peer = b"\x01" * 32
        conn = MagicMock()
        assert pool.add(peer, conn) is True
        assert pool.get(peer) is conn

    def test_get_returns_none_for_unknown(self):
        """get() returns None for a peer not in the pool."""
        pool = ConnectionPool()
        assert pool.get(b"\xff" * 32) is None

    def test_remove_decreases_count(self):
        """remove() removes the connection and decreases active_count."""
        pool = ConnectionPool()
        peer = b"\x02" * 32
        conn = MagicMock()
        pool.add(peer, conn)
        assert pool.active_count == 1
        pool.remove(peer)
        assert pool.active_count == 0
        assert pool.get(peer) is None

    def test_active_count_tracks_connections(self):
        """active_count reflects the number of stored connections."""
        pool = ConnectionPool()
        for i in range(5):
            pool.add(bytes([i]) * 32, MagicMock())
        assert pool.active_count == 5

    def test_is_connected_true_false(self):
        """is_connected returns True for connected peers, False otherwise."""
        pool = ConnectionPool()
        peer = b"\x03" * 32
        assert pool.is_connected(peer) is False
        pool.add(peer, MagicMock())
        assert pool.is_connected(peer) is True

    def test_respects_max_connections(self):
        """add() accepts up to capacity, then returns False and rejects."""
        pool = ConnectionPool(max_connections=2)
        # The adds that fit within capacity must succeed...
        assert pool.add(b"\x01" * 32, MagicMock()) is True
        assert pool.add(b"\x02" * 32, MagicMock()) is True
        # ...and the first one beyond capacity must be refused without side effects.
        result = pool.add(b"\x03" * 32, MagicMock())
        assert result is False
        assert pool.active_count == 2
        assert pool.is_connected(b"\x03" * 32) is False

    def test_get_all_peer_hashes(self):
        """get_all_peer_hashes() returns all connected peer hashes."""
        pool = ConnectionPool()
        peers = [bytes([i]) * 32 for i in range(3)]
        for p in peers:
            pool.add(p, MagicMock())
        result = pool.get_all_peer_hashes()
        assert sorted(result) == sorted(peers)

    def test_remove_nonexistent_is_noop(self):
        """remove() on an unknown peer does not raise."""
        pool = ConnectionPool()
        pool.remove(b"\xff" * 32)  # should not raise
        assert pool.active_count == 0


# ---------------------------------------------------------------------------
# PeerConnector tests (mocked network)
# ---------------------------------------------------------------------------

class TestPeerConnector:
    """Tests for PeerConnector.connect() with mocked NTCP2RealConnector."""

    @pytest.fixture
    def connector(self):
        static_priv = b"\x10" * 32
        static_pub = b"\x20" * 32
        iv = b"\x30" * 16
        return PeerConnector(
            our_static_keypair=(static_priv, static_pub),
            our_iv=iv,
        )

    @pytest.mark.asyncio
    async def test_connect_returns_none_when_no_ntcp2_address(self, connector):
        """connect() returns None if router_info has no NTCP2 address."""
        ri = _make_router_info([])
        result = await connector.connect(ri)
        assert result is None

    @pytest.mark.asyncio
    async def test_connect_returns_connection_on_success(self, connector):
        """connect() returns a connection when handshake succeeds."""
        ri = _make_connectable_router_info(b"\xaa" * 32, b"\xbb" * 16)

        mock_conn = MagicMock()
        mock_ntcp2_connector = AsyncMock()
        mock_ntcp2_connector.connect = AsyncMock(return_value=mock_conn)

        with patch("i2p_router.peer_connector.NTCP2RealConnector", return_value=mock_ntcp2_connector):
            result = await connector.connect(ri)

        assert result is mock_conn
        # assert_awaited_once is stricter than assert_called_once for an
        # AsyncMock: a call whose coroutine was never awaited does not count.
        mock_ntcp2_connector.connect.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_connect_returns_none_on_error(self, connector):
        """connect() returns None when the underlying NTCP2 connect raises."""
        ri = _make_connectable_router_info(b"\xaa" * 32, b"\xbb" * 16)

        mock_ntcp2_connector = AsyncMock()
        mock_ntcp2_connector.connect = AsyncMock(side_effect=ConnectionRefusedError("refused"))

        with patch("i2p_router.peer_connector.NTCP2RealConnector", return_value=mock_ntcp2_connector):
            result = await connector.connect(ri)

        assert result is None
        # Prove the None came from the failed handshake, not from an early
        # bail-out (e.g. address extraction failing), which would also yield
        # None. Only "at least once" is required, so retries are tolerated.
        mock_ntcp2_connector.connect.assert_awaited()