"""Noise protocol framework for I2P. Implements CipherState, SymmetricState, and HandshakeState for the Noise_IK and Noise_XK patterns used by NTCP2 and SSU2. Uses: - ChaCha20-Poly1305 for AEAD (via cryptography library) - X25519 for DH - HKDF-SHA256 for key derivation """ import hashlib import os import struct from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 from i2p_crypto.x25519 import X25519DH from i2p_crypto.hkdf import HKDF from i2p_crypto.mlkem import MLKEMVariant, MLKEMKeyPair _hkdf = HKDF() class CipherState: """Wraps ChaCha20-Poly1305 for the Noise protocol.""" MAX_NONCE = 2**64 - 1 def __init__(self, key: bytes | None = None): self._key = key self._n = 0 def has_key(self) -> bool: return self._key is not None def set_nonce(self, n: int): self._n = n def _nonce_bytes(self) -> bytes: """4 zero bytes + 8-byte little-endian counter = 12-byte nonce.""" return b"\x00\x00\x00\x00" + struct.pack(" bytes: if not self.has_key(): return plaintext if self._n >= self.MAX_NONCE: raise RuntimeError("Nonce exhausted — rekey required") assert self._key is not None aead = ChaCha20Poly1305(self._key) ct = aead.encrypt(self._nonce_bytes(), plaintext, ad) self._n += 1 return ct def decrypt_with_ad(self, ad: bytes, ciphertext: bytes) -> bytes: if not self.has_key(): return ciphertext if self._n >= self.MAX_NONCE: raise RuntimeError("Nonce exhausted — rekey required") assert self._key is not None aead = ChaCha20Poly1305(self._key) pt = aead.decrypt(self._nonce_bytes(), ciphertext, ad) self._n += 1 return pt def rekey(self): aead = ChaCha20Poly1305(self._key) nonce = b"\xff" * 12 # max nonce # Noise spec: REKEY(k) = ENCRYPT(k, maxnonce, "", zeros[32]) # Output is 32 bytes ciphertext + 16 byte tag; take first 32 self._key = aead.encrypt(nonce, b"\x00" * 32, b"")[:32] class SymmetricState: """Noise SymmetricState — manages chaining key and handshake hash.""" def __init__(self, protocol_name: bytes): if len(protocol_name) <= 32: self.h = protocol_name + b"\x00" * (32 - len(protocol_name)) else: self.h = hashlib.sha256(protocol_name).digest() self.ck = self.h self._cipher = CipherState() def mix_key(self, input_key_material: bytes): output = _hkdf.extract_and_expand(self.ck, input_key_material, b"", 64) self.ck = output[:32] self._cipher = CipherState(output[32:64]) def mix_hash(self, data: bytes): self.h = hashlib.sha256(self.h + data).digest() def encrypt_and_hash(self, plaintext: bytes) -> bytes: ct = self._cipher.encrypt_with_ad(self.h, plaintext) self.mix_hash(ct) return ct def decrypt_and_hash(self, ciphertext: bytes) -> bytes: pt = self._cipher.decrypt_with_ad(self.h, ciphertext) self.mix_hash(ciphertext) return pt def split(self) -> tuple[CipherState, CipherState]: output = _hkdf.extract_and_expand(self.ck, b"", b"", 64) c1 = CipherState(output[:32]) c2 = CipherState(output[32:64]) return c1, c2 # Handshake pattern definitions # Each pattern is a list of message patterns. # Each message pattern is a list of tokens. # Tokens: 'e', 's', 'ee', 'es', 'se', 'ss' from typing import Any as _Any _PATTERNS: dict[str, dict[str, _Any]] = { "Noise_IK": { "pre_i": [], # initiator has no pre-message "pre_r": ["s"], # responder's static is known "messages": [ ["e", "es", "s", "ss"], # -> e, es, s, ss ["e", "ee", "se"], # <- e, ee, se ], }, "Noise_XK": { "pre_i": [], "pre_r": ["s"], "messages": [ ["e", "es"], # -> e, es ["e", "ee"], # <- e, ee ["s", "se"], # -> s, se ], }, } class HandshakeState: """Noise handshake state machine for IK and XK patterns.""" def __init__(self, pattern: str, initiator: bool, s: tuple | None = None, e: tuple | None = None, rs: bytes | None = None, re: bytes | None = None, prologue: bytes = b"", protocol_name: bytes | None = None): """ Args: pattern: "Noise_IK" or "Noise_XK" initiator: True if we are the initiator s: our static keypair (private, public) e: our ephemeral keypair (private, public) — usually generated rs: remote static public key (32 bytes) re: remote ephemeral public key (32 bytes) prologue: Prologue data mixed into hash before pre-messages. Noise spec: MixHash(prologue) after SymmetricState init. protocol_name: Override the full protocol name bytes. If None, constructed from pattern name. """ if pattern not in _PATTERNS: raise ValueError(f"Unknown pattern: {pattern}") self._pattern_name = pattern self._pattern = _PATTERNS[pattern] self._initiator = initiator self._s = s # (priv, pub) self._e = e # (priv, pub) self._rs = rs # remote static pub self._re = re # remote ephemeral pub self._msg_index = 0 self._complete = False # Initialize SymmetricState if protocol_name is not None: proto_name = protocol_name else: proto_name = f"{pattern}_25519_ChaChaPoly_SHA256".encode() self._ss = SymmetricState(proto_name) # MixHash(prologue) — required by the Noise spec between # SymmetricState init and pre-message processing. # For I2P NTCP2 this is an empty prologue: h = SHA256(h || "") self._ss.mix_hash(prologue) # Process pre-messages if self._initiator: # pre_r: responder's pre-message keys get mixed into hash for token in self._pattern["pre_r"]: if token == "s" and rs is not None: self._ss.mix_hash(rs) else: # pre_r: our own static (we are responder) for token in self._pattern["pre_r"]: if token == "s" and s is not None: self._ss.mix_hash(s[1]) def write_message(self, payload: bytes = b"") -> bytes: """Process the next outgoing message pattern.""" if self._complete: raise RuntimeError("Handshake already complete") messages = self._pattern["messages"] if self._msg_index >= len(messages): raise RuntimeError("No more handshake messages") # Determine if this message index is ours to write is_initiator_turn = (self._msg_index % 2 == 0) if is_initiator_turn != self._initiator: raise RuntimeError("Not our turn to write") tokens = messages[self._msg_index] buf = b"" for token in tokens: buf += self._process_write_token(token) buf += self._ss.encrypt_and_hash(payload) self._msg_index += 1 if self._msg_index >= len(messages): self._complete = True return buf def read_message(self, message: bytes) -> bytes: """Process the next incoming message pattern.""" if self._complete: raise RuntimeError("Handshake already complete") messages = self._pattern["messages"] if self._msg_index >= len(messages): raise RuntimeError("No more handshake messages") is_initiator_turn = (self._msg_index % 2 == 0) if is_initiator_turn == self._initiator: raise RuntimeError("Not our turn to read") tokens = messages[self._msg_index] offset = 0 for token in tokens: consumed = self._process_read_token(token, message, offset) offset += consumed # Remaining is encrypted payload payload = self._ss.decrypt_and_hash(message[offset:]) self._msg_index += 1 if self._msg_index >= len(messages): self._complete = True return payload def split(self) -> tuple[CipherState, CipherState]: """Split after handshake completion.""" if not self._complete: raise RuntimeError("Handshake not complete") return self._ss.split() @property def complete(self) -> bool: return self._complete @property def remote_static(self) -> bytes | None: return self._rs def _process_write_token(self, token: str) -> bytes: if token == "e": if self._e is None: self._e = X25519DH.generate_keypair() self._ss.mix_hash(self._e[1]) return self._e[1] elif token == "s": assert self._s is not None return self._ss.encrypt_and_hash(self._s[1]) elif token == "ee": assert self._e is not None and self._re is not None self._ss.mix_key(X25519DH.dh(self._e[0], self._re)) return b"" elif token == "es": if self._initiator: assert self._e is not None and self._rs is not None self._ss.mix_key(X25519DH.dh(self._e[0], self._rs)) else: assert self._s is not None and self._re is not None self._ss.mix_key(X25519DH.dh(self._s[0], self._re)) return b"" elif token == "se": if self._initiator: assert self._s is not None and self._re is not None self._ss.mix_key(X25519DH.dh(self._s[0], self._re)) else: assert self._e is not None and self._rs is not None self._ss.mix_key(X25519DH.dh(self._e[0], self._rs)) return b"" elif token == "ss": assert self._s is not None and self._rs is not None self._ss.mix_key(X25519DH.dh(self._s[0], self._rs)) return b"" else: raise ValueError(f"Unknown token: {token}") def _process_read_token(self, token: str, message: bytes, offset: int) -> int: if token == "e": self._re = message[offset:offset + 32] self._ss.mix_hash(self._re) return 32 elif token == "s": # Encrypted static key: 32 bytes + 16 byte tag if self._ss._cipher.has_key(): enc_s = message[offset:offset + 48] self._rs = self._ss.decrypt_and_hash(enc_s) return 48 else: self._rs = message[offset:offset + 32] self._ss.mix_hash(self._rs) return 32 elif token == "ee": assert self._e is not None and self._re is not None self._ss.mix_key(X25519DH.dh(self._e[0], self._re)) return 0 elif token == "es": if self._initiator: assert self._e is not None and self._rs is not None self._ss.mix_key(X25519DH.dh(self._e[0], self._rs)) else: assert self._s is not None and self._re is not None self._ss.mix_key(X25519DH.dh(self._s[0], self._re)) return 0 elif token == "se": if self._initiator: assert self._s is not None and self._re is not None self._ss.mix_key(X25519DH.dh(self._s[0], self._re)) else: assert self._e is not None and self._rs is not None self._ss.mix_key(X25519DH.dh(self._e[0], self._rs)) return 0 elif token == "ss": assert self._s is not None and self._rs is not None self._ss.mix_key(X25519DH.dh(self._s[0], self._rs)) return 0 else: raise ValueError(f"Unknown token: {token}") class HybridDHState: """Hybrid DH state combining X25519 + ML-KEM for post-quantum forward secrecy. Ported from com.southernstorm.noise.protocol.MLKEMDHState. In the Noise handshake with the ``hfs`` modifier: - Initiator (Alice) generates X25519 + ML-KEM keypairs - Responder (Bob) generates X25519 keypair + encapsulates with Alice's ML-KEM pubkey - Both derive: SHA-256(x25519_ss || mlkem_ss) The hfs modifier adds an extra DH-like exchange using KEM: - Alice sends: x25519_pub || mlkem_pub - Bob sends: x25519_pub || mlkem_ciphertext - Both compute hybrid shared secret """ def __init__(self, variant: MLKEMVariant = MLKEMVariant.ML_KEM_768): self._variant = variant self._x25519_private: bytes | None = None self._x25519_public: bytes | None = None self._mlkem_keypair: MLKEMKeyPair | None = None self._mlkem_public: bytes | None = None # remote ML-KEM public key (for Bob) self._remote_x25519_public: bytes | None = None self._has_keypair = False @property def public_key_len(self) -> int: """Total public key length: 32 (X25519) + ML-KEM pubkey.""" return 32 + self._variant.public_key_len @property def ciphertext_len(self) -> int: """Total response length: 32 (X25519) + ML-KEM ciphertext.""" return 32 + self._variant.ciphertext_len def generate_keypair(self) -> None: """Generate both X25519 and ML-KEM keypairs (Alice side).""" from i2p_crypto import mlkem as mlkem_mod self._x25519_private, self._x25519_public = X25519DH.generate_keypair() self._mlkem_keypair = mlkem_mod.generate_keys(self._variant) self._has_keypair = True def get_public_key(self) -> bytes: """Return x25519_pub || mlkem_pub.""" if not self._has_keypair or self._mlkem_keypair is None or self._x25519_public is None: raise RuntimeError("No keypair generated; call generate_keypair() first") return self._x25519_public + self._mlkem_keypair.public_key def set_remote_public_key(self, data: bytes) -> None: """Parse remote x25519_pub || mlkem_pub (Bob receives Alice's public keys).""" expected = 32 + self._variant.public_key_len if len(data) != expected: raise ValueError( f"Remote public key must be {expected} bytes, got {len(data)}" ) self._remote_x25519_public = data[:32] self._mlkem_public = data[32:] def encapsulate(self) -> tuple[bytes, bytes]: """Bob side: generate X25519 keypair, encapsulate with remote ML-KEM pubkey. Returns: (x25519_pub || mlkem_ciphertext, hybrid_shared_secret) """ from i2p_crypto import mlkem as mlkem_mod if self._mlkem_public is None or self._remote_x25519_public is None: raise RuntimeError( "Remote public key not set; call set_remote_public_key() first" ) # Generate Bob's X25519 ephemeral keypair bob_x25519_priv, bob_x25519_pub = X25519DH.generate_keypair() # X25519 DH with Alice's X25519 public key x25519_ss = X25519DH.dh(bob_x25519_priv, self._remote_x25519_public) # ML-KEM encapsulation with Alice's ML-KEM public key mlkem_ct, mlkem_ss = mlkem_mod.encapsulate(self._variant, self._mlkem_public) # Hybrid shared secret hybrid_ss = self._compute_hybrid_secret(x25519_ss, mlkem_ss) # Response: Bob's X25519 pub || ML-KEM ciphertext response = bob_x25519_pub + mlkem_ct return response, hybrid_ss def decapsulate(self, response: bytes) -> bytes: """Alice side: extract X25519 pub and ML-KEM ciphertext from Bob's response. Returns: hybrid_shared_secret """ from i2p_crypto import mlkem as mlkem_mod if not self._has_keypair or self._mlkem_keypair is None or self._x25519_private is None: raise RuntimeError("No keypair generated; call generate_keypair() first") expected = 32 + self._variant.ciphertext_len if len(response) != expected: raise ValueError( f"Response must be {expected} bytes, got {len(response)}" ) # Parse Bob's response bob_x25519_pub = response[:32] mlkem_ct = response[32:] # X25519 DH with Bob's X25519 public key x25519_ss = X25519DH.dh(self._x25519_private, bob_x25519_pub) # ML-KEM decapsulation mlkem_ss = mlkem_mod.decapsulate( self._variant, mlkem_ct, self._mlkem_keypair.private_key ) # Hybrid shared secret return self._compute_hybrid_secret(x25519_ss, mlkem_ss) def _compute_hybrid_secret(self, x25519_ss: bytes, mlkem_ss: bytes) -> bytes: """SHA-256(x25519_ss || mlkem_ss).""" return hashlib.sha256(x25519_ss + mlkem_ss).digest()