Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

crypto: arm/aes-neonbs - implement ciphertext stealing for XTS

Update the AES-XTS implementation based on NEON instructions so that it
can deal with inputs whose size is not a multiple of the cipher block
size. This is part of the original XTS specification, but was never
implemented before in the Linux kernel.

Signed-off-by: Ard Biesheuvel <ard.biesheuvel@linaro.org>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>

Authored by: Ard Biesheuvel
Committed by: Herbert Xu
2ed8b790 c61b1607

+72 -13
+12 -4
arch/arm/crypto/aes-neonbs-core.S
···
 
 /*
  * aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
- *		     int blocks, u8 iv[])
+ *		     int blocks, u8 iv[], int reorder_last_tweak)
  * aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
- *		     int blocks, u8 iv[])
+ *		     int blocks, u8 iv[], int reorder_last_tweak)
  */
 __xts_prepare8:
 	vld1.8		{q14}, [r7]		// load iv
···
 
 	vld1.8		{q7}, [r1]!
 	next_tweak	q14, q12, q15, q13
-	veor		q7, q7, q12
+THUMB(	itt		le		)
+	W(cmple)	r8, #0
+	ble		1f
+0:	veor		q7, q7, q12
 	vst1.8		{q12}, [r4, :128]
 
-0:	vst1.8		{q14}, [r7]		// store next iv
+	vst1.8		{q14}, [r7]		// store next iv
 	bx		lr
+
+1:	vswp		q12, q14
+	b		0b
 ENDPROC(__xts_prepare8)
 
 	.macro		__xts_crypt, do8, o0, o1, o2, o3, o4, o5, o6, o7
 	push		{r4-r8, lr}
 	mov		r5, sp			// preserve sp
 	ldrd		r6, r7, [sp, #24]	// get blocks and iv args
+	ldr		r8, [sp, #32]		// reorder final tweak?
+	rsb		r8, r8, #1
 	sub		ip, sp, #128		// make room for 8x tweak
 	bic		ip, ip, #0xf		// align sp to 16 bytes
 	mov		sp, ip
+60 -9
arch/arm/crypto/aes-neonbs-glue.c
···
 #include <crypto/ctr.h>
 #include <crypto/internal/simd.h>
 #include <crypto/internal/skcipher.h>
+#include <crypto/scatterwalk.h>
 #include <crypto/xts.h>
 #include <linux/module.h>
···
 		  int rounds, int blocks, u8 ctr[], u8 final[]);
 
 asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
-				  int rounds, int blocks, u8 iv[]);
+				  int rounds, int blocks, u8 iv[], int);
 asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
-				  int rounds, int blocks, u8 iv[]);
+				  int rounds, int blocks, u8 iv[], int);
 
 struct aesbs_ctx {
 	int	rounds;
···
 
 struct aesbs_xts_ctx {
 	struct aesbs_ctx	key;
+	struct crypto_cipher	*cts_tfm;
 	struct crypto_cipher	*tweak_tfm;
 };
 
···
 		return err;
 
 	key_len /= 2;
+	err = crypto_cipher_setkey(ctx->cts_tfm, in_key, key_len);
+	if (err)
+		return err;
 	err = crypto_cipher_setkey(ctx->tweak_tfm, in_key + key_len, key_len);
 	if (err)
 		return err;
···
 {
 	struct aesbs_xts_ctx *ctx = crypto_tfm_ctx(tfm);
 
+	ctx->cts_tfm = crypto_alloc_cipher("aes", 0, 0);
+	if (IS_ERR(ctx->cts_tfm))
+		return PTR_ERR(ctx->cts_tfm);
+
 	ctx->tweak_tfm = crypto_alloc_cipher("aes", 0, 0);
+	if (IS_ERR(ctx->tweak_tfm))
+		crypto_free_cipher(ctx->cts_tfm);
 
 	return PTR_ERR_OR_ZERO(ctx->tweak_tfm);
 }
···
 	struct aesbs_xts_ctx *ctx = crypto_tfm_ctx(tfm);
 
 	crypto_free_cipher(ctx->tweak_tfm);
+	crypto_free_cipher(ctx->cts_tfm);
 }
 
-static int __xts_crypt(struct skcipher_request *req,
+static int __xts_crypt(struct skcipher_request *req, bool encrypt,
 		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
-				  int rounds, int blocks, u8 iv[]))
+				  int rounds, int blocks, u8 iv[], int))
 {
 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
+	int tail = req->cryptlen % AES_BLOCK_SIZE;
+	struct skcipher_request subreq;
+	u8 buf[2 * AES_BLOCK_SIZE];
 	struct skcipher_walk walk;
 	int err;
+
+	if (req->cryptlen < AES_BLOCK_SIZE)
+		return -EINVAL;
+
+	if (unlikely(tail)) {
+		skcipher_request_set_tfm(&subreq, tfm);
+		skcipher_request_set_callback(&subreq,
+					      skcipher_request_flags(req),
+					      NULL, NULL);
+		skcipher_request_set_crypt(&subreq, req->src, req->dst,
+					   req->cryptlen - tail, req->iv);
+		req = &subreq;
+	}
 
 	err = skcipher_walk_virt(&walk, req, true);
 	if (err)
···
 
 	while (walk.nbytes >= AES_BLOCK_SIZE) {
 		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
+		int reorder_last_tweak = !encrypt && tail > 0;
 
-		if (walk.nbytes < walk.total)
+		if (walk.nbytes < walk.total) {
 			blocks = round_down(blocks,
 					    walk.stride / AES_BLOCK_SIZE);
+			reorder_last_tweak = 0;
+		}
 
 		kernel_neon_begin();
 		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
-		   ctx->key.rounds, blocks, walk.iv);
+		   ctx->key.rounds, blocks, walk.iv, reorder_last_tweak);
 		kernel_neon_end();
 		err = skcipher_walk_done(&walk,
 					 walk.nbytes - blocks * AES_BLOCK_SIZE);
 	}
 
-	return err;
+	if (err || likely(!tail))
+		return err;
+
+	/* handle ciphertext stealing */
+	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
+				 AES_BLOCK_SIZE, 0);
+	memcpy(buf + AES_BLOCK_SIZE, buf, tail);
+	scatterwalk_map_and_copy(buf, req->src, req->cryptlen, tail, 0);
+
+	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);
+
+	if (encrypt)
+		crypto_cipher_encrypt_one(ctx->cts_tfm, buf, buf);
+	else
+		crypto_cipher_decrypt_one(ctx->cts_tfm, buf, buf);
+
+	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);
+
+	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
+				 AES_BLOCK_SIZE + tail, 1);
+	return 0;
 }
 
 static int xts_encrypt(struct skcipher_request *req)
 {
-	return __xts_crypt(req, aesbs_xts_encrypt);
+	return __xts_crypt(req, true, aesbs_xts_encrypt);
 }
 
 static int xts_decrypt(struct skcipher_request *req)
 {
-	return __xts_crypt(req, aesbs_xts_decrypt);
+	return __xts_crypt(req, false, aesbs_xts_decrypt);
 }
 
 static struct skcipher_alg aes_algs[] = { {