"""Tests for SSU2 Noise_XK handshake state machines.

TDD: tests written before implementation.
"""

import time

import pytest

from i2p_crypto.x25519 import X25519DH
from i2p_transport.ssu2_handshake import (
    HandshakeRole,
    HandshakePhase,
    SSU2Token,
    HandshakeKeys,
    TokenManager,
    OutboundHandshake,
    InboundHandshake,
    derive_header_keys,
    derive_data_keys,
)


# ---------------------------------------------------------------------------
# TokenManager tests
# ---------------------------------------------------------------------------


class TestTokenManager:
    def test_token_manager_create_validate(self):
        """Create a token for an IP and validate it."""
        mgr = TokenManager(token_lifetime_seconds=3600)
        remote_ip = b"\x7f\x00\x00\x01"  # 127.0.0.1
        token = mgr.create_token(remote_ip)
        assert isinstance(token, SSU2Token)
        assert token.remote_ip == remote_ip
        assert token.token != 0
        # With a one-hour lifetime the expiry must lie in the future.
        assert token.expires > time.time()
        assert mgr.validate_token(token.token, remote_ip) is True

    def test_token_manager_expired(self):
        """A token whose expiry time has passed fails validation."""
        mgr = TokenManager(token_lifetime_seconds=3600)
        remote_ip = b"\x7f\x00\x00\x01"
        value = 0xDEADBEEF
        # Control: the same token with a future expiry is accepted...
        mgr._tokens[remote_ip] = SSU2Token(
            token=value, remote_ip=remote_ip, expires=time.time() + 3600
        )
        assert mgr.validate_token(value, remote_ip) is True
        # ...and is rejected once its expiry lies in the past. The expiry is
        # set explicitly rather than via lifetime=0: a zero lifetime makes
        # ``expires`` equal the creation time, so the outcome would depend on
        # clock resolution and on ``>`` vs ``>=`` in the check (flaky).
        mgr._tokens[remote_ip] = SSU2Token(
            token=value, remote_ip=remote_ip, expires=time.time() - 1
        )
        assert mgr.validate_token(value, remote_ip) is False

    def test_token_manager_wrong_ip(self):
        """Token for one IP is not valid for another."""
        mgr = TokenManager(token_lifetime_seconds=3600)
        ip1 = b"\x7f\x00\x00\x01"
        ip2 = b"\x0a\x00\x00\x01"
        token = mgr.create_token(ip1)
        assert mgr.validate_token(token.token, ip1) is True
        assert mgr.validate_token(token.token, ip2) is False

    def test_token_manager_cleanup_expired(self):
        """cleanup_expired removes expired tokens and keeps live ones."""
        mgr = TokenManager(token_lifetime_seconds=3600)
        live_ip = b"\x7f\x00\x00\x01"
        stale_ip = b"\x0a\x00\x00\x01"
        live = mgr.create_token(live_ip)
        # Explicitly backdated (see test_token_manager_expired for why
        # lifetime=0 is avoided).
        mgr._tokens[stale_ip] = SSU2Token(
            token=0x1234, remote_ip=stale_ip, expires=time.time() - 1
        )
        mgr.cleanup_expired()
        # Only the live token should survive cleanup.
        assert len(mgr._tokens) == 1
        assert mgr.validate_token(live.token, live_ip) is True


# ---------------------------------------------------------------------------
# Outbound handshake initial state
# ---------------------------------------------------------------------------


class TestOutboundHandshakeState:
    def test_outbound_initial_phase(self):
        """Without a token, starts at TOKEN_REQUEST."""
        alice_priv, _ = X25519DH.generate_keypair()
        _, bob_pub = X25519DH.generate_keypair()
        intro_key = b"\x00" * 32
        hs = OutboundHandshake(
            local_static_key=alice_priv,
            remote_static_key=bob_pub,
            remote_intro_key=intro_key,
            token=None,
        )
        assert hs.phase == HandshakePhase.TOKEN_REQUEST

    def test_outbound_with_token_skips_token_request(self):
        """With a cached token, starts at SESSION_REQUEST."""
        alice_priv, _ = X25519DH.generate_keypair()
        _, bob_pub = X25519DH.generate_keypair()
        intro_key = b"\x00" * 32
        hs = OutboundHandshake(
            local_static_key=alice_priv,
            remote_static_key=bob_pub,
            remote_intro_key=intro_key,
            token=0xDEADBEEF,
        )
        assert hs.phase == HandshakePhase.SESSION_REQUEST


# ---------------------------------------------------------------------------
# HandshakeKeys structure
# ---------------------------------------------------------------------------


