//! Arbitrary-precision unsigned integer arithmetic for RSA.
//!
//! Provides a `BigUint` type backed by little-endian `u64` limbs, with
//! modular exponentiation (square-and-multiply) for up to 4096-bit keys.
//!
//! NOTE(review): the exponentiation loops branch on individual exponent bits
//! and the Montgomery product ends with a conditional subtraction, so running
//! time depends on the operands. Confirm this is acceptable wherever a secret
//! exponent is used.

use core::cmp::Ordering;

/// Arbitrary-precision unsigned integer stored as little-endian `u64` limbs.
///
/// Invariant: every value handed out by this module is *normalized*, i.e. it
/// holds at least one limb and has no most-significant zero limbs other than
/// the single limb that represents zero.
#[derive(Clone, Debug)]
pub struct BigUint {
    /// Limbs in little-endian order (limbs[0] is the least significant).
    limbs: Vec<u64>,
}

impl BigUint {
    /// The number zero.
    pub fn zero() -> Self {
        Self { limbs: vec![0] }
    }

    /// The number one.
    pub fn one() -> Self {
        Self { limbs: vec![1] }
    }

    /// Create from a single u64.
    pub fn from_u64(v: u64) -> Self {
        Self { limbs: vec![v] }
    }

    /// Create from big-endian bytes (as used in DER INTEGER encoding).
    ///
    /// Leading zero bytes are accepted and ignored; an empty slice yields zero.
    pub fn from_be_bytes(bytes: &[u8]) -> Self {
        if bytes.is_empty() {
            return Self::zero();
        }

        // `rchunks` walks the input from its least significant end in groups
        // of eight bytes, which is exactly little-endian limb order. Only the
        // final (most significant) chunk may be shorter than eight bytes.
        let limbs: Vec<u64> = bytes
            .rchunks(8)
            .map(|chunk| chunk.iter().fold(0u64, |acc, &b| (acc << 8) | u64::from(b)))
            .collect();

        let mut result = Self { limbs };
        result.normalize();
        result
    }

    /// Export to big-endian bytes.
    ///
    /// The output has no leading zero bytes; zero is encoded as a single `0`.
    pub fn to_be_bytes(&self) -> Vec<u8> {
        let mut bytes: Vec<u8> = self
            .limbs
            .iter()
            .rev()
            .flat_map(|limb| limb.to_be_bytes())
            .skip_while(|&b| b == 0)
            .collect();

        if bytes.is_empty() {
            bytes.push(0);
        }
        bytes
    }

    /// Export to big-endian bytes, zero-padded to exactly `len` bytes.
    ///
    /// If the value needs more than `len` bytes, the most significant bytes
    /// are dropped and only the low-order `len` bytes are returned.
    pub fn to_be_bytes_padded(&self, len: usize) -> Vec<u8> {
        let raw = self.to_be_bytes();
        if raw.len() >= len {
            // Take the last `len` bytes (truncate leading bytes).
            return raw[raw.len() - len..].to_vec();
        }

        let mut out = vec![0u8; len - raw.len()];
        out.extend_from_slice(&raw);
        out
    }

    /// True if the value is zero.
    pub fn is_zero(&self) -> bool {
        self.limbs.iter().all(|&l| l == 0)
    }

    /// Number of significant bits (zero has a bit length of 0).
    pub fn bit_len(&self) -> usize {
        // Search for the highest non-zero limb instead of trusting the last
        // one, so the answer stays correct even for a value that carries
        // redundant high zero limbs.
        match self.limbs.iter().rposition(|&l| l != 0) {
            Some(top) => top * 64 + (64 - self.limbs[top].leading_zeros() as usize),
            None => 0,
        }
    }

    /// Get bit at position `i` (0-indexed from LSB).
    ///
    /// Positions beyond the stored limbs read as `false`.
    pub fn bit(&self, i: usize) -> bool {
        let limb_idx = i / 64;
        let bit_idx = i % 64;
        match self.limbs.get(limb_idx) {
            Some(&limb) => (limb >> bit_idx) & 1 == 1,
            None => false,
        }
    }

