A Python port of the Invisible Internet Project (I2P)
at main 282 lines 9.8 kB view raw
"""Tests for SSU2 data-phase connection state (PeerState2).

TDD: tests written before implementation.
"""

import os
import time

import pytest

from i2p_transport.ssu2_handshake import HandshakeKeys
from i2p_transport.ssu2_connection import SSU2Connection, SentPacket
from i2p_transport.ssu2_payload import (
    AckBlock, PaddingBlock, TerminationBlock,
    PathChallengeBlock, PathResponseBlock, parse_payload,
)


def _make_keys() -> HandshakeKeys:
    """Create a HandshakeKeys with four independent random 32-byte keys."""
    return HandshakeKeys(
        send_cipher_key=os.urandom(32),
        recv_cipher_key=os.urandom(32),
        send_header_key=os.urandom(32),
        recv_header_key=os.urandom(32),
    )


def _make_connection_pair() -> tuple[SSU2Connection, SSU2Connection]:
    """Create a matched pair of SSU2Connections (Alice + Bob).

    Alice's send keys = Bob's recv keys and vice versa, and each side's
    source connection ID is the other side's destination connection ID.

    Returns:
        ``(alice, bob)`` where Alice is the initiator and Bob is not.
    """
    keys_alice = _make_keys()
    # Mirror Alice's keys so that what one side sends the other can read.
    keys_bob = HandshakeKeys(
        send_cipher_key=keys_alice.recv_cipher_key,
        recv_cipher_key=keys_alice.send_cipher_key,
        send_header_key=keys_alice.recv_header_key,
        recv_header_key=keys_alice.send_header_key,
    )
    src_id = 0x1234567890ABCDEF
    dst_id = 0xFEDCBA0987654321

    alice = SSU2Connection(
        keys=keys_alice,
        src_conn_id=src_id,
        dest_conn_id=dst_id,
        remote_address=("192.168.1.2", 5000),
        is_initiator=True,
    )
    bob = SSU2Connection(
        keys=keys_bob,
        src_conn_id=dst_id,
        dest_conn_id=src_id,
        remote_address=("192.168.1.1", 5001),
        is_initiator=False,
    )
    return alice, bob


class TestEncryptDecryptRoundtrip:
    """Encrypt a data packet on one side, decrypt on the other."""

    def test_alice_to_bob(self):
        """A packet encrypted by Alice is decrypted by Bob as packet 0."""
        alice, bob = _make_connection_pair()
        blocks = [PaddingBlock(padding=os.urandom(64))]
        packet = alice.encrypt_data_packet(blocks)
        assert isinstance(packet, bytes)
        assert len(packet) > 0

        pkt_num, decoded_blocks = bob.decrypt_data_packet(packet)
        assert pkt_num == 0
        assert len(decoded_blocks) >= 1

    def test_bob_to_alice(self):
        """A packet encrypted by Bob is decrypted by Alice as packet 0."""
        alice, bob = _make_connection_pair()
        blocks = [PaddingBlock(padding=os.urandom(32))]
        packet = bob.encrypt_data_packet(blocks)
        pkt_num, decoded_blocks = alice.decrypt_data_packet(packet)
        assert pkt_num == 0
        assert len(decoded_blocks) >= 1

    def test_multiple_packets(self):
        """Five packets in a row decrypt with packet numbers 0..4."""
        alice, bob = _make_connection_pair()
        for i in range(5):
            blocks = [PaddingBlock(padding=os.urandom(16))]
            packet = alice.encrypt_data_packet(blocks)
            pkt_num, _ = bob.decrypt_data_packet(packet)
            assert pkt_num == i

    def test_corrupted_packet_fails(self):
        """Decrypting a packet with one flipped byte raises."""
        alice, bob = _make_connection_pair()
        blocks = [PaddingBlock(padding=os.urandom(32))]
        packet = alice.encrypt_data_packet(blocks)
        # Flip every bit of the fifth byte from the end of the packet.
        corrupted = bytearray(packet)
        corrupted[-5] ^= 0xFF
        # NOTE(review): Exception is broad and would also pass on an
        # unrelated bug (e.g. TypeError); narrow it to the decryption
        # failure type raised by SSU2Connection once that is confirmed.
        with pytest.raises(Exception):
            bob.decrypt_data_packet(bytes(corrupted))


class TestPacketNumbering:
    """Packet numbers are sequential."""

    def test_sequential_send_numbers(self):
        """Ten packets sent by Alice are numbered 0..9 in order."""
        alice, bob = _make_connection_pair()
        for expected_num in range(10):
            blocks = [PaddingBlock(padding=b"\x00" * 8)]
            packet = alice.encrypt_data_packet(blocks)
            pkt_num, _ = bob.decrypt_data_packet(packet)
            assert pkt_num == expected_num

    def test_send_and_recv_independent(self):
        """Each direction numbers its packets starting from 0."""
        alice, bob = _make_connection_pair()
        # Alice sends 3 packets
        for i in range(3):
            pkt = alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
            pkt_num, _ = bob.decrypt_data_packet(pkt)
            assert pkt_num == i
        # Bob sends 2 packets -- independent numbering
        for i in range(2):
            pkt = bob.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
            pkt_num, _ = alice.decrypt_data_packet(pkt)
            assert pkt_num == i


class TestAckTracking:
    """Recv bitfield tracks received packets."""

    def test_recv_bitfield_updated(self):
        """After receiving packets 0..2 the ACK block acks through 2."""
        alice, bob = _make_connection_pair()
        # Alice sends 3 packets
        for _ in range(3):
            pkt = alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
            bob.decrypt_data_packet(pkt)

        # Bob's recv bitfield should have packets 0, 1, 2
        ack = bob.build_ack_block()
        assert ack.ack_through == 2

    def test_build_ack_block(self):
        """build_ack_block returns an AckBlock acking through packet 0."""
        alice, bob = _make_connection_pair()
        pkt = alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
        bob.decrypt_data_packet(pkt)
        ack = bob.build_ack_block()
        assert isinstance(ack, AckBlock)
        assert ack.ack_through == 0