class TestHandshakeKeys:
    def test_handshake_keys_structure(self):
        """HandshakeKeys has correct field sizes."""
        keys = HandshakeKeys(
            send_cipher_key=b"\x01" * 32,
            recv_cipher_key=b"\x02" * 32,
            send_header_key=b"\x03" * 32,
            recv_header_key=b"\x04" * 32,
        )
        assert len(keys.send_cipher_key) == 32
        assert len(keys.recv_cipher_key) == 32
        assert len(keys.send_header_key) == 32
        assert len(keys.recv_header_key) == 32


# ---------------------------------------------------------------------------
# Key derivation
# ---------------------------------------------------------------------------


class TestKeyDerivation:
    def test_derive_header_keys(self):
        """HKDF produces two 32-byte keys."""
        ck = b"\xab" * 32
        k1, k2 = derive_header_keys(ck)
        assert len(k1) == 32
        assert len(k2) == 32
        assert k1 != k2

    def test_derive_data_keys(self):
        """Produces HandshakeKeys with all fields populated."""
        ck = b"\xcd" * 32
        hh = b"\xef" * 32
        keys = derive_data_keys(ck, hh)
        assert isinstance(keys, HandshakeKeys)
        assert len(keys.send_cipher_key) == 32
        assert len(keys.recv_cipher_key) == 32
        assert len(keys.send_header_key) == 32
        assert len(keys.recv_header_key) == 32

    def test_derive_header_keys_deterministic(self):
        """Same input produces same output."""
        ck = b"\x11" * 32
        k1a, k2a = derive_header_keys(ck)
        k1b, k2b = derive_header_keys(ck)
        assert k1a == k1b
        assert k2a == k2b


# ---------------------------------------------------------------------------
# get_keys before ESTABLISHED
# ---------------------------------------------------------------------------


class TestGetKeysBeforeEstablished:
    def test_get_keys_before_established_raises(self):
        """get_keys raises ValueError if not ESTABLISHED."""
        alice_priv, _ = X25519DH.generate_keypair()
        _, bob_pub = X25519DH.generate_keypair()
        intro_key = b"\x00" * 32
        hs = OutboundHandshake(
            local_static_key=alice_priv,
            remote_static_key=bob_pub,
            remote_intro_key=intro_key,
            token=0x1234,
        )
        with pytest.raises(ValueError, match="not.*ESTABLISHED"):
            hs.get_keys()


# ---------------------------------------------------------------------------
# Full handshake roundtrip tests
# ---------------------------------------------------------------------------


