# A Python port of the Invisible Internet Project (I2P).
1"""Tests for SSU2 data-phase connection state (PeerState2).
2
3TDD: tests written before implementation.
4"""
5
6import os
7import time
8
9import pytest
10
11from i2p_transport.ssu2_handshake import HandshakeKeys
12from i2p_transport.ssu2_connection import SSU2Connection, SentPacket
13from i2p_transport.ssu2_payload import (
14 AckBlock, PaddingBlock, TerminationBlock,
15 PathChallengeBlock, PathResponseBlock, parse_payload,
16)
17
18
def _make_keys() -> HandshakeKeys:
    """Return a HandshakeKeys bundle whose four keys are fresh 32-byte random values."""
    field_names = (
        "send_cipher_key",
        "recv_cipher_key",
        "send_header_key",
        "recv_header_key",
    )
    # One independent os.urandom(32) draw per key, in declaration order.
    return HandshakeKeys(**{name: os.urandom(32) for name in field_names})
27
28
def _make_connection_pair() -> tuple[SSU2Connection, SSU2Connection]:
    """Build a matching (initiator, responder) SSU2Connection pair.

    The responder's keys are the initiator's keys with the send/recv roles
    swapped, and the source/destination connection IDs are mirrored in the
    same way. Anything one side encrypts can therefore be decrypted by the other.
    """
    initiator_keys = _make_keys()
    responder_keys = HandshakeKeys(
        send_cipher_key=initiator_keys.recv_cipher_key,
        recv_cipher_key=initiator_keys.send_cipher_key,
        send_header_key=initiator_keys.recv_header_key,
        recv_header_key=initiator_keys.send_header_key,
    )
    alice_id, bob_id = 0x1234567890ABCDEF, 0xFEDCBA0987654321

    initiator = SSU2Connection(
        keys=initiator_keys,
        src_conn_id=alice_id,
        dest_conn_id=bob_id,
        remote_address=("192.168.1.2", 5000),
        is_initiator=True,
    )
    responder = SSU2Connection(
        keys=responder_keys,
        src_conn_id=bob_id,
        dest_conn_id=alice_id,
        remote_address=("192.168.1.1", 5001),
        is_initiator=False,
    )
    return initiator, responder
59
60
class TestEncryptDecryptRoundtrip:
    """A data packet sealed by one peer opens cleanly on the other peer."""

    def test_alice_to_bob(self):
        initiator, responder = _make_connection_pair()
        wire = initiator.encrypt_data_packet([PaddingBlock(padding=os.urandom(64))])
        assert isinstance(wire, bytes)
        assert wire  # non-empty

        number, payload_blocks = responder.decrypt_data_packet(wire)
        assert number == 0
        assert len(payload_blocks) >= 1

    def test_bob_to_alice(self):
        initiator, responder = _make_connection_pair()
        wire = responder.encrypt_data_packet([PaddingBlock(padding=os.urandom(32))])
        number, payload_blocks = initiator.decrypt_data_packet(wire)
        assert number == 0
        assert len(payload_blocks) >= 1

    def test_multiple_packets(self):
        sender, receiver = _make_connection_pair()
        for expected in range(5):
            wire = sender.encrypt_data_packet([PaddingBlock(padding=os.urandom(16))])
            number, _ = receiver.decrypt_data_packet(wire)
            assert number == expected

    def test_corrupted_packet_fails(self):
        sender, receiver = _make_connection_pair()
        wire = bytearray(
            sender.encrypt_data_packet([PaddingBlock(padding=os.urandom(32))])
        )
        # Invert every bit of one byte near the tail of the packet. Whether it
        # lands in ciphertext or in the authentication tag, authenticated
        # decryption must reject the tampered packet.
        wire[-5] ^= 0xFF
        with pytest.raises(Exception):
            receiver.decrypt_data_packet(bytes(wire))
100
101
class TestPacketNumbering:
    """Each direction numbers its packets 0, 1, 2, ... independently."""

    def test_sequential_send_numbers(self):
        sender, receiver = _make_connection_pair()
        for expected in range(10):
            wire = sender.encrypt_data_packet([PaddingBlock(padding=bytes(8))])
            number, _ = receiver.decrypt_data_packet(wire)
            assert number == expected

    def test_send_and_recv_independent(self):
        alice, bob = _make_connection_pair()

        def exchange(src, dst, count):
            # Send `count` one-byte padding packets src -> dst and check numbering.
            for expected in range(count):
                wire = src.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
                number, _ = dst.decrypt_data_packet(wire)
                assert number == expected

        # Alice -> Bob uses numbers 0, 1, 2.
        exchange(alice, bob, 3)
        # Bob -> Alice starts again at 0 because numbering is per direction.
        exchange(bob, alice, 2)
125
126
class TestAckTracking:
    """The receiver's bitfield records every data packet it has decrypted."""

    def test_recv_bitfield_updated(self):
        alice, bob = _make_connection_pair()
        # Alice sends 3 packets and Bob decrypts all of them.
        for _ in range(3):
            pkt = alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
            bob.decrypt_data_packet(pkt)

        # Bob's recv bitfield should cover packets 0, 1 and 2.
        ack = bob.build_ack_block()
        assert ack.ack_through == 2
        # ack_through alone only proves the highest packet was seen. Feed the
        # ACK back to the sender to confirm packets 0 and 1 are covered too.
        assert set(alice.process_ack(ack)) == {0, 1, 2}

    def test_recv_bitfield_with_gap(self):
        alice, bob = _make_connection_pair()
        packets = [
            alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
            for _ in range(3)
        ]
        # Simulate losing packet 1: Bob only decrypts packets 0 and 2.
        bob.decrypt_data_packet(packets[0])
        bob.decrypt_data_packet(packets[2])

        ack = bob.build_ack_block()
        assert ack.ack_through == 2
        # The missing packet must not be acknowledged.
        assert set(alice.process_ack(ack)) == {0, 2}

    def test_build_ack_block(self):
        alice, bob = _make_connection_pair()
        pkt = alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
        bob.decrypt_data_packet(pkt)
        ack = bob.build_ack_block()
        assert isinstance(ack, AckBlock)
        assert ack.ack_through == 0
