From 80a7e4455b7e6713ea0f345db01307c0c71ba418 Mon Sep 17 00:00:00 2001 From: Remco Bloemen Date: Thu, 3 Aug 2023 18:58:57 +0200 Subject: [PATCH 1/6] Constant space from base conversions --- src/base_convert.rs | 99 ++++++++++++++++++++++----------------------- 1 file changed, 48 insertions(+), 51 deletions(-) diff --git a/src/base_convert.rs b/src/base_convert.rs index 3c21d16..daa7055 100644 --- a/src/base_convert.rs +++ b/src/base_convert.rs @@ -85,29 +85,6 @@ impl Uint { } } - /// Adds a digit in base `base` to the number. This is used internally by - /// [`Uint::from_base_le`] and [`Uint::from_base_be`]. - #[inline] - fn add_digit(&mut self, digit: u64, base: u64) -> Result<(), BaseConvertError> { - if digit >= base { - return Err(BaseConvertError::InvalidDigit(digit, base)); - } - // Multiply by base. - // OPT: keep track of non-zero limbs and mul the minimum. - let mut carry: u128 = u128::from(digit); - #[allow(clippy::cast_possible_truncation)] - for limb in &mut self.limbs { - carry += u128::from(*limb) * u128::from(base); - *limb = carry as u64; - carry >>= 64; - } - if carry > 0 || (LIMBS != 0 && self.limbs[LIMBS - 1] > Self::MASK) { - return Err(BaseConvertError::Overflow); - } - - Ok(()) - } - /// Constructs the [`Uint`] from digits in the base `base` in little-endian. /// /// # Errors @@ -124,36 +101,42 @@ impl Uint { if base < 2 { return Err(BaseConvertError::InvalidBase(base)); } - - let mut tail = digits.into_iter(); - match tail.next() { - Some(digit) => Self::from_base_le_recurse(digit, base, &mut tail), - None => Ok(Self::ZERO), + if BITS == 0 { + for digit in digits { + if digit >= base { + return Err(BaseConvertError::InvalidDigit(digit, base)); + } + if digit != 0 { + return Err(BaseConvertError::Overflow); + } + } + return Ok(Self::ZERO); } - } - /// This is the recursive part of [`Uint::from_base_le`]. - /// - /// We drain the iterator via the recursive calls, and then perform the - /// same construction loop as [`Uint::from_base_be`] while exiting the - /// recursive callstack. - #[inline] - fn from_base_le_recurse>( - digit: u64, - base: u64, - tail: &mut I, - ) -> Result { - if digit > base { - return Err(BaseConvertError::InvalidDigit(digit, base)); + let mut iter = digits.into_iter(); + let mut result = Self::ZERO; + let mut power = Self::from(1); + while let Some(digit) = iter.next() { + if digit >= base { + return Err(BaseConvertError::InvalidDigit(digit, base)); + } + let term = power.checked_mul(Self::from(digit)).ok_or(BaseConvertError::Overflow)?; + result = result.checked_add(term).ok_or(BaseConvertError::Overflow)?; + if let Some(next_power) = power.checked_mul(Self::from(base)) { + power = next_power; + } else { + break; + } } - - let mut acc = match tail.next() { - Some(digit) => Self::from_base_le_recurse::(digit, base, tail)?, - None => Self::ZERO, - }; - - acc.add_digit(digit, base)?; - Ok(acc) + while let Some(digit) = iter.next() { + if digit >= base { + return Err(BaseConvertError::InvalidDigit(digit, base)); + } + if digit != 0 { + return Err(BaseConvertError::Overflow); + } + } + Ok(result) } /// Constructs the [`Uint`] from digits in the base `base` in big-endian. @@ -178,7 +161,21 @@ impl Uint { let mut result = Self::ZERO; for digit in digits { - result.add_digit(digit, base)?; + if digit >= base { + return Err(BaseConvertError::InvalidDigit(digit, base)); + } + // Multiply by base. + // OPT: keep track of non-zero limbs and mul the minimum. + let mut carry: u128 = u128::from(digit); + #[allow(clippy::cast_possible_truncation)] + for limb in &mut result.limbs { + carry += u128::from(*limb) * u128::from(base); + *limb = carry as u64; + carry >>= 64; + } + if carry > 0 || (LIMBS != 0 && result.limbs[LIMBS - 1] > Self::MASK) { + return Err(BaseConvertError::Overflow); + } } Ok(result) From dff57c6bb6aa1b0b6d6889ce471f7cbb03688aca Mon Sep 17 00:00:00 2001 From: Remco Bloemen Date: Thu, 3 Aug 2023 19:25:16 +0200 Subject: [PATCH 2/6] Use single-limb mul algorithms --- src/algorithms/mod.rs | 2 +- src/algorithms/mul.rs | 13 ++++++++++++- src/base_convert.rs | 18 ++++++++++++------ 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs index a96ec9a..4be8fc1 100644 --- a/src/algorithms/mod.rs +++ b/src/algorithms/mod.rs @@ -16,7 +16,7 @@ pub use self::{ add::{adc_n, cmp, sbb_n}, div::div, gcd::{gcd, gcd_extended, inv_mod, LehmerMatrix}, - mul::{add_nx1, addmul, addmul_n, addmul_nx1, addmul_ref, submul_nx1}, + mul::{mul_nx1, add_nx1, addmul, addmul_n, addmul_nx1, addmul_ref, submul_nx1}, ops::{adc, sbb}, shift::{shift_left_small, shift_right_small}, }; diff --git a/src/algorithms/mul.rs b/src/algorithms/mul.rs index 1d6770e..34b2e30 100644 --- a/src/algorithms/mul.rs +++ b/src/algorithms/mul.rs @@ -224,7 +224,18 @@ fn mac(lhs: &mut u64, a: u64, b: u64, c: u64) -> u64 { prod.high() } -/// Computes `lhs += a * b` and returns the borrow. +/// Computes `lhs *= a` and returns the borrow. +pub fn mul_nx1(lhs: &mut [u64], a: u64) -> u64 { + let mut carry = 0; + for lhs in lhs.iter_mut() { + let product = u128::muladd(*lhs, a, carry); + *lhs = product.low(); + carry = product.high(); + } + carry +} + +/// Computes `lhs += a * b` and returns the carry. /// /// Requires `lhs.len() == a.len()`. /// diff --git a/src/base_convert.rs b/src/base_convert.rs index daa7055..89dd375 100644 --- a/src/base_convert.rs +++ b/src/base_convert.rs @@ -1,4 +1,4 @@ -use crate::Uint; +use crate::{Uint, algorithms::{mul_nx1, addmul_nx1}}; use core::fmt; /// Error for [`from_base_le`][Uint::from_base_le] and @@ -120,11 +120,17 @@ impl Uint { if digit >= base { return Err(BaseConvertError::InvalidDigit(digit, base)); } - let term = power.checked_mul(Self::from(digit)).ok_or(BaseConvertError::Overflow)?; - result = result.checked_add(term).ok_or(BaseConvertError::Overflow)?; - if let Some(next_power) = power.checked_mul(Self::from(base)) { - power = next_power; - } else { + + // Add digit to result + let overflow = addmul_nx1(&mut result.limbs, &power.limbs, digit); + if overflow != 0 || result.limbs[LIMBS - 1] > Self::MASK { + return Err(BaseConvertError::Overflow); + } + + // Update power + let overflow = mul_nx1(&mut power.limbs, base); + if overflow != 0 || power.limbs[LIMBS - 1] > Self::MASK { + // Following digits must be zero break; } } From 924ea2e4132db6f7764d1dbd0b0e2741e3771be8 Mon Sep 17 00:00:00 2001 From: Remco Bloemen Date: Thu, 3 Aug 2023 19:28:17 +0200 Subject: [PATCH 3/6] Typo --- src/algorithms/mul.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/algorithms/mul.rs b/src/algorithms/mul.rs index 34b2e30..651ba6c 100644 --- a/src/algorithms/mul.rs +++ b/src/algorithms/mul.rs @@ -224,7 +224,7 @@ fn mac(lhs: &mut u64, a: u64, b: u64, c: u64) -> u64 { prod.high() } -/// Computes `lhs *= a` and returns the borrow. +/// Computes `lhs *= a` and returns the carry. pub fn mul_nx1(lhs: &mut [u64], a: u64) -> u64 { let mut carry = 0; for lhs in lhs.iter_mut() { From bfdb1ec3de0f990ee51b3b67286c4e7c71983243 Mon Sep 17 00:00:00 2001 From: James Date: Sun, 13 Aug 2023 12:53:03 -0700 Subject: [PATCH 4/6] chore: fmt --- src/algorithms/mod.rs | 2 +- src/base_convert.rs | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs index 4be8fc1..08a9603 100644 --- a/src/algorithms/mod.rs +++ b/src/algorithms/mod.rs @@ -16,7 +16,7 @@ pub use self::{ add::{adc_n, cmp, sbb_n}, div::div, gcd::{gcd, gcd_extended, inv_mod, LehmerMatrix}, - mul::{mul_nx1, add_nx1, addmul, addmul_n, addmul_nx1, addmul_ref, submul_nx1}, + mul::{add_nx1, addmul, addmul_n, addmul_nx1, addmul_ref, mul_nx1, submul_nx1}, ops::{adc, sbb}, shift::{shift_left_small, shift_right_small}, }; diff --git a/src/base_convert.rs b/src/base_convert.rs index 89dd375..d34b369 100644 --- a/src/base_convert.rs +++ b/src/base_convert.rs @@ -1,4 +1,7 @@ -use crate::{Uint, algorithms::{mul_nx1, addmul_nx1}}; +use crate::{ + algorithms::{addmul_nx1, mul_nx1}, + Uint, +}; use core::fmt; /// Error for [`from_base_le`][Uint::from_base_le] and From feffcb96330155917e2965b7dc65a9f7dacc7896 Mon Sep 17 00:00:00 2001 From: James Date: Sun, 13 Aug 2023 12:57:30 -0700 Subject: [PATCH 5/6] lint: clippy --- src/aliases.rs | 2 +- src/base_convert.rs | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/aliases.rs b/src/aliases.rs index c58c3ed..93b004a 100644 --- a/src/aliases.rs +++ b/src/aliases.rs @@ -68,7 +68,7 @@ pub mod tests { use super::*; #[test] - fn instantiate_consts() { + const fn instantiate_consts() { let _ = (U0::ZERO, U0::MAX, B0::ZERO); let _ = (U1::ZERO, U1::MAX, B1::ZERO); let _ = (U8::ZERO, U8::MAX, B8::ZERO); diff --git a/src/base_convert.rs b/src/base_convert.rs index d34b369..65bfe9e 100644 --- a/src/base_convert.rs +++ b/src/base_convert.rs @@ -119,7 +119,7 @@ impl Uint { let mut iter = digits.into_iter(); let mut result = Self::ZERO; let mut power = Self::from(1); - while let Some(digit) = iter.next() { + for digit in iter.by_ref() { if digit >= base { return Err(BaseConvertError::InvalidDigit(digit, base)); } @@ -137,7 +137,7 @@ impl Uint { break; } } - while let Some(digit) = iter.next() { + for digit in iter { if digit >= base { return Err(BaseConvertError::InvalidDigit(digit, base)); } @@ -315,20 +315,20 @@ mod tests { #[test] fn test_from_base_be_overflow() { assert_eq!( - Uint::<0, 0>::from_base_be(10, [].into_iter()), + Uint::<0, 0>::from_base_be(10, std::iter::empty()), Ok(Uint::<0, 0>::ZERO) ); assert_eq!( - Uint::<0, 0>::from_base_be(10, [0].into_iter()), + Uint::<0, 0>::from_base_be(10, std::iter::once(0)), Ok(Uint::<0, 0>::ZERO) ); assert_eq!( - Uint::<0, 0>::from_base_be(10, [1].into_iter()), + Uint::<0, 0>::from_base_be(10, std::iter::once(1)), Err(BaseConvertError::Overflow) ); assert_eq!( Uint::<1, 1>::from_base_be(10, [1, 0, 0].into_iter()), Err(BaseConvertError::Overflow) - ) + ); } } From bbe5beff99879ca0e855f2fda56ce7291647ef9f Mon Sep 17 00:00:00 2001 From: James Date: Sun, 13 Aug 2023 13:03:18 -0700 Subject: [PATCH 6/6] lint: clippy in features too --- src/algorithms/mul.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/algorithms/mul.rs b/src/algorithms/mul.rs index 651ba6c..95cb9c0 100644 --- a/src/algorithms/mul.rs +++ b/src/algorithms/mul.rs @@ -227,7 +227,7 @@ fn mac(lhs: &mut u64, a: u64, b: u64, c: u64) -> u64 { /// Computes `lhs *= a` and returns the carry. pub fn mul_nx1(lhs: &mut [u64], a: u64) -> u64 { let mut carry = 0; - for lhs in lhs.iter_mut() { + for lhs in &mut *lhs { let product = u128::muladd(*lhs, a, carry); *lhs = product.low(); carry = product.high();