class TestProcessAck:
    """process_ack returns newly acked packet numbers."""

    def test_ack_returns_newly_acked(self):
        """An ACK covering packets 0..2 reports exactly those numbers."""
        alice, _ = _make_connection_pair()
        # Alice sends 3 packets
        for _ in range(3):
            alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])

        # Build an ACK block as if Bob is acking packets 0, 1, 2
        ack = AckBlock(ack_through=2, ack_count=1, ranges=[(3, 0)])
        newly_acked = alice.process_ack(ack)
        assert set(newly_acked) == {0, 1, 2}

    def test_ack_idempotent(self):
        """Processing the same ACK twice reports nothing the second time."""
        alice, _ = _make_connection_pair()
        alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
        ack = AckBlock(ack_through=0, ack_count=1, ranges=[(1, 0)])
        first = alice.process_ack(ack)
        assert len(first) == 1
        second = alice.process_ack(ack)
        assert len(second) == 0  # Already acked


class TestUnackedPackets:
    """get_unacked_packets returns packets needing retransmit."""

    def test_returns_unacked(self):
        """Two sent, none acked: both come back as SentPacket objects."""
        alice, _ = _make_connection_pair()
        alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
        alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
        unacked = alice.get_unacked_packets(max_age_seconds=0.0)
        assert len(unacked) == 2
        assert all(isinstance(sp, SentPacket) for sp in unacked)

    def test_acked_packets_excluded(self):
        """Once packet 0 is acked only packet 1 remains unacked."""
        alice, _ = _make_connection_pair()
        alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
        alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
        # ACK packet 0 only
        ack = AckBlock(ack_through=0, ack_count=1, ranges=[(1, 0)])
        alice.process_ack(ack)
        unacked = alice.get_unacked_packets(max_age_seconds=0.0)
        assert len(unacked) == 1
        assert unacked[0].packet_num == 1


class TestPathChallengeResponse:
    """Path challenge/response roundtrip."""

    def test_challenge_response(self):
        """Bob's response echoes the challenge data Alice sent."""
        alice, bob = _make_connection_pair()
        challenge_packet = alice.send_path_challenge()
        assert isinstance(challenge_packet, bytes)

        # Bob decrypts the challenge
        _, blocks = bob.decrypt_data_packet(challenge_packet)
        challenge_blocks = [b for b in blocks if isinstance(b, PathChallengeBlock)]
        assert len(challenge_blocks) == 1

        # Bob builds response
        response_packet = bob.process_path_challenge(challenge_blocks[0].challenge_data)
        assert isinstance(response_packet, bytes)

        # Alice decrypts response
        _, resp_blocks = alice.decrypt_data_packet(response_packet)
        response_blocks = [b for b in resp_blocks if isinstance(b, PathResponseBlock)]
        assert len(response_blocks) == 1

        # Challenge data should match
        assert response_blocks[0].response_data == challenge_blocks[0].challenge_data


class TestCloseTermination:
    """close() produces a termination packet."""

    def test_close_produces_packet(self):
        """close(reason=42) yields a packet with one TerminationBlock(42)."""
        alice, bob = _make_connection_pair()
        term_packet = alice.close(reason=42)
        assert isinstance(term_packet, bytes)

        _, blocks = bob.decrypt_data_packet(term_packet)
        term_blocks = [b for b in blocks if isinstance(b, TerminationBlock)]
        assert len(term_blocks) == 1
        assert term_blocks[0].reason == 42

    def test_close_marks_not_established(self):
        """is_established is true on a fresh connection, false after close()."""
        alice, _ = _make_connection_pair()
        assert alice.is_established
        alice.close()
        assert not alice.is_established


class TestIdleTime:
    """Tracks last activity."""

    def test_idle_time_increases(self):
        """idle_time is under one second immediately after sending.

        NOTE(review): despite the name, this does not check that
        idle_time grows over time; it only checks the value right after
        a send.
        """
        alice, _ = _make_connection_pair()
        alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
        # idle_time should be very small right after sending
        assert alice.idle_time < 1.0

    def test_recv_updates_idle(self):
        """idle_time is under one second immediately after receiving."""
        alice, bob = _make_connection_pair()
        pkt = alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
        bob.decrypt_data_packet(pkt)
        assert bob.idle_time < 1.0


class TestDifferentConnectionsDifferentKeys:
    """Independent connections have independent cipher states."""

    def test_independent_cipher_states(self):
        """Packets from one pair cannot be decrypted by the other pair."""
        alice1, bob1 = _make_connection_pair()
        alice2, bob2 = _make_connection_pair()

        blocks = [PaddingBlock(padding=b"\x42" * 16)]
        pkt1 = alice1.encrypt_data_packet(blocks)
        pkt2 = alice2.encrypt_data_packet(blocks)

        # Different keys produce different ciphertext
        assert pkt1 != pkt2

        # Each can only be decrypted by the matching peer
        _, dec1 = bob1.decrypt_data_packet(pkt1)
        _, dec2 = bob2.decrypt_data_packet(pkt2)
        assert len(dec1) >= 1
        assert len(dec2) >= 1

        # Cross-decryption should fail
        # NOTE(review): Exception is broad; narrow to the concrete
        # decryption failure type once it is confirmed.
        with pytest.raises(Exception):
            bob2.decrypt_data_packet(pkt1)