"""ML-KEM (Module-Lattice Key Encapsulation Mechanism) -- FIPS 203. Wraps pqcrypto for ML-KEM-512/768/1024. Used in hybrid X25519+ML-KEM Noise handshakes for post-quantum forward secrecy. """ from __future__ import annotations import enum from dataclasses import dataclass from typing import Any, Tuple # Lazy-loaded module references per variant _kem_modules: dict[str, Any] = {} def _get_kem_module(module_name: str) -> Any: """Lazy-import a pqcrypto.kem module and cache it.""" if module_name not in _kem_modules: try: import importlib _kem_modules[module_name] = importlib.import_module(f"pqcrypto.kem.{module_name}") except ImportError: raise ImportError( "pqcrypto is required for ML-KEM support. " "Install it with: pip install 'i2p-python[pqc]'" ) return _kem_modules[module_name] def is_available() -> bool: """Return True if pqcrypto is installed and importable.""" try: _get_kem_module("ml_kem_768") return True except ImportError: return False class MLKEMVariant(enum.Enum): """ML-KEM security levels.""" # pub_len priv_len ct_len ss_len module_name ML_KEM_512 = (800, 1632, 768, 32, "ml_kem_512") ML_KEM_768 = (1184, 2400, 1088, 32, "ml_kem_768") ML_KEM_1024 = (1568, 3168, 1568, 32, "ml_kem_1024") def __init__( self, public_key_len: int, private_key_len: int, ciphertext_len: int, shared_secret_len: int, module_name: str, ) -> None: self._public_key_len = public_key_len self._private_key_len = private_key_len self._ciphertext_len = ciphertext_len self._shared_secret_len = shared_secret_len self._module_name = module_name @property def public_key_len(self) -> int: return self._public_key_len @property def private_key_len(self) -> int: return self._private_key_len @property def ciphertext_len(self) -> int: return self._ciphertext_len @property def shared_secret_len(self) -> int: return self._shared_secret_len @property def module_name(self) -> str: return self._module_name @dataclass(frozen=True) class MLKEMKeyPair: """A ML-KEM keypair.""" public_key: bytes private_key: bytes variant: MLKEMVariant def generate_keys(variant: MLKEMVariant) -> MLKEMKeyPair: """Generate a new ML-KEM keypair for the given variant.""" mod = _get_kem_module(variant.module_name) public_key, private_key = mod.generate_keypair() return MLKEMKeyPair( public_key=bytes(public_key), private_key=bytes(private_key), variant=variant, ) def encapsulate( variant: MLKEMVariant, public_key: bytes ) -> Tuple[bytes, bytes]: """Encapsulate against a public key, returning (ciphertext, shared_secret).""" mod = _get_kem_module(variant.module_name) ciphertext, shared_secret = mod.encrypt(public_key) return bytes(ciphertext), bytes(shared_secret) def decapsulate( variant: MLKEMVariant, ciphertext: bytes, private_key: bytes ) -> bytes: """Decapsulate a ciphertext with a private key, returning the shared secret.""" mod = _get_kem_module(variant.module_name) shared_secret = mod.decrypt(private_key, ciphertext) return bytes(shared_secret)