"""Noise protocol compliance verification tests.

Verifies that the Noise_XK and Noise_IK implementations conform to the
Noise Protocol Framework specification (noiseprotocol.org/noise.html).

Covers: message sequence correctness, MixHash/MixKey calls, Split
correctness, prologue binding, nonce overflow detection, and static key
validation.
"""

from __future__ import annotations

import hashlib
import os

import pytest

from i2p_crypto.noise import CipherState, SymmetricState, HandshakeState
from i2p_crypto.x25519 import X25519DH

# Wire sizes used by the message-length assertions below.
_DH_LEN = 32  # bytes of a public key as it appears in a handshake message
_TAG_LEN = 16  # bytes of AEAD tag appended to each encrypted field


def _make_keypair():
    """Return a fresh key pair from ``X25519DH.generate_keypair()``.

    The tests index the result with ``[1]`` wherever the public half is
    needed (as the ``rs`` argument and when comparing ``remote_static``).
    """
    return X25519DH.generate_keypair()


def _xk_pair(
    *,
    initiator_prologue: bytes | None = None,
    responder_prologue: bytes | None = None,
):
    """Build a Noise_XK initiator/responder pair with fresh static keys.

    A prologue is forwarded to ``HandshakeState`` only when one is given,
    so calling this with no arguments exercises the constructor's own
    default prologue on both sides.

    Args:
        initiator_prologue: Prologue for the initiator, or ``None`` to omit
            the argument entirely.
        responder_prologue: Prologue for the responder, or ``None`` to omit
            the argument entirely.

    Returns:
        ``(initiator, responder, initiator_static)`` where
        ``initiator_static`` is the initiator's key pair as returned by
        ``_make_keypair()``.
    """
    i_s = _make_keypair()
    r_s = _make_keypair()

    i_extra = {} if initiator_prologue is None else {"prologue": initiator_prologue}
    r_extra = {} if responder_prologue is None else {"prologue": responder_prologue}

    initiator = HandshakeState(
        "Noise_XK",
        initiator=True,
        s=i_s,
        rs=r_s[1],
        **i_extra,
    )
    responder = HandshakeState(
        "Noise_XK",
        initiator=False,
        s=r_s,
        **r_extra,
    )
    return initiator, responder, i_s


def _run_xk_handshake(initiator, responder) -> None:
    """Exchange all three XK messages using the default payload."""
    msg1 = initiator.write_message()
    responder.read_message(msg1)

    msg2 = responder.write_message()
    initiator.read_message(msg2)

    msg3 = initiator.write_message()
    responder.read_message(msg3)


# ---------- Prologue binding ----------


class TestPrologueBinding:
    """Noise spec: prologue is MixHash'd into h before any DH.

    Mismatched prologues must cause handshake failure.
    """

    def test_prologue_mismatch_causes_failure(self) -> None:
        """XK handshake with different prologues must fail.

        In XK, the 'es' token in message 1 runs MixKey, giving the
        CipherState a key. The payload in message 1 is then encrypted
        with AD = h (which includes the prologue). If prologues differ,
        h values diverge, causing AEAD decryption to fail at message 1.
        """
        initiator, responder, _ = _xk_pair(
            initiator_prologue=b"prologue-A",
            responder_prologue=b"prologue-B",
        )

        msg1 = initiator.write_message(b"test")

        # Responder fails to read message 1 — hash chains diverged.
        # NOTE(review): the concrete exception type raised on AEAD failure
        # is not visible from this file; narrow this once confirmed.
        with pytest.raises(Exception):
            responder.read_message(msg1)

    def test_matching_prologue_succeeds(self) -> None:
        """Same prologue on both sides must produce successful handshake."""
        initiator, responder, _ = _xk_pair(
            initiator_prologue=b"same-prologue",
            responder_prologue=b"same-prologue",
        )

        msg1 = initiator.write_message(b"hello")
        assert responder.read_message(msg1) == b"hello"

        msg2 = responder.write_message(b"world")
        assert initiator.read_message(msg2) == b"world"

        msg3 = initiator.write_message(b"final")
        assert responder.read_message(msg3) == b"final"

        # Separate asserts so a failure names the side that did not finish.
        assert initiator.complete
        assert responder.complete

    def test_empty_prologue_default(self) -> None:
        """Default empty prologue must match another empty prologue."""
        initiator, responder, _ = _xk_pair()

        _run_xk_handshake(initiator, responder)

        assert initiator.complete
        assert responder.complete


# ---------- Nonce overflow ----------