148
149
class TestProcessAck:
    """process_ack reports only packet numbers acknowledged for the first time."""

    def test_ack_returns_newly_acked(self):
        alice, _bob = _make_connection_pair()
        for _ in range(3):
            alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])

        # ACK block standing in for Bob acknowledging packets 0, 1 and 2.
        full_ack = AckBlock(ack_through=2, ack_count=1, ranges=[(3, 0)])
        assert set(alice.process_ack(full_ack)) == {0, 1, 2}

    def test_ack_idempotent(self):
        alice, _bob = _make_connection_pair()
        alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
        single_ack = AckBlock(ack_through=0, ack_count=1, ranges=[(1, 0)])
        assert len(alice.process_ack(single_ack)) == 1
        # Replaying the same ACK must not report packet 0 a second time.
        assert len(alice.process_ack(single_ack)) == 0
172
173
class TestUnackedPackets:
    """get_unacked_packets lists sent packets that are still awaiting an ACK."""

    def test_returns_unacked(self):
        alice, _bob = _make_connection_pair()
        for _ in range(2):
            alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
        pending = alice.get_unacked_packets(max_age_seconds=0.0)
        assert len(pending) == 2
        assert all(isinstance(entry, SentPacket) for entry in pending)

    def test_acked_packets_excluded(self):
        alice, _bob = _make_connection_pair()
        for _ in range(2):
            alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
        # Acknowledge packet 0 only, so packet 1 stays outstanding.
        alice.process_ack(AckBlock(ack_through=0, ack_count=1, ranges=[(1, 0)]))
        pending = alice.get_unacked_packets(max_age_seconds=0.0)
        assert len(pending) == 1
        assert pending[0].packet_num == 1
195
196
class TestPathChallengeResponse:
    """A path challenge is echoed back unchanged in a path response."""

    def test_challenge_response(self):
        alice, bob = _make_connection_pair()

        challenge_wire = alice.send_path_challenge()
        assert isinstance(challenge_wire, bytes)

        # Bob opens the packet and must find exactly one challenge block.
        _, received = bob.decrypt_data_packet(challenge_wire)
        challenges = [blk for blk in received if isinstance(blk, PathChallengeBlock)]
        assert len(challenges) == 1
        token = challenges[0].challenge_data

        # Bob answers the challenge.
        response_wire = bob.process_path_challenge(token)
        assert isinstance(response_wire, bytes)

        # Alice opens the answer and must find exactly one response block.
        _, replied = alice.decrypt_data_packet(response_wire)
        responses = [blk for blk in replied if isinstance(blk, PathResponseBlock)]
        assert len(responses) == 1

        # The echoed data must equal the original challenge data.
        assert responses[0].response_data == token
221
222
class TestCloseTermination:
    """close() emits a termination packet and clears is_established."""

    def test_close_produces_packet(self):
        alice, bob = _make_connection_pair()
        goodbye = alice.close(reason=42)
        assert isinstance(goodbye, bytes)

        _, received = bob.decrypt_data_packet(goodbye)
        terminations = [blk for blk in received if isinstance(blk, TerminationBlock)]
        assert len(terminations) == 1
        assert terminations[0].reason == 42

    def test_close_marks_not_established(self):
        alice, _ = _make_connection_pair()
        assert alice.is_established
        alice.close()
        assert not alice.is_established
241
242
class TestIdleTime:
    """idle_time reports how long the connection has gone without activity."""

    def test_idle_time_increases(self):
        alice, _ = _make_connection_pair()
        alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
        t1 = alice.idle_time
        # idle_time should be very small right after sending.
        assert t1 < 1.0
        # With no further traffic, idle time must grow. The sleep is kept well
        # above typical clock granularity so the difference is observable.
        time.sleep(0.05)
        assert alice.idle_time > t1

    def test_recv_updates_idle(self):
        alice, bob = _make_connection_pair()
        pkt = alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
        # Let Bob sit idle first. Otherwise a freshly created connection would
        # pass the check even if receiving never reset the idle clock.
        time.sleep(0.05)
        idle_before = bob.idle_time
        bob.decrypt_data_packet(pkt)
        assert bob.idle_time < idle_before
        assert bob.idle_time < 1.0
258
259
class TestDifferentConnectionsDifferentKeys:
    """Separately keyed connections do not share cipher state."""

    def test_independent_cipher_states(self):
        alice1, bob1 = _make_connection_pair()
        alice2, bob2 = _make_connection_pair()

        shared_blocks = [PaddingBlock(padding=bytes([0x42]) * 16)]
        first = alice1.encrypt_data_packet(shared_blocks)
        second = alice2.encrypt_data_packet(shared_blocks)

        # Same plaintext under different keys must give different wire bytes.
        assert first != second

        # Each packet opens on its own matching peer...
        for receiver, wire in ((bob1, first), (bob2, second)):
            _, opened = receiver.decrypt_data_packet(wire)
            assert len(opened) >= 1

        # ...but not on the other connection's peer.
        with pytest.raises(Exception):
            bob2.decrypt_data_packet(first)
282 bob2.decrypt_data_packet(pkt1)