Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

crypto: arm64/aes-xctr - Add accelerated implementation of XCTR

Add hardware accelerated version of XCTR for ARM64 CPUs with ARMv8
Crypto Extension support. This XCTR implementation is based on the CTR
implementation in aes-modes.S.

More information on XCTR can be found in
the HCTR2 paper: "Length-preserving encryption with HCTR2":
https://eprint.iacr.org/2021/1441.pdf

Signed-off-by: Nathan Huckleberry <nhuck@google.com>
Reviewed-by: Ard Biesheuvel <ardb@kernel.org>
Reviewed-by: Eric Biggers <ebiggers@google.com>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>

authored by

Nathan Huckleberry and committed by
Herbert Xu
23a251cc fd94fcf0

+166 -64
+2 -2
arch/arm64/crypto/Kconfig
··· 96 96 select CRYPTO_LIB_AES 97 97 98 98 config CRYPTO_AES_ARM64_CE_BLK 99 - tristate "AES in ECB/CBC/CTR/XTS modes using ARMv8 Crypto Extensions" 99 + tristate "AES in ECB/CBC/CTR/XTS/XCTR modes using ARMv8 Crypto Extensions" 100 100 depends on KERNEL_MODE_NEON 101 101 select CRYPTO_SKCIPHER 102 102 select CRYPTO_AES_ARM64_CE 103 103 104 104 config CRYPTO_AES_ARM64_NEON_BLK 105 - tristate "AES in ECB/CBC/CTR/XTS modes using NEON instructions" 105 + tristate "AES in ECB/CBC/CTR/XTS/XCTR modes using NEON instructions" 106 106 depends on KERNEL_MODE_NEON 107 107 select CRYPTO_SKCIPHER 108 108 select CRYPTO_LIB_AES
+62 -2
arch/arm64/crypto/aes-glue.c
··· 34 34 #define aes_essiv_cbc_encrypt ce_aes_essiv_cbc_encrypt 35 35 #define aes_essiv_cbc_decrypt ce_aes_essiv_cbc_decrypt 36 36 #define aes_ctr_encrypt ce_aes_ctr_encrypt 37 + #define aes_xctr_encrypt ce_aes_xctr_encrypt 37 38 #define aes_xts_encrypt ce_aes_xts_encrypt 38 39 #define aes_xts_decrypt ce_aes_xts_decrypt 39 40 #define aes_mac_update ce_aes_mac_update 40 - MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions"); 41 + MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS/XCTR using ARMv8 Crypto Extensions"); 41 42 #else 42 43 #define MODE "neon" 43 44 #define PRIO 200 ··· 51 50 #define aes_essiv_cbc_encrypt neon_aes_essiv_cbc_encrypt 52 51 #define aes_essiv_cbc_decrypt neon_aes_essiv_cbc_decrypt 53 52 #define aes_ctr_encrypt neon_aes_ctr_encrypt 53 + #define aes_xctr_encrypt neon_aes_xctr_encrypt 54 54 #define aes_xts_encrypt neon_aes_xts_encrypt 55 55 #define aes_xts_decrypt neon_aes_xts_decrypt 56 56 #define aes_mac_update neon_aes_mac_update 57 - MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 NEON"); 57 + MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS/XCTR using ARMv8 NEON"); 58 58 #endif 59 59 #if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS) 60 60 MODULE_ALIAS_CRYPTO("ecb(aes)"); 61 61 MODULE_ALIAS_CRYPTO("cbc(aes)"); 62 62 MODULE_ALIAS_CRYPTO("ctr(aes)"); 63 63 MODULE_ALIAS_CRYPTO("xts(aes)"); 64 + MODULE_ALIAS_CRYPTO("xctr(aes)"); 64 65 #endif 65 66 MODULE_ALIAS_CRYPTO("cts(cbc(aes))"); 66 67 MODULE_ALIAS_CRYPTO("essiv(cbc(aes),sha256)"); ··· 91 88 92 89 asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[], 93 90 int rounds, int bytes, u8 ctr[]); 91 + 92 + asmlinkage void aes_xctr_encrypt(u8 out[], u8 const in[], u32 const rk[], 93 + int rounds, int bytes, u8 ctr[], int byte_ctr); 94 94 95 95 asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[], 96 96 int rounds, int bytes, u32 const rk2[], u8 iv[], ··· 448 442 return err ?: cbc_decrypt_walk(req, &walk); 449 443 } 450 444 445 + static int __maybe_unused xctr_encrypt(struct skcipher_request *req) 446 + { 447 + struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 448 + struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm); 449 + int err, rounds = 6 + ctx->key_length / 4; 450 + struct skcipher_walk walk; 451 + unsigned int byte_ctr = 0; 452 + 453 + err = skcipher_walk_virt(&walk, req, false); 454 + 455 + while (walk.nbytes > 0) { 456 + const u8 *src = walk.src.virt.addr; 457 + unsigned int nbytes = walk.nbytes; 458 + u8 *dst = walk.dst.virt.addr; 459 + u8 buf[AES_BLOCK_SIZE]; 460 + 461 + if (unlikely(nbytes < AES_BLOCK_SIZE)) 462 + src = dst = memcpy(buf + sizeof(buf) - nbytes, 463 + src, nbytes); 464 + else if (nbytes < walk.total) 465 + nbytes &= ~(AES_BLOCK_SIZE - 1); 466 + 467 + kernel_neon_begin(); 468 + aes_xctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes, 469 + walk.iv, byte_ctr); 470 + kernel_neon_end(); 471 + 472 + if (unlikely(nbytes < AES_BLOCK_SIZE)) 473 + memcpy(walk.dst.virt.addr, 474 + buf + sizeof(buf) - nbytes, nbytes); 475 + byte_ctr += nbytes; 476 + 477 + err = skcipher_walk_done(&walk, walk.nbytes - nbytes); 478 + } 479 + 480 + return err; 481 + } 482 + 451 483 static int __maybe_unused ctr_encrypt(struct skcipher_request *req) 452 484 { 453 485 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); ··· 713 669 .setkey = skcipher_aes_setkey, 714 670 .encrypt = ctr_encrypt, 715 671 .decrypt = ctr_encrypt, 672 + }, { 673 + .base = { 674 + .cra_name = "xctr(aes)", 675 + .cra_driver_name = "xctr-aes-" MODE, 676 + .cra_priority = PRIO, 677 + .cra_blocksize = 1, 678 + .cra_ctxsize = sizeof(struct crypto_aes_ctx), 679 + .cra_module = THIS_MODULE, 680 + }, 681 + .min_keysize = AES_MIN_KEY_SIZE, 682 + .max_keysize = AES_MAX_KEY_SIZE, 683 + .ivsize = AES_BLOCK_SIZE, 684 + .chunksize = AES_BLOCK_SIZE, 685 + .setkey = skcipher_aes_setkey, 686 + .encrypt = xctr_encrypt, 687 + .decrypt = xctr_encrypt, 716 688 }, { 717 689 .base = { 718 690 .cra_name = "xts(aes)",
+102 -60
arch/arm64/crypto/aes-modes.S
··· 318 318 .byte 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff 319 319 .previous 320 320 321 - 322 321 /* 323 - * aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds, 324 - * int bytes, u8 ctr[]) 322 + * This macro generates the code for CTR and XCTR mode. 325 323 */ 326 - 327 - AES_FUNC_START(aes_ctr_encrypt) 324 + .macro ctr_encrypt xctr 328 325 stp x29, x30, [sp, #-16]! 329 326 mov x29, sp 330 327 331 328 enc_prepare w3, x2, x12 332 329 ld1 {vctr.16b}, [x5] 333 330 334 - umov x12, vctr.d[1] /* keep swabbed ctr in reg */ 335 - rev x12, x12 331 + .if \xctr 332 + umov x12, vctr.d[0] 333 + lsr w11, w6, #4 334 + .else 335 + umov x12, vctr.d[1] /* keep swabbed ctr in reg */ 336 + rev x12, x12 337 + .endif 336 338 337 - .LctrloopNx: 339 + .LctrloopNx\xctr: 338 340 add w7, w4, #15 339 341 sub w4, w4, #MAX_STRIDE << 4 340 342 lsr w7, w7, #4 341 343 mov w8, #MAX_STRIDE 342 344 cmp w7, w8 343 345 csel w7, w7, w8, lt 344 - adds x12, x12, x7 345 346 347 + .if \xctr 348 + add x11, x11, x7 349 + .else 350 + adds x12, x12, x7 351 + .endif 346 352 mov v0.16b, vctr.16b 347 353 mov v1.16b, vctr.16b 348 354 mov v2.16b, vctr.16b 349 355 mov v3.16b, vctr.16b 350 356 ST5( mov v4.16b, vctr.16b ) 351 - bcs 0f 357 + .if \xctr 358 + sub x6, x11, #MAX_STRIDE - 1 359 + sub x7, x11, #MAX_STRIDE - 2 360 + sub x8, x11, #MAX_STRIDE - 3 361 + sub x9, x11, #MAX_STRIDE - 4 362 + ST5( sub x10, x11, #MAX_STRIDE - 5 ) 363 + eor x6, x6, x12 364 + eor x7, x7, x12 365 + eor x8, x8, x12 366 + eor x9, x9, x12 367 + ST5( eor x10, x10, x12 ) 368 + mov v0.d[0], x6 369 + mov v1.d[0], x7 370 + mov v2.d[0], x8 371 + mov v3.d[0], x9 372 + ST5( mov v4.d[0], x10 ) 373 + .else 374 + bcs 0f 375 + .subsection 1 376 + /* apply carry to outgoing counter */ 377 + 0: umov x8, vctr.d[0] 378 + rev x8, x8 379 + add x8, x8, #1 380 + rev x8, x8 381 + ins vctr.d[0], x8 352 382 353 - .subsection 1 354 - /* apply carry to outgoing counter */ 355 - 0: umov x8, vctr.d[0] 356 - rev x8, x8 357 - add x8, x8, #1 358 - rev x8, x8 359 - ins vctr.d[0], x8 383 + /* apply carry to N counter blocks for N := x12 */ 384 + cbz x12, 2f 385 + adr x16, 1f 386 + sub x16, x16, x12, lsl #3 387 + br x16 388 + bti c 389 + mov v0.d[0], vctr.d[0] 390 + bti c 391 + mov v1.d[0], vctr.d[0] 392 + bti c 393 + mov v2.d[0], vctr.d[0] 394 + bti c 395 + mov v3.d[0], vctr.d[0] 396 + ST5( bti c ) 397 + ST5( mov v4.d[0], vctr.d[0] ) 398 + 1: b 2f 399 + .previous 360 400 361 - /* apply carry to N counter blocks for N := x12 */ 362 - cbz x12, 2f 363 - adr x16, 1f 364 - sub x16, x16, x12, lsl #3 365 - br x16 366 - bti c 367 - mov v0.d[0], vctr.d[0] 368 - bti c 369 - mov v1.d[0], vctr.d[0] 370 - bti c 371 - mov v2.d[0], vctr.d[0] 372 - bti c 373 - mov v3.d[0], vctr.d[0] 374 - ST5( bti c ) 375 - ST5( mov v4.d[0], vctr.d[0] ) 376 - 1: b 2f 377 - .previous 378 - 379 - 2: rev x7, x12 380 - ins vctr.d[1], x7 381 - sub x7, x12, #MAX_STRIDE - 1 382 - sub x8, x12, #MAX_STRIDE - 2 383 - sub x9, x12, #MAX_STRIDE - 3 384 - rev x7, x7 385 - rev x8, x8 386 - mov v1.d[1], x7 387 - rev x9, x9 388 - ST5( sub x10, x12, #MAX_STRIDE - 4 ) 389 - mov v2.d[1], x8 390 - ST5( rev x10, x10 ) 391 - mov v3.d[1], x9 392 - ST5( mov v4.d[1], x10 ) 393 - tbnz w4, #31, .Lctrtail 401 + 2: rev x7, x12 402 + ins vctr.d[1], x7 403 + sub x7, x12, #MAX_STRIDE - 1 404 + sub x8, x12, #MAX_STRIDE - 2 405 + sub x9, x12, #MAX_STRIDE - 3 406 + rev x7, x7 407 + rev x8, x8 408 + mov v1.d[1], x7 409 + rev x9, x9 410 + ST5( sub x10, x12, #MAX_STRIDE - 4 ) 411 + mov v2.d[1], x8 412 + ST5( rev x10, x10 ) 413 + mov v3.d[1], x9 414 + ST5( mov v4.d[1], x10 ) 415 + .endif 416 + tbnz w4, #31, .Lctrtail\xctr 394 417 ld1 {v5.16b-v7.16b}, [x1], #48 395 418 ST4( bl aes_encrypt_block4x ) 396 419 ST5( bl aes_encrypt_block5x ) ··· 426 403 ST5( eor v4.16b, v6.16b, v4.16b ) 427 404 st1 {v0.16b-v3.16b}, [x0], #64 428 405 ST5( st1 {v4.16b}, [x0], #16 ) 429 - cbz w4, .Lctrout 430 - b .LctrloopNx 406 + cbz w4, .Lctrout\xctr 407 + b .LctrloopNx\xctr 431 408 432 - .Lctrout: 433 - st1 {vctr.16b}, [x5] /* return next CTR value */ 409 + .Lctrout\xctr: 410 + .if !\xctr 411 + st1 {vctr.16b}, [x5] /* return next CTR value */ 412 + .endif 434 413 ldp x29, x30, [sp], #16 435 414 ret 436 415 437 - .Lctrtail: 438 - /* XOR up to MAX_STRIDE * 16 - 1 bytes of in/output with v0 ... v3/v4 */ 416 + .Lctrtail\xctr: 439 417 mov x16, #16 440 418 ands x6, x4, #0xf 441 419 csel x13, x6, x16, ne ··· 451 427 452 428 adr_l x12, .Lcts_permute_table 453 429 add x12, x12, x13 454 - ble .Lctrtail1x 430 + ble .Lctrtail1x\xctr 455 431 456 432 ST5( ld1 {v5.16b}, [x1], x14 ) 457 433 ld1 {v6.16b}, [x1], x15 ··· 483 459 add x13, x13, x0 484 460 st1 {v9.16b}, [x13] // overlapping stores 485 461 st1 {v8.16b}, [x0] 486 - b .Lctrout 462 + b .Lctrout\xctr 487 463 488 - .Lctrtail1x: 464 + .Lctrtail1x\xctr: 489 465 sub x7, x6, #16 490 466 csel x6, x6, x7, eq 491 467 add x1, x1, x6 ··· 500 476 eor v5.16b, v5.16b, v3.16b 501 477 bif v5.16b, v6.16b, v11.16b 502 478 st1 {v5.16b}, [x0] 503 - b .Lctrout 479 + b .Lctrout\xctr 480 + .endm 481 + 482 + /* 483 + * aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds, 484 + * int bytes, u8 ctr[]) 485 + */ 486 + 487 + AES_FUNC_START(aes_ctr_encrypt) 488 + ctr_encrypt 0 504 489 AES_FUNC_END(aes_ctr_encrypt) 490 + 491 + /* 492 + * aes_xctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds, 493 + * int bytes, u8 const iv[], int byte_ctr) 494 + */ 495 + 496 + AES_FUNC_START(aes_xctr_encrypt) 497 + ctr_encrypt 1 498 + AES_FUNC_END(aes_xctr_encrypt) 505 499 506 500 507 501 /*