A Python port of the Invisible Internet Project (I2P)
at main 468 lines 17 kB view raw
"""Noise protocol framework for I2P.

Implements CipherState, SymmetricState, and HandshakeState
for the Noise_IK and Noise_XK patterns used by NTCP2 and SSU2.

Uses:
- ChaCha20-Poly1305 for AEAD (via cryptography library)
- X25519 for DH
- HKDF-SHA256 for key derivation
"""

import hashlib
import os
import struct

from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305


from i2p_crypto.x25519 import X25519DH
from i2p_crypto.hkdf import HKDF
from i2p_crypto.mlkem import MLKEMVariant, MLKEMKeyPair

_hkdf = HKDF()


class CipherState:
    """Wraps ChaCha20-Poly1305 for the Noise protocol.

    Holds an optional AEAD key and a message counter ``n``.  While no key
    is set, encrypt/decrypt pass data through unchanged.  The counter is
    advanced by one after every successful keyed encrypt or decrypt.
    """

    # The counter value 2**64 - 1 is never used for a regular message
    # (encrypt/decrypt refuse it); rekey() uses it as its nonce.
    MAX_NONCE = 2**64 - 1

    def __init__(self, key: bytes | None = None):
        """Create a cipher state with an optional key and a zero counter."""
        self._key = key
        self._n = 0

    def has_key(self) -> bool:
        """Return True if a key has been set."""
        return self._key is not None

    def set_nonce(self, n: int):
        """Set the counter used for the next encrypt/decrypt call."""
        self._n = n

    @staticmethod
    def _encode_nonce(n: int) -> bytes:
        """Encode counter *n* as 4 zero bytes + 8-byte little-endian value."""
        return b"\x00\x00\x00\x00" + struct.pack("<Q", n)

    def _nonce_bytes(self) -> bytes:
        """4 zero bytes + 8-byte little-endian counter = 12-byte nonce."""
        return self._encode_nonce(self._n)

    def encrypt_with_ad(self, ad: bytes, plaintext: bytes) -> bytes:
        """Encrypt *plaintext* with associated data *ad*.

        Returns *plaintext* unchanged when no key is set; otherwise returns
        the AEAD output and increments the counter.

        Raises:
            RuntimeError: if the counter has reached MAX_NONCE.
        """
        key = self._key
        if key is None:
            return plaintext
        if self._n >= self.MAX_NONCE:
            raise RuntimeError("Nonce exhausted — rekey required")
        ct = ChaCha20Poly1305(key).encrypt(self._nonce_bytes(), plaintext, ad)
        self._n += 1
        return ct

    def decrypt_with_ad(self, ad: bytes, ciphertext: bytes) -> bytes:
        """Decrypt *ciphertext* with associated data *ad*.

        Returns *ciphertext* unchanged when no key is set.  The counter is
        only incremented after the AEAD call returns, so a failed
        decryption leaves it untouched.

        Raises:
            RuntimeError: if the counter has reached MAX_NONCE.
        """
        key = self._key
        if key is None:
            return ciphertext
        if self._n >= self.MAX_NONCE:
            raise RuntimeError("Nonce exhausted — rekey required")
        pt = ChaCha20Poly1305(key).decrypt(self._nonce_bytes(), ciphertext, ad)
        self._n += 1
        return pt

    def rekey(self):
        """Replace the key with one derived from the current key.

        Noise spec: REKEY(k) = ENCRYPT(k, maxnonce, "", zeros[32]), where
        maxnonce is the counter value 2**64 - 1.  The nonce is encoded the
        same way as every other nonce (4 zero bytes + little-endian
        counter), i.e. it is NOT twelve 0xff bytes.

        Raises:
            RuntimeError: if no key is set.
        """
        if self._key is None:
            raise RuntimeError("Cannot rekey: no key set")
        aead = ChaCha20Poly1305(self._key)
        nonce = self._encode_nonce(self.MAX_NONCE)
        # Output is 32 bytes ciphertext + 16 byte tag; take first 32
        self._key = aead.encrypt(nonce, b"\x00" * 32, b"")[:32]