    /// Remove trailing zero limbs (keep at least one limb).
    fn normalize(&mut self) {
        while self.limbs.len() > 1 && self.limbs.last() == Some(&0) {
            self.limbs.pop();
        }
    }

    /// Addition: self + other.
    pub fn add(&self, other: &Self) -> Self {
        let max_len = self.limbs.len().max(other.limbs.len());
        let mut result = Vec::with_capacity(max_len + 1);
        let mut carry = 0u64;

        for i in 0..max_len {
            let a = self.limbs.get(i).copied().unwrap_or(0);
            let b = other.limbs.get(i).copied().unwrap_or(0);
            let (sum1, c1) = a.overflowing_add(b);
            let (sum2, c2) = sum1.overflowing_add(carry);
            result.push(sum2);
            // At most one of the two additions can overflow, so carry is 0 or 1.
            carry = (c1 as u64) + (c2 as u64);
        }
        if carry > 0 {
            result.push(carry);
        }

        let mut r = Self { limbs: result };
        r.normalize();
        r
    }

    /// Subtraction: self - other.
    ///
    /// # Panics
    ///
    /// Panics if `other > self`.
    pub fn sub(&self, other: &Self) -> Self {
        // A hard assert (not debug-only) so the documented panic also holds in
        // release builds instead of silently returning a wrapped result.
        assert!(
            self.cmp(other) != Ordering::Less,
            "BigUint::sub: subtrahend is larger than minuend"
        );

        let mut result = Vec::with_capacity(self.limbs.len());
        let mut borrow = 0u64;

        for (i, &a) in self.limbs.iter().enumerate() {
            let b = other.limbs.get(i).copied().unwrap_or(0);
            let (diff1, b1) = a.overflowing_sub(b);
            let (diff2, b2) = diff1.overflowing_sub(borrow);
            result.push(diff2);
            borrow = (b1 as u64) + (b2 as u64);
        }

        let mut r = Self { limbs: result };
        r.normalize();
        r
    }

    /// Multiplication: self * other (schoolbook, limb by limb).
    pub fn mul(&self, other: &Self) -> Self {
        let n = self.limbs.len();
        let m = other.limbs.len();
        let mut result = vec![0u64; n + m];

        for i in 0..n {
            let mut carry = 0u128;
            for j in 0..m {
                // limb * limb + limb + limb always fits in 128 bits.
                let prod = (self.limbs[i] as u128) * (other.limbs[j] as u128)
                    + result[i + j] as u128
                    + carry;
                result[i + j] = prod as u64;
                carry = prod >> 64;
            }
            result[i + m] = carry as u64;
        }

        let mut r = Self { limbs: result };
        r.normalize();
        r
    }

    /// Division with remainder: returns (quotient, remainder).
    ///
    /// # Panics
    ///
    /// Panics if `divisor` is zero.
    pub fn div_rem(&self, divisor: &Self) -> (Self, Self) {
        assert!(!divisor.is_zero(), "division by zero");

        if self.cmp(divisor) == Ordering::Less {
            return (Self::zero(), self.clone());
        }
        if divisor.limbs.len() == 1 {
            return self.div_rem_single(divisor.limbs[0]);
        }
        self.div_rem_knuth(divisor)
    }

    /// Single-limb division.
    fn div_rem_single(&self, d: u64) -> (Self, Self) {
        let mut quotient = vec![0u64; self.limbs.len()];
        let mut rem = 0u128;

        // Long division from the most significant limb down; `rem` always
        // stays below `d`, so `(rem << 64) | limb` fits in 128 bits.
        for i in (0..self.limbs.len()).rev() {
            rem = (rem << 64) | self.limbs[i] as u128;
            quotient[i] = (rem / d as u128) as u64;
            rem %= d as u128;
        }

        let mut q = Self { limbs: quotient };
        q.normalize();
        (q, Self::from_u64(rem as u64))
    }

