Linux kernel mirror (for testing)
git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel
os
linux
1// SPDX-License-Identifier: GPL-2.0-or-later
2/*
3 * KUnit tests and benchmark for ML-DSA
4 *
5 * Copyright 2025 Google LLC
6 */
#include <crypto/mldsa.h>
#include <kunit/test.h>
#include <linux/random.h>
#include <linux/sched.h>
#include <linux/unaligned.h>
11
12#define Q 8380417 /* The prime q = 2^23 - 2^13 + 1 */
13
14/* ML-DSA parameters that the tests use */
15static const struct {
16 int sig_len;
17 int pk_len;
18 int k;
19 int lambda;
20 int gamma1;
21 int beta;
22 int omega;
23} params[] = {
24 [MLDSA44] = {
25 .sig_len = MLDSA44_SIGNATURE_SIZE,
26 .pk_len = MLDSA44_PUBLIC_KEY_SIZE,
27 .k = 4,
28 .lambda = 128,
29 .gamma1 = 1 << 17,
30 .beta = 78,
31 .omega = 80,
32 },
33 [MLDSA65] = {
34 .sig_len = MLDSA65_SIGNATURE_SIZE,
35 .pk_len = MLDSA65_PUBLIC_KEY_SIZE,
36 .k = 6,
37 .lambda = 192,
38 .gamma1 = 1 << 19,
39 .beta = 196,
40 .omega = 55,
41 },
42 [MLDSA87] = {
43 .sig_len = MLDSA87_SIGNATURE_SIZE,
44 .pk_len = MLDSA87_PUBLIC_KEY_SIZE,
45 .k = 8,
46 .lambda = 256,
47 .gamma1 = 1 << 19,
48 .beta = 120,
49 .omega = 75,
50 },
51};
52
53#include "mldsa-testvecs.h"
54
55static void do_mldsa_and_assert_success(struct kunit *test,
56 const struct mldsa_testvector *tv)
57{
58 int err = mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
59 tv->msg_len, tv->pk, tv->pk_len);
60 KUNIT_ASSERT_EQ(test, err, 0);
61}
62
63static u8 *kunit_kmemdup_or_fail(struct kunit *test, const u8 *src, size_t len)
64{
65 u8 *dst = kunit_kmalloc(test, len, GFP_KERNEL);
66
67 KUNIT_ASSERT_NOT_NULL(test, dst);
68 return memcpy(dst, src, len);
69}
70
71/*
72 * Test that changing coefficients in a valid signature's z vector results in
73 * the following behavior from mldsa_verify():
74 *
75 * * -EBADMSG if a coefficient is changed to have an out-of-range value, i.e.
76 * absolute value >= gamma1 - beta, corresponding to the verifier detecting
77 * the out-of-range coefficient and rejecting the signature as malformed
78 *
79 * * -EKEYREJECTED if a coefficient is changed to a different in-range value,
80 * i.e. absolute value < gamma1 - beta, corresponding to the verifier
81 * continuing to the "real" signature check and that check failing
82 */
83static void test_mldsa_z_range(struct kunit *test,
84 const struct mldsa_testvector *tv)
85{
86 u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, tv->sig_len);
87 const int lambda = params[tv->alg].lambda;
88 const s32 gamma1 = params[tv->alg].gamma1;
89 const int beta = params[tv->alg].beta;
90 /*
91 * We just modify the first coefficient. The coefficient is gamma1
92 * minus either the first 18 or 20 bits of the u32, depending on gamma1.
93 *
94 * The layout of ML-DSA signatures is ctilde || z || h. ctilde is
95 * lambda / 4 bytes, so z starts at &sig[lambda / 4].
96 */
97 u8 *z_ptr = &sig[lambda / 4];
98 const u32 z_data = get_unaligned_le32(z_ptr);
99 const u32 mask = (gamma1 << 1) - 1;
100 /* These are the four boundaries of the out-of-range values. */
101 const s32 out_of_range_coeffs[] = {
102 -gamma1 + 1,
103 -(gamma1 - beta),
104 gamma1,
105 gamma1 - beta,
106 };
107 /*
108 * These are the two boundaries of the valid range, along with 0. We
109 * assume that none of these matches the original coefficient.
110 */
111 const s32 in_range_coeffs[] = {
112 -(gamma1 - beta - 1),
113 0,
114 gamma1 - beta - 1,
115 };
116
117 /* Initially the signature is valid. */
118 do_mldsa_and_assert_success(test, tv);
119
120 /* Test some out-of-range coefficients. */
121 for (int i = 0; i < ARRAY_SIZE(out_of_range_coeffs); i++) {
122 const s32 c = out_of_range_coeffs[i];
123
124 put_unaligned_le32((z_data & ~mask) | (mask & (gamma1 - c)),
125 z_ptr);
126 KUNIT_ASSERT_EQ(test, -EBADMSG,
127 mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
128 tv->msg_len, tv->pk, tv->pk_len));
129 }
130
131 /* Test some in-range coefficients. */
132 for (int i = 0; i < ARRAY_SIZE(in_range_coeffs); i++) {
133 const s32 c = in_range_coeffs[i];
134
135 put_unaligned_le32((z_data & ~mask) | (mask & (gamma1 - c)),
136 z_ptr);
137 KUNIT_ASSERT_EQ(test, -EKEYREJECTED,
138 mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
139 tv->msg_len, tv->pk, tv->pk_len));
140 }
141}
142
143/* Test that mldsa_verify() rejects malformed hint vectors with -EBADMSG. */
144static void test_mldsa_bad_hints(struct kunit *test,
145 const struct mldsa_testvector *tv)
146{
147 const int omega = params[tv->alg].omega;
148 const int k = params[tv->alg].k;
149 u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, tv->sig_len);
150 /* Pointer to the encoded hint vector in the signature */
151 u8 *hintvec = &sig[tv->sig_len - omega - k];
152 u8 h;
153
154 /* Initially the signature is valid. */
155 do_mldsa_and_assert_success(test, tv);
156
157 /* Cumulative hint count exceeds omega */
158 memcpy(sig, tv->sig, tv->sig_len);
159 hintvec[omega + k - 1] = omega + 1;
160 KUNIT_ASSERT_EQ(test, -EBADMSG,
161 mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
162 tv->msg_len, tv->pk, tv->pk_len));
163
164 /* Cumulative hint count decreases */
165 memcpy(sig, tv->sig, tv->sig_len);
166 KUNIT_ASSERT_GE(test, hintvec[omega + k - 2], 1);
167 hintvec[omega + k - 1] = hintvec[omega + k - 2] - 1;
168 KUNIT_ASSERT_EQ(test, -EBADMSG,
169 mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
170 tv->msg_len, tv->pk, tv->pk_len));
171
172 /*
173 * Hint indices out of order. To test this, swap hintvec[0] and
174 * hintvec[1]. This assumes that the original valid signature had at
175 * least two nonzero hints in the first element (asserted below).
176 */
177 memcpy(sig, tv->sig, tv->sig_len);
178 KUNIT_ASSERT_GE(test, hintvec[omega], 2);
179 h = hintvec[0];
180 hintvec[0] = hintvec[1];
181 hintvec[1] = h;
182 KUNIT_ASSERT_EQ(test, -EBADMSG,
183 mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
184 tv->msg_len, tv->pk, tv->pk_len));
185
186 /*
187 * Extra hint indices given. For this test to work, the original valid
188 * signature must have fewer than omega nonzero hints (asserted below).
189 */
190 memcpy(sig, tv->sig, tv->sig_len);
191 KUNIT_ASSERT_LT(test, hintvec[omega + k - 1], omega);
192 hintvec[omega - 1] = 0xff;
193 KUNIT_ASSERT_EQ(test, -EBADMSG,
194 mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
195 tv->msg_len, tv->pk, tv->pk_len));
196}
197
198static void test_mldsa_mutation(struct kunit *test,
199 const struct mldsa_testvector *tv)
200{
201 const int sig_len = tv->sig_len;
202 const int msg_len = tv->msg_len;
203 const int pk_len = tv->pk_len;
204 const int num_iter = 200;
205 u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, sig_len);
206 u8 *msg = kunit_kmemdup_or_fail(test, tv->msg, msg_len);
207 u8 *pk = kunit_kmemdup_or_fail(test, tv->pk, pk_len);
208
209 /* Initially the signature is valid. */
210 do_mldsa_and_assert_success(test, tv);
211
212 /* Changing any bit in the signature should invalidate the signature */
213 for (int i = 0; i < num_iter; i++) {
214 size_t pos = get_random_u32_below(sig_len);
215 u8 b = 1 << get_random_u32_below(8);
216
217 sig[pos] ^= b;
218 KUNIT_ASSERT_NE(test, 0,
219 mldsa_verify(tv->alg, sig, sig_len, msg,
220 msg_len, pk, pk_len));
221 sig[pos] ^= b;
222 }
223
224 /* Changing any bit in the message should invalidate the signature */
225 for (int i = 0; i < num_iter; i++) {
226 size_t pos = get_random_u32_below(msg_len);
227 u8 b = 1 << get_random_u32_below(8);
228
229 msg[pos] ^= b;
230 KUNIT_ASSERT_NE(test, 0,
231 mldsa_verify(tv->alg, sig, sig_len, msg,
232 msg_len, pk, pk_len));
233 msg[pos] ^= b;
234 }
235
236 /* Changing any bit in the public key should invalidate the signature */
237 for (int i = 0; i < num_iter; i++) {
238 size_t pos = get_random_u32_below(pk_len);
239 u8 b = 1 << get_random_u32_below(8);
240
241 pk[pos] ^= b;
242 KUNIT_ASSERT_NE(test, 0,
243 mldsa_verify(tv->alg, sig, sig_len, msg,
244 msg_len, pk, pk_len));
245 pk[pos] ^= b;
246 }
247
248 /* All changes should have been undone. */
249 KUNIT_ASSERT_EQ(test, 0,
250 mldsa_verify(tv->alg, sig, sig_len, msg, msg_len, pk,
251 pk_len));
252}
253
254static void test_mldsa(struct kunit *test, const struct mldsa_testvector *tv)
255{
256 /* Valid signature */
257 KUNIT_ASSERT_EQ(test, tv->sig_len, params[tv->alg].sig_len);
258 KUNIT_ASSERT_EQ(test, tv->pk_len, params[tv->alg].pk_len);
259 do_mldsa_and_assert_success(test, tv);
260
261 /* Signature too short */
262 KUNIT_ASSERT_EQ(test, -EBADMSG,
263 mldsa_verify(tv->alg, tv->sig, tv->sig_len - 1, tv->msg,
264 tv->msg_len, tv->pk, tv->pk_len));
265
266 /* Signature too long */
267 KUNIT_ASSERT_EQ(test, -EBADMSG,
268 mldsa_verify(tv->alg, tv->sig, tv->sig_len + 1, tv->msg,
269 tv->msg_len, tv->pk, tv->pk_len));
270
271 /* Public key too short */
272 KUNIT_ASSERT_EQ(test, -EBADMSG,
273 mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
274 tv->msg_len, tv->pk, tv->pk_len - 1));
275
276 /* Public key too long */
277 KUNIT_ASSERT_EQ(test, -EBADMSG,
278 mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
279 tv->msg_len, tv->pk, tv->pk_len + 1));
280
281 /*
282 * Message too short. Error is EKEYREJECTED because it gets rejected by
283 * the "real" signature check rather than the well-formedness checks.
284 */
285 KUNIT_ASSERT_EQ(test, -EKEYREJECTED,
286 mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
287 tv->msg_len - 1, tv->pk, tv->pk_len));
288 /*
289 * Can't simply try (tv->msg, tv->msg_len + 1) too, as tv->msg would be
290 * accessed out of bounds. However, ML-DSA just hashes the message and
291 * doesn't handle different message lengths differently anyway.
292 */
293
294 /* Test the validity checks on the z vector. */
295 test_mldsa_z_range(test, tv);
296
297 /* Test the validity checks on the hint vector. */
298 test_mldsa_bad_hints(test, tv);
299
300 /* Test randomly mutating the inputs. */
301 test_mldsa_mutation(test, tv);
302}
303
304static void test_mldsa44(struct kunit *test)
305{
306 test_mldsa(test, &mldsa44_testvector);
307}
308
309static void test_mldsa65(struct kunit *test)
310{
311 test_mldsa(test, &mldsa65_testvector);
312}
313
314static void test_mldsa87(struct kunit *test)
315{
316 test_mldsa(test, &mldsa87_testvector);
317}
318
319static s32 mod(s32 a, s32 m)
320{
321 a %= m;
322 if (a < 0)
323 a += m;
324 return a;
325}
326
327static s32 symmetric_mod(s32 a, s32 m)
328{
329 a = mod(a, m);
330 if (a > m / 2)
331 a -= m;
332 return a;
333}
334
/*
 * Mechanical, inefficient translation of FIPS 204 Algorithm 36, Decompose.
 * Kept as a statement-by-statement copy of the pseudocode so that it can serve
 * as an independent reference for test_mldsa_use_hint().
 *
 * Splits @r into a high part *r1 and a low part *r0, using modulus 2 * @gamma2
 * for the low part.
 */
