"""Tests for NTCP2 Noise handshake integration.

TDD: these tests are written before the implementation.
"""

import pytest

from i2p_crypto.x25519 import X25519DH
from i2p_crypto.noise import CipherState
from i2p_transport.ntcp2 import NTCP2Frame, FrameType
from i2p_transport.ntcp2_handshake import (
    NTCP2Handshake,
    NTCP2FrameCodec,
    NTCP2MessageFragmenter,
)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_keypair():
    """Return a fresh keypair from ``X25519DH.generate_keypair()``.

    Throughout this file, index ``[1]`` of the result is used as the
    public key (it is passed as ``peer_static_pub``).
    """
    return X25519DH.generate_keypair()


def _do_full_handshake():
    """Run a complete 3-message Noise_XK handshake.

    Returns:
        ``(alice, bob, alice_static, bob_static)`` — the initiator and
        responder ``NTCP2Handshake`` objects after message 3 has been
        processed, plus the static keypairs each side was built with.
    """
    alice_static = _make_keypair()
    bob_static = _make_keypair()

    # Alice (initiator) knows Bob's static public key in advance (XK).
    alice = NTCP2Handshake(
        our_static=alice_static,
        peer_static_pub=bob_static[1],
        initiator=True,
    )
    # Bob (responder) does NOT know Alice's static key yet; he learns it
    # from message 3.
    bob = NTCP2Handshake(
        our_static=bob_static,
        peer_static_pub=None,
        initiator=False,
    )

    # msg1: Alice -> Bob
    msg1 = alice.create_message_1(options=b"hello")
    # msg2: Bob processes msg1, produces msg2
    msg2 = bob.process_message_1(msg1)
    # msg3: Alice processes msg2, produces msg3
    msg3 = alice.process_message_2(msg2)
    # Bob processes msg3, completing the handshake on his side
    bob.process_message_3(msg3)

    return alice, bob, alice_static, bob_static


# ---------------------------------------------------------------------------
# Handshake tests
# ---------------------------------------------------------------------------


class TestNTCP2Handshake:
    """End-to-end behaviour of ``NTCP2Handshake`` for both roles."""

    def test_full_handshake_completes(self):
        alice, bob, _, _ = _do_full_handshake()
        assert alice.is_complete()
        assert bob.is_complete()

    def test_split_produces_cipher_pair(self):
        alice, bob, _, _ = _do_full_handshake()
        a_send, a_recv = alice.split()
        b_send, b_recv = bob.split()
        assert isinstance(a_send, CipherState)
        assert isinstance(a_recv, CipherState)
        assert isinstance(b_send, CipherState)
        assert isinstance(b_recv, CipherState)

    def test_transport_encryption_after_handshake(self):
        """Initiator encrypts, responder decrypts."""
        alice, bob, _, _ = _do_full_handshake()
        a_send, _ = alice.split()
        _, b_recv = bob.split()

        plaintext = b"I2P rocks"
        ct = a_send.encrypt_with_ad(b"", plaintext)
        pt = b_recv.decrypt_with_ad(b"", ct)
        assert pt == plaintext

    def test_bidirectional_encryption(self):
        """Both directions work: alice->bob and bob->alice."""
        alice, bob, _, _ = _do_full_handshake()
        a_send, a_recv = alice.split()
        b_send, b_recv = bob.split()

        # Alice -> Bob
        ct1 = a_send.encrypt_with_ad(b"", b"from alice")
        assert b_recv.decrypt_with_ad(b"", ct1) == b"from alice"

        # Bob -> Alice
        ct2 = b_send.encrypt_with_ad(b"", b"from bob")
        assert a_recv.decrypt_with_ad(b"", ct2) == b"from bob"

    def test_remote_static_key_recovery(self):
        """Responder learns initiator's static key through XK pattern."""
        alice, bob, alice_static, bob_static = _do_full_handshake()
        # Responder should now know Alice's static public key.
        assert bob.remote_static_key() == alice_static[1]
        # Initiator already knew Bob's static key.
        assert alice.remote_static_key() == bob_static[1]

    def test_handshake_out_of_order_error(self):
        """Calling handshake methods out of order should raise."""
        alice_static = _make_keypair()
        bob_static = _make_keypair()
        alice = NTCP2Handshake(
            our_static=alice_static,
            peer_static_pub=bob_static[1],
            initiator=True,
        )
        # Cannot process_message_2 before create_message_1.
        with pytest.raises(RuntimeError):
            alice.process_message_2(b"fake_msg")

    def test_decrypt_with_wrong_cipher_fails(self):
        """Decrypting with mismatched cipher must fail."""
        alice, bob, _, _ = _do_full_handshake()
        a_send, _ = alice.split()
        b_send, _ = bob.split()

        ct = a_send.encrypt_with_ad(b"", b"secret")
        # Try to decrypt with the wrong cipher (b_send instead of b_recv).
        # NOTE(review): deliberately broad — the concrete authentication
        # error type depends on the CipherState/AEAD backend; narrow this
        # once that type is part of the public API.
        with pytest.raises(Exception):
            b_send.decrypt_with_ad(b"", ct)

    def test_is_complete_false_before_handshake(self):
        alice_static = _make_keypair()
        bob_static = _make_keypair()
        alice = NTCP2Handshake(
            our_static=alice_static,
            peer_static_pub=bob_static[1],
            initiator=True,
        )
        assert not alice.is_complete()