class SymmetricState:
    """Noise SymmetricState — manages chaining key and handshake hash.

    Attributes:
        h: running handshake hash (32 bytes).
        ck: chaining key, updated by every mix_key() call.
    """

    def __init__(self, protocol_name: bytes):
        """Initialise ``h`` and ``ck`` from *protocol_name*.

        Names of up to 32 bytes are zero-padded to 32 bytes; longer names
        are replaced by their SHA-256 digest.
        """
        if len(protocol_name) <= 32:
            self.h = protocol_name.ljust(32, b"\x00")
        else:
            self.h = hashlib.sha256(protocol_name).digest()
        self.ck = self.h
        self._cipher = CipherState()

    def mix_key(self, input_key_material: bytes):
        """Derive a new chaining key and cipher key from *input_key_material*.

        The first 32 bytes of the HKDF output become the chaining key; the
        next 32 bytes key a fresh CipherState (counter reset to zero).
        """
        output = _hkdf.extract_and_expand(self.ck, input_key_material, b"", 64)
        self.ck = output[:32]
        self._cipher = CipherState(output[32:64])

    def mix_hash(self, data: bytes):
        """Update the handshake hash: h = SHA256(h || data)."""
        self.h = hashlib.sha256(self.h + data).digest()

    def encrypt_and_hash(self, plaintext: bytes) -> bytes:
        """Encrypt *plaintext* with ``h`` as AD, then mix the result into ``h``."""
        ct = self._cipher.encrypt_with_ad(self.h, plaintext)
        self.mix_hash(ct)
        return ct

    def decrypt_and_hash(self, ciphertext: bytes) -> bytes:
        """Decrypt *ciphertext* with ``h`` as AD, then mix *ciphertext* into ``h``."""
        pt = self._cipher.decrypt_with_ad(self.h, ciphertext)
        self.mix_hash(ciphertext)
        return pt

    def split(self) -> tuple[CipherState, CipherState]:
        """Derive two transport CipherStates from the chaining key."""
        output = _hkdf.extract_and_expand(self.ck, b"", b"", 64)
        return CipherState(output[:32]), CipherState(output[32:64])


# Handshake pattern definitions
# Each pattern is a list of message patterns.
# Each message pattern is a list of tokens.
# Tokens: 'e', 's', 'ee', 'es', 'se', 'ss'

from typing import Any as _Any
_PATTERNS: dict[str, dict[str, _Any]] = {
    "Noise_IK": {
        "pre_i": [],       # initiator has no pre-message
        "pre_r": ["s"],    # responder's static is known
        "messages": [
            ["e", "es", "s", "ss"],   # -> e, es, s, ss
            ["e", "ee", "se"],        # <- e, ee, se
        ],
    },
    "Noise_XK": {
        "pre_i": [],
        "pre_r": ["s"],
        "messages": [
            ["e", "es"],     # -> e, es
            ["e", "ee"],     # <- e, ee
            ["s", "se"],     # -> s, se
        ],
    },
}


