# A Python port of the Invisible Internet Project (I2P).
"""Tests for peer_connector: extract_ntcp2_address, ConnectionPool and PeerConnector.

TDD: these tests were written before the implementation.
"""
5
6import base64
7from unittest.mock import AsyncMock, MagicMock, patch
8
9import pytest
10
11from i2p_router.peer_connector import extract_ntcp2_address, ConnectionPool, PeerConnector
12from i2p_data.router import RouterInfo, RouterAddress
13
14
def _make_router_info(addresses: list[RouterAddress]) -> MagicMock:
    """Return a RouterInfo stand-in whose ``addresses`` attribute is *addresses*.

    The mock is spec'd against RouterInfo, so reading an attribute that does
    not exist on the RouterInfo class raises AttributeError unless a test
    assigns it explicitly.
    """
    return MagicMock(spec=RouterInfo, addresses=addresses)
20
21
def _make_ntcp2_address(host: str, port: int, static_key_b64: str, iv_b64: str) -> RouterAddress:
    """Build an NTCP2 RouterAddress publishing *host*, *port*, ``s`` and ``i``.

    Cost is fixed at 10 and expiration at 0. The port is stored as its decimal
    string; the static-key and IV strings are passed through unchanged.
    """
    ntcp2_options = dict(host=host, port=str(port), s=static_key_b64, i=iv_b64)
    return RouterAddress(
        transport="NTCP2",
        cost=10,
        expiration=0,
        options=ntcp2_options,
    )
35
36
37# ---------------------------------------------------------------------------
38# extract_ntcp2_address tests
39# ---------------------------------------------------------------------------
40
class TestExtractNtcp2Address:
    """Tests for extract_ntcp2_address().

    NOTE(review): the 's'/'i' options below are encoded with the standard
    base64 alphabet, and none of the key/IV byte patterns used here encode
    to '+' or '/'. I2P publishes these options in a modified alphabet
    ('-' and '~' in place of '+' and '/'); that case is not exercised here --
    confirm which alphabet extract_ntcp2_address is expected to accept.
    """

    def test_extracts_ntcp2_address(self):
        """extract_ntcp2_address returns (host, port, static_key, iv) for valid NTCP2."""
        static_key = b"\x01" * 32
        iv = b"\x02" * 16
        s_b64 = base64.b64encode(static_key).decode()
        i_b64 = base64.b64encode(iv).decode()
        addr = _make_ntcp2_address("192.168.1.1", 15555, s_b64, i_b64)
        ri = _make_router_info([addr])

        result = extract_ntcp2_address(ri)

        assert result is not None
        host, port, sk, iv_out = result
        assert host == "192.168.1.1"
        assert port == 15555
        assert sk == static_key
        assert iv_out == iv

    def test_returns_none_for_no_ntcp2(self):
        """extract_ntcp2_address returns None when no NTCP2 address exists."""
        addr = RouterAddress(cost=5, expiration=0, transport="SSU2", options={"host": "1.2.3.4", "port": "1234"})
        ri = _make_router_info([addr])

        result = extract_ntcp2_address(ri)
        assert result is None

    def test_returns_none_for_empty_addresses(self):
        """extract_ntcp2_address returns None for RouterInfo with no addresses."""
        ri = _make_router_info([])
        assert extract_ntcp2_address(ri) is None

    def test_decodes_base64_s_and_i(self):
        """The 's' and 'i' options are base64-decoded to raw bytes of correct length."""
        static_key = bytes(range(32))
        iv = bytes(range(16))
        s_b64 = base64.b64encode(static_key).decode()
        i_b64 = base64.b64encode(iv).decode()
        addr = _make_ntcp2_address("10.0.0.1", 9999, s_b64, i_b64)
        ri = _make_router_info([addr])

        result = extract_ntcp2_address(ri)
        assert result is not None
        _, _, sk, iv_out = result
        assert len(sk) == 32
        assert sk == static_key
        assert len(iv_out) == 16
        assert iv_out == iv

    def test_returns_none_when_s_option_missing(self):
        """Returns None if NTCP2 address lacks the 's' option."""
        iv = b"\x02" * 16
        addr = RouterAddress(
            cost=10, expiration=0, transport="NTCP2",
            options={"host": "1.2.3.4", "port": "1234", "i": base64.b64encode(iv).decode()},
        )
        ri = _make_router_info([addr])
        assert extract_ntcp2_address(ri) is None

    def test_returns_none_when_i_option_missing(self):
        """Returns None if NTCP2 address lacks the 'i' option."""
        sk = b"\x01" * 32
        addr = RouterAddress(
            cost=10, expiration=0, transport="NTCP2",
            options={"host": "1.2.3.4", "port": "1234", "s": base64.b64encode(sk).decode()},
        )
        ri = _make_router_info([addr])
        assert extract_ntcp2_address(ri) is None

    def test_skips_non_ntcp2_picks_ntcp2(self):
        """When multiple addresses exist, every returned field comes from the NTCP2 one."""
        # The SSU2 entry is listed first so that a "just take the first
        # address" implementation is caught.
        ssu_addr = RouterAddress(cost=5, expiration=0, transport="SSU2", options={})
        static_key = b"\xaa" * 32
        iv = b"\xbb" * 16
        ntcp2_addr = _make_ntcp2_address(
            "5.6.7.8", 7777,
            base64.b64encode(static_key).decode(),
            base64.b64encode(iv).decode(),
        )
        ri = _make_router_info([ssu_addr, ntcp2_addr])

        result = extract_ntcp2_address(ri)
        assert result is not None
        # Check all four fields, not just host/port, so the key and IV are
        # proven to come from the same NTCP2 entry.
        host, port, sk, iv_out = result
        assert host == "5.6.7.8"
        assert port == 7777
        assert sk == static_key
        assert iv_out == iv
