From a5804dbc3d4384c3219fd825aa01c799567b7ac8 Mon Sep 17 00:00:00 2001 From: DaniPopes <57450786+DaniPopes@users.noreply.github.com> Date: Sun, 1 Oct 2023 19:59:51 +0200 Subject: [PATCH] chore: loops clean up --- src/algorithms/add.rs | 34 +++++++--------------------------- src/algorithms/div/knuth.rs | 5 +---- src/algorithms/div/small.rs | 1 - src/algorithms/mod.rs | 33 ++++++++++++++++++++++++++++++++- src/bits.rs | 9 +++++++-- src/bytes.rs | 9 ++++++--- src/utils.rs | 4 ++-- 7 files changed, 55 insertions(+), 40 deletions(-) diff --git a/src/algorithms/add.rs b/src/algorithms/add.rs index 74cbb57..92db341 100644 --- a/src/algorithms/add.rs +++ b/src/algorithms/add.rs @@ -1,39 +1,19 @@ -#![allow(dead_code)] // TODO - use super::ops::{adc, sbb}; -use core::cmp::Ordering; - -#[inline(always)] -#[must_use] -pub fn cmp(lhs: &[u64], rhs: &[u64]) -> Ordering { - debug_assert_eq!(lhs.len(), rhs.len()); - for (l, r) in lhs.iter().rev().zip(rhs.iter().rev()) { - match l.cmp(r) { - Ordering::Equal => continue, - other => return other, - } - } - Ordering::Equal -} /// `lhs += rhs + carry` #[inline(always)] pub fn adc_n(lhs: &mut [u64], rhs: &[u64], mut carry: u64) -> u64 { - for (l, r) in lhs.iter_mut().zip(rhs.iter()) { - let (result, new_carry) = adc(*l, *r, carry); - *l = result; - carry = new_carry; + for i in 0..lhs.len() { + (lhs[i], carry) = adc(lhs[i], rhs[i], carry); } carry } -/// `lhs -= rhs + carry` +/// `lhs -= rhs + borrow` #[inline(always)] -pub fn sbb_n(lhs: &mut [u64], rhs: &[u64], mut carry: u64) -> u64 { - for (l, r) in lhs.iter_mut().zip(rhs.iter()) { - let (result, new_carry) = sbb(*l, *r, carry); - *l = result; - carry = new_carry; +pub fn sbb_n(lhs: &mut [u64], rhs: &[u64], mut borrow: u64) -> u64 { + for i in 0..lhs.len() { + (lhs[i], borrow) = sbb(lhs[i], rhs[i], borrow); } - carry + borrow } diff --git a/src/algorithms/div/knuth.rs b/src/algorithms/div/knuth.rs index 57750ad..dbaca63 100644 --- a/src/algorithms/div/knuth.rs +++ 
b/src/algorithms/div/knuth.rs @@ -192,10 +192,7 @@ pub fn div_nxm(numerator: &mut [u64], divisor: &mut [u64]) { #[cfg(test)] mod tests { use super::*; - use crate::algorithms::{ - add::{cmp, sbb_n}, - addmul, - }; + use crate::algorithms::{addmul, cmp, sbb_n}; use alloc::vec::Vec; use core::cmp::Ordering; use proptest::{ diff --git a/src/algorithms/div/small.rs b/src/algorithms/div/small.rs index 719eecb..98ddcfa 100644 --- a/src/algorithms/div/small.rs +++ b/src/algorithms/div/small.rs @@ -6,7 +6,6 @@ #![allow(clippy::many_single_char_names, clippy::similar_names)] // Truncation is intentional #![allow(clippy::cast_possible_truncation)] -#![allow(dead_code)] // TODO use super::reciprocal::{reciprocal, reciprocal_2}; use crate::{algorithms::DoubleWord, utils::unlikely}; diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs index fcce70c..69e09e2 100644 --- a/src/algorithms/mod.rs +++ b/src/algorithms/mod.rs @@ -5,6 +5,8 @@ #![allow(missing_docs)] // TODO: document algorithms +use core::cmp::Ordering; + mod add; pub mod div; mod gcd; @@ -15,7 +17,7 @@ mod ops; mod shift; pub use self::{ - add::{adc_n, cmp, sbb_n}, + add::{adc_n, sbb_n}, div::div, gcd::{gcd, gcd_extended, inv_mod, LehmerMatrix}, mul::{add_nx1, addmul, addmul_n, addmul_nx1, addmul_ref, mul_nx1, submul_nx1}, @@ -85,3 +87,32 @@ impl DoubleWord for u128 { (self.low(), self.high()) } } + +/// Compare two `u64` slices in reverse order. 
+#[inline(always)] +#[must_use] +pub fn cmp(left: &[u64], right: &[u64]) -> Ordering { + let l = core::cmp::min(left.len(), right.len()); + + // Slice to the loop iteration range to enable bound check + // elimination in the compiler + let lhs = &left[..l]; + let rhs = &right[..l]; + + for i in (0..l).rev() { + match i8::from(lhs[i] > rhs[i]) - i8::from(lhs[i] < rhs[i]) { + -1 => return Ordering::Less, + 0 => {} + 1 => return Ordering::Greater, + _ => unsafe { core::hint::unreachable_unchecked() }, + } + + // Equivalent to: + // match lhs[i].cmp(&rhs[i]) { + // Ordering::Equal => {} + // non_eq => return non_eq, + // } + } + + left.len().cmp(&right.len()) +} diff --git a/src/bits.rs b/src/bits.rs index fd54c89..aaa25a0 100644 --- a/src/bits.rs +++ b/src/bits.rs @@ -483,16 +483,18 @@ macro_rules! impl_bit_op { self.$fn_assign(&rhs); } } + impl $trait_assign<&Uint> for Uint { #[inline] fn $fn_assign(&mut self, rhs: &Uint) { - for (limb, &rhs) in self.limbs.iter_mut().zip(rhs.as_limbs()) { - u64::$fn_assign(limb, rhs); + for i in 0..LIMBS { + u64::$fn_assign(&mut self.limbs[i], rhs.limbs[i]); } } } + impl $trait> for Uint { @@ -504,6 +506,7 @@ macro_rules! impl_bit_op { self } } + impl $trait<&Uint> for Uint { @@ -515,6 +518,7 @@ macro_rules! impl_bit_op { self } } + impl $trait> for &Uint { @@ -526,6 +530,7 @@ macro_rules! impl_bit_op { rhs } } + impl $trait<&Uint> for &Uint { diff --git a/src/bytes.rs b/src/bytes.rs index e873e70..0faa5cb 100644 --- a/src/bytes.rs +++ b/src/bytes.rs @@ -48,7 +48,8 @@ impl Uint { #[cfg(feature = "alloc")] #[must_use] #[inline] - pub const fn as_le_bytes(&self) -> Cow<'_, [u8]> { + #[allow(clippy::missing_const_for_fn)] + pub fn as_le_bytes(&self) -> Cow<'_, [u8]> { // On little endian platforms this is a no-op. 
#[cfg(target_endian = "little")] return Cow::Borrowed(self.as_le_slice()); @@ -57,8 +58,10 @@ impl Uint { #[cfg(target_endian = "big")] return Cow::Owned({ let mut cpy = *self; - cpy.limbs.iter_mut().for_each(|limb| limb.reverse_bits()); - slice::from_raw_parts(cpy.limbs.as_ptr().cast(), Self::BYTES).to_vec() + for limb in &mut cpy.limbs { + *limb = limb.reverse_bits(); + } + unsafe { slice::from_raw_parts(cpy.limbs.as_ptr().cast(), Self::BYTES).to_vec() } }); } diff --git a/src/utils.rs b/src/utils.rs index 75bdce8..60744f8 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -12,13 +12,13 @@ pub(crate) const fn rem_up(a: usize, b: usize) -> usize { } } -#[allow(dead_code)] +#[allow(dead_code)] // This is used by some support features. #[inline] fn last_idx(x: &[T], value: &T) -> usize { x.iter().rposition(|b| b != value).map_or(0, |idx| idx + 1) } -#[allow(dead_code)] +#[allow(dead_code)] // This is used by some support features. #[inline] #[must_use] pub(crate) fn trim_end_slice<'a, T: PartialEq>(slice: &'a [T], value: &T) -> &'a [T] {