"""Tests for NTCP2 wire-format frame encryption.

TDD: these tests are written before the implementation. The codec module
``i2p_transport.ntcp2_wire`` is imported inside each test (via ``_make_codec``)
rather than at module level, so a missing implementation fails each test
individually instead of aborting collection of this whole file.

The wire format under test encrypts each frame in two operations:

1. Encrypt the 2-byte frame length -> 18 bytes (2 + 16-byte AEAD tag).
2. Encrypt the frame payload (type + length + data) -> N + 16 bytes.

Each encrypt/decrypt consumes one nonce from the CipherState, so every frame
advances the sender's nonce counter by two.

NOTE(review): the published NTCP2 specification obfuscates the data-phase
frame length with SipHash rather than AEAD-encrypting it; confirm that the
AEAD-encrypted length exercised here is an intentional simplification.
"""

import os
import struct

import pytest

from i2p_crypto.x25519 import X25519DH
from i2p_crypto.noise import CipherState
from i2p_transport.ntcp2 import NTCP2Frame, FrameType
from i2p_transport.ntcp2_handshake import NTCP2Handshake

# Wire size of the encrypted 2-byte frame length: 2 bytes + 16-byte AEAD tag.
_ENC_LEN_SIZE = 2 + 16
# AEAD tag appended to the encrypted frame payload.
_TAG_SIZE = 16


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_keypair():
    """Return a freshly generated X25519 keypair."""
    return X25519DH.generate_keypair()


def _make_cipher_pair() -> "tuple[CipherState, CipherState]":
    """Complete a Noise_XK handshake and return (send_cipher, recv_cipher).

    send_cipher is the initiator's sending cipher and recv_cipher is the
    responder's receiving cipher, so frames encrypted with the former decrypt
    with the latter when processed in the same order (matching keys/nonces).
    """
    init_static = _make_keypair()
    resp_static = _make_keypair()

    # The initiator is given resp_static[1] as the responder's static key
    # (presumably the public half of the keypair tuple -- confirm).
    init_hs = NTCP2Handshake(init_static, resp_static[1], True)
    resp_hs = NTCP2Handshake(resp_static, initiator=False)

    msg1 = init_hs.create_message_1()
    msg2 = resp_hs.process_message_1(msg1)
    msg3 = init_hs.process_message_2(msg2)
    resp_hs.process_message_3(msg3)

    send, _ = init_hs.split()
    _, resp_recv = resp_hs.split()
    return send, resp_recv


def _make_codec():
    """Import and instantiate the codec under test.

    The import lives here rather than at module level so that, while the
    implementation does not exist yet, each test fails on its own instead of
    the whole module failing to import.
    """
    from i2p_transport.ntcp2_wire import NTCP2WireCodec

    return NTCP2WireCodec()


def _expected_wire_len(frame):
    """Return the expected encrypt_frame output size for ``frame``.

    Encrypted length field + serialized frame bytes + payload AEAD tag.
    """
    return _ENC_LEN_SIZE + len(frame.to_bytes()) + _TAG_SIZE


def _decrypt_wire(codec, recv, wire):
    """Decrypt one encrypted frame and return (decrypted_length, frame).

    Calls decrypt_frame_length before decrypt_frame_payload, the same order in
    which encrypt_frame consumed nonces on the send side.
    """
    frame_len = codec.decrypt_frame_length(recv, wire[:_ENC_LEN_SIZE])
    frame = codec.decrypt_frame_payload(recv, wire[_ENC_LEN_SIZE:])
    return frame_len, frame


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


class TestEncryptFrameByteCount:
    """encrypt_frame produces correct byte count: 18 + len(frame_bytes) + 16."""

    def test_data_frame_size(self):
        send, _ = _make_cipher_pair()
        codec = _make_codec()
        frame = NTCP2Frame(FrameType.DATA, b"hello world")

        encrypted = codec.encrypt_frame(send, frame)

        assert len(encrypted) == _expected_wire_len(frame)

    def test_empty_payload_frame_size(self):
        send, _ = _make_cipher_pair()
        codec = _make_codec()
        frame = NTCP2Frame(FrameType.PADDING, b"")

        encrypted = codec.encrypt_frame(send, frame)

        assert len(encrypted) == _expected_wire_len(frame)

    def test_large_payload_frame_size(self):
        send, _ = _make_cipher_pair()
        codec = _make_codec()
        frame = NTCP2Frame(FrameType.I2NP, os.urandom(1024))

        encrypted = codec.encrypt_frame(send, frame)

        assert len(encrypted) == _expected_wire_len(frame)