    /// Multi-limb division using Knuth's Algorithm D.
    ///
    /// Called only from `div_rem`, which guarantees `self >= divisor` and that
    /// `divisor` has at least two limbs.
    fn div_rem_knuth(&self, divisor: &Self) -> (Self, Self) {
        let n = divisor.limbs.len();
        let m = self.limbs.len() - n;

        // Normalize: shift so that the top bit of divisor is set.
        let shift = divisor.limbs[n - 1].leading_zeros();
        let u = self.shl_bits(shift);
        let v = divisor.shl_bits(shift);

        let mut u_limbs = u.limbs.clone();
        // Ensure u has m + n + 1 limbs.
        while u_limbs.len() <= m + n {
            u_limbs.push(0);
        }

        let mut q = vec![0u64; m + 1];

        for j in (0..=m).rev() {
            // Estimate q_hat from the top two limbs of the current window.
            let u_hi = ((u_limbs[j + n] as u128) << 64) | u_limbs[j + n - 1] as u128;
            let mut q_hat = u_hi / v.limbs[n - 1] as u128;
            let mut r_hat = u_hi % v.limbs[n - 1] as u128;

            // Refine estimate: lower q_hat while it is provably too large.
            loop {
                if q_hat >= (1u128 << 64)
                    || (q_hat * v.limbs[n - 2] as u128
                        > ((r_hat << 64) | u_limbs[j + n - 2] as u128))
                {
                    q_hat -= 1;
                    r_hat += v.limbs[n - 1] as u128;
                    if r_hat < (1u128 << 64) {
                        continue;
                    }
                }
                break;
            }

            // Multiply and subtract: u[j..=j+n] -= q_hat * v.
            let mut borrow: i128 = 0;
            for i in 0..n {
                let prod = q_hat * v.limbs[i] as u128;
                let diff = u_limbs[j + i] as i128 - borrow - (prod as u64) as i128;
                u_limbs[j + i] = diff as u64;
                // `diff >> 64` is an arithmetic shift, i.e. floor(diff / 2^64),
                // which is 0 or negative here; subtracting it adds the borrow.
                borrow = (prod >> 64) as i128 - (diff >> 64);
            }
            let diff = u_limbs[j + n] as i128 - borrow;
            u_limbs[j + n] = diff as u64;

            q[j] = q_hat as u64;

            // If we subtracted too much, add back.
            if diff < 0 {
                q[j] -= 1;
                let mut carry = 0u64;
                for i in 0..n {
                    let sum = u_limbs[j + i] as u128 + v.limbs[i] as u128 + carry as u128;
                    u_limbs[j + i] = sum as u64;
                    carry = (sum >> 64) as u64;
                }
                u_limbs[j + n] = u_limbs[j + n].wrapping_add(carry);
            }
        }

        // Remainder: unnormalize. The low `n` limbs can carry high zero limbs
        // (in particular when `shift == 0`), so strip them explicitly.
        u_limbs.truncate(n);
        let mut rem = Self { limbs: u_limbs }.shr_bits(shift);
        rem.normalize();

        let mut quotient = Self { limbs: q };
        quotient.normalize();

        (quotient, rem)
    }

    /// Left shift by `bits` positions (bits < 64).
    fn shl_bits(&self, bits: u32) -> Self {
        if bits == 0 {
            return self.clone();
        }

        let mut result = Vec::with_capacity(self.limbs.len() + 1);
        let mut carry = 0u64;
        for &limb in &self.limbs {
            result.push((limb << bits) | carry);
            carry = limb >> (64 - bits);
        }
        if carry > 0 {
            result.push(carry);
        }

        let mut r = Self { limbs: result };
        r.normalize();
        r
    }

