/* SPDX-License-Identifier: GPL-2.0-or-later */
/*
 * SM4 Cipher Algorithm, using ARMv8 NEON
 * as specified in
 * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
 *
 * Copyright (C) 2022, Alibaba Group.
 * Copyright (C) 2022 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
 */

#include <linux/module.h>
#include <linux/crypto.h>
#include <linux/kernel.h>
#include <linux/cpufeature.h>
#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/sm4.h>

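/*
 * Walk chunks are converted into 16-byte SM4 blocks: BYTES2BLKS() gives the
 * total number of full blocks, and BYTES2BLK8() rounds that count down to a
 * multiple of 8 for the 8-block NEON fast paths.
 */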
#define BYTES2BLKS(nbytes)	((nbytes) >> 4)
#define BYTES2BLK8(nbytes)	(((nbytes) >> 4) & ~(8 - 1))

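/*
 * NEON block primitives, presumably provided by the accompanying assembly
 * file (sm4-neon-core.S in the kernel tree). The blk8 variants expect nblks
 * to be a multiple of 8; blk1_8 handles any count from 1 to 8.
 */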
asmlinkage void sm4_neon_crypt_blk1_8(const u32 *rkey, u8 *dst, const u8 *src,
				      unsigned int nblks);
asmlinkage void sm4_neon_crypt_blk8(const u32 *rkey, u8 *dst, const u8 *src,
				    unsigned int nblks);
asmlinkage void sm4_neon_cbc_dec_blk8(const u32 *rkey, u8 *dst, const u8 *src,
				      u8 *iv, unsigned int nblks);
asmlinkage void sm4_neon_cfb_dec_blk8(const u32 *rkey, u8 *dst, const u8 *src,
				      u8 *iv, unsigned int nblks);
asmlinkage void sm4_neon_ctr_enc_blk8(const u32 *rkey, u8 *dst, const u8 *src,
				      u8 *iv, unsigned int nblks);

static int sm4_setkey(struct crypto_skcipher *tfm, const u8 *key,
		      unsigned int key_len)
{
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_expandkey(ctx, key, key_len);
}

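/*
 * ECB blocks are independent, so each walk step processes as many 8-block
 * NEON batches as possible first, then sends the remaining 1..8 full blocks
 * through the variable-count path, all inside a single
 * kernel_neon_begin()/kernel_neon_end() section.
 */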
static int sm4_ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
{
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblks;

		kernel_neon_begin();

		nblks = BYTES2BLK8(nbytes);
		if (nblks) {
			sm4_neon_crypt_blk8(rkey, dst, src, nblks);
			dst += nblks * SM4_BLOCK_SIZE;
			src += nblks * SM4_BLOCK_SIZE;
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		nblks = BYTES2BLKS(nbytes);
		if (nblks) {
			sm4_neon_crypt_blk1_8(rkey, dst, src, nblks);
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		kernel_neon_end();

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

static int sm4_ecb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_ecb_do_crypt(req, ctx->rkey_enc);
}

static int sm4_ecb_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_ecb_do_crypt(req, ctx->rkey_dec);
}

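/*
 * CBC encryption is inherently serial (each block depends on the previous
 * ciphertext block), so it runs one block at a time through the scalar
 * sm4_crypt_block() helper rather than through NEON.
 */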
static int sm4_cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *iv = walk.iv;
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		while (nbytes >= SM4_BLOCK_SIZE) {
			crypto_xor_cpy(dst, src, iv, SM4_BLOCK_SIZE);
			sm4_crypt_block(ctx->rkey_enc, dst, dst);
			iv = dst;
			src += SM4_BLOCK_SIZE;
			dst += SM4_BLOCK_SIZE;
			nbytes -= SM4_BLOCK_SIZE;
		}
		if (iv != walk.iv)
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

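/*
 * CBC decryption parallelizes: 8-block batches use the dedicated NEON
 * routine, while the 1..8 block tail is decrypted into a temporary buffer
 * and then XORed with the preceding ciphertext blocks in reverse order, so
 * that in-place (dst == src) requests remain correct.
 */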
static int sm4_cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblks;

		kernel_neon_begin();

		nblks = BYTES2BLK8(nbytes);
		if (nblks) {
			sm4_neon_cbc_dec_blk8(ctx->rkey_dec, dst, src,
					      walk.iv, nblks);
			dst += nblks * SM4_BLOCK_SIZE;
			src += nblks * SM4_BLOCK_SIZE;
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		nblks = BYTES2BLKS(nbytes);
		if (nblks) {
			u8 keystream[SM4_BLOCK_SIZE * 8];
			u8 iv[SM4_BLOCK_SIZE];
			int i;

			sm4_neon_crypt_blk1_8(ctx->rkey_dec, keystream,
					      src, nblks);

			src += ((int)nblks - 2) * SM4_BLOCK_SIZE;
			dst += (nblks - 1) * SM4_BLOCK_SIZE;
			memcpy(iv, src + SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);

			for (i = nblks - 1; i > 0; i--) {
				crypto_xor_cpy(dst, src,
					       &keystream[i * SM4_BLOCK_SIZE],
					       SM4_BLOCK_SIZE);
				src -= SM4_BLOCK_SIZE;
				dst -= SM4_BLOCK_SIZE;
			}
			crypto_xor_cpy(dst, walk.iv,
				       keystream, SM4_BLOCK_SIZE);
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		kernel_neon_end();

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

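/*
 * Like CBC, CFB encryption is serial, so it stays scalar. CFB is a stream
 * mode: a final partial block, if any, is handled by truncating the last
 * keystream block.
 */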
static int sm4_cfb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		u8 keystream[SM4_BLOCK_SIZE];
		const u8 *iv = walk.iv;
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		while (nbytes >= SM4_BLOCK_SIZE) {
			sm4_crypt_block(ctx->rkey_enc, keystream, iv);
			crypto_xor_cpy(dst, src, keystream, SM4_BLOCK_SIZE);
			iv = dst;
			src += SM4_BLOCK_SIZE;
			dst += SM4_BLOCK_SIZE;
			nbytes -= SM4_BLOCK_SIZE;
		}
		if (iv != walk.iv)
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

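/*
 * CFB decryption is parallel: the keystream input is the IV followed by the
 * ciphertext blocks themselves, so the 1..8 block tail assembles that
 * sequence in a buffer, encrypts it with the variable-count NEON routine,
 * and XORs the result onto the ciphertext.
 */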
static int sm4_cfb_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblks;

		kernel_neon_begin();

		nblks = BYTES2BLK8(nbytes);
		if (nblks) {
			sm4_neon_cfb_dec_blk8(ctx->rkey_enc, dst, src,
					      walk.iv, nblks);
			dst += nblks * SM4_BLOCK_SIZE;
			src += nblks * SM4_BLOCK_SIZE;
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		nblks = BYTES2BLKS(nbytes);
		if (nblks) {
			u8 keystream[SM4_BLOCK_SIZE * 8];

			memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
			if (nblks > 1)
				memcpy(&keystream[SM4_BLOCK_SIZE], src,
				       (nblks - 1) * SM4_BLOCK_SIZE);
			memcpy(walk.iv, src + (nblks - 1) * SM4_BLOCK_SIZE,
			       SM4_BLOCK_SIZE);

			sm4_neon_crypt_blk1_8(ctx->rkey_enc, keystream,
					      keystream, nblks);

			crypto_xor_cpy(dst, src, keystream,
				       nblks * SM4_BLOCK_SIZE);
			dst += nblks * SM4_BLOCK_SIZE;
			src += nblks * SM4_BLOCK_SIZE;
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		kernel_neon_end();

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

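/*
 * CTR encryption and decryption are the same operation. The 1..8 block tail
 * builds successive counter values in a buffer and encrypts them in one
 * NEON call; a final partial block takes one more keystream block from the
 * scalar helper.
 */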
static int sm4_ctr_crypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblks;

		kernel_neon_begin();

		nblks = BYTES2BLK8(nbytes);
		if (nblks) {
			sm4_neon_ctr_enc_blk8(ctx->rkey_enc, dst, src,
					      walk.iv, nblks);
			dst += nblks * SM4_BLOCK_SIZE;
			src += nblks * SM4_BLOCK_SIZE;
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		nblks = BYTES2BLKS(nbytes);
		if (nblks) {
			u8 keystream[SM4_BLOCK_SIZE * 8];
			int i;

			for (i = 0; i < nblks; i++) {
				memcpy(&keystream[i * SM4_BLOCK_SIZE],
				       walk.iv, SM4_BLOCK_SIZE);
				crypto_inc(walk.iv, SM4_BLOCK_SIZE);
			}
			sm4_neon_crypt_blk1_8(ctx->rkey_enc, keystream,
					      keystream, nblks);

			crypto_xor_cpy(dst, src, keystream,
				       nblks * SM4_BLOCK_SIZE);
			dst += nblks * SM4_BLOCK_SIZE;
			src += nblks * SM4_BLOCK_SIZE;
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		kernel_neon_end();

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_inc(walk.iv, SM4_BLOCK_SIZE);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

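/*
 * cra_priority 200 is presumably chosen to rank this driver above the
 * generic C implementation (priority 100) while leaving room for a faster
 * SM4 crypto-extensions driver to take precedence where available.
 */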
static struct skcipher_alg sm4_algs[] = {
	{
		.base = {
			.cra_name		= "ecb(sm4)",
			.cra_driver_name	= "ecb-sm4-neon",
			.cra_priority		= 200,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_ecb_encrypt,
		.decrypt	= sm4_ecb_decrypt,
	}, {
		.base = {
			.cra_name		= "cbc(sm4)",
			.cra_driver_name	= "cbc-sm4-neon",
			.cra_priority		= 200,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_cbc_encrypt,
		.decrypt	= sm4_cbc_decrypt,
	}, {
		.base = {
			.cra_name		= "cfb(sm4)",
			.cra_driver_name	= "cfb-sm4-neon",
			.cra_priority		= 200,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_cfb_encrypt,
		.decrypt	= sm4_cfb_decrypt,
	}, {
		.base = {
			.cra_name		= "ctr(sm4)",
			.cra_driver_name	= "ctr-sm4-neon",
			.cra_priority		= 200,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_ctr_crypt,
		.decrypt	= sm4_ctr_crypt,
	}
};

static int __init sm4_init(void)
{
	return crypto_register_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
}

static void __exit sm4_exit(void)
{
	crypto_unregister_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
}

module_init(sm4_init);
module_exit(sm4_exit);

MODULE_DESCRIPTION("SM4 ECB/CBC/CFB/CTR using ARMv8 NEON");
MODULE_ALIAS_CRYPTO("sm4-neon");
MODULE_ALIAS_CRYPTO("sm4");
MODULE_ALIAS_CRYPTO("ecb(sm4)");
MODULE_ALIAS_CRYPTO("cbc(sm4)");
MODULE_ALIAS_CRYPTO("cfb(sm4)");
MODULE_ALIAS_CRYPTO("ctr(sm4)");
MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
MODULE_LICENSE("GPL v2");