static void decompose_ref(s32 r, s32 gamma2, s32 *r0, s32 *r1)
{
	/* r+ = r mod q, in [0, q) */
	s32 rplus = mod(r, Q);

	/* r0 = r+ mod± (2 * gamma2), the centered remainder */
	*r0 = symmetric_mod(rplus, 2 * gamma2);
	if (rplus - *r0 == Q - 1) {
		/* Special case from the spec: r+ - r0 == q - 1 maps to r1 = 0. */
		*r1 = 0;
		*r0 = *r0 - 1;
	} else {
		*r1 = (rplus - *r0) / (2 * gamma2);
	}
}
348
/*
 * Mechanical, inefficient translation of FIPS 204 Algorithm 40, UseHint.
 * Kept as a statement-by-statement copy of the pseudocode so that it can serve
 * as an independent reference for mldsa_use_hint().
 *
 * Returns the high part of @r (per decompose_ref()) when @h is 0.  When @h is
 * 1, returns that high part moved one step up (if the low part is positive) or
 * down (otherwise), wrapping modulo m.
 */
static s32 use_hint_ref(u8 h, s32 r, s32 gamma2)
{
	/* High parts are reduced modulo m = (q - 1) / (2 * gamma2). */
	s32 m = (Q - 1) / (2 * gamma2);
	s32 r0, r1;

	decompose_ref(r, gamma2, &r0, &r1);
	if (h == 1 && r0 > 0)
		return mod(r1 + 1, m);
	if (h == 1 && r0 <= 0)
		return mod(r1 - 1, m);
	return r1;
}
362
363/*
364 * Test that for all possible inputs, mldsa_use_hint() gives the same output as
365 * a mechanical translation of the pseudocode from FIPS 204.
366 */
367static void test_mldsa_use_hint(struct kunit *test)
368{
369 for (int i = 0; i < 2; i++) {
370 const s32 gamma2 = (Q - 1) / (i == 0 ? 88 : 32);
371
372 for (u8 h = 0; h < 2; h++) {
373 for (s32 r = 0; r < Q; r++) {
374 KUNIT_ASSERT_EQ(test,
375 mldsa_use_hint(h, r, gamma2),
376 use_hint_ref(h, r, gamma2));
377 }
378 }
379 }
380}
381
382static void benchmark_mldsa(struct kunit *test,
383 const struct mldsa_testvector *tv)
384{
385 const int warmup_niter = 200;
386 const int benchmark_niter = 200;
387 u64 t0, t1;
388
389 if (!IS_ENABLED(CONFIG_CRYPTO_LIB_BENCHMARK))
390 kunit_skip(test, "not enabled");
391
392 for (int i = 0; i < warmup_niter; i++)
393 do_mldsa_and_assert_success(test, tv);
394
395 t0 = ktime_get_ns();
396 for (int i = 0; i < benchmark_niter; i++)
397 do_mldsa_and_assert_success(test, tv);
398 t1 = ktime_get_ns();
399 kunit_info(test, "%llu ops/s",
400 div64_u64((u64)benchmark_niter * NSEC_PER_SEC,
401 t1 - t0 ?: 1));
402}
403
404static void benchmark_mldsa44(struct kunit *test)
405{
406 benchmark_mldsa(test, &mldsa44_testvector);
407}
408
409static void benchmark_mldsa65(struct kunit *test)
410{
411 benchmark_mldsa(test, &mldsa65_testvector);
412}
413
414static void benchmark_mldsa87(struct kunit *test)
415{
416 benchmark_mldsa(test, &mldsa87_testvector);
417}
418
/*
 * Test cases of the "mldsa" suite, in the order KUnit runs them: functional
 * tests for each parameter set, the exhaustive UseHint comparison, then the
 * benchmarks (which skip themselves unless CONFIG_CRYPTO_LIB_BENCHMARK=y).
 * The empty entry terminates the array.
 */
static struct kunit_case mldsa_kunit_cases[] = {
	KUNIT_CASE(test_mldsa44),
	KUNIT_CASE(test_mldsa65),
	KUNIT_CASE(test_mldsa87),
	KUNIT_CASE(test_mldsa_use_hint),
	KUNIT_CASE(benchmark_mldsa44),
	KUNIT_CASE(benchmark_mldsa65),
	KUNIT_CASE(benchmark_mldsa87),
	{},
};

static struct kunit_suite mldsa_kunit_suite = {
	.name = "mldsa",
	.test_cases = mldsa_kunit_cases,
};
kunit_test_suite(mldsa_kunit_suite);

MODULE_DESCRIPTION("KUnit tests and benchmark for ML-DSA");
/* Needed for symbols exported only for KUnit, e.g. mldsa_use_hint() (presumably). */
MODULE_IMPORT_NS("EXPORTED_FOR_KUNIT_TESTING");
MODULE_LICENSE("GPL");