124
125
126# ---------------------------------------------------------------------------
127# ConnectionPool tests
128# ---------------------------------------------------------------------------
129
class TestConnectionPool:
    """Behaviour of ConnectionPool: storage, lookup, removal and capacity limits."""

    def test_add_and_get(self):
        """A connection accepted by add() is handed back unchanged by get()."""
        pool = ConnectionPool(max_connections=10)
        peer_hash = bytes([1]) * 32
        connection = MagicMock()
        accepted = pool.add(peer_hash, connection)
        assert accepted is True
        assert pool.get(peer_hash) is connection

    def test_get_returns_none_for_unknown(self):
        """Looking up a peer that was never added yields None."""
        unknown_peer = bytes([0xFF]) * 32
        assert ConnectionPool().get(unknown_peer) is None

    def test_remove_decreases_count(self):
        """Removing a peer drops it from the pool and from active_count."""
        pool = ConnectionPool()
        peer_hash = bytes([2]) * 32
        pool.add(peer_hash, MagicMock())
        assert pool.active_count == 1

        pool.remove(peer_hash)

        assert pool.active_count == 0
        assert pool.get(peer_hash) is None

    def test_active_count_tracks_connections(self):
        """active_count equals the number of distinct peers added."""
        pool = ConnectionPool()
        peer_hashes = [bytes([n]) * 32 for n in range(5)]
        for peer_hash in peer_hashes:
            pool.add(peer_hash, MagicMock())
        assert pool.active_count == len(peer_hashes)

    def test_is_connected_true_false(self):
        """is_connected flips from False to True once the peer is added."""
        pool = ConnectionPool()
        peer_hash = bytes([3]) * 32
        assert pool.is_connected(peer_hash) is False

        pool.add(peer_hash, MagicMock())

        assert pool.is_connected(peer_hash) is True

    def test_respects_max_connections(self):
        """Once the pool is full, add() refuses further peers and returns False."""
        pool = ConnectionPool(max_connections=2)
        first, second, overflow = (bytes([n]) * 32 for n in (1, 2, 3))
        for peer_hash in (first, second):
            pool.add(peer_hash, MagicMock())

        accepted = pool.add(overflow, MagicMock())

        assert accepted is False
        assert pool.active_count == 2
        assert pool.is_connected(overflow) is False

    def test_get_all_peer_hashes(self):
        """get_all_peer_hashes() reports exactly the peers that were added."""
        pool = ConnectionPool()
        expected = [bytes([n]) * 32 for n in range(3)]
        for peer_hash in expected:
            pool.add(peer_hash, MagicMock())
        assert sorted(pool.get_all_peer_hashes()) == sorted(expected)

    def test_remove_nonexistent_is_noop(self):
        """Removing an unknown peer is silently ignored."""
        pool = ConnectionPool()
        pool.remove(bytes([0xFF]) * 32)  # must not raise
        assert pool.active_count == 0