# ---------------------------------------------------------------------------
# FrameCodec tests
# ---------------------------------------------------------------------------


class TestNTCP2FrameCodec:
    """Round-trip ``NTCP2FrameCodec`` over a freshly negotiated session."""

    def _make_cipher_pair(self):
        """Return ``(sender, receiver)`` for the Alice -> Bob direction."""
        alice, bob, _, _ = _do_full_handshake()
        a_send, _ = alice.split()
        _, b_recv = bob.split()
        return a_send, b_recv

    def test_encrypt_decrypt_data_frame(self):
        codec = NTCP2FrameCodec()
        send_c, recv_c = self._make_cipher_pair()
        frame = NTCP2Frame(FrameType.DATA, b"some data payload")
        encrypted = codec.encrypt_frame(send_c, frame)
        decrypted = codec.decrypt_frame(recv_c, encrypted)
        assert decrypted.frame_type == FrameType.DATA
        assert decrypted.payload == b"some data payload"

    def test_encrypt_decrypt_i2np_frame(self):
        codec = NTCP2FrameCodec()
        send_c, recv_c = self._make_cipher_pair()
        frame = NTCP2Frame(FrameType.I2NP, b"\x01\x02\x03\x04")
        encrypted = codec.encrypt_frame(send_c, frame)
        decrypted = codec.decrypt_frame(recv_c, encrypted)
        assert decrypted.frame_type == FrameType.I2NP
        assert decrypted.payload == b"\x01\x02\x03\x04"

    def test_encrypt_decrypt_padding_frame(self):
        codec = NTCP2FrameCodec()
        send_c, recv_c = self._make_cipher_pair()
        frame = NTCP2Frame(FrameType.PADDING, b"\x00" * 16)
        encrypted = codec.encrypt_frame(send_c, frame)
        decrypted = codec.decrypt_frame(recv_c, encrypted)
        assert decrypted.frame_type == FrameType.PADDING
        assert decrypted.payload == b"\x00" * 16


# ---------------------------------------------------------------------------
# MessageFragmenter tests
# ---------------------------------------------------------------------------


class TestNTCP2MessageFragmenter:
    """Splitting of I2NP messages into frames and reassembly."""

    def test_single_chunk_fits_one_frame(self):
        frag = NTCP2MessageFragmenter()
        data = b"short message"
        frames = frag.fragment(data, max_payload=65535)
        assert len(frames) == 1
        assert frames[0].frame_type == FrameType.I2NP
        assert frames[0].payload == data

    def test_large_message_split_across_frames(self):
        frag = NTCP2MessageFragmenter()
        data = b"A" * 1000
        frames = frag.fragment(data, max_payload=300)
        # ceil(1000 / 300) = 4 frames: 300 + 300 + 300 + 100 bytes.
        assert len(frames) == 4
        assert [len(f.payload) for f in frames] == [300, 300, 300, 100]
        for f in frames:
            assert f.frame_type == FrameType.I2NP
        assert frag.reassemble(frames) == data

    def test_reassemble_single_frame(self):
        frag = NTCP2MessageFragmenter()
        data = b"hello world"
        frames = frag.fragment(data)
        result = frag.reassemble(frames)
        assert result == data

    def test_reassemble_multiple_frames(self):
        frag = NTCP2MessageFragmenter()
        data = b"B" * 1000
        frames = frag.fragment(data, max_payload=256)
        result = frag.reassemble(frames)
        assert result == data

    def test_fragment_exact_multiple(self):
        frag = NTCP2MessageFragmenter()
        data = b"C" * 600
        frames = frag.fragment(data, max_payload=200)
        assert len(frames) == 3
        assert frag.reassemble(frames) == data

    def test_empty_message(self):
        frag = NTCP2MessageFragmenter()
        frames = frag.fragment(b"")
        assert len(frames) == 1
        assert frames[0].payload == b""
        assert frag.reassemble(frames) == b""