class HandshakeState:
    """Noise handshake state machine for IK and XK patterns."""

    # Wire sizes used when parsing incoming handshake messages.
    _KEY_LEN = 32              # public key
    _ENC_STATIC_LEN = 32 + 16  # encrypted static key: 32 bytes + 16 byte tag

    def __init__(self, pattern: str, initiator: bool,
                 s: tuple | None = None, e: tuple | None = None,
                 rs: bytes | None = None, re: bytes | None = None,
                 prologue: bytes = b"",
                 protocol_name: bytes | None = None):
        """
        Args:
            pattern: "Noise_IK" or "Noise_XK"
            initiator: True if we are the initiator
            s: our static keypair (private, public)
            e: our ephemeral keypair (private, public) — usually generated
            rs: remote static public key (32 bytes)
            re: remote ephemeral public key (32 bytes)
            prologue: Prologue data mixed into hash before pre-messages.
                      Noise spec: MixHash(prologue) after SymmetricState init.
            protocol_name: Override the full protocol name bytes.
                          If None, constructed from pattern name.

        Raises:
            ValueError: if *pattern* is not a known pattern name.
        """
        if pattern not in _PATTERNS:
            raise ValueError(f"Unknown pattern: {pattern}")

        self._pattern_name = pattern
        self._pattern = _PATTERNS[pattern]
        self._initiator = initiator
        self._s = s     # (priv, pub)
        self._e = e     # (priv, pub)
        self._rs = rs   # remote static pub
        self._re = re   # remote ephemeral pub
        self._msg_index = 0
        self._complete = False

        # Initialize SymmetricState
        if protocol_name is not None:
            proto_name = protocol_name
        else:
            proto_name = f"{pattern}_25519_ChaChaPoly_SHA256".encode()
        self._ss = SymmetricState(proto_name)

        # MixHash(prologue) — required by the Noise spec between
        # SymmetricState init and pre-message processing.
        # For I2P NTCP2 this is an empty prologue: h = SHA256(h || "")
        self._ss.mix_hash(prologue)

        # Process pre-messages.  Both sides mix the responder's static
        # public key: the initiator knows it as `rs`, the responder as its
        # own `s[1]`.  A missing key is skipped here without error.
        if self._initiator:
            # pre_r: responder's pre-message keys get mixed into hash
            for token in self._pattern["pre_r"]:
                if token == "s" and rs is not None:
                    self._ss.mix_hash(rs)
        else:
            # pre_r: our own static (we are responder)
            for token in self._pattern["pre_r"]:
                if token == "s" and s is not None:
                    self._ss.mix_hash(s[1])

    def write_message(self, payload: bytes = b"") -> bytes:
        """Process the next outgoing message pattern.

        Returns the concatenated token outputs followed by the payload
        (encrypted once a cipher key has been established).

        Raises:
            RuntimeError: if the handshake is complete, it is not our turn
                to write, or key material required by a token is missing.
        """
        tokens = self._next_tokens(writing=True)
        parts = [self._process_write_token(token) for token in tokens]
        parts.append(self._ss.encrypt_and_hash(payload))
        self._advance()
        return b"".join(parts)

    def read_message(self, message: bytes) -> bytes:
        """Process the next incoming message pattern.

        Returns the decrypted payload carried after the handshake tokens.

        Raises:
            RuntimeError: if the handshake is complete, it is not our turn
                to read, or key material required by a token is missing.
            ValueError: if *message* is too short for the expected tokens.
        """
        tokens = self._next_tokens(writing=False)
        offset = 0
        for token in tokens:
            offset += self._process_read_token(token, message, offset)

        # Remaining is encrypted payload
        payload = self._ss.decrypt_and_hash(message[offset:])
        self._advance()
        return payload

    def split(self) -> tuple[CipherState, CipherState]:
        """Split after handshake completion.

        Raises:
            RuntimeError: if the handshake has not completed yet.
        """
        if not self._complete:
            raise RuntimeError("Handshake not complete")
        return self._ss.split()

    @property
    def complete(self) -> bool:
        """True once every message of the pattern has been processed."""
        return self._complete

    @property
    def remote_static(self) -> bytes | None:
        """Remote static public key, if known or received."""
        return self._rs

    def _next_tokens(self, writing: bool) -> list:
        """Validate state and turn order; return the current message's tokens."""
        if self._complete:
            raise RuntimeError("Handshake already complete")

        messages = self._pattern["messages"]
        if self._msg_index >= len(messages):
            raise RuntimeError("No more handshake messages")

        # Even-indexed messages are written by the initiator, odd-indexed
        # ones by the responder.
        is_initiator_turn = (self._msg_index % 2 == 0)
        our_turn_to_write = (is_initiator_turn == self._initiator)
        if writing and not our_turn_to_write:
            raise RuntimeError("Not our turn to write")
        if not writing and our_turn_to_write:
            raise RuntimeError("Not our turn to read")

        return messages[self._msg_index]

    def _advance(self) -> None:
        """Move to the next message; mark completion after the last one."""
        self._msg_index += 1
        if self._msg_index >= len(self._pattern["messages"]):
            self._complete = True

    def _mix_dh(self, token: str) -> None:
        """Perform the DH for a two-letter token and mix it into the key.

        The first letter names the initiator's key and the second the
        responder's, so which of our keys is used depends on our role.
        The computation is identical when reading and writing.
        """
        if token == "ee":
            local, remote = self._e, self._re
        elif token == "es":
            if self._initiator:
                local, remote = self._e, self._rs
            else:
                local, remote = self._s, self._re
        elif token == "se":
            if self._initiator:
                local, remote = self._s, self._re
            else:
                local, remote = self._e, self._rs
        elif token == "ss":
            local, remote = self._s, self._rs
        else:
            raise ValueError(f"Unknown token: {token}")

        # Explicit check instead of assert: asserts vanish under `python -O`.
        if local is None or remote is None:
            raise RuntimeError(f"Missing key material for '{token}' token")
        self._ss.mix_key(X25519DH.dh(local[0], remote))

    @staticmethod
    def _take(message: bytes, offset: int, length: int, what: str) -> bytes:
        """Return message[offset:offset+length], rejecting truncated input."""
        chunk = message[offset:offset + length]
        if len(chunk) != length:
            raise ValueError(
                f"Truncated handshake message: need {length} bytes for "
                f"{what} at offset {offset}, got {len(chunk)}"
            )
        return chunk

    def _process_write_token(self, token: str) -> bytes:
        """Apply one token on the sending side; return the bytes it emits."""
        if token == "e":
            if self._e is None:
                self._e = X25519DH.generate_keypair()
            self._ss.mix_hash(self._e[1])
            return self._e[1]
        if token == "s":
            if self._s is None:
                raise RuntimeError("Missing key material for 's' token")
            return self._ss.encrypt_and_hash(self._s[1])
        self._mix_dh(token)
        return b""

    def _process_read_token(self, token: str, message: bytes, offset: int) -> int:
        """Apply one token on the receiving side.

        Returns the number of bytes of *message* consumed starting at
        *offset* (0 for DH tokens).

        Raises:
            ValueError: if *message* does not contain enough bytes.
        """
        if token == "e":
            self._re = self._take(message, offset, self._KEY_LEN,
                                  "ephemeral key")
            self._ss.mix_hash(self._re)
            return self._KEY_LEN
        if token == "s":
            # Encrypted static key: 32 bytes + 16 byte tag
            if self._ss._cipher.has_key():
                enc_s = self._take(message, offset, self._ENC_STATIC_LEN,
                                   "encrypted static key")
                self._rs = self._ss.decrypt_and_hash(enc_s)
                return self._ENC_STATIC_LEN
            self._rs = self._take(message, offset, self._KEY_LEN,
                                  "static key")
            self._ss.mix_hash(self._rs)
            return self._KEY_LEN
        self._mix_dh(token)
        return 0


