# A Python port of the Invisible Internet Project (I2P).
1"""Noise protocol framework for I2P.
2
Implements CipherState, SymmetricState, and HandshakeState
for the Noise_IK and Noise_XK patterns used by NTCP2 and SSU2,
plus HybridDHState, an X25519 + ML-KEM hybrid key exchange.
5
6Uses:
7- ChaCha20-Poly1305 for AEAD (via cryptography library)
8- X25519 for DH
- HKDF-SHA256 for key derivation
- ML-KEM for the post-quantum half of HybridDHState
10"""
11
12import hashlib
13import os
14import struct
15
16from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305
17
18
19from i2p_crypto.x25519 import X25519DH
20from i2p_crypto.hkdf import HKDF
21from i2p_crypto.mlkem import MLKEMVariant, MLKEMKeyPair
22
# Shared HKDF instance used by SymmetricState.mix_key() and SymmetricState.split().
_hkdf = HKDF()
24
25
26class CipherState:
27 """Wraps ChaCha20-Poly1305 for the Noise protocol."""
28
29 MAX_NONCE = 2**64 - 1
30
31 def __init__(self, key: bytes | None = None):
32 self._key = key
33 self._n = 0
34
35 def has_key(self) -> bool:
36 return self._key is not None
37
38 def set_nonce(self, n: int):
39 self._n = n
40
41 def _nonce_bytes(self) -> bytes:
42 """4 zero bytes + 8-byte little-endian counter = 12-byte nonce."""
43 return b"\x00\x00\x00\x00" + struct.pack("<Q", self._n)
44
45 def encrypt_with_ad(self, ad: bytes, plaintext: bytes) -> bytes:
46 if not self.has_key():
47 return plaintext
48 if self._n >= self.MAX_NONCE:
49 raise RuntimeError("Nonce exhausted — rekey required")
50 assert self._key is not None
51 aead = ChaCha20Poly1305(self._key)
52 ct = aead.encrypt(self._nonce_bytes(), plaintext, ad)
53 self._n += 1
54 return ct
55
56 def decrypt_with_ad(self, ad: bytes, ciphertext: bytes) -> bytes:
57 if not self.has_key():
58 return ciphertext
59 if self._n >= self.MAX_NONCE:
60 raise RuntimeError("Nonce exhausted — rekey required")
61 assert self._key is not None
62 aead = ChaCha20Poly1305(self._key)
63 pt = aead.decrypt(self._nonce_bytes(), ciphertext, ad)
64 self._n += 1
65 return pt
66
67 def rekey(self):
68 aead = ChaCha20Poly1305(self._key)
69 nonce = b"\xff" * 12 # max nonce
70 # Noise spec: REKEY(k) = ENCRYPT(k, maxnonce, "", zeros[32])
71 # Output is 32 bytes ciphertext + 16 byte tag; take first 32
72 self._key = aead.encrypt(nonce, b"\x00" * 32, b"")[:32]
73
74
class SymmetricState:
    """Noise SymmetricState.

    Holds the chaining key ``ck``, the running handshake hash ``h``, and the
    CipherState that protects handshake payloads (replaced on every MixKey()).
    """

    def __init__(self, protocol_name: bytes):
        """InitializeSymmetric(protocol_name).

        ``h`` becomes *protocol_name* right-padded with zero bytes to 32 bytes
        when it fits, otherwise SHA-256(*protocol_name*). ``ck`` starts equal to
        ``h`` and the cipher starts without a key.
        """
        shortfall = 32 - len(protocol_name)
        if shortfall < 0:
            self.h = hashlib.sha256(protocol_name).digest()
        else:
            self.h = protocol_name + bytes(shortfall)
        self.ck = self.h
        self._cipher = CipherState()

    def mix_key(self, input_key_material: bytes):
        """MixKey(ikm): derive a new chaining key and a fresh cipher key.

        Calls ``_hkdf.extract_and_expand(ck, ikm, b"", 64)``; the first 32
        bytes become the new ``ck`` and the next 32 key a new CipherState
        (whose nonce therefore starts at 0).
        """
        derived = _hkdf.extract_and_expand(self.ck, input_key_material, b"", 64)
        new_ck, new_key = derived[:32], derived[32:64]
        self.ck = new_ck
        self._cipher = CipherState(new_key)

    def mix_hash(self, data: bytes):
        """MixHash(data): h = SHA-256(h || data)."""
        hasher = hashlib.sha256(self.h)
        hasher.update(data)
        self.h = hasher.digest()

    def encrypt_and_hash(self, plaintext: bytes) -> bytes:
        """Encrypt *plaintext* with ``h`` as associated data, then mix the ciphertext into ``h``."""
        sealed = self._cipher.encrypt_with_ad(self.h, plaintext)
        self.mix_hash(sealed)
        return sealed

    def decrypt_and_hash(self, ciphertext: bytes) -> bytes:
        """Decrypt *ciphertext* with ``h`` as associated data, then mix the ciphertext into ``h``.

        ``h`` is only updated after decryption succeeds, since the decrypt call
        comes first and any exception propagates before the mix.
        """
        opened = self._cipher.decrypt_with_ad(self.h, ciphertext)
        self.mix_hash(ciphertext)
        return opened

    def split(self) -> tuple[CipherState, CipherState]:
        """Split(): derive the two transport CipherStates from ``ck`` (first, second)."""
        okm = _hkdf.extract_and_expand(self.ck, b"", b"", 64)
        return CipherState(okm[:32]), CipherState(okm[32:64])