class TestDecryptFrameLength:
    """decrypt_frame_length recovers original length."""

    def test_recover_length(self):
        send, recv = _make_cipher_pair()
        codec = _make_codec()
        frame = NTCP2Frame(FrameType.DATA, b"test payload data")

        encrypted = codec.encrypt_frame(send, frame)
        # The first _ENC_LEN_SIZE bytes carry the encrypted frame length.
        recovered_length = codec.decrypt_frame_length(
            recv, encrypted[:_ENC_LEN_SIZE]
        )

        assert recovered_length == len(frame.to_bytes())

    def test_recover_zero_payload_length(self):
        send, recv = _make_cipher_pair()
        codec = _make_codec()
        frame = NTCP2Frame(FrameType.PADDING, b"")

        encrypted = codec.encrypt_frame(send, frame)
        recovered_length = codec.decrypt_frame_length(
            recv, encrypted[:_ENC_LEN_SIZE]
        )

        assert recovered_length == len(frame.to_bytes())


class TestDecryptFramePayload:
    """decrypt_frame_payload recovers original frame."""

    def test_recover_frame(self):
        send, recv = _make_cipher_pair()
        codec = _make_codec()
        payload = b"i2np message content"
        frame = NTCP2Frame(FrameType.I2NP, payload)

        encrypted = codec.encrypt_frame(send, frame)

        # The length must be decrypted first: it consumes the receive nonce
        # the sender used for the length, leaving the next one for the payload.
        frame_len = codec.decrypt_frame_length(recv, encrypted[:_ENC_LEN_SIZE])
        recovered = codec.decrypt_frame_payload(recv, encrypted[_ENC_LEN_SIZE:])

        assert frame_len == len(frame.to_bytes())
        assert recovered.frame_type == FrameType.I2NP
        assert recovered.payload == payload


class TestFullRoundtrip:
    """Full roundtrip: encrypt -> decrypt_length -> decrypt_payload -> original frame."""

    def test_roundtrip_data_frame(self):
        send, recv = _make_cipher_pair()
        codec = _make_codec()
        original = NTCP2Frame(FrameType.DATA, b"roundtrip test data")

        wire_bytes = codec.encrypt_frame(send, original)

        # Decrypt: first the length, then the payload.
        frame_len = codec.decrypt_frame_length(recv, wire_bytes[:_ENC_LEN_SIZE])
        recovered = codec.decrypt_frame_payload(recv, wire_bytes[_ENC_LEN_SIZE:])

        assert recovered.frame_type == original.frame_type
        assert recovered.payload == original.payload
        assert frame_len == len(original.to_bytes())

    def test_roundtrip_via_convenience(self):
        send, recv = _make_cipher_pair()
        codec = _make_codec()
        original = NTCP2Frame(FrameType.ROUTER_INFO, os.urandom(200))

        wire_bytes = codec.encrypt_and_get_wire_bytes(send, original)
        frame_len, recovered = _decrypt_wire(codec, recv, wire_bytes)

        # Previously frame_len was decrypted but never checked.
        assert frame_len == len(original.to_bytes())
        assert recovered.frame_type == original.frame_type
        assert recovered.payload == original.payload


class TestMultipleFrames:
    """Multiple frames in sequence (nonces advance correctly)."""

    def test_three_sequential_frames(self):
        send, recv = _make_cipher_pair()
        codec = _make_codec()
        frames = [
            NTCP2Frame(FrameType.DATETIME, struct.pack("!I", 1710000000)),
            NTCP2Frame(FrameType.I2NP, os.urandom(128)),
            NTCP2Frame(FrameType.PADDING, os.urandom(32)),
        ]

        encrypted_list = [codec.encrypt_frame(send, f) for f in frames]

        # Decrypt in the order the frames were encrypted so the receive
        # nonces line up with the send nonces.
        for original, wire in zip(frames, encrypted_list):
            frame_len, recovered = _decrypt_wire(codec, recv, wire)
            assert frame_len == len(original.to_bytes())
            assert recovered.frame_type == original.frame_type
            assert recovered.payload == original.payload

    def test_nonces_advance_two_per_frame(self):
        """Each frame uses 2 nonces (one for length, one for payload).

        Encrypting 3 frames should use nonces 0-5 on the send side.
        """
        send, _ = _make_cipher_pair()
        codec = _make_codec()

        for _ in range(3):
            codec.encrypt_frame(send, NTCP2Frame(FrameType.DATA, b"x"))

        # After 3 frames the send cipher has used nonces 0..5, so its internal
        # counter should be 6. NOTE(review): this reads CipherState's private
        # ``_n``; switch to a public accessor if one exists.
        assert send._n == 6

    def test_out_of_order_decrypt_fails(self):
        """Decrypting frames out of order should fail (nonce mismatch)."""
        send, recv = _make_cipher_pair()
        codec = _make_codec()

        # The first frame is encrypted only to advance the send nonces.
        codec.encrypt_frame(send, NTCP2Frame(FrameType.DATA, b"first"))
        enc2 = codec.encrypt_frame(send, NTCP2Frame(FrameType.DATA, b"second"))

        # Decrypting enc2 first must fail because recv still expects nonce 0.
        # NOTE(review): narrow Exception to the codec's authentication-failure
        # type once the implementation defines one.
        with pytest.raises(Exception):
            codec.decrypt_frame_length(recv, enc2[:_ENC_LEN_SIZE])


