Skip to content

Commit

Permalink
Merge pull request #293 from recmo/remco/base-to
Browse files Browse the repository at this point in the history
Constant space from base conversions
  • Loading branch information
prestwich authored Aug 13, 2023
2 parents 1b14a40 + bbe5bef commit 55c5411
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 58 deletions.
2 changes: 1 addition & 1 deletion src/algorithms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub use self::{
add::{adc_n, cmp, sbb_n},
div::div,
gcd::{gcd, gcd_extended, inv_mod, LehmerMatrix},
mul::{add_nx1, addmul, addmul_n, addmul_nx1, addmul_ref, submul_nx1},
mul::{add_nx1, addmul, addmul_n, addmul_nx1, addmul_ref, mul_nx1, submul_nx1},
ops::{adc, sbb},
shift::{shift_left_small, shift_right_small},
};
Expand Down
13 changes: 12 additions & 1 deletion src/algorithms/mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,18 @@ fn mac(lhs: &mut u64, a: u64, b: u64, c: u64) -> u64 {
prod.high()
}

/// Computes `lhs += a * b` and returns the borrow.
/// Computes `lhs *= a` and returns the carry.
pub fn mul_nx1(lhs: &mut [u64], a: u64) -> u64 {
let mut carry = 0;
for lhs in &mut *lhs {
let product = u128::muladd(*lhs, a, carry);
*lhs = product.low();
carry = product.high();
}
carry
}

/// Computes `lhs += a * b` and returns the carry.
///
/// Requires `lhs.len() == a.len()`.
///
Expand Down
2 changes: 1 addition & 1 deletion src/aliases.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ pub mod tests {
use super::*;

#[test]
fn instantiate_consts() {
const fn instantiate_consts() {
let _ = (U0::ZERO, U0::MAX, B0::ZERO);
let _ = (U1::ZERO, U1::MAX, B1::ZERO);
let _ = (U8::ZERO, U8::MAX, B8::ZERO);
Expand Down
116 changes: 61 additions & 55 deletions src/base_convert.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::Uint;
use crate::{
algorithms::{addmul_nx1, mul_nx1},
Uint,
};
use core::fmt;

/// Error for [`from_base_le`][Uint::from_base_le] and
Expand Down Expand Up @@ -85,29 +88,6 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
}
}

/// Adds a digit in base `base` to the number. This is used internally by
/// [`Uint::from_base_le`] and [`Uint::from_base_be`].
#[inline]
fn add_digit(&mut self, digit: u64, base: u64) -> Result<(), BaseConvertError> {
if digit >= base {
return Err(BaseConvertError::InvalidDigit(digit, base));
}
// Multiply by base.
// OPT: keep track of non-zero limbs and mul the minimum.
let mut carry: u128 = u128::from(digit);
#[allow(clippy::cast_possible_truncation)]
for limb in &mut self.limbs {
carry += u128::from(*limb) * u128::from(base);
*limb = carry as u64;
carry >>= 64;
}
if carry > 0 || (LIMBS != 0 && self.limbs[LIMBS - 1] > Self::MASK) {
return Err(BaseConvertError::Overflow);
}

Ok(())
}

/// Constructs the [`Uint`] from digits in the base `base` in little-endian.
///
/// # Errors
Expand All @@ -124,36 +104,48 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
if base < 2 {
return Err(BaseConvertError::InvalidBase(base));
}

let mut tail = digits.into_iter();
match tail.next() {
Some(digit) => Self::from_base_le_recurse(digit, base, &mut tail),
None => Ok(Self::ZERO),
if BITS == 0 {
for digit in digits {
if digit >= base {
return Err(BaseConvertError::InvalidDigit(digit, base));
}
if digit != 0 {
return Err(BaseConvertError::Overflow);
}
}
return Ok(Self::ZERO);
}
}

/// This is the recursive part of [`Uint::from_base_le`].
///
/// We drain the iterator via the recursive calls, and then perform the
/// same construction loop as [`Uint::from_base_be`] while exiting the
/// recursive callstack.
#[inline]
fn from_base_le_recurse<I: Iterator<Item = u64>>(
digit: u64,
base: u64,
tail: &mut I,
) -> Result<Self, BaseConvertError> {
if digit > base {
return Err(BaseConvertError::InvalidDigit(digit, base));
}
let mut iter = digits.into_iter();
let mut result = Self::ZERO;
let mut power = Self::from(1);
for digit in iter.by_ref() {
if digit >= base {
return Err(BaseConvertError::InvalidDigit(digit, base));
}

let mut acc = match tail.next() {
Some(digit) => Self::from_base_le_recurse::<I>(digit, base, tail)?,
None => Self::ZERO,
};
// Add digit to result
let overflow = addmul_nx1(&mut result.limbs, &power.limbs, digit);
if overflow != 0 || result.limbs[LIMBS - 1] > Self::MASK {
return Err(BaseConvertError::Overflow);
}

acc.add_digit(digit, base)?;
Ok(acc)
// Update power
let overflow = mul_nx1(&mut power.limbs, base);
if overflow != 0 || power.limbs[LIMBS - 1] > Self::MASK {
// Following digits must be zero
break;
}
}
for digit in iter {
if digit >= base {
return Err(BaseConvertError::InvalidDigit(digit, base));
}
if digit != 0 {
return Err(BaseConvertError::Overflow);
}
}
Ok(result)
}

/// Constructs the [`Uint`] from digits in the base `base` in big-endian.
Expand All @@ -178,7 +170,21 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {

let mut result = Self::ZERO;
for digit in digits {
result.add_digit(digit, base)?;
if digit >= base {
return Err(BaseConvertError::InvalidDigit(digit, base));
}
// Multiply by base.
// OPT: keep track of non-zero limbs and mul the minimum.
let mut carry: u128 = u128::from(digit);
#[allow(clippy::cast_possible_truncation)]
for limb in &mut result.limbs {
carry += u128::from(*limb) * u128::from(base);
*limb = carry as u64;
carry >>= 64;
}
if carry > 0 || (LIMBS != 0 && result.limbs[LIMBS - 1] > Self::MASK) {
return Err(BaseConvertError::Overflow);
}
}

Ok(result)
Expand Down Expand Up @@ -309,20 +315,20 @@ mod tests {
#[test]
fn test_from_base_be_overflow() {
assert_eq!(
Uint::<0, 0>::from_base_be(10, [].into_iter()),
Uint::<0, 0>::from_base_be(10, std::iter::empty()),
Ok(Uint::<0, 0>::ZERO)
);
assert_eq!(
Uint::<0, 0>::from_base_be(10, [0].into_iter()),
Uint::<0, 0>::from_base_be(10, std::iter::once(0)),
Ok(Uint::<0, 0>::ZERO)
);
assert_eq!(
Uint::<0, 0>::from_base_be(10, [1].into_iter()),
Uint::<0, 0>::from_base_be(10, std::iter::once(1)),
Err(BaseConvertError::Overflow)
);
assert_eq!(
Uint::<1, 1>::from_base_be(10, [1, 0, 0].into_iter()),
Err(BaseConvertError::Overflow)
)
);
}
}

0 comments on commit 55c5411

Please sign in to comment.