109
110
# Handshake pattern definitions, in Noise spec notation.
# Each entry maps a pattern name to a dict with:
#   "pre_i"    - tokens of the initiator's pre-message (keys known in advance);
#                empty for both patterns and not consulted by HandshakeState
#   "pre_r"    - tokens of the responder's pre-message
#   "messages" - one token list per handshake message; HandshakeState treats
#                even indices as initiator-written and odd as responder-written
# Tokens: 'e', 's', 'ee', 'es', 'se', 'ss'

from typing import Any as _Any
_PATTERNS: dict[str, dict[str, _Any]] = {
    "Noise_IK": {
        "pre_i": [],  # initiator has no pre-message
        "pre_r": ["s"],  # responder's static is known
        "messages": [
            ["e", "es", "s", "ss"],  # -> e, es, s, ss
            ["e", "ee", "se"],  # <- e, ee, se
        ],
    },
    "Noise_XK": {
        "pre_i": [],
        "pre_r": ["s"],
        "messages": [
            ["e", "es"],  # -> e, es
            ["e", "ee"],  # <- e, ee
            ["s", "se"],  # -> s, se
        ],
    },
}
136
137
138class HandshakeState:
139 """Noise handshake state machine for IK and XK patterns."""
140
141 def __init__(self, pattern: str, initiator: bool,
142 s: tuple | None = None, e: tuple | None = None,
143 rs: bytes | None = None, re: bytes | None = None,
144 prologue: bytes = b"",
145 protocol_name: bytes | None = None):
146 """
147 Args:
148 pattern: "Noise_IK" or "Noise_XK"
149 initiator: True if we are the initiator
150 s: our static keypair (private, public)
151 e: our ephemeral keypair (private, public) — usually generated
152 rs: remote static public key (32 bytes)
153 re: remote ephemeral public key (32 bytes)
154 prologue: Prologue data mixed into hash before pre-messages.
155 Noise spec: MixHash(prologue) after SymmetricState init.
156 protocol_name: Override the full protocol name bytes.
157 If None, constructed from pattern name.
158 """
159 if pattern not in _PATTERNS:
160 raise ValueError(f"Unknown pattern: {pattern}")
161
162 self._pattern_name = pattern
163 self._pattern = _PATTERNS[pattern]
164 self._initiator = initiator
165 self._s = s # (priv, pub)
166 self._e = e # (priv, pub)
167 self._rs = rs # remote static pub
168 self._re = re # remote ephemeral pub
169 self._msg_index = 0
170 self._complete = False
171
172 # Initialize SymmetricState
173 if protocol_name is not None:
174 proto_name = protocol_name
175 else:
176 proto_name = f"{pattern}_25519_ChaChaPoly_SHA256".encode()
177 self._ss = SymmetricState(proto_name)
178
179 # MixHash(prologue) — required by the Noise spec between
180 # SymmetricState init and pre-message processing.
181 # For I2P NTCP2 this is an empty prologue: h = SHA256(h || "")
182 self._ss.mix_hash(prologue)
183
184 # Process pre-messages
185 if self._initiator:
186 # pre_r: responder's pre-message keys get mixed into hash
187 for token in self._pattern["pre_r"]:
188 if token == "s" and rs is not None:
189 self._ss.mix_hash(rs)
190 else:
191 # pre_r: our own static (we are responder)
192 for token in self._pattern["pre_r"]:
193 if token == "s" and s is not None:
194 self._ss.mix_hash(s[1])
195
196 def write_message(self, payload: bytes = b"") -> bytes:
197 """Process the next outgoing message pattern."""
198 if self._complete:
199 raise RuntimeError("Handshake already complete")
200
201 messages = self._pattern["messages"]
202 if self._msg_index >= len(messages):
203 raise RuntimeError("No more handshake messages")
204
205 # Determine if this message index is ours to write
206 is_initiator_turn = (self._msg_index % 2 == 0)
207 if is_initiator_turn != self._initiator:
208 raise RuntimeError("Not our turn to write")
209
210 tokens = messages[self._msg_index]
211 buf = b""
212
213 for token in tokens:
214 buf += self._process_write_token(token)
215
216 buf += self._ss.encrypt_and_hash(payload)
217
218 self._msg_index += 1
219 if self._msg_index >= len(messages):
220 self._complete = True
221
222 return buf
223
224 def read_message(self, message: bytes) -> bytes:
225 """Process the next incoming message pattern."""
226 if self._complete:
227 raise RuntimeError("Handshake already complete")
228
229 messages = self._pattern["messages"]
230 if self._msg_index >= len(messages):
231 raise RuntimeError("No more handshake messages")
232
233 is_initiator_turn = (self._msg_index % 2 == 0)
234 if is_initiator_turn == self._initiator:
235 raise RuntimeError("Not our turn to read")
236
237 tokens = messages[self._msg_index]
238 offset = 0
239
240 for token in tokens:
241 consumed = self._process_read_token(token, message, offset)
242 offset += consumed
243
244 # Remaining is encrypted payload
245 payload = self._ss.decrypt_and_hash(message[offset:])
246
247 self._msg_index += 1
248 if self._msg_index >= len(messages):
249 self._complete = True
250
251 return payload
252
253 def split(self) -> tuple[CipherState, CipherState]:
254 """Split after handshake completion."""
255 if not self._complete:
256 raise RuntimeError("Handshake not complete")
257 return self._ss.split()
258
259 @property
260 def complete(self) -> bool:
261 return self._complete
262
263 @property
264 def remote_static(self) -> bytes | None:
265 return self._rs
266
267 def _process_write_token(self, token: str) -> bytes:
268 if token == "e":
269 if self._e is None:
270 self._e = X25519DH.generate_keypair()
271 self._ss.mix_hash(self._e[1])
272 return self._e[1]
273 elif token == "s":
274 assert self._s is not None
275 return self._ss.encrypt_and_hash(self._s[1])
276 elif token == "ee":
277 assert self._e is not None and self._re is not None
278 self._ss.mix_key(X25519DH.dh(self._e[0], self._re))
279 return b""
280 elif token == "es":
281 if self._initiator:
282 assert self._e is not None and self._rs is not None
283 self._ss.mix_key(X25519DH.dh(self._e[0], self._rs))
284 else:
285 assert self._s is not None and self._re is not None
286 self._ss.mix_key(X25519DH.dh(self._s[0], self._re))
287 return b""
288 elif token == "se":
289 if self._initiator:
290 assert self._s is not None and self._re is not None
291 self._ss.mix_key(X25519DH.dh(self._s[0], self._re))
292 else:
293 assert self._e is not None and self._rs is not None
294 self._ss.mix_key(X25519DH.dh(self._e[0], self._rs))
295 return b""
296 elif token == "ss":
297 assert self._s is not None and self._rs is not None
298 self._ss.mix_key(X25519DH.dh(self._s[0], self._rs))
299 return b""
300 else:
301 raise ValueError(f"Unknown token: {token}")
302
303 def _process_read_token(self, token: str, message: bytes, offset: int) -> int:
304 if token == "e":
305 self._re = message[offset:offset + 32]
306 self._ss.mix_hash(self._re)
307 return 32
308 elif token == "s":
309 # Encrypted static key: 32 bytes + 16 byte tag
310 if self._ss._cipher.has_key():
311 enc_s = message[offset:offset + 48]
312 self._rs = self._ss.decrypt_and_hash(enc_s)
313 return 48
314 else:
315 self._rs = message[offset:offset + 32]
316 self._ss.mix_hash(self._rs)
317 return 32
318 elif token == "ee":
319 assert self._e is not None and self._re is not None
320 self._ss.mix_key(X25519DH.dh(self._e[0], self._re))
321 return 0
322 elif token == "es":
323 if self._initiator:
324 assert self._e is not None and self._rs is not None
325 self._ss.mix_key(X25519DH.dh(self._e[0], self._rs))
326 else:
327 assert self._s is not None and self._re is not None
328 self._ss.mix_key(X25519DH.dh(self._s[0], self._re))
329 return 0
330 elif token == "se":
331 if self._initiator:
332 assert self._s is not None and self._re is not None
333 self._ss.mix_key(X25519DH.dh(self._s[0], self._re))
334 else:
335 assert self._e is not None and self._rs is not None
336 self._ss.mix_key(X25519DH.dh(self._e[0], self._rs))
337 return 0
338 elif token == "ss":
339 assert self._s is not None and self._rs is not None
340 self._ss.mix_key(X25519DH.dh(self._s[0], self._rs))
341 return 0
342 else:
343 raise ValueError(f"Unknown token: {token}")
344
345
class HybridDHState:
    """Hybrid X25519 + ML-KEM key agreement for post-quantum forward secrecy.

    Ported from com.southernstorm.noise.protocol.MLKEMDHState.

    Flow used with the ``hfs`` modifier:
      * Alice (initiator) calls :meth:`generate_keypair` and sends
        :meth:`get_public_key`, i.e. ``x25519_pub || mlkem_pub``.
      * Bob (responder) feeds that blob to :meth:`set_remote_public_key`, then
        calls :meth:`encapsulate`, which returns ``x25519_pub || mlkem_ciphertext``
        together with the hybrid secret.
      * Alice passes Bob's response to :meth:`decapsulate` to derive the same
        secret.

    The hybrid secret is ``SHA-256(x25519_shared || mlkem_shared)``.
    """

    def __init__(self, variant: MLKEMVariant = MLKEMVariant.ML_KEM_768):
        self._variant = variant
        # Our own key material, filled in by generate_keypair() (Alice side).
        self._x25519_private: bytes | None = None
        self._x25519_public: bytes | None = None
        self._mlkem_keypair: MLKEMKeyPair | None = None
        self._has_keypair = False
        # Peer's public keys, filled in by set_remote_public_key() (Bob side).
        self._mlkem_public: bytes | None = None
        self._remote_x25519_public: bytes | None = None

    @property
    def public_key_len(self) -> int:
        """Length of ``x25519_pub || mlkem_pub``: 32 plus the variant's ML-KEM public key size."""
        return 32 + self._variant.public_key_len

    @property
    def ciphertext_len(self) -> int:
        """Length of ``x25519_pub || mlkem_ciphertext``: 32 plus the variant's ciphertext size."""
        return 32 + self._variant.ciphertext_len

    def generate_keypair(self) -> None:
        """Create fresh X25519 and ML-KEM keypairs for the Alice side."""
        from i2p_crypto import mlkem as kem

        priv, pub = X25519DH.generate_keypair()
        self._x25519_private = priv
        self._x25519_public = pub
        self._mlkem_keypair = kem.generate_keys(self._variant)
        self._has_keypair = True

    def get_public_key(self) -> bytes:
        """Return ``x25519_pub || mlkem_pub``.

        Raises:
            RuntimeError: if generate_keypair() has not been called.
        """
        ready = (
            self._has_keypair
            and self._mlkem_keypair is not None
            and self._x25519_public is not None
        )
        if not ready:
            raise RuntimeError("No keypair generated; call generate_keypair() first")
        return self._x25519_public + self._mlkem_keypair.public_key

    def set_remote_public_key(self, data: bytes) -> None:
        """Store the peer's ``x25519_pub || mlkem_pub`` (Bob receiving Alice's keys).

        Raises:
            ValueError: if *data* is not exactly ``public_key_len`` bytes.
        """
        wanted = self.public_key_len
        if len(data) != wanted:
            raise ValueError(
                f"Remote public key must be {wanted} bytes, got {len(data)}"
            )
        self._remote_x25519_public, self._mlkem_public = data[:32], data[32:]

    def encapsulate(self) -> tuple[bytes, bytes]:
        """Bob side: fresh X25519 DH plus ML-KEM encapsulation to the peer's keys.

        Returns:
            (x25519_pub || mlkem_ciphertext, hybrid_shared_secret)

        Raises:
            RuntimeError: if set_remote_public_key() has not been called.
        """
        from i2p_crypto import mlkem as kem

        if self._mlkem_public is None or self._remote_x25519_public is None:
            raise RuntimeError(
                "Remote public key not set; call set_remote_public_key() first"
            )

        # A new X25519 keypair per call; its private half is used once and dropped.
        eph_priv, eph_pub = X25519DH.generate_keypair()
        classical_secret = X25519DH.dh(eph_priv, self._remote_x25519_public)
        kem_ct, kem_secret = kem.encapsulate(self._variant, self._mlkem_public)
        shared = self._compute_hybrid_secret(classical_secret, kem_secret)
        return eph_pub + kem_ct, shared

    def decapsulate(self, response: bytes) -> bytes:
        """Alice side: recover the hybrid secret from Bob's ``x25519_pub || mlkem_ciphertext``.

        Raises:
            RuntimeError: if generate_keypair() has not been called.
            ValueError: if *response* is not exactly ``ciphertext_len`` bytes.
        """
        from i2p_crypto import mlkem as kem

        if not self._has_keypair or self._mlkem_keypair is None or self._x25519_private is None:
            raise RuntimeError("No keypair generated; call generate_keypair() first")

        wanted = self.ciphertext_len
        if len(response) != wanted:
            raise ValueError(
                f"Response must be {wanted} bytes, got {len(response)}"
            )

        peer_pub, kem_ct = response[:32], response[32:]
        classical_secret = X25519DH.dh(self._x25519_private, peer_pub)
        kem_secret = kem.decapsulate(
            self._variant, kem_ct, self._mlkem_keypair.private_key
        )
        return self._compute_hybrid_secret(classical_secret, kem_secret)

    def _compute_hybrid_secret(self, x25519_ss: bytes, mlkem_ss: bytes) -> bytes:
        """Combine both shared secrets as SHA-256(x25519_ss || mlkem_ss)."""
        digest = hashlib.sha256()
        digest.update(x25519_ss)
        digest.update(mlkem_ss)
        return digest.digest()