//! X25519 Diffie-Hellman key exchange (RFC 7748). //! //! Field arithmetic in GF(2^255 - 19) using 5×51-bit limb representation. //! Montgomery ladder scalar multiplication on Curve25519. //! Constant-time: no secret-dependent branches or memory accesses. // --------------------------------------------------------------------------- // GF(2^255 - 19) field element // --------------------------------------------------------------------------- const MASK51: u64 = (1u64 << 51) - 1; /// Basepoint u-coordinate = 9. const BASEPOINT: [u8; 32] = [ 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]; #[derive(Clone, Copy)] struct Fe([u64; 5]); impl Fe { const ZERO: Fe = Fe([0; 5]); const ONE: Fe = Fe([1, 0, 0, 0, 0]); } // --------------------------------------------------------------------------- // Field encoding / decoding // --------------------------------------------------------------------------- fn fe_frombytes(bytes: &[u8; 32]) -> Fe { let load8 = |src: &[u8]| -> u64 { let mut dst = [0u8; 8]; let n = src.len().min(8); dst[..n].copy_from_slice(&src[..n]); u64::from_le_bytes(dst) }; let mut h = Fe::ZERO; h.0[0] = load8(&bytes[0..]) & MASK51; h.0[1] = (load8(&bytes[6..]) >> 3) & MASK51; h.0[2] = (load8(&bytes[12..]) >> 6) & MASK51; h.0[3] = (load8(&bytes[19..]) >> 1) & MASK51; h.0[4] = (load8(&bytes[24..]) >> 12) & MASK51; h } fn fe_tobytes(h: &Fe) -> [u8; 32] { let h = fe_reduce(h); let mut out = [0u8; 32]; // Pack 5×51-bit limbs into 256 bits little-endian let combine = |lo: u64, hi: u64, shift: u32| -> u64 { lo | (hi << shift) }; let w0 = combine(h.0[0], h.0[1], 51); let w1 = combine(h.0[1] >> 13, h.0[2], 38); let w2 = combine(h.0[2] >> 26, h.0[3], 25); let w3 = combine(h.0[3] >> 39, h.0[4], 12); out[0..8].copy_from_slice(&w0.to_le_bytes()); out[8..16].copy_from_slice(&w1.to_le_bytes()); out[16..24].copy_from_slice(&w2.to_le_bytes()); out[24..32].copy_from_slice(&w3.to_le_bytes()); out } // --------------------------------------------------------------------------- // Field addition and subtraction // --------------------------------------------------------------------------- fn fe_add(a: &Fe, b: &Fe) -> Fe { Fe([ a.0[0] + b.0[0], a.0[1] + b.0[1], a.0[2] + b.0[2], a.0[3] + b.0[3], a.0[4] + b.0[4], ]) } fn fe_sub(a: &Fe, b: &Fe) -> Fe { // Add 2*p before subtracting to avoid underflow. // 2*p = 2*(2^255 - 19) in 5×51-bit limbs: // limb 0: 2*(2^51 - 19) = 0xFFFFFFFFFFFDA // limbs 1-4: 2*(2^51 - 1) = 0xFFFFFFFFFFFFE Fe([ (a.0[0] + 0xFFFFFFFFFFFDA) - b.0[0], (a.0[1] + 0xFFFFFFFFFFFFE) - b.0[1], (a.0[2] + 0xFFFFFFFFFFFFE) - b.0[2], (a.0[3] + 0xFFFFFFFFFFFFE) - b.0[3], (a.0[4] + 0xFFFFFFFFFFFFE) - b.0[4], ]) } // --------------------------------------------------------------------------- // Field multiplication // --------------------------------------------------------------------------- fn fe_mul(a: &Fe, b: &Fe) -> Fe { let a0 = a.0[0] as u128; let a1 = a.0[1] as u128; let a2 = a.0[2] as u128; let a3 = a.0[3] as u128; let a4 = a.0[4] as u128; let b0 = b.0[0] as u128; let b1 = b.0[1] as u128; let b2 = b.0[2] as u128; let b3 = b.0[3] as u128; let b4 = b.0[4] as u128; // Precompute 19*b[j] for reduction (2^255 ≡ 19 mod p) let b1_19 = 19 * b1; let b2_19 = 19 * b2; let b3_19 = 19 * b3; let b4_19 = 19 * b4; // Schoolbook multiply with modular reduction let mut t0 = a0 * b0 + a1 * b4_19 + a2 * b3_19 + a3 * b2_19 + a4 * b1_19; let mut t1 = a0 * b1 + a1 * b0 + a2 * b4_19 + a3 * b3_19 + a4 * b2_19; let mut t2 = a0 * b2 + a1 * b1 + a2 * b0 + a3 * b4_19 + a4 * b3_19; let mut t3 = a0 * b3 + a1 * b2 + a2 * b1 + a3 * b0 + a4 * b4_19; let mut t4 = a0 * b4 + a1 * b3 + a2 * b2 + a3 * b1 + a4 * b0; // Carry propagation let c = t0 >> 51; t0 &= MASK51 as u128; t1 += c; let c = t1 >> 51; t1 &= MASK51 as u128; t2 += c; let c = t2 >> 51; t2 &= MASK51 as u128; t3 += c; let c = t3 >> 51; t3 &= MASK51 as u128; t4 += c; let c = t4 >> 51; t4 &= MASK51 as u128; t0 += c * 19; let c = t0 >> 51; t0 &= MASK51 as u128; t1 += c; Fe([t0 as u64, t1 as u64, t2 as u64, t3 as u64, t4 as u64]) } fn fe_sq(a: &Fe) -> Fe { let a0 = a.0[0] as u128; let a1 = a.0[1] as u128; let a2 = a.0[2] as u128; let a3 = a.0[3] as u128; let a4 = a.0[4] as u128; let d0 = 2 * a0; let d1 = 2 * a1; let d2 = 2 * a2; let d3 = 2 * a3; let a4_19 = 19 * a4; let a3_19 = 19 * a3; let mut t0 = a0 * a0 + d1 * a4_19 + d2 * a3_19; let mut t1 = d0 * a1 + d2 * a4_19 + a3 * a3_19; let mut t2 = d0 * a2 + a1 * a1 + d3 * a4_19; let mut t3 = d0 * a3 + d1 * a2 + a4 * a4_19; let mut t4 = d0 * a4 + d1 * a3 + a2 * a2; // Carry propagation let c = t0 >> 51; t0 &= MASK51 as u128; t1 += c; let c = t1 >> 51; t1 &= MASK51 as u128; t2 += c; let c = t2 >> 51; t2 &= MASK51 as u128; t3 += c; let c = t3 >> 51; t3 &= MASK51 as u128; t4 += c; let c = t4 >> 51; t4 &= MASK51 as u128; t0 += c * 19; let c = t0 >> 51; t0 &= MASK51 as u128; t1 += c; Fe([t0 as u64, t1 as u64, t2 as u64, t3 as u64, t4 as u64]) } /// Multiply by the curve constant a24 = (A - 2) / 4 = 121665 for Curve25519. /// This is the correct constant when the ladder uses AA (not BB) in the z_2 update. fn fe_mul_a24(a: &Fe) -> Fe { let mut t0 = a.0[0] as u128 * 121665; let mut t1 = a.0[1] as u128 * 121665; let mut t2 = a.0[2] as u128 * 121665; let mut t3 = a.0[3] as u128 * 121665; let mut t4 = a.0[4] as u128 * 121665; let c = t0 >> 51; t0 &= MASK51 as u128; t1 += c; let c = t1 >> 51; t1 &= MASK51 as u128; t2 += c; let c = t2 >> 51; t2 &= MASK51 as u128; t3 += c; let c = t3 >> 51; t3 &= MASK51 as u128; t4 += c; let c = t4 >> 51; t4 &= MASK51 as u128; t0 += c * 19; let c = t0 >> 51; t0 &= MASK51 as u128; t1 += c; Fe([t0 as u64, t1 as u64, t2 as u64, t3 as u64, t4 as u64]) } // --------------------------------------------------------------------------- // Field reduction and canonical form // --------------------------------------------------------------------------- /// Fully reduce to canonical form in [0, p). fn fe_reduce(a: &Fe) -> Fe { let mut h = *a; // First, carry-propagate let mut c = h.0[0] >> 51; h.0[0] &= MASK51; h.0[1] += c; c = h.0[1] >> 51; h.0[1] &= MASK51; h.0[2] += c; c = h.0[2] >> 51; h.0[2] &= MASK51; h.0[3] += c; c = h.0[3] >> 51; h.0[3] &= MASK51; h.0[4] += c; c = h.0[4] >> 51; h.0[4] &= MASK51; h.0[0] += c * 19; c = h.0[0] >> 51; h.0[0] &= MASK51; h.0[1] += c; // Now test if h >= p by computing h + 19 and checking overflow let mut q = (h.0[0] + 19) >> 51; q = (h.0[1] + q) >> 51; q = (h.0[2] + q) >> 51; q = (h.0[3] + q) >> 51; q = (h.0[4] + q) >> 51; // q is 1 if h >= p, 0 otherwise h.0[0] += 19 * q; // Carry-propagate again c = h.0[0] >> 51; h.0[0] &= MASK51; h.0[1] += c; c = h.0[1] >> 51; h.0[1] &= MASK51; h.0[2] += c; c = h.0[2] >> 51; h.0[2] &= MASK51; h.0[3] += c; c = h.0[3] >> 51; h.0[3] &= MASK51; h.0[4] += c; h.0[4] &= MASK51; h } // --------------------------------------------------------------------------- // Field inversion via Fermat's little theorem: a^(p-2) // --------------------------------------------------------------------------- /// Compute a^(2^n) by repeated squaring. fn fe_sq_n(a: &Fe, n: usize) -> Fe { let mut r = fe_sq(a); for _ in 1..n { r = fe_sq(&r); } r } /// Compute a^(p-2) = a^(2^255 - 21) using an addition chain. fn fe_invert(a: &Fe) -> Fe { let z1 = *a; // t0 = a^2 let t0 = fe_sq(&z1); // t1 = a^4 let t1 = fe_sq(&t0); // t1 = a^8 let t1 = fe_sq(&t1); // t1 = a^9 = a^(8+1) let t1 = fe_mul(&t1, &z1); // t0 = a^11 = a^(9+2) let t0 = fe_mul(&t0, &t1); // t2 = a^22 let t2 = fe_sq(&t0); // t1 = a^31 = a^(22+9) = a^(2^5 - 1) let t1 = fe_mul(&t1, &t2); // t2 = a^(2^10 - 2^5) let t2 = fe_sq_n(&t1, 5); // t1 = a^(2^10 - 1) let t1 = fe_mul(&t2, &t1); // t2 = a^(2^20 - 2^10) let t2 = fe_sq_n(&t1, 10); // t2 = a^(2^20 - 1) let t2 = fe_mul(&t2, &t1); // t3 = a^(2^40 - 2^20) let t3 = fe_sq_n(&t2, 20); // t2 = a^(2^40 - 1) let t2 = fe_mul(&t3, &t2); // t2 = a^(2^50 - 2^10) let t2 = fe_sq_n(&t2, 10); // t1 = a^(2^50 - 1) let t1 = fe_mul(&t2, &t1); // t2 = a^(2^100 - 2^50) let t2 = fe_sq_n(&t1, 50); // t2 = a^(2^100 - 1) let t2 = fe_mul(&t2, &t1); // t3 = a^(2^200 - 2^100) let t3 = fe_sq_n(&t2, 100); // t2 = a^(2^200 - 1) let t2 = fe_mul(&t3, &t2); // t2 = a^(2^250 - 2^50) let t2 = fe_sq_n(&t2, 50); // t1 = a^(2^250 - 1) let t1 = fe_mul(&t2, &t1); // t1 = a^(2^255 - 2^5) let t1 = fe_sq_n(&t1, 5); // a^(2^255 - 21) = a^(p-2) fe_mul(&t1, &t0) } // --------------------------------------------------------------------------- // Constant-time utilities // --------------------------------------------------------------------------- /// Constant-time conditional swap: swap a and b if swap == 1, no-op if swap == 0. fn fe_cswap(a: &mut Fe, b: &mut Fe, swap: u64) { let mask = 0u64.wrapping_sub(swap); for i in 0..5 { let t = mask & (a.0[i] ^ b.0[i]); a.0[i] ^= t; b.0[i] ^= t; } } // --------------------------------------------------------------------------- // Scalar clamping // --------------------------------------------------------------------------- fn clamp_scalar(s: &[u8; 32]) -> [u8; 32] { let mut k = *s; k[0] &= 248; k[31] &= 127; k[31] |= 64; k } // --------------------------------------------------------------------------- // Montgomery ladder scalar multiplication // --------------------------------------------------------------------------- fn x25519_scalar_mult(scalar: &[u8; 32], u_point: &[u8; 32]) -> [u8; 32] { let mut u_bytes = *u_point; u_bytes[31] &= 127; // Mask bit 255 per RFC 7748 let u = fe_frombytes(&u_bytes); let mut x_2 = Fe::ONE; let mut z_2 = Fe::ZERO; let mut x_3 = u; let mut z_3 = Fe::ONE; let mut swap: u64 = 0; // Montgomery ladder: iterate from bit 254 down to 0 for t in (0..=254).rev() { let k_t = ((scalar[t >> 3] >> (t & 7)) & 1) as u64; swap ^= k_t; fe_cswap(&mut x_2, &mut x_3, swap); fe_cswap(&mut z_2, &mut z_3, swap); swap = k_t; let a = fe_add(&x_2, &z_2); let aa = fe_sq(&a); let b = fe_sub(&x_2, &z_2); let bb = fe_sq(&b); let e = fe_sub(&aa, &bb); let c = fe_add(&x_3, &z_3); let d = fe_sub(&x_3, &z_3); let da = fe_mul(&d, &a); let cb = fe_mul(&c, &b); x_3 = fe_sq(&fe_add(&da, &cb)); z_3 = fe_mul(&u, &fe_sq(&fe_sub(&da, &cb))); x_2 = fe_mul(&aa, &bb); z_2 = fe_mul(&e, &fe_add(&aa, &fe_mul_a24(&e))); } fe_cswap(&mut x_2, &mut x_3, swap); fe_cswap(&mut z_2, &mut z_3, swap); let result = fe_mul(&x_2, &fe_invert(&z_2)); fe_tobytes(&result) } // --------------------------------------------------------------------------- // Public API // --------------------------------------------------------------------------- /// Compute X25519 Diffie-Hellman: scalar multiplication of `u_point` by `scalar`. /// /// The scalar is clamped internally per RFC 7748. pub fn x25519(scalar: &[u8; 32], u_point: &[u8; 32]) -> [u8; 32] { let k = clamp_scalar(scalar); x25519_scalar_mult(&k, u_point) } /// Compute the X25519 public key from a private key (basepoint multiplication). pub fn x25519_base(scalar: &[u8; 32]) -> [u8; 32] { x25519(scalar, &BASEPOINT) } // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- #[cfg(test)] mod tests { use super::*; fn hex(bytes: &[u8]) -> String { bytes.iter().map(|b| format!("{:02x}", b)).collect() } fn from_hex(s: &str) -> Vec { (0..s.len()) .step_by(2) .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap()) .collect() } fn hex32(s: &str) -> [u8; 32] { let v = from_hex(s); let mut out = [0u8; 32]; out.copy_from_slice(&v); out } // --- Field arithmetic tests --- #[test] fn fe_encode_decode_roundtrip() { let bytes: [u8; 32] = [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, ]; let fe = fe_frombytes(&bytes); let out = fe_tobytes(&fe); assert_eq!(bytes, out); } #[test] fn fe_mul_identity() { let bytes = hex32("0900000000000000000000000000000000000000000000000000000000000000"); let a = fe_frombytes(&bytes); let result = fe_mul(&a, &Fe::ONE); assert_eq!(fe_tobytes(&result), bytes); } #[test] fn fe_mul_commutative() { let a = fe_frombytes(&hex32( "0900000000000000000000000000000000000000000000000000000000000000", )); let b = fe_frombytes(&hex32( "0500000000000000000000000000000000000000000000000000000000000000", )); let ab = fe_tobytes(&fe_mul(&a, &b)); let ba = fe_tobytes(&fe_mul(&b, &a)); assert_eq!(ab, ba); } #[test] fn fe_sq_matches_mul() { let a = fe_frombytes(&hex32( "0900000000000000000000000000000000000000000000000000000000000000", )); let sq = fe_tobytes(&fe_sq(&a)); let mul = fe_tobytes(&fe_mul(&a, &a)); assert_eq!(sq, mul); } #[test] fn fe_invert_roundtrip() { let a = fe_frombytes(&hex32( "0900000000000000000000000000000000000000000000000000000000000000", )); let inv = fe_invert(&a); let product = fe_mul(&a, &inv); assert_eq!( fe_tobytes(&product), hex32("0100000000000000000000000000000000000000000000000000000000000000") ); } #[test] fn fe_add_sub_roundtrip() { let a = fe_frombytes(&hex32( "0900000000000000000000000000000000000000000000000000000000000000", )); let b = fe_frombytes(&hex32( "0500000000000000000000000000000000000000000000000000000000000000", )); let sum = fe_add(&a, &b); let diff = fe_sub(&sum, &b); assert_eq!(fe_tobytes(&diff), fe_tobytes(&a)); } // --- RFC 7748 §6.1 test vectors --- #[test] fn rfc7748_section_6_1_alice_public_key() { let alice_private = hex32("77076d0a7318a57d3c16c17251b26645df4c2f87ebc0992ab177fba51db92c2a"); let expected = hex32("8520f0098930a754748b7ddcb43ef75a0dbf3a0d26381af4eba4a98eaa9b4e6a"); assert_eq!(x25519_base(&alice_private), expected); } #[test] fn rfc7748_section_6_1_bob_public_key() { let bob_private = hex32("5dab087e624a8a4b79e17f8b83800ee66f3bb1292618b6fd1c2f8b27ff88e0eb"); let expected = hex32("de9edb7d7b7dc1b4d35b61c2ece435373f8343c85b78674dadfc7e146f882b4f"); assert_eq!(x25519_base(&bob_private), expected); } #[test] fn rfc7748_section_6_1_shared_secret() { let alice_private = hex32("77076d0a7318a57d3c16c17251b26645df4c2f87ebc0992ab177fba51db92c2a"); let bob_public = hex32("de9edb7d7b7dc1b4d35b61c2ece435373f8343c85b78674dadfc7e146f882b4f"); let expected = hex32("4a5d9d5ba4ce2de1728e3bf480350f25e07e21c947d19e3376f09b3c1e161742"); assert_eq!(x25519(&alice_private, &bob_public), expected); } #[test] fn rfc7748_section_6_1_shared_secret_both_sides() { let alice_private = hex32("77076d0a7318a57d3c16c17251b26645df4c2f87ebc0992ab177fba51db92c2a"); let bob_private = hex32("5dab087e624a8a4b79e17f8b83800ee66f3bb1292618b6fd1c2f8b27ff88e0eb"); let alice_public = x25519_base(&alice_private); let bob_public = x25519_base(&bob_private); let shared_ab = x25519(&alice_private, &bob_public); let shared_ba = x25519(&bob_private, &alice_public); assert_eq!(shared_ab, shared_ba); assert_eq!( hex(&shared_ab), "4a5d9d5ba4ce2de1728e3bf480350f25e07e21c947d19e3376f09b3c1e161742" ); } // --- RFC 7748 §5.2 iterated test vectors --- #[test] fn rfc7748_section_5_2_one_iteration() { let mut k = hex32("0900000000000000000000000000000000000000000000000000000000000000"); let mut u = k; let output = x25519(&k, &u); u = k; k = output; assert_eq!( hex(&k), "422c8e7a6227d7bca1350b3e2bb7279f7897b87bb6854b783c60e80311ae3079" ); let _ = u; // suppress unused warning } #[test] fn rfc7748_section_5_2_1000_iterations() { let mut k = hex32("0900000000000000000000000000000000000000000000000000000000000000"); let mut u = k; for _ in 0..1000 { let output = x25519(&k, &u); u = k; k = output; } assert_eq!( hex(&k), "684cf59ba83309552800ef566f2f4d3c1c3887c49360e3875f2eb94d99532c51" ); } #[test] #[ignore] // Takes too long for regular test runs fn rfc7748_section_5_2_1000000_iterations() { let mut k = hex32("0900000000000000000000000000000000000000000000000000000000000000"); let mut u = k; for _ in 0..1_000_000 { let output = x25519(&k, &u); u = k; k = output; } assert_eq!( hex(&k), "7c3911e0ab2586fd864497297e575e6f3bc601c0883c30df5f4dd2d24f665424" ); } // --- Low-order point test --- #[test] fn x25519_zero_point_gives_zero() { let scalar = hex32("77076d0a7318a57d3c16c17251b26645df4c2f87ebc0992ab177fba51db92c2a"); let zero_point = [0u8; 32]; let result = x25519(&scalar, &zero_point); assert_eq!(result, [0u8; 32]); } }