class HybridDHState:
    """Hybrid DH state combining X25519 + ML-KEM for post-quantum forward secrecy.

    Ported from com.southernstorm.noise.protocol.MLKEMDHState.

    In the Noise handshake with the ``hfs`` modifier:
    - Initiator (Alice) generates X25519 + ML-KEM keypairs
    - Responder (Bob) generates X25519 keypair + encapsulates with Alice's ML-KEM pubkey
    - Both derive: SHA-256(x25519_ss || mlkem_ss)

    The hfs modifier adds an extra DH-like exchange using KEM:
    - Alice sends: x25519_pub || mlkem_pub
    - Bob sends: x25519_pub || mlkem_ciphertext
    - Both compute hybrid shared secret
    """

    def __init__(self, variant: MLKEMVariant = MLKEMVariant.ML_KEM_768):
        """Create an empty state for the given ML-KEM *variant*."""
        self._variant = variant
        self._x25519_private: bytes | None = None
        self._x25519_public: bytes | None = None
        self._mlkem_keypair: MLKEMKeyPair | None = None
        self._mlkem_public: bytes | None = None  # remote ML-KEM public key (for Bob)
        self._remote_x25519_public: bytes | None = None
        self._has_keypair = False

    @property
    def public_key_len(self) -> int:
        """Total public key length: 32 (X25519) + ML-KEM pubkey."""
        return 32 + self._variant.public_key_len

    @property
    def ciphertext_len(self) -> int:
        """Total response length: 32 (X25519) + ML-KEM ciphertext."""
        return 32 + self._variant.ciphertext_len

    def generate_keypair(self) -> None:
        """Generate both X25519 and ML-KEM keypairs (Alice side)."""
        from i2p_crypto import mlkem as mlkem_mod

        self._x25519_private, self._x25519_public = X25519DH.generate_keypair()
        self._mlkem_keypair = mlkem_mod.generate_keys(self._variant)
        self._has_keypair = True

    def get_public_key(self) -> bytes:
        """Return x25519_pub || mlkem_pub.

        Raises:
            RuntimeError: if generate_keypair() has not been called.
        """
        if not self._has_keypair or self._mlkem_keypair is None or self._x25519_public is None:
            raise RuntimeError("No keypair generated; call generate_keypair() first")
        return self._x25519_public + self._mlkem_keypair.public_key

    def set_remote_public_key(self, data: bytes) -> None:
        """Parse remote x25519_pub || mlkem_pub (Bob receives Alice's public keys).

        Raises:
            ValueError: if *data* is not exactly ``public_key_len`` bytes.
        """
        expected = self.public_key_len
        if len(data) != expected:
            raise ValueError(
                f"Remote public key must be {expected} bytes, got {len(data)}"
            )
        self._remote_x25519_public = data[:32]
        self._mlkem_public = data[32:]

    def encapsulate(self) -> tuple[bytes, bytes]:
        """Bob side: generate X25519 keypair, encapsulate with remote ML-KEM pubkey.

        Returns:
            (x25519_pub || mlkem_ciphertext, hybrid_shared_secret)

        Raises:
            RuntimeError: if set_remote_public_key() has not been called.
        """
        from i2p_crypto import mlkem as mlkem_mod

        if self._mlkem_public is None or self._remote_x25519_public is None:
            raise RuntimeError(
                "Remote public key not set; call set_remote_public_key() first"
            )

        # Generate Bob's X25519 ephemeral keypair; the private half is only
        # used for the DH below and is not stored on the instance.
        bob_x25519_priv, bob_x25519_pub = X25519DH.generate_keypair()

        # X25519 DH with Alice's X25519 public key
        x25519_ss = X25519DH.dh(bob_x25519_priv, self._remote_x25519_public)

        # ML-KEM encapsulation with Alice's ML-KEM public key
        mlkem_ct, mlkem_ss = mlkem_mod.encapsulate(self._variant, self._mlkem_public)

        # Hybrid shared secret
        hybrid_ss = self._compute_hybrid_secret(x25519_ss, mlkem_ss)

        # Response: Bob's X25519 pub || ML-KEM ciphertext
        return bob_x25519_pub + mlkem_ct, hybrid_ss

    def decapsulate(self, response: bytes) -> bytes:
        """Alice side: extract X25519 pub and ML-KEM ciphertext from Bob's response.

        Returns:
            hybrid_shared_secret

        Raises:
            RuntimeError: if generate_keypair() has not been called.
            ValueError: if *response* is not exactly ``ciphertext_len`` bytes.
        """
        from i2p_crypto import mlkem as mlkem_mod

        if not self._has_keypair or self._mlkem_keypair is None or self._x25519_private is None:
            raise RuntimeError("No keypair generated; call generate_keypair() first")

        expected = self.ciphertext_len
        if len(response) != expected:
            raise ValueError(
                f"Response must be {expected} bytes, got {len(response)}"
            )

        # Parse Bob's response
        bob_x25519_pub = response[:32]
        mlkem_ct = response[32:]

        # X25519 DH with Bob's X25519 public key
        x25519_ss = X25519DH.dh(self._x25519_private, bob_x25519_pub)

        # ML-KEM decapsulation
        mlkem_ss = mlkem_mod.decapsulate(
            self._variant, mlkem_ct, self._mlkem_keypair.private_key
        )

        # Hybrid shared secret
        return self._compute_hybrid_secret(x25519_ss, mlkem_ss)

    def _compute_hybrid_secret(self, x25519_ss: bytes, mlkem_ss: bytes) -> bytes:
        """SHA-256(x25519_ss || mlkem_ss)."""
        return hashlib.sha256(x25519_ss + mlkem_ss).digest()