class TestNonceOverflow:
    """CipherState must refuse to operate once its nonce reaches MAX_NONCE."""

    def test_nonce_overflow_encrypt_raises(self) -> None:
        """CipherState must reject encryption at MAX_NONCE."""
        cs = CipherState(os.urandom(32))
        cs.set_nonce(CipherState.MAX_NONCE)

        with pytest.raises(RuntimeError, match="Nonce exhausted"):
            cs.encrypt_with_ad(b"", b"data")

    def test_nonce_overflow_decrypt_raises(self) -> None:
        """CipherState must reject decryption at MAX_NONCE."""
        key = os.urandom(32)
        cs_enc = CipherState(key)
        ct = cs_enc.encrypt_with_ad(b"", b"data")

        cs_dec = CipherState(key)
        cs_dec.set_nonce(CipherState.MAX_NONCE)

        with pytest.raises(RuntimeError, match="Nonce exhausted"):
            cs_dec.decrypt_with_ad(b"", ct)

    def test_nonce_just_below_max_works(self) -> None:
        """Nonce at MAX_NONCE - 1 should still work."""
        key = os.urandom(32)
        last_nonce = CipherState.MAX_NONCE - 1

        cs_enc = CipherState(key)
        cs_enc.set_nonce(last_nonce)
        ct = cs_enc.encrypt_with_ad(b"", b"last-message")

        cs_dec = CipherState(key)
        cs_dec.set_nonce(last_nonce)
        pt = cs_dec.decrypt_with_ad(b"", ct)

        assert pt == b"last-message"


# ---------- Split correctness ----------


class TestSplitCorrectness:
    """Checks on the pair of CipherStates returned by SymmetricState.split()."""

    def test_split_produces_distinct_keys(self) -> None:
        """Split must produce two CipherStates with different keys."""
        ss = SymmetricState(b"test-protocol")
        ss.mix_key(os.urandom(32))

        c1, c2 = ss.split()

        assert c1.has_key() and c2.has_key()
        assert c1._key != c2._key

    def test_split_ciphers_start_at_nonce_zero(self) -> None:
        """Both CipherStates from Split must start with nonce = 0."""
        ss = SymmetricState(b"test-protocol")
        ss.mix_key(os.urandom(32))

        c1, c2 = ss.split()

        assert c1._n == 0
        assert c2._n == 0

    def test_split_is_deterministic(self) -> None:
        """Two SymmetricStates with same history produce same Split keys."""
        ikm = os.urandom(32)

        ss1 = SymmetricState(b"test")
        ss1.mix_key(ikm)
        c1a, c1b = ss1.split()

        ss2 = SymmetricState(b"test")
        ss2.mix_key(ikm)
        c2a, c2b = ss2.split()

        assert c1a._key == c2a._key
        assert c1b._key == c2b._key


# ---------- XK message sequence compliance ----------


class TestXKMessageSequence:
    """Verify XK pattern: -> e, es / <- e, ee / -> s, se"""

    def test_message_1_contains_ephemeral(self) -> None:
        """Message 1 must be the 32-byte ephemeral plus the payload tag."""
        initiator, _, _ = _xk_pair()

        msg1 = initiator.write_message(b"")

        # Message 1 for XK with empty payload:
        #   e (32) + encrypted payload (0 bytes + 16-byte tag).
        # After MixKey from es the cipher has a key, so even an empty
        # payload is encrypted and carries a tag — hence an exact length
        # rather than a lower bound.
        assert len(msg1) == _DH_LEN + _TAG_LEN

    def test_message_3_contains_encrypted_static(self) -> None:
        """Message 3 must contain initiator's encrypted static key."""
        initiator, responder, i_s = _xk_pair()

        msg1 = initiator.write_message()
        responder.read_message(msg1)
        msg2 = responder.write_message()
        initiator.read_message(msg2)
        msg3 = initiator.write_message()

        # Message 3: encrypted static (32 + 16 tag)
        #          + encrypted payload (0 + 16 tag) = 64
        assert len(msg3) == (_DH_LEN + _TAG_LEN) + _TAG_LEN

        # Responder recovers initiator's static key.
        responder.read_message(msg3)
        assert responder.remote_static == i_s[1]


# ---------- MixHash chain integrity ----------


class TestMixHashChainIntegrity:
    """Both parties must arrive at the same handshake hash."""

    def test_handshake_hash_chains_match(self) -> None:
        """After handshake, both sides must have identical h (handshake hash)."""
        initiator, responder, _ = _xk_pair()

        _run_xk_handshake(initiator, responder)

        # Both sides must have identical handshake hash.
        assert initiator._ss.h == responder._ss.h


# ---------- Rekey correctness ----------


class TestRekey:
    """Checks on CipherState.rekey()."""

    def test_rekey_changes_key(self) -> None:
        """rekey() must replace the key with a different 32-byte value."""
        key = os.urandom(32)
        cs = CipherState(key)
        old_key = cs._key

        cs.rekey()

        assert cs._key != old_key
        assert len(cs._key) == 32

    def test_rekey_preserves_nonce(self) -> None:
        """rekey() must leave the nonce counter untouched."""
        cs = CipherState(os.urandom(32))
        cs.set_nonce(42)

        cs.rekey()

        assert cs._n == 42

    def test_rekey_is_deterministic(self) -> None:
        """Two CipherStates with the same key must rekey to the same key."""
        key = os.urandom(32)
        cs1 = CipherState(key)
        cs2 = CipherState(key)

        cs1.rekey()
        cs2.rekey()

        assert cs1._key == cs2._key