"""Tests for SSU2 data-phase connection state (PeerState2).

TDD: tests written before implementation.
"""

import os
import time

import pytest

from i2p_transport.ssu2_handshake import HandshakeKeys
from i2p_transport.ssu2_connection import SSU2Connection, SentPacket
from i2p_transport.ssu2_payload import (
    AckBlock,
    PaddingBlock,
    TerminationBlock,
    PathChallengeBlock,
    PathResponseBlock,
    parse_payload,
)


def _make_keys() -> HandshakeKeys:
    """Create random handshake keys."""
    return HandshakeKeys(
        send_cipher_key=os.urandom(32),
        recv_cipher_key=os.urandom(32),
        send_header_key=os.urandom(32),
        recv_header_key=os.urandom(32),
    )


def _make_connection_pair() -> tuple[SSU2Connection, SSU2Connection]:
    """Create a matched pair of SSU2Connections (Alice + Bob).

    Alice's send keys = Bob's recv keys and vice versa.
    Alice is the initiator; the two connection IDs are mirrored so each
    side's source ID is the other side's destination ID.
    """
    keys_alice = _make_keys()
    # Mirror Alice's keys so that whatever Alice encrypts, Bob can decrypt.
    keys_bob = HandshakeKeys(
        send_cipher_key=keys_alice.recv_cipher_key,
        recv_cipher_key=keys_alice.send_cipher_key,
        send_header_key=keys_alice.recv_header_key,
        recv_header_key=keys_alice.send_header_key,
    )
    src_id = 0x1234567890ABCDEF
    dst_id = 0xFEDCBA0987654321
    alice = SSU2Connection(
        keys=keys_alice,
        src_conn_id=src_id,
        dest_conn_id=dst_id,
        remote_address=("192.168.1.2", 5000),
        is_initiator=True,
    )
    bob = SSU2Connection(
        keys=keys_bob,
        src_conn_id=dst_id,
        dest_conn_id=src_id,
        remote_address=("192.168.1.1", 5001),
        is_initiator=False,
    )
    return alice, bob


class TestEncryptDecryptRoundtrip:
    """Encrypt a data packet on one side, decrypt on the other."""

    def test_alice_to_bob(self):
        alice, bob = _make_connection_pair()
        blocks = [PaddingBlock(padding=os.urandom(64))]
        packet = alice.encrypt_data_packet(blocks)
        assert isinstance(packet, bytes)
        assert len(packet) > 0
        pkt_num, decoded_blocks = bob.decrypt_data_packet(packet)
        assert pkt_num == 0
        assert len(decoded_blocks) >= 1

    def test_bob_to_alice(self):
        alice, bob = _make_connection_pair()
        blocks = [PaddingBlock(padding=os.urandom(32))]
        packet = bob.encrypt_data_packet(blocks)
        pkt_num, decoded_blocks = alice.decrypt_data_packet(packet)
        assert pkt_num == 0
        assert len(decoded_blocks) >= 1

    def test_multiple_packets(self):
        alice, bob = _make_connection_pair()
        for i in range(5):
            blocks = [PaddingBlock(padding=os.urandom(16))]
            packet = alice.encrypt_data_packet(blocks)
            pkt_num, _ = bob.decrypt_data_packet(packet)
            assert pkt_num == i

    def test_corrupted_packet_fails(self):
        alice, bob = _make_connection_pair()
        blocks = [PaddingBlock(padding=os.urandom(32))]
        packet = alice.encrypt_data_packet(blocks)
        # Corrupt the encrypted payload
        corrupted = bytearray(packet)
        corrupted[-5] ^= 0xFF
        with pytest.raises(Exception):
            bob.decrypt_data_packet(bytes(corrupted))


class TestPacketNumbering:
    """Packet numbers are sequential."""

    def test_sequential_send_numbers(self):
        alice, bob = _make_connection_pair()
        for expected_num in range(10):
            blocks = [PaddingBlock(padding=b"\x00" * 8)]
            packet = alice.encrypt_data_packet(blocks)
            pkt_num, _ = bob.decrypt_data_packet(packet)
            assert pkt_num == expected_num

    def test_send_and_recv_independent(self):
        alice, bob = _make_connection_pair()
        # Alice sends 3 packets
        for i in range(3):
            pkt = alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
            pkt_num, _ = bob.decrypt_data_packet(pkt)
            assert pkt_num == i
        # Bob sends 2 packets -- independent numbering
        for i in range(2):
            pkt = bob.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
            pkt_num, _ = alice.decrypt_data_packet(pkt)
            assert pkt_num == i


class TestAckTracking:
    """Recv bitfield tracks received packets."""

    def test_recv_bitfield_updated(self):
        alice, bob = _make_connection_pair()
        # Alice sends 3 packets
        for _ in range(3):
            pkt = alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
            bob.decrypt_data_packet(pkt)
        # Bob's recv bitfield should have packets 0, 1, 2
        ack = bob.build_ack_block()
        assert ack.ack_through == 2

    def test_build_ack_block(self):
        alice, bob = _make_connection_pair()
        pkt = alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
        bob.decrypt_data_packet(pkt)
        ack = bob.build_ack_block()
        assert isinstance(ack, AckBlock)
        assert ack.ack_through == 0