196
197
198# ---------------------------------------------------------------------------
199# PeerConnector tests (mocked network)
200# ---------------------------------------------------------------------------
201
class TestPeerConnector:
    """Tests for PeerConnector.connect() with mocked NTCP2RealConnector.

    NTCP2RealConnector is patched inside i2p_router.peer_connector, so no
    real network connection is attempted. The patched class returns an
    AsyncMock whose ``connect`` coroutine either yields a fake connection or
    raises.
    """

    # Where peer_connector looks up NTCP2RealConnector; patched per test.
    _PATCH_TARGET = "i2p_router.peer_connector.NTCP2RealConnector"

    @pytest.fixture
    def connector(self):
        """A PeerConnector built from a fixed dummy static keypair and IV."""
        static_priv = b"\x10" * 32
        static_pub = b"\x20" * 32
        iv = b"\x30" * 16
        return PeerConnector(
            our_static_keypair=(static_priv, static_pub),
            our_iv=iv,
        )

    @staticmethod
    def _make_reachable_router_info() -> MagicMock:
        """Return a mock RouterInfo with one well-formed NTCP2 address.

        ``to_bytes()`` and ``identity.to_bytes()`` are stubbed to return
        zero-filled byte strings (100 and 387 bytes). Presumably connect()
        serialises one or both of them; these tests do not check which.
        """
        static_key = b"\xaa" * 32
        iv = b"\xbb" * 16
        addr = _make_ntcp2_address(
            "10.0.0.1", 15555,
            base64.b64encode(static_key).decode(),
            base64.b64encode(iv).decode(),
        )
        ri = _make_router_info([addr])
        ri.to_bytes = MagicMock(return_value=b"\x00" * 100)
        mock_identity = MagicMock()
        mock_identity.to_bytes = MagicMock(return_value=b"\x00" * 387)
        ri.identity = mock_identity
        return ri

    @pytest.mark.asyncio
    async def test_connect_returns_none_when_no_ntcp2_address(self, connector):
        """connect() returns None if router_info has no NTCP2 address."""
        ri = _make_router_info([])
        result = await connector.connect(ri)
        assert result is None

    @pytest.mark.asyncio
    async def test_connect_returns_connection_on_success(self, connector):
        """connect() returns the connection produced by a successful NTCP2 connect."""
        ri = self._make_reachable_router_info()

        mock_conn = MagicMock()
        mock_ntcp2_connector = AsyncMock()
        mock_ntcp2_connector.connect = AsyncMock(return_value=mock_conn)

        with patch(self._PATCH_TARGET, return_value=mock_ntcp2_connector):
            result = await connector.connect(ri)

        assert result is mock_conn
        # assert_awaited_once is stricter than assert_called_once: it also
        # proves the coroutine returned by connect() was actually awaited.
        mock_ntcp2_connector.connect.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_connect_returns_none_on_error(self, connector):
        """connect() returns None instead of propagating an error raised by the NTCP2 connect."""
        ri = self._make_reachable_router_info()

        mock_ntcp2_connector = AsyncMock()
        mock_ntcp2_connector.connect = AsyncMock(side_effect=ConnectionRefusedError("refused"))

        with patch(self._PATCH_TARGET, return_value=mock_ntcp2_connector):
            result = await connector.connect(ri)

        # Only the return value is checked; any logging of the failure is not
        # asserted here.
        assert result is None
270 assert result is None