Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux
at master 438 lines 12 kB view raw
1// SPDX-License-Identifier: GPL-2.0-or-later 2/* 3 * KUnit tests and benchmark for ML-DSA 4 * 5 * Copyright 2025 Google LLC 6 */ 7#include <crypto/mldsa.h> 8#include <kunit/test.h> 9#include <linux/random.h> 10#include <linux/unaligned.h> 11 12#define Q 8380417 /* The prime q = 2^23 - 2^13 + 1 */ 13 14/* ML-DSA parameters that the tests use */ 15static const struct { 16 int sig_len; 17 int pk_len; 18 int k; 19 int lambda; 20 int gamma1; 21 int beta; 22 int omega; 23} params[] = { 24 [MLDSA44] = { 25 .sig_len = MLDSA44_SIGNATURE_SIZE, 26 .pk_len = MLDSA44_PUBLIC_KEY_SIZE, 27 .k = 4, 28 .lambda = 128, 29 .gamma1 = 1 << 17, 30 .beta = 78, 31 .omega = 80, 32 }, 33 [MLDSA65] = { 34 .sig_len = MLDSA65_SIGNATURE_SIZE, 35 .pk_len = MLDSA65_PUBLIC_KEY_SIZE, 36 .k = 6, 37 .lambda = 192, 38 .gamma1 = 1 << 19, 39 .beta = 196, 40 .omega = 55, 41 }, 42 [MLDSA87] = { 43 .sig_len = MLDSA87_SIGNATURE_SIZE, 44 .pk_len = MLDSA87_PUBLIC_KEY_SIZE, 45 .k = 8, 46 .lambda = 256, 47 .gamma1 = 1 << 19, 48 .beta = 120, 49 .omega = 75, 50 }, 51}; 52 53#include "mldsa-testvecs.h" 54 55static void do_mldsa_and_assert_success(struct kunit *test, 56 const struct mldsa_testvector *tv) 57{ 58 int err = mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg, 59 tv->msg_len, tv->pk, tv->pk_len); 60 KUNIT_ASSERT_EQ(test, err, 0); 61} 62 63static u8 *kunit_kmemdup_or_fail(struct kunit *test, const u8 *src, size_t len) 64{ 65 u8 *dst = kunit_kmalloc(test, len, GFP_KERNEL); 66 67 KUNIT_ASSERT_NOT_NULL(test, dst); 68 return memcpy(dst, src, len); 69} 70 71/* 72 * Test that changing coefficients in a valid signature's z vector results in 73 * the following behavior from mldsa_verify(): 74 * 75 * * -EBADMSG if a coefficient is changed to have an out-of-range value, i.e. 76 * absolute value >= gamma1 - beta, corresponding to the verifier detecting 77 * the out-of-range coefficient and rejecting the signature as malformed 78 * 79 * * -EKEYREJECTED if a coefficient is changed to a different in-range value, 80 * i.e. absolute value < gamma1 - beta, corresponding to the verifier 81 * continuing to the "real" signature check and that check failing 82 */ 83static void test_mldsa_z_range(struct kunit *test, 84 const struct mldsa_testvector *tv) 85{ 86 u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, tv->sig_len); 87 const int lambda = params[tv->alg].lambda; 88 const s32 gamma1 = params[tv->alg].gamma1; 89 const int beta = params[tv->alg].beta; 90 /* 91 * We just modify the first coefficient. The coefficient is gamma1 92 * minus either the first 18 or 20 bits of the u32, depending on gamma1. 93 * 94 * The layout of ML-DSA signatures is ctilde || z || h. ctilde is 95 * lambda / 4 bytes, so z starts at &sig[lambda / 4]. 96 */ 97 u8 *z_ptr = &sig[lambda / 4]; 98 const u32 z_data = get_unaligned_le32(z_ptr); 99 const u32 mask = (gamma1 << 1) - 1; 100 /* These are the four boundaries of the out-of-range values. */ 101 const s32 out_of_range_coeffs[] = { 102 -gamma1 + 1, 103 -(gamma1 - beta), 104 gamma1, 105 gamma1 - beta, 106 }; 107 /* 108 * These are the two boundaries of the valid range, along with 0. We 109 * assume that none of these matches the original coefficient. 110 */ 111 const s32 in_range_coeffs[] = { 112 -(gamma1 - beta - 1), 113 0, 114 gamma1 - beta - 1, 115 }; 116 117 /* Initially the signature is valid. */ 118 do_mldsa_and_assert_success(test, tv); 119 120 /* Test some out-of-range coefficients. */ 121 for (int i = 0; i < ARRAY_SIZE(out_of_range_coeffs); i++) { 122 const s32 c = out_of_range_coeffs[i]; 123 124 put_unaligned_le32((z_data & ~mask) | (mask & (gamma1 - c)), 125 z_ptr); 126 KUNIT_ASSERT_EQ(test, -EBADMSG, 127 mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg, 128 tv->msg_len, tv->pk, tv->pk_len)); 129 } 130 131 /* Test some in-range coefficients. */ 132 for (int i = 0; i < ARRAY_SIZE(in_range_coeffs); i++) { 133 const s32 c = in_range_coeffs[i]; 134 135 put_unaligned_le32((z_data & ~mask) | (mask & (gamma1 - c)), 136 z_ptr); 137 KUNIT_ASSERT_EQ(test, -EKEYREJECTED, 138 mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg, 139 tv->msg_len, tv->pk, tv->pk_len)); 140 } 141} 142 143/* Test that mldsa_verify() rejects malformed hint vectors with -EBADMSG. */ 144static void test_mldsa_bad_hints(struct kunit *test, 145 const struct mldsa_testvector *tv) 146{ 147 const int omega = params[tv->alg].omega; 148 const int k = params[tv->alg].k; 149 u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, tv->sig_len); 150 /* Pointer to the encoded hint vector in the signature */ 151 u8 *hintvec = &sig[tv->sig_len - omega - k]; 152 u8 h; 153 154 /* Initially the signature is valid. */ 155 do_mldsa_and_assert_success(test, tv); 156 157 /* Cumulative hint count exceeds omega */ 158 memcpy(sig, tv->sig, tv->sig_len); 159 hintvec[omega + k - 1] = omega + 1; 160 KUNIT_ASSERT_EQ(test, -EBADMSG, 161 mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg, 162 tv->msg_len, tv->pk, tv->pk_len)); 163 164 /* Cumulative hint count decreases */ 165 memcpy(sig, tv->sig, tv->sig_len); 166 KUNIT_ASSERT_GE(test, hintvec[omega + k - 2], 1); 167 hintvec[omega + k - 1] = hintvec[omega + k - 2] - 1; 168 KUNIT_ASSERT_EQ(test, -EBADMSG, 169 mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg, 170 tv->msg_len, tv->pk, tv->pk_len)); 171 172 /* 173 * Hint indices out of order. To test this, swap hintvec[0] and 174 * hintvec[1]. This assumes that the original valid signature had at 175 * least two nonzero hints in the first element (asserted below). 176 */ 177 memcpy(sig, tv->sig, tv->sig_len); 178 KUNIT_ASSERT_GE(test, hintvec[omega], 2); 179 h = hintvec[0]; 180 hintvec[0] = hintvec[1]; 181 hintvec[1] = h; 182 KUNIT_ASSERT_EQ(test, -EBADMSG, 183 mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg, 184 tv->msg_len, tv->pk, tv->pk_len)); 185 186 /* 187 * Extra hint indices given. For this test to work, the original valid 188 * signature must have fewer than omega nonzero hints (asserted below). 189 */ 190 memcpy(sig, tv->sig, tv->sig_len); 191 KUNIT_ASSERT_LT(test, hintvec[omega + k - 1], omega); 192 hintvec[omega - 1] = 0xff; 193 KUNIT_ASSERT_EQ(test, -EBADMSG, 194 mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg, 195 tv->msg_len, tv->pk, tv->pk_len)); 196} 197 198static void test_mldsa_mutation(struct kunit *test, 199 const struct mldsa_testvector *tv) 200{ 201 const int sig_len = tv->sig_len; 202 const int msg_len = tv->msg_len; 203 const int pk_len = tv->pk_len; 204 const int num_iter = 200; 205 u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, sig_len); 206 u8 *msg = kunit_kmemdup_or_fail(test, tv->msg, msg_len); 207 u8 *pk = kunit_kmemdup_or_fail(test, tv->pk, pk_len); 208 209 /* Initially the signature is valid. */ 210 do_mldsa_and_assert_success(test, tv); 211 212 /* Changing any bit in the signature should invalidate the signature */ 213 for (int i = 0; i < num_iter; i++) { 214 size_t pos = get_random_u32_below(sig_len); 215 u8 b = 1 << get_random_u32_below(8); 216 217 sig[pos] ^= b; 218 KUNIT_ASSERT_NE(test, 0, 219 mldsa_verify(tv->alg, sig, sig_len, msg, 220 msg_len, pk, pk_len)); 221 sig[pos] ^= b; 222 } 223 224 /* Changing any bit in the message should invalidate the signature */ 225 for (int i = 0; i < num_iter; i++) { 226 size_t pos = get_random_u32_below(msg_len); 227 u8 b = 1 << get_random_u32_below(8); 228 229 msg[pos] ^= b; 230 KUNIT_ASSERT_NE(test, 0, 231 mldsa_verify(tv->alg, sig, sig_len, msg, 232 msg_len, pk, pk_len)); 233 msg[pos] ^= b; 234 } 235 236 /* Changing any bit in the public key should invalidate the signature */ 237 for (int i = 0; i < num_iter; i++) { 238 size_t pos = get_random_u32_below(pk_len); 239 u8 b = 1 << get_random_u32_below(8); 240 241 pk[pos] ^= b; 242 KUNIT_ASSERT_NE(test, 0, 243 mldsa_verify(tv->alg, sig, sig_len, msg, 244 msg_len, pk, pk_len)); 245 pk[pos] ^= b; 246 } 247 248 /* All changes should have been undone. */ 249 KUNIT_ASSERT_EQ(test, 0, 250 mldsa_verify(tv->alg, sig, sig_len, msg, msg_len, pk, 251 pk_len)); 252} 253 254static void test_mldsa(struct kunit *test, const struct mldsa_testvector *tv) 255{ 256 /* Valid signature */ 257 KUNIT_ASSERT_EQ(test, tv->sig_len, params[tv->alg].sig_len); 258 KUNIT_ASSERT_EQ(test, tv->pk_len, params[tv->alg].pk_len); 259 do_mldsa_and_assert_success(test, tv); 260 261 /* Signature too short */ 262 KUNIT_ASSERT_EQ(test, -EBADMSG, 263 mldsa_verify(tv->alg, tv->sig, tv->sig_len - 1, tv->msg, 264 tv->msg_len, tv->pk, tv->pk_len)); 265 266 /* Signature too long */ 267 KUNIT_ASSERT_EQ(test, -EBADMSG, 268 mldsa_verify(tv->alg, tv->sig, tv->sig_len + 1, tv->msg, 269 tv->msg_len, tv->pk, tv->pk_len)); 270 271 /* Public key too short */ 272 KUNIT_ASSERT_EQ(test, -EBADMSG, 273 mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg, 274 tv->msg_len, tv->pk, tv->pk_len - 1)); 275 276 /* Public key too long */ 277 KUNIT_ASSERT_EQ(test, -EBADMSG, 278 mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg, 279 tv->msg_len, tv->pk, tv->pk_len + 1)); 280 281 /* 282 * Message too short. Error is EKEYREJECTED because it gets rejected by 283 * the "real" signature check rather than the well-formedness checks. 284 */ 285 KUNIT_ASSERT_EQ(test, -EKEYREJECTED, 286 mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg, 287 tv->msg_len - 1, tv->pk, tv->pk_len)); 288 /* 289 * Can't simply try (tv->msg, tv->msg_len + 1) too, as tv->msg would be 290 * accessed out of bounds. However, ML-DSA just hashes the message and 291 * doesn't handle different message lengths differently anyway. 292 */ 293 294 /* Test the validity checks on the z vector. */ 295 test_mldsa_z_range(test, tv); 296 297 /* Test the validity checks on the hint vector. */ 298 test_mldsa_bad_hints(test, tv); 299 300 /* Test randomly mutating the inputs. */ 301 test_mldsa_mutation(test, tv); 302} 303 304static void test_mldsa44(struct kunit *test) 305{ 306 test_mldsa(test, &mldsa44_testvector); 307} 308 309static void test_mldsa65(struct kunit *test) 310{ 311 test_mldsa(test, &mldsa65_testvector); 312} 313 314static void test_mldsa87(struct kunit *test) 315{ 316 test_mldsa(test, &mldsa87_testvector); 317} 318 319static s32 mod(s32 a, s32 m) 320{ 321 a %= m; 322 if (a < 0) 323 a += m; 324 return a; 325} 326 327static s32 symmetric_mod(s32 a, s32 m) 328{ 329 a = mod(a, m); 330 if (a > m / 2) 331 a -= m; 332 return a; 333} 334 335/* Mechanical, inefficient translation of FIPS 204 Algorithm 36, Decompose */ 336static void decompose_ref(s32 r, s32 gamma2, s32 *r0, s32 *r1) 337{ 338 s32 rplus = mod(r, Q); 339 340 *r0 = symmetric_mod(rplus, 2 * gamma2); 341 if (rplus - *r0 == Q - 1) { 342 *r1 = 0; 343 *r0 = *r0 - 1; 344 } else { 345 *r1 = (rplus - *r0) / (2 * gamma2); 346 } 347} 348 349/* Mechanical, inefficient translation of FIPS 204 Algorithm 40, UseHint */ 350static s32 use_hint_ref(u8 h, s32 r, s32 gamma2) 351{ 352 s32 m = (Q - 1) / (2 * gamma2); 353 s32 r0, r1; 354 355 decompose_ref(r, gamma2, &r0, &r1); 356 if (h == 1 && r0 > 0) 357 return mod(r1 + 1, m); 358 if (h == 1 && r0 <= 0) 359 return mod(r1 - 1, m); 360 return r1; 361} 362 363/* 364 * Test that for all possible inputs, mldsa_use_hint() gives the same output as 365 * a mechanical translation of the pseudocode from FIPS 204. 366 */ 367static void test_mldsa_use_hint(struct kunit *test) 368{ 369 for (int i = 0; i < 2; i++) { 370 const s32 gamma2 = (Q - 1) / (i == 0 ? 88 : 32); 371 372 for (u8 h = 0; h < 2; h++) { 373 for (s32 r = 0; r < Q; r++) { 374 KUNIT_ASSERT_EQ(test, 375 mldsa_use_hint(h, r, gamma2), 376 use_hint_ref(h, r, gamma2)); 377 } 378 } 379 } 380} 381 382static void benchmark_mldsa(struct kunit *test, 383 const struct mldsa_testvector *tv) 384{ 385 const int warmup_niter = 200; 386 const int benchmark_niter = 200; 387 u64 t0, t1; 388 389 if (!IS_ENABLED(CONFIG_CRYPTO_LIB_BENCHMARK)) 390 kunit_skip(test, "not enabled"); 391 392 for (int i = 0; i < warmup_niter; i++) 393 do_mldsa_and_assert_success(test, tv); 394 395 t0 = ktime_get_ns(); 396 for (int i = 0; i < benchmark_niter; i++) 397 do_mldsa_and_assert_success(test, tv); 398 t1 = ktime_get_ns(); 399 kunit_info(test, "%llu ops/s", 400 div64_u64((u64)benchmark_niter * NSEC_PER_SEC, 401 t1 - t0 ?: 1)); 402} 403 404static void benchmark_mldsa44(struct kunit *test) 405{ 406 benchmark_mldsa(test, &mldsa44_testvector); 407} 408 409static void benchmark_mldsa65(struct kunit *test) 410{ 411 benchmark_mldsa(test, &mldsa65_testvector); 412} 413 414static void benchmark_mldsa87(struct kunit *test) 415{ 416 benchmark_mldsa(test, &mldsa87_testvector); 417} 418 419static struct kunit_case mldsa_kunit_cases[] = { 420 KUNIT_CASE(test_mldsa44), 421 KUNIT_CASE(test_mldsa65), 422 KUNIT_CASE(test_mldsa87), 423 KUNIT_CASE(test_mldsa_use_hint), 424 KUNIT_CASE(benchmark_mldsa44), 425 KUNIT_CASE(benchmark_mldsa65), 426 KUNIT_CASE(benchmark_mldsa87), 427 {}, 428}; 429 430static struct kunit_suite mldsa_kunit_suite = { 431 .name = "mldsa", 432 .test_cases = mldsa_kunit_cases, 433}; 434kunit_test_suite(mldsa_kunit_suite); 435 436MODULE_DESCRIPTION("KUnit tests and benchmark for ML-DSA"); 437MODULE_IMPORT_NS("EXPORTED_FOR_KUNIT_TESTING"); 438MODULE_LICENSE("GPL");