Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

crypto: arm64/aes-xctr - Improve readability of XCTR and CTR modes

Added some clarifying comments, changed the register allocations to make
the code clearer, and added register aliases.

Signed-off-by: Nathan Huckleberry <nhuck@google.com>
Reviewed-by: Eric Biggers <ebiggers@google.com>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>

authored by

Nathan Huckleberry and committed by
Herbert Xu
c0eb7591 23a251cc

+185 -68
+16
arch/arm64/crypto/aes-glue.c
··· 464 464 u8 *dst = walk.dst.virt.addr; 465 465 u8 buf[AES_BLOCK_SIZE]; 466 466 467 + /* 468 + * If given less than 16 bytes, we must copy the partial block 469 + * into a temporary buffer of 16 bytes to avoid out of bounds 470 + * reads and writes. Furthermore, this code is somewhat unusual 471 + * in that it expects the end of the data to be at the end of 472 + * the temporary buffer, rather than the start of the data at 473 + * the start of the temporary buffer. 474 + */ 467 475 if (unlikely(nbytes < AES_BLOCK_SIZE)) 468 476 src = dst = memcpy(buf + sizeof(buf) - nbytes, 469 477 src, nbytes); ··· 509 501 u8 *dst = walk.dst.virt.addr; 510 502 u8 buf[AES_BLOCK_SIZE]; 511 503 504 + /* 505 + * If given less than 16 bytes, we must copy the partial block 506 + * into a temporary buffer of 16 bytes to avoid out of bounds 507 + * reads and writes. Furthermore, this code is somewhat unusual 508 + * in that it expects the end of the data to be at the end of 509 + * the temporary buffer, rather than the start of the data at 510 + * the start of the temporary buffer. 511 + */ 512 512 if (unlikely(nbytes < AES_BLOCK_SIZE)) 513 513 src = dst = memcpy(buf + sizeof(buf) - nbytes, 514 514 src, nbytes);
+169 -68
arch/arm64/crypto/aes-modes.S
··· 322 322 * This macro generates the code for CTR and XCTR mode. 323 323 */ 324 324 .macro ctr_encrypt xctr 325 + // Arguments 326 + OUT .req x0 327 + IN .req x1 328 + KEY .req x2 329 + ROUNDS_W .req w3 330 + BYTES_W .req w4 331 + IV .req x5 332 + BYTE_CTR_W .req w6 // XCTR only 333 + // Intermediate values 334 + CTR_W .req w11 // XCTR only 335 + CTR .req x11 // XCTR only 336 + IV_PART .req x12 337 + BLOCKS .req x13 338 + BLOCKS_W .req w13 339 + 325 340 stp x29, x30, [sp, #-16]! 326 341 mov x29, sp 327 342 328 - enc_prepare w3, x2, x12 329 - ld1 {vctr.16b}, [x5] 343 + enc_prepare ROUNDS_W, KEY, IV_PART 344 + ld1 {vctr.16b}, [IV] 330 345 346 + /* 347 + * Keep 64 bits of the IV in a register. For CTR mode this lets us 348 + * easily increment the IV. For XCTR mode this lets us efficiently XOR 349 + * the 64-bit counter with the IV. 350 + */ 331 351 .if \xctr 332 - umov x12, vctr.d[0] 333 - lsr w11, w6, #4 352 + umov IV_PART, vctr.d[0] 353 + lsr CTR_W, BYTE_CTR_W, #4 334 354 .else 335 - umov x12, vctr.d[1] /* keep swabbed ctr in reg */ 336 - rev x12, x12 355 + umov IV_PART, vctr.d[1] 356 + rev IV_PART, IV_PART 337 357 .endif 338 358 339 359 .LctrloopNx\xctr: 340 - add w7, w4, #15 341 - sub w4, w4, #MAX_STRIDE << 4 342 - lsr w7, w7, #4 360 + add BLOCKS_W, BYTES_W, #15 361 + sub BYTES_W, BYTES_W, #MAX_STRIDE << 4 362 + lsr BLOCKS_W, BLOCKS_W, #4 343 363 mov w8, #MAX_STRIDE 344 - cmp w7, w8 345 - csel w7, w7, w8, lt 364 + cmp BLOCKS_W, w8 365 + csel BLOCKS_W, BLOCKS_W, w8, lt 346 366 367 + /* 368 + * Set up the counter values in v0-v{MAX_STRIDE-1}. 369 + * 370 + * If we are encrypting less than MAX_STRIDE blocks, the tail block 371 + * handling code expects the last keystream block to be in 372 + * v{MAX_STRIDE-1}. For example: if encrypting two blocks with 373 + * MAX_STRIDE=5, then v3 and v4 should have the next two counter blocks. 374 + */ 347 375 .if \xctr 348 - add x11, x11, x7 376 + add CTR, CTR, BLOCKS 349 377 .else 350 - adds x12, x12, x7 378 + adds IV_PART, IV_PART, BLOCKS 351 379 .endif 352 380 mov v0.16b, vctr.16b 353 381 mov v1.16b, vctr.16b ··· 383 355 mov v3.16b, vctr.16b 384 356 ST5( mov v4.16b, vctr.16b ) 385 357 .if \xctr 386 - sub x6, x11, #MAX_STRIDE - 1 387 - sub x7, x11, #MAX_STRIDE - 2 388 - sub x8, x11, #MAX_STRIDE - 3 389 - sub x9, x11, #MAX_STRIDE - 4 390 - ST5( sub x10, x11, #MAX_STRIDE - 5 ) 391 - eor x6, x6, x12 392 - eor x7, x7, x12 393 - eor x8, x8, x12 394 - eor x9, x9, x12 395 - ST5( eor x10, x10, x12 ) 358 + sub x6, CTR, #MAX_STRIDE - 1 359 + sub x7, CTR, #MAX_STRIDE - 2 360 + sub x8, CTR, #MAX_STRIDE - 3 361 + sub x9, CTR, #MAX_STRIDE - 4 362 + ST5( sub x10, CTR, #MAX_STRIDE - 5 ) 363 + eor x6, x6, IV_PART 364 + eor x7, x7, IV_PART 365 + eor x8, x8, IV_PART 366 + eor x9, x9, IV_PART 367 + ST5( eor x10, x10, IV_PART ) 396 368 mov v0.d[0], x6 397 369 mov v1.d[0], x7 398 370 mov v2.d[0], x8 ··· 401 373 .else 402 374 bcs 0f 403 375 .subsection 1 404 - /* apply carry to outgoing counter */ 376 + /* 377 + * This subsection handles carries. 378 + * 379 + * Conditional branching here is allowed with respect to time 380 + * invariance since the branches are dependent on the IV instead 381 + * of the plaintext or key. This code is rarely executed in 382 + * practice anyway. 383 + */ 384 + 385 + /* Apply carry to outgoing counter. */ 405 386 0: umov x8, vctr.d[0] 406 387 rev x8, x8 407 388 add x8, x8, #1 408 389 rev x8, x8 409 390 ins vctr.d[0], x8 410 391 411 - /* apply carry to N counter blocks for N := x12 */ 412 - cbz x12, 2f 392 + /* 393 + * Apply carry to counter blocks if needed. 394 + * 395 + * Since the carry flag was set, we know 0 <= IV_PART < 396 + * MAX_STRIDE. Using the value of IV_PART we can determine how 397 + * many counter blocks need to be updated. 398 + */ 399 + cbz IV_PART, 2f 413 400 adr x16, 1f 414 - sub x16, x16, x12, lsl #3 401 + sub x16, x16, IV_PART, lsl #3 415 402 br x16 416 403 bti c 417 404 mov v0.d[0], vctr.d[0] ··· 441 398 1: b 2f 442 399 .previous 443 400 444 - 2: rev x7, x12 401 + 2: rev x7, IV_PART 445 402 ins vctr.d[1], x7 446 - sub x7, x12, #MAX_STRIDE - 1 447 - sub x8, x12, #MAX_STRIDE - 2 448 - sub x9, x12, #MAX_STRIDE - 3 403 + sub x7, IV_PART, #MAX_STRIDE - 1 404 + sub x8, IV_PART, #MAX_STRIDE - 2 405 + sub x9, IV_PART, #MAX_STRIDE - 3 449 406 rev x7, x7 450 407 rev x8, x8 451 408 mov v1.d[1], x7 452 409 rev x9, x9 453 - ST5( sub x10, x12, #MAX_STRIDE - 4 ) 410 + ST5( sub x10, IV_PART, #MAX_STRIDE - 4 ) 454 411 mov v2.d[1], x8 455 412 ST5( rev x10, x10 ) 456 413 mov v3.d[1], x9 457 414 ST5( mov v4.d[1], x10 ) 458 415 .endif 459 - tbnz w4, #31, .Lctrtail\xctr 460 - ld1 {v5.16b-v7.16b}, [x1], #48 416 + 417 + /* 418 + * If there are at least MAX_STRIDE blocks left, XOR the data with 419 + * keystream and store. Otherwise jump to tail handling. 420 + */ 421 + tbnz BYTES_W, #31, .Lctrtail\xctr 422 + ld1 {v5.16b-v7.16b}, [IN], #48 461 423 ST4( bl aes_encrypt_block4x ) 462 424 ST5( bl aes_encrypt_block5x ) 463 425 eor v0.16b, v5.16b, v0.16b 464 - ST4( ld1 {v5.16b}, [x1], #16 ) 426 + ST4( ld1 {v5.16b}, [IN], #16 ) 465 427 eor v1.16b, v6.16b, v1.16b 466 - ST5( ld1 {v5.16b-v6.16b}, [x1], #32 ) 428 + ST5( ld1 {v5.16b-v6.16b}, [IN], #32 ) 467 429 eor v2.16b, v7.16b, v2.16b 468 430 eor v3.16b, v5.16b, v3.16b 469 431 ST5( eor v4.16b, v6.16b, v4.16b ) 470 - st1 {v0.16b-v3.16b}, [x0], #64 471 - ST5( st1 {v4.16b}, [x0], #16 ) 472 - cbz w4, .Lctrout\xctr 432 + st1 {v0.16b-v3.16b}, [OUT], #64 433 + ST5( st1 {v4.16b}, [OUT], #16 ) 434 + cbz BYTES_W, .Lctrout\xctr 473 435 b .LctrloopNx\xctr 474 436 475 437 .Lctrout\xctr: 476 438 .if !\xctr 477 - st1 {vctr.16b}, [x5] /* return next CTR value */ 439 + st1 {vctr.16b}, [IV] /* return next CTR value */ 478 440 .endif 479 441 ldp x29, x30, [sp], #16 480 442 ret 481 443 482 444 .Lctrtail\xctr: 445 + /* 446 + * Handle up to MAX_STRIDE * 16 - 1 bytes of plaintext 447 + * 448 + * This code expects the last keystream block to be in v{MAX_STRIDE-1}. 449 + * For example: if encrypting two blocks with MAX_STRIDE=5, then v3 and 450 + * v4 should have the next two counter blocks. 451 + * 452 + * This allows us to store the ciphertext by writing to overlapping 453 + * regions of memory. Any invalid ciphertext blocks get overwritten by 454 + * correctly computed blocks. This approach greatly simplifies the 455 + * logic for storing the ciphertext. 456 + */ 483 457 mov x16, #16 484 - ands x6, x4, #0xf 485 - csel x13, x6, x16, ne 458 + ands w7, BYTES_W, #0xf 459 + csel x13, x7, x16, ne 486 460 487 - ST5( cmp w4, #64 - (MAX_STRIDE << 4) ) 461 + ST5( cmp BYTES_W, #64 - (MAX_STRIDE << 4)) 488 462 ST5( csel x14, x16, xzr, gt ) 489 - cmp w4, #48 - (MAX_STRIDE << 4) 463 + cmp BYTES_W, #48 - (MAX_STRIDE << 4) 490 464 csel x15, x16, xzr, gt 491 - cmp w4, #32 - (MAX_STRIDE << 4) 465 + cmp BYTES_W, #32 - (MAX_STRIDE << 4) 492 466 csel x16, x16, xzr, gt 493 - cmp w4, #16 - (MAX_STRIDE << 4) 467 + cmp BYTES_W, #16 - (MAX_STRIDE << 4) 494 468 495 - adr_l x12, .Lcts_permute_table 496 - add x12, x12, x13 469 + adr_l x9, .Lcts_permute_table 470 + add x9, x9, x13 497 471 ble .Lctrtail1x\xctr 498 472 499 - ST5( ld1 {v5.16b}, [x1], x14 ) 500 - ld1 {v6.16b}, [x1], x15 501 - ld1 {v7.16b}, [x1], x16 473 + ST5( ld1 {v5.16b}, [IN], x14 ) 474 + ld1 {v6.16b}, [IN], x15 475 + ld1 {v7.16b}, [IN], x16 502 476 503 477 ST4( bl aes_encrypt_block4x ) 504 478 ST5( bl aes_encrypt_block5x ) 505 479 506 - ld1 {v8.16b}, [x1], x13 507 - ld1 {v9.16b}, [x1] 508 - ld1 {v10.16b}, [x12] 480 + ld1 {v8.16b}, [IN], x13 481 + ld1 {v9.16b}, [IN] 482 + ld1 {v10.16b}, [x9] 509 483 510 484 ST4( eor v6.16b, v6.16b, v0.16b ) 511 485 ST4( eor v7.16b, v7.16b, v1.16b ) ··· 537 477 ST5( eor v8.16b, v8.16b, v3.16b ) 538 478 ST5( eor v9.16b, v9.16b, v4.16b ) 539 479 540 - ST5( st1 {v5.16b}, [x0], x14 ) 541 - st1 {v6.16b}, [x0], x15 542 - st1 {v7.16b}, [x0], x16 543 - add x13, x13, x0 480 + ST5( st1 {v5.16b}, [OUT], x14 ) 481 + st1 {v6.16b}, [OUT], x15 482 + st1 {v7.16b}, [OUT], x16 483 + add x13, x13, OUT 544 484 st1 {v9.16b}, [x13] // overlapping stores 545 - st1 {v8.16b}, [x0] 485 + st1 {v8.16b}, [OUT] 546 486 b .Lctrout\xctr 547 487 548 488 .Lctrtail1x\xctr: 549 - sub x7, x6, #16 550 - csel x6, x6, x7, eq 551 - add x1, x1, x6 552 - add x0, x0, x6 553 - ld1 {v5.16b}, [x1] 554 - ld1 {v6.16b}, [x0] 489 + /* 490 + * Handle <= 16 bytes of plaintext 491 + * 492 + * This code always reads and writes 16 bytes. To avoid out of bounds 493 + * accesses, XCTR and CTR modes must use a temporary buffer when 494 + * encrypting/decrypting less than 16 bytes. 495 + * 496 + * This code is unusual in that it loads the input and stores the output 497 + * relative to the end of the buffers rather than relative to the start. 498 + * This causes unusual behaviour when encrypting/decrypting less than 16 499 + * bytes; the end of the data is expected to be at the end of the 500 + * temporary buffer rather than the start of the data being at the start 501 + * of the temporary buffer. 502 + */ 503 + sub x8, x7, #16 504 + csel x7, x7, x8, eq 505 + add IN, IN, x7 506 + add OUT, OUT, x7 507 + ld1 {v5.16b}, [IN] 508 + ld1 {v6.16b}, [OUT] 555 509 ST5( mov v3.16b, v4.16b ) 556 - encrypt_block v3, w3, x2, x8, w7 557 - ld1 {v10.16b-v11.16b}, [x12] 510 + encrypt_block v3, ROUNDS_W, KEY, x8, w7 511 + ld1 {v10.16b-v11.16b}, [x9] 558 512 tbl v3.16b, {v3.16b}, v10.16b 559 513 sshr v11.16b, v11.16b, #7 560 514 eor v5.16b, v5.16b, v3.16b 561 515 bif v5.16b, v6.16b, v11.16b 562 - st1 {v5.16b}, [x0] 516 + st1 {v5.16b}, [OUT] 563 517 b .Lctrout\xctr 518 + 519 + // Arguments 520 + .unreq OUT 521 + .unreq IN 522 + .unreq KEY 523 + .unreq ROUNDS_W 524 + .unreq BYTES_W 525 + .unreq IV 526 + .unreq BYTE_CTR_W // XCTR only 527 + // Intermediate values 528 + .unreq CTR_W // XCTR only 529 + .unreq CTR // XCTR only 530 + .unreq IV_PART 531 + .unreq BLOCKS 532 + .unreq BLOCKS_W 564 533 .endm 565 534 566 535 /* 567 536 * aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds, 568 537 * int bytes, u8 ctr[]) 538 + * 539 + * The input and output buffers must always be at least 16 bytes even if 540 + * encrypting/decrypting less than 16 bytes. Otherwise out of bounds 541 + * accesses will occur. The data to be encrypted/decrypted is expected 542 + * to be at the end of this 16-byte temporary buffer rather than the 543 + * start. 569 544 */ 570 545 571 546 AES_FUNC_START(aes_ctr_encrypt) ··· 610 515 /* 611 516 * aes_xctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds, 612 517 * int bytes, u8 const iv[], int byte_ctr) 518 + * 519 + * The input and output buffers must always be at least 16 bytes even if 520 + * encrypting/decrypting less than 16 bytes. Otherwise out of bounds 521 + * accesses will occur. The data to be encrypted/decrypted is expected 522 + * to be at the end of this 16-byte temporary buffer rather than the 523 + * start. 613 524 */ 614 525 615 526 AES_FUNC_START(aes_xctr_encrypt)