class TestFullHandshake:
    def _make_keypairs(self):
        """Generate Alice and Bob keypairs."""
        alice_priv, alice_pub = X25519DH.generate_keypair()
        bob_priv, bob_pub = X25519DH.generate_keypair()
        return alice_priv, alice_pub, bob_priv, bob_pub

    @staticmethod
    def _complete_handshake(alice, bob, remote_ip):
        """Run Session Request -> Created -> Confirmed between two peers.

        Expects *alice* to already be in the SESSION_REQUEST phase. Asserts
        that every message is non-empty and that both sides end ESTABLISHED.
        """
        # 1. Alice builds Session Request
        session_request = alice.build_session_request()
        assert len(session_request) > 0
        # 2. Bob processes Session Request
        bob.process_session_request(session_request, remote_ip=remote_ip)
        # 3. Bob builds Session Created
        session_created = bob.build_session_created()
        assert len(session_created) > 0
        # 4. Alice processes Session Created
        alice.process_session_created(session_created)
        # 5. Alice builds Session Confirmed
        session_confirmed = alice.build_session_confirmed()
        assert len(session_confirmed) > 0
        # 6. Bob processes Session Confirmed
        bob.process_session_confirmed(session_confirmed)

        assert alice.phase == HandshakePhase.ESTABLISHED
        assert bob.phase == HandshakePhase.ESTABLISHED

    @staticmethod
    def _assert_mirrored_keys(alice, bob):
        """Assert each side's send keys equal the other side's recv keys."""
        alice_keys = alice.get_keys()
        bob_keys = bob.get_keys()
        assert alice_keys.send_cipher_key == bob_keys.recv_cipher_key
        assert alice_keys.recv_cipher_key == bob_keys.send_cipher_key
        assert alice_keys.send_header_key == bob_keys.recv_header_key
        assert alice_keys.recv_header_key == bob_keys.send_header_key

    def test_full_handshake_roundtrip(self):
        """Alice and Bob complete full handshake, both derive same data keys."""
        alice_priv, _, bob_priv, bob_pub = self._make_keypairs()
        bob_intro_key = b"\x42" * 32
        token_mgr = TokenManager()
        remote_ip = b"\x7f\x00\x00\x01"

        # Alice starts with a token (skip token request for simplicity)
        alice = OutboundHandshake(
            local_static_key=alice_priv,
            remote_static_key=bob_pub,
            remote_intro_key=bob_intro_key,
            token=0xCAFEBABE,
        )
        bob = InboundHandshake(
            local_static_key=bob_priv,
            local_intro_key=bob_intro_key,
            token_manager=token_mgr,
        )

        # Pre-register the token so Bob accepts it
        token_mgr._tokens[remote_ip] = SSU2Token(
            token=0xCAFEBABE,
            remote_ip=remote_ip,
            expires=time.time() + 3600,
        )

        self._complete_handshake(alice, bob, remote_ip)
        # Both derive the same data-phase keys (swapped for send/recv)
        self._assert_mirrored_keys(alice, bob)

    def test_handshake_with_token_request(self):
        """Full flow including token request/retry."""
        alice_priv, _, bob_priv, bob_pub = self._make_keypairs()
        bob_intro_key = b"\x42" * 32
        token_mgr = TokenManager()
        remote_ip = b"\x7f\x00\x00\x01"

        # Alice has no token
        alice = OutboundHandshake(
            local_static_key=alice_priv,
            remote_static_key=bob_pub,
            remote_intro_key=bob_intro_key,
            token=None,
        )
        bob = InboundHandshake(
            local_static_key=bob_priv,
            local_intro_key=bob_intro_key,
            token_manager=token_mgr,
        )
        assert alice.phase == HandshakePhase.TOKEN_REQUEST

        # 1. Alice builds Token Request
        token_request = alice.build_token_request()
        assert len(token_request) > 0

        # 2. Bob processes Token Request and returns Retry
        retry = bob.process_token_request(token_request, remote_ip)
        assert len(retry) > 0

        # 3. Alice processes Retry (gets token)
        alice.process_retry(retry)
        assert alice.phase == HandshakePhase.SESSION_REQUEST

        # 4. Continue with normal handshake
        self._complete_handshake(alice, bob, remote_ip)
        self._assert_mirrored_keys(alice, bob)

    def test_handshake_without_token(self):
        """Full flow with a pre-cached token, skipping Token Request/Retry."""
        alice_priv, _, bob_priv, bob_pub = self._make_keypairs()
        bob_intro_key = b"\x42" * 32
        token_mgr = TokenManager()
        remote_ip = b"\x0a\x00\x00\x01"

        # Pre-register token
        cached_token = 0xFEEDFACE
        token_mgr._tokens[remote_ip] = SSU2Token(
            token=cached_token,
            remote_ip=remote_ip,
            expires=time.time() + 3600,
        )

        alice = OutboundHandshake(
            local_static_key=alice_priv,
            remote_static_key=bob_pub,
            remote_intro_key=bob_intro_key,
            token=cached_token,
        )
        bob = InboundHandshake(
            local_static_key=bob_priv,
            local_intro_key=bob_intro_key,
            token_manager=token_mgr,
        )
        assert alice.phase == HandshakePhase.SESSION_REQUEST

        self._complete_handshake(alice, bob, remote_ip)
        self._assert_mirrored_keys(alice, bob)

    def test_session_request_contains_ephemeral(self):
        """Session Request contains a 32-byte ephemeral key."""
        alice_priv, _, _, bob_pub = self._make_keypairs()
        intro_key = b"\x00" * 32
        alice = OutboundHandshake(
            local_static_key=alice_priv,
            remote_static_key=bob_pub,
            remote_intro_key=intro_key,
            token=0x1234,
        )
        pkt = alice.build_session_request()
        # The session request long header is 32 bytes, followed by the
        # 32-byte ephemeral key (start of Noise message 1)
        assert len(pkt) >= 64  # at least header + ephemeral key
        # After building, the ephemeral public should be set
        assert alice._ephemeral_public is not None
        assert len(alice._ephemeral_public) == 32

    def test_session_created_contains_ephemeral(self):
        """Session Created contains Bob's ephemeral key."""
        alice_priv, _, bob_priv, bob_pub = self._make_keypairs()
        bob_intro_key = b"\x42" * 32
        token_mgr = TokenManager()
        remote_ip = b"\x7f\x00\x00\x01"
        token_mgr._tokens[remote_ip] = SSU2Token(
            token=0xBEEF, remote_ip=remote_ip, expires=time.time() + 3600
        )

        alice = OutboundHandshake(
            local_static_key=alice_priv,
            remote_static_key=bob_pub,
            remote_intro_key=bob_intro_key,
            token=0xBEEF,
        )
        bob = InboundHandshake(
            local_static_key=bob_priv,
            local_intro_key=bob_intro_key,
            token_manager=token_mgr,
        )

        session_request = alice.build_session_request()
        bob.process_session_request(session_request, remote_ip)
        session_created = bob.build_session_created()

        # Bob should have generated an ephemeral key
        assert bob._ephemeral_public is not None
        assert len(bob._ephemeral_public) == 32
        # Session created packet should be non-trivial
        assert len(session_created) >= 64