"""Tests for the ML-KEM (FIPS 203) wrapper around pqcrypto.

TDD: tests written before implementation.
"""

from __future__ import annotations

import pytest

from i2p_crypto.mlkem import (
    MLKEMVariant,
    MLKEMKeyPair,
    generate_keys,
    encapsulate,
    decapsulate,
    is_available,
)

# Skip every test in this module (including TestAvailability) when the
# pqcrypto backend is not installed.
pytestmark = pytest.mark.skipif(
    not is_available(), reason="pqcrypto not installed"
)

_ALL_VARIANTS = list(MLKEMVariant)

# Expected byte sizes per parameter set, as specified in FIPS 203 (Section 8):
# (public/encapsulation key, private/decapsulation key, ciphertext, shared secret).
_EXPECTED_SIZES = {
    MLKEMVariant.ML_KEM_512: (800, 1632, 768, 32),
    MLKEMVariant.ML_KEM_768: (1184, 2400, 1088, 32),
    MLKEMVariant.ML_KEM_1024: (1568, 3168, 1568, 32),
}

# pqcrypto backend module name expected for each variant.
_EXPECTED_MODULE_NAMES = {
    MLKEMVariant.ML_KEM_512: "ml_kem_512",
    MLKEMVariant.ML_KEM_768: "ml_kem_768",
    MLKEMVariant.ML_KEM_1024: "ml_kem_1024",
}


# ---------- key generation ----------


class TestGenerateKeys:
    @pytest.mark.parametrize("variant", _ALL_VARIANTS)
    def test_generate_keys(self, variant: MLKEMVariant) -> None:
        """generate_keys returns a key pair tagged with, and sized for, the requested variant."""
        kp = generate_keys(variant)
        assert isinstance(kp, MLKEMKeyPair)
        assert kp.variant is variant
        assert len(kp.public_key) == variant.public_key_len
        assert len(kp.private_key) == variant.private_key_len


# ---------- encaps / decaps roundtrip ----------


class TestEncapsDecaps:
    @pytest.mark.parametrize("variant", _ALL_VARIANTS)
    def test_roundtrip(self, variant: MLKEMVariant) -> None:
        """Decapsulating with the matching private key recovers the encapsulated secret."""
        kp = generate_keys(variant)
        ciphertext, shared_secret_enc = encapsulate(variant, kp.public_key)
        shared_secret_dec = decapsulate(variant, ciphertext, kp.private_key)
        assert len(ciphertext) == variant.ciphertext_len
        assert len(shared_secret_enc) == variant.shared_secret_len
        assert len(shared_secret_dec) == variant.shared_secret_len
        assert shared_secret_enc == shared_secret_dec


# ---------- security properties ----------


class TestSecurityProperties:
    def test_different_keypairs_different_secrets(self) -> None:
        """Two encapsulations with different keys must produce different shared secrets."""
        kp1 = generate_keys(MLKEMVariant.ML_KEM_768)
        kp2 = generate_keys(MLKEMVariant.ML_KEM_768)
        _, ss1 = encapsulate(MLKEMVariant.ML_KEM_768, kp1.public_key)
        _, ss2 = encapsulate(MLKEMVariant.ML_KEM_768, kp2.public_key)
        # With overwhelming probability these differ
        assert ss1 != ss2

    def test_encapsulation_is_randomized(self) -> None:
        """Two encapsulations to the same public key must yield different ciphertexts and secrets."""
        kp = generate_keys(MLKEMVariant.ML_KEM_768)
        ct1, ss1 = encapsulate(MLKEMVariant.ML_KEM_768, kp.public_key)
        ct2, ss2 = encapsulate(MLKEMVariant.ML_KEM_768, kp.public_key)
        # Encapsulation draws fresh randomness each call; a repeat would
        # indicate a broken or deterministic RNG path.
        assert ct1 != ct2
        assert ss1 != ss2

    def test_wrong_private_key_yields_different_secret(self) -> None:
        """ML-KEM uses implicit rejection: wrong key yields a different (random) shared secret."""
        kp1 = generate_keys(MLKEMVariant.ML_KEM_768)
        kp2 = generate_keys(MLKEMVariant.ML_KEM_768)
        ct, ss_enc = encapsulate(MLKEMVariant.ML_KEM_768, kp1.public_key)
        ss_wrong = decapsulate(MLKEMVariant.ML_KEM_768, ct, kp2.private_key)
        # Implicit rejection: decapsulate succeeds but returns a different secret
        assert len(ss_wrong) == MLKEMVariant.ML_KEM_768.shared_secret_len
        assert ss_wrong != ss_enc

    def test_tampered_ciphertext_yields_different_secret(self) -> None:
        """Implicit rejection also applies to a modified ciphertext under the correct key."""
        kp = generate_keys(MLKEMVariant.ML_KEM_768)
        ct, ss_enc = encapsulate(MLKEMVariant.ML_KEM_768, kp.public_key)
        # Flip one bit; the length is unchanged, so decapsulation must not
        # raise but must return an unrelated secret.
        tampered = bytes([ct[0] ^ 0x01]) + bytes(ct[1:])
        ss_rejected = decapsulate(MLKEMVariant.ML_KEM_768, tampered, kp.private_key)
        assert len(ss_rejected) == MLKEMVariant.ML_KEM_768.shared_secret_len
        assert ss_rejected != ss_enc


# ---------- variant properties ----------


class TestVariantProperties:
    @pytest.mark.parametrize(
        ("variant", "sizes"), list(_EXPECTED_SIZES.items())
    )
    def test_key_sizes(
        self, variant: MLKEMVariant, sizes: tuple[int, int, int, int]
    ) -> None:
        """Enum size properties match the FIPS 203 parameter-set sizes."""
        public_len, private_len, ciphertext_len, secret_len = sizes
        assert variant.public_key_len == public_len
        assert variant.private_key_len == private_len
        assert variant.ciphertext_len == ciphertext_len
        assert variant.shared_secret_len == secret_len

    @pytest.mark.parametrize(
        ("variant", "module_name"), list(_EXPECTED_MODULE_NAMES.items())
    )
    def test_module_names(self, variant: MLKEMVariant, module_name: str) -> None:
        """Each variant maps to its pqcrypto backend module name."""
        assert variant.module_name == module_name


# ---------- availability ----------


class TestAvailability:
    def test_is_available(self) -> None:
        """is_available() returns True when pqcrypto is installed (this test only runs when it is)."""
        assert is_available() is True