    /// Right shift by `bits` positions (bits < 64).
    ///
    /// The result is always normalized, including when `bits == 0`.
    fn shr_bits(&self, bits: u32) -> Self {
        if bits == 0 {
            // The non-zero path normalizes its result; do the same here so
            // callers get a normalized value regardless of the shift amount.
            let mut r = self.clone();
            r.normalize();
            return r;
        }

        let mut result = Vec::with_capacity(self.limbs.len());
        let mut carry = 0u64;
        for &limb in self.limbs.iter().rev() {
            result.push((limb >> bits) | carry);
            carry = limb << (64 - bits);
        }
        result.reverse();

        let mut r = Self { limbs: result };
        r.normalize();
        r
    }

    /// Modular exponentiation: self^exp mod modulus.
    /// Uses left-to-right binary method (square-and-multiply).
    ///
    /// # Panics
    ///
    /// Panics if `modulus` is zero.
    pub fn modpow(&self, exp: &Self, modulus: &Self) -> Self {
        assert!(!modulus.is_zero(), "modulus must be non-zero");

        // Everything is congruent to zero modulo one.
        if modulus.limbs == [1] {
            return Self::zero();
        }

        // Montgomery multiplication requires an odd modulus.
        // RSA moduli are always odd, but handle even case with simple modpow.
        if modulus.limbs[0] & 1 == 0 {
            return simple_modpow(self, exp, modulus);
        }

        montgomery_modpow(self, exp, modulus)
    }

    /// Simple modular reduction: self mod modulus.
    ///
    /// # Panics
    ///
    /// Panics if `modulus` is zero.
    pub fn modulo(&self, modulus: &Self) -> Self {
        self.div_rem(modulus).1
    }

    /// Number of limbs.
    pub fn num_limbs(&self) -> usize {
        self.limbs.len()
    }
}

impl Ord for BigUint {
    /// Numeric comparison; missing high limbs are treated as zero, so the
    /// result does not depend on how many zero limbs either side carries.
    fn cmp(&self, other: &Self) -> Ordering {
        let a_len = self.limbs.len();
        let b_len = other.limbs.len();
        let max_len = a_len.max(b_len);

        for i in (0..max_len).rev() {
            let a = if i < a_len { self.limbs[i] } else { 0 };
            let b = if i < b_len { other.limbs[i] } else { 0 };
            match a.cmp(&b) {
                Ordering::Equal => continue,
                ord => return ord,
            }
        }
        Ordering::Equal
    }
}

impl PartialOrd for BigUint {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl PartialEq for BigUint {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for BigUint {}

// ---------------------------------------------------------------------------
// Montgomery multiplication for efficient modular exponentiation
// ---------------------------------------------------------------------------

/// Montgomery context for a given modulus.
///
/// R = 2^(64*num_limbs) is implicit and never stored.
struct Montgomery {
    /// The modulus N.
    n: BigUint,
    /// Number of limbs in N.
    num_limbs: usize,
    /// R^2 mod N — used to convert into Montgomery form.
    r_squared: BigUint,
    /// n0_inv = -N^(-1) mod 2^64.
    n0_inv: u64,
}

impl Montgomery {
    /// Build the context for modulus `n`, which must be odd.
    fn new(n: &BigUint) -> Self {
        let num_limbs = n.limbs.len();

        // Compute n0_inv = -N^(-1) mod 2^64 using Newton's method.
        // Start with x = 1, which is correct modulo 2 because n is odd. Each
        // step x <- x * (2 - n0 * x) doubles the number of correct low bits
        // (1 -> 2 -> 4 -> 8 -> 16 -> 32 -> 64), so six steps are enough.
        let n0 = n.limbs[0];
        debug_assert!(n0 & 1 == 1, "Montgomery modulus must be odd");
        let mut inv = 1u64;
        for _ in 0..6 {
            inv = inv.wrapping_mul(2u64.wrapping_sub(n0.wrapping_mul(inv)));
        }
        let n0_inv = inv.wrapping_neg(); // -N^(-1) mod 2^64

        // Compute R^2 mod N where R = 2^(64*num_limbs).
        // Start with R mod N, then square it.
        let r_mod_n = {
            // R = 2^(64*num_limbs). We compute R mod N by creating R and reducing.
            let mut r_limbs = vec![0u64; num_limbs + 1];
            r_limbs[num_limbs] = 1;
            let r = BigUint { limbs: r_limbs };
            r.div_rem(n).1
        };
        let r_squared = r_mod_n.mul(&r_mod_n).div_rem(n).1;

        Self {
            n: n.clone(),
            num_limbs,
            r_squared,
            n0_inv,
        }
    }