class TestDifferentFrameTypes:
    """Different frame types (I2NP, PADDING, TERMINATION)."""

    def test_i2np_frame(self):
        send, recv = _make_cipher_pair()
        codec = _make_codec()
        frame = NTCP2Frame(FrameType.I2NP, os.urandom(256))

        wire = codec.encrypt_frame(send, frame)
        frame_len, recovered = _decrypt_wire(codec, recv, wire)

        assert frame_len == len(frame.to_bytes())
        assert recovered.frame_type == FrameType.I2NP
        assert recovered.payload == frame.payload

    def test_padding_frame(self):
        send, recv = _make_cipher_pair()
        codec = _make_codec()
        frame = NTCP2Frame(FrameType.PADDING, os.urandom(64))

        wire = codec.encrypt_frame(send, frame)
        frame_len, recovered = _decrypt_wire(codec, recv, wire)

        assert frame_len == len(frame.to_bytes())
        assert recovered.frame_type == FrameType.PADDING
        assert recovered.payload == frame.payload

    def test_termination_frame(self):
        send, recv = _make_cipher_pair()
        codec = _make_codec()
        # reason(1) + valid_received(8).
        # NOTE(review): the NTCP2 spec lays out the Termination block as
        # valid_received(8) followed by reason(1). This test only checks that
        # the bytes round-trip, so confirm the intended order before relying
        # on it elsewhere.
        payload = struct.pack("!BQ", 0, 42)
        frame = NTCP2Frame(FrameType.TERMINATION, payload)

        wire = codec.encrypt_frame(send, frame)
        frame_len, recovered = _decrypt_wire(codec, recv, wire)

        assert frame_len == len(frame.to_bytes())
        assert recovered.frame_type == FrameType.TERMINATION
        assert recovered.payload == payload

    def test_datetime_frame(self):
        send, recv = _make_cipher_pair()
        codec = _make_codec()
        ts = struct.pack("!I", 1710000000)
        frame = NTCP2Frame(FrameType.DATETIME, ts)

        wire = codec.encrypt_frame(send, frame)
        frame_len, recovered = _decrypt_wire(codec, recv, wire)

        assert frame_len == len(frame.to_bytes())
        assert recovered.frame_type == FrameType.DATETIME
        assert struct.unpack("!I", recovered.payload)[0] == 1710000000

    def test_options_frame(self):
        send, recv = _make_cipher_pair()
        codec = _make_codec()
        frame = NTCP2Frame(FrameType.OPTIONS, b"\x01\x02\x03\x04")

        wire = codec.encrypt_frame(send, frame)
        frame_len, recovered = _decrypt_wire(codec, recv, wire)

        assert frame_len == len(frame.to_bytes())
        assert recovered.frame_type == FrameType.OPTIONS
        assert recovered.payload == b"\x01\x02\x03\x04"


class TestEmptyPayloadFrame:
    """Empty payload frame."""

    def test_empty_data_frame_roundtrip(self):
        send, recv = _make_cipher_pair()
        codec = _make_codec()
        frame = NTCP2Frame(FrameType.DATA, b"")

        wire = codec.encrypt_frame(send, frame)

        # frame_bytes = type(1) + length(2) + payload(0) = 3 bytes
        assert len(wire) == _ENC_LEN_SIZE + 3 + _TAG_SIZE

        frame_len = codec.decrypt_frame_length(recv, wire[:_ENC_LEN_SIZE])
        assert frame_len == 3  # type(1) + length(2)

        recovered = codec.decrypt_frame_payload(recv, wire[_ENC_LEN_SIZE:])
        assert recovered.frame_type == FrameType.DATA
        assert recovered.payload == b""

    def test_empty_padding_frame_roundtrip(self):
        send, recv = _make_cipher_pair()
        codec = _make_codec()
        frame = NTCP2Frame(FrameType.PADDING, b"")

        wire = codec.encrypt_frame(send, frame)
        frame_len, recovered = _decrypt_wire(codec, recv, wire)

        assert frame_len == len(frame.to_bytes())
        assert recovered.frame_type == FrameType.PADDING
        assert recovered.payload == b""