class TestProcessAck:
    """process_ack returns newly acked packet numbers."""

    def test_ack_returns_newly_acked(self):
        alice, _ = _make_connection_pair()
        # Alice sends 3 packets
        for _ in range(3):
            alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
        # Build an ACK block as if Bob is acking packets 0, 1, 2
        ack = AckBlock(ack_through=2, ack_count=1, ranges=[(3, 0)])
        newly_acked = alice.process_ack(ack)
        assert set(newly_acked) == {0, 1, 2}

    def test_ack_idempotent(self):
        alice, _ = _make_connection_pair()
        alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
        ack = AckBlock(ack_through=0, ack_count=1, ranges=[(1, 0)])
        first = alice.process_ack(ack)
        assert len(first) == 1
        second = alice.process_ack(ack)
        assert len(second) == 0  # Already acked


class TestUnackedPackets:
    """get_unacked_packets returns packets needing retransmit."""

    def test_returns_unacked(self):
        alice, _ = _make_connection_pair()
        alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
        alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
        unacked = alice.get_unacked_packets(max_age_seconds=0.0)
        assert len(unacked) == 2
        assert all(isinstance(sp, SentPacket) for sp in unacked)

    def test_acked_packets_excluded(self):
        alice, _ = _make_connection_pair()
        alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
        alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
        # ACK packet 0 only
        ack = AckBlock(ack_through=0, ack_count=1, ranges=[(1, 0)])
        alice.process_ack(ack)
        unacked = alice.get_unacked_packets(max_age_seconds=0.0)
        assert len(unacked) == 1
        assert unacked[0].packet_num == 1


class TestPathChallengeResponse:
    """Path challenge/response roundtrip."""

    def test_challenge_response(self):
        alice, bob = _make_connection_pair()
        challenge_packet = alice.send_path_challenge()
        assert isinstance(challenge_packet, bytes)

        # Bob decrypts the challenge
        _, blocks = bob.decrypt_data_packet(challenge_packet)
        challenge_blocks = [b for b in blocks if isinstance(b, PathChallengeBlock)]
        assert len(challenge_blocks) == 1

        # Bob builds response
        response_packet = bob.process_path_challenge(challenge_blocks[0].challenge_data)
        assert isinstance(response_packet, bytes)

        # Alice decrypts response
        _, resp_blocks = alice.decrypt_data_packet(response_packet)
        response_blocks = [b for b in resp_blocks if isinstance(b, PathResponseBlock)]
        assert len(response_blocks) == 1

        # Challenge data should match
        assert response_blocks[0].response_data == challenge_blocks[0].challenge_data


class TestCloseTermination:
    """close() produces a termination packet."""

    def test_close_produces_packet(self):
        alice, bob = _make_connection_pair()
        term_packet = alice.close(reason=42)
        assert isinstance(term_packet, bytes)
        _, blocks = bob.decrypt_data_packet(term_packet)
        term_blocks = [b for b in blocks if isinstance(b, TerminationBlock)]
        assert len(term_blocks) == 1
        assert term_blocks[0].reason == 42

    def test_close_marks_not_established(self):
        alice, _ = _make_connection_pair()
        assert alice.is_established
        alice.close()
        assert not alice.is_established


class TestIdleTime:
    """Tracks last activity."""

    def test_idle_time_increases(self):
        alice, _ = _make_connection_pair()
        alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
        t1 = alice.idle_time
        # idle_time should be very small (and never negative) right after sending
        assert 0.0 <= t1 < 1.0
        # With no further activity, idle time must not go backwards.
        t2 = alice.idle_time
        assert t2 >= t1

    def test_recv_updates_idle(self):
        alice, bob = _make_connection_pair()
        pkt = alice.encrypt_data_packet([PaddingBlock(padding=b"\x00")])
        bob.decrypt_data_packet(pkt)
        assert bob.idle_time < 1.0


class TestDifferentConnectionsDifferentKeys:
    """Independent connections have independent cipher states."""

    def test_independent_cipher_states(self):
        alice1, bob1 = _make_connection_pair()
        alice2, bob2 = _make_connection_pair()

        blocks = [PaddingBlock(padding=b"\x42" * 16)]
        pkt1 = alice1.encrypt_data_packet(blocks)
        pkt2 = alice2.encrypt_data_packet(blocks)

        # Different keys produce different ciphertext
        assert pkt1 != pkt2

        # Each can only be decrypted by the matching peer
        _, dec1 = bob1.decrypt_data_packet(pkt1)
        _, dec2 = bob2.decrypt_data_packet(pkt2)
        assert len(dec1) >= 1
        assert len(dec2) >= 1

        # Cross-decryption should fail
        with pytest.raises(Exception):
            bob2.decrypt_data_packet(pkt1)