    /// Convert a value into Montgomery form: aR mod N.
    fn to_montgomery(&self, a: &BigUint) -> BigUint {
        self.montgomery_mul(a, &self.r_squared)
    }

    /// Convert from Montgomery form back to normal: a * R^(-1) mod N.
    fn reduce(&self, a: &BigUint) -> BigUint {
        self.montgomery_mul(a, &BigUint::one())
    }

    /// Montgomery multiplication: computes a * b * R^(-1) mod N.
    ///
    /// Only the low `num_limbs` limbs of `a` and `b` are read; callers pass
    /// operands that are already reduced modulo N.
    fn montgomery_mul(&self, a: &BigUint, b: &BigUint) -> BigUint {
        let n = self.num_limbs;
        let mut t = vec![0u64; 2 * n + 2];

        for i in 0..n {
            // t = t + a[i] * b
            let ai = a.limbs.get(i).copied().unwrap_or(0);
            let mut carry = 0u128;
            for j in 0..n {
                let bj = b.limbs.get(j).copied().unwrap_or(0);
                let sum = t[i + j] as u128 + (ai as u128) * (bj as u128) + carry;
                t[i + j] = sum as u64;
                carry = sum >> 64;
            }
            // Propagate carry.
            let mut k = i + n;
            while carry > 0 {
                let sum = t[k] as u128 + carry;
                t[k] = sum as u64;
                carry = sum >> 64;
                k += 1;
            }

            // Montgomery reduction step: pick m so that adding m * N clears
            // limb i of t.
            let m = t[i].wrapping_mul(self.n0_inv);
            carry = 0u128;
            for j in 0..n {
                let sum = t[i + j] as u128 + (m as u128) * (self.n.limbs[j] as u128) + carry;
                t[i + j] = sum as u64;
                carry = sum >> 64;
            }
            let mut k = i + n;
            while carry > 0 {
                let sum = t[k] as u128 + carry;
                t[k] = sum as u64;
                carry = sum >> 64;
                k += 1;
            }
        }

        // The low n limbs are now zero; the result is t[n..=2n].
        let result_limbs: Vec<u64> = t[n..2 * n + 1].to_vec();
        let mut result = BigUint {
            limbs: result_limbs,
        };
        result.normalize();

        // Final subtraction if result >= N.
        if result.cmp(&self.n) != Ordering::Less {
            result = result.sub(&self.n);
        }

        result
    }
}

/// Simple modular exponentiation (for even moduli where Montgomery doesn't work).
fn simple_modpow(base: &BigUint, exp: &BigUint, modulus: &BigUint) -> BigUint {
    let base_reduced = base.modulo(modulus);
    let mut result = BigUint::one();

    let bits = exp.bit_len();
    for i in (0..bits).rev() {
        result = result.mul(&result).modulo(modulus);
        if exp.bit(i) {
            result = result.mul(&base_reduced).modulo(modulus);
        }
    }

    result
}

/// Modular exponentiation using Montgomery multiplication.
///
/// `modulus` must be odd and greater than one; `BigUint::modpow` checks this
/// before dispatching here.
fn montgomery_modpow(base: &BigUint, exp: &BigUint, modulus: &BigUint) -> BigUint {
    let mont = Montgomery::new(modulus);

    // Convert base to Montgomery form.
    let base_reduced = base.modulo(modulus);
    let mut result = mont.to_montgomery(&BigUint::one()); // 1 in Montgomery form = R mod N
    let base_mont = mont.to_montgomery(&base_reduced);

    // Left-to-right binary method.
    let bits = exp.bit_len();
    for i in (0..bits).rev() {
        result = mont.montgomery_mul(&result, &result); // square
        if exp.bit(i) {
            result = mont.montgomery_mul(&result, &base_mont); // multiply
        }
    }

    mont.reduce(&result)
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    fn hex(bytes: &[u8]) -> String {
        bytes.iter().map(|b| format!("{b:02x}")).collect()
    }

    fn from_hex(s: &str) -> Vec<u8> {
        let s = s.replace(' ', "");
        (0..s.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
            .collect()
    }

    #[test]
    fn zero_and_one() {
        assert!(BigUint::zero().is_zero());
        assert!(!BigUint::one().is_zero());
        assert_eq!(BigUint::zero().bit_len(), 0);
        assert_eq!(BigUint::one().bit_len(), 1);
    }

    #[test]
    fn from_be_bytes_roundtrip() {
        let bytes = from_hex("0123456789abcdef");
        let n = BigUint::from_be_bytes(&bytes);
        assert_eq!(hex(&n.to_be_bytes()), "0123456789abcdef");
    }

    #[test]
    fn from_be_bytes_large() {
        // 128-bit number.
        let bytes = from_hex("ffffffffffffffffffffffffffffffff");
        let n = BigUint::from_be_bytes(&bytes);
        assert_eq!(hex(&n.to_be_bytes()), "ffffffffffffffffffffffffffffffff");
    }

    #[test]
    fn addition() {
        let a = BigUint::from_be_bytes(&from_hex("ffffffffffffffff"));
        let b = BigUint::from_u64(1);
        let c = a.add(&b);
        assert_eq!(hex(&c.to_be_bytes()), "010000000000000000");
    }

    #[test]
    fn subtraction() {
        let a = BigUint::from_be_bytes(&from_hex("010000000000000000"));
        let b = BigUint::from_u64(1);
        let c = a.sub(&b);
        assert_eq!(hex(&c.to_be_bytes()), "ffffffffffffffff");
    }

    #[test]
    fn multiplication() {
        let a = BigUint::from_u64(0xFFFFFFFF);
        let b = BigUint::from_u64(0xFFFFFFFF);
        let c = a.mul(&b);
        // 0xFFFFFFFF * 0xFFFFFFFF = 0xFFFFFFFE00000001
        assert_eq!(hex(&c.to_be_bytes()), "fffffffe00000001");
    }

    #[test]
    fn division() {
        let a = BigUint::from_be_bytes(&from_hex("fffffffe00000001"));
        let b = BigUint::from_u64(0xFFFFFFFF);
        let (q, r) = a.div_rem(&b);
        assert_eq!(hex(&q.to_be_bytes()), "ffffffff");
        assert!(r.is_zero());
    }

    #[test]
    fn division_with_remainder() {
        let a = BigUint::from_u64(17);
        let b = BigUint::from_u64(5);
        let (q, r) = a.div_rem(&b);
        assert_eq!(q, BigUint::from_u64(3));
        assert_eq!(r, BigUint::from_u64(2));
    }

    #[test]
    fn modpow_small() {
        // 2^10 mod 1000 = 1024 mod 1000 = 24
        let base = BigUint::from_u64(2);
        let exp = BigUint::from_u64(10);
        let modulus = BigUint::from_u64(1000);
        let result = base.modpow(&exp, &modulus);
        assert_eq!(result, BigUint::from_u64(24));
    }

    #[test]
    fn modpow_fermat() {
        // Fermat's little theorem: a^(p-1) ≡ 1 (mod p) for prime p, gcd(a,p)=1.
        // p = 65537 (prime), a = 12345.
        let a = BigUint::from_u64(12345);
        let p = BigUint::from_u64(65537);
        let exp = BigUint::from_u64(65536); // p-1
        let result = a.modpow(&exp, &p);
        assert_eq!(result, BigUint::one());
    }

    #[test]
    fn modpow_rsa_sized() {
        // Test with a larger modulus to verify it works at scale.
        // 3^17 mod 2^128+1 — verify with known computation.
        let base = BigUint::from_u64(3);
        let exp = BigUint::from_u64(17);
        // 2^128 + 1
        let mut mod_bytes = vec![0u8; 17];
        mod_bytes[0] = 1;
        mod_bytes[16] = 1;
        let modulus = BigUint::from_be_bytes(&mod_bytes);
        let result = base.modpow(&exp, &modulus);
        // 3^17 = 129140163. This is < 2^128+1, so result = 129140163.
        assert_eq!(result, BigUint::from_u64(129140163));
    }

    #[test]
    fn to_be_bytes_padded() {
        let n = BigUint::from_u64(0xFF);
        let padded = n.to_be_bytes_padded(4);
        assert_eq!(padded, vec![0, 0, 0, 0xFF]);
    }

    #[test]
    fn comparison() {
        let a = BigUint::from_u64(100);
        let b = BigUint::from_u64(200);
        assert_eq!(a.cmp(&b), Ordering::Less);
        assert_eq!(b.cmp(&a), Ordering::Greater);
        assert_eq!(a.cmp(&a), Ordering::Equal);
    }

    #[test]
    fn multi_limb_division() {
        // (2^128 - 1) / (2^64 - 1) = 2^64 + 1
        let a = BigUint::from_be_bytes(&from_hex("ffffffffffffffffffffffffffffffff"));
        let b = BigUint::from_be_bytes(&from_hex("ffffffffffffffff"));
        let (q, r) = a.div_rem(&b);
        assert_eq!(hex(&q.to_be_bytes()), "010000000000000001");
        assert!(r.is_zero());
    }

    #[test]
    fn multi_limb_remainder_is_normalized() {
        // (2^128 + 5) / 2^127 = 2 remainder 5. The divisor's top bit is
        // already set, so Knuth division runs with a shift of zero.
        let a = BigUint::from_be_bytes(&from_hex("0100000000000000000000000000000005"));
        let d = BigUint::from_be_bytes(&from_hex("80000000000000000000000000000000"));
        let (q, r) = a.div_rem(&d);
        assert_eq!(q, BigUint::from_u64(2));
        assert_eq!(r, BigUint::from_u64(5));
        assert_eq!(r.num_limbs(), 1);
        assert_eq!(r.bit_len(), 3);
    }

    #[test]
    fn modpow_256bit() {
        // Verify modpow with 256-bit numbers.
        // base = 2, exp = 255, mod = 2^256 - 189 (a prime near 2^256)
        // 2^255 mod (2^256 - 189)
        let base = BigUint::from_u64(2);
        let exp = BigUint::from_u64(255);
        // mod = 2^256 - 189
        let mut mod_bytes = vec![0xFF; 32];
        // 2^256 - 189 = 0xFFFFFF...FF43
        mod_bytes[31] = 0x43; // 0xFF - 188 = 0x43
        let modulus = BigUint::from_be_bytes(&mod_bytes);
        let result = base.modpow(&exp, &modulus);
        // The result must be fully reduced.
        assert_eq!(result.cmp(&modulus), Ordering::Less);
        // Double-check: result should be 2^255.
        // 2^255 < 2^256 - 189, so result = 2^255 exactly.
        let expected = {
            let mut b = vec![0u8; 32];
            b[0] = 0x80;
            BigUint::from_be_bytes(&b)
        };
        assert_eq!(result, expected);
    }
}