diff --git a/arrow-buffer/benches/i256.rs b/arrow-buffer/benches/i256.rs index 2c43e0e91070..ebb45e793bd0 100644 --- a/arrow-buffer/benches/i256.rs +++ b/arrow-buffer/benches/i256.rs @@ -21,18 +21,7 @@ use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use std::str::FromStr; -/// Returns fixed seedable RNG -fn seedable_rng() -> StdRng { - StdRng::seed_from_u64(42) -} - -fn create_i256_vec(size: usize) -> Vec { - let mut rng = seedable_rng(); - - (0..size) - .map(|_| i256::from_i128(rng.gen::())) - .collect() -} +const SIZE: usize = 1024; fn criterion_benchmark(c: &mut Criterion) { let numbers = vec![ @@ -54,24 +43,40 @@ fn criterion_benchmark(c: &mut Criterion) { }); } - c.bench_function("i256_div", |b| { + let mut rng = StdRng::seed_from_u64(42); + + let numerators: Vec<_> = (0..SIZE) + .map(|_| { + let high = rng.gen_range(1000..i128::MAX); + let low = rng.gen(); + i256::from_parts(low, high) + }) + .collect(); + + let divisors: Vec<_> = numerators + .iter() + .map(|n| { + let quotient = rng.gen_range(1..100_i32); + n.wrapping_div(i256::from(quotient)) + }) + .collect(); + + c.bench_function("i256_div_rem small quotient", |b| { b.iter(|| { - for number_a in create_i256_vec(10) { - for number_b in create_i256_vec(5) { - number_a.checked_div(number_b); - number_a.wrapping_div(number_b); - } + for (n, d) in numerators.iter().zip(&divisors) { + black_box(n.wrapping_div(*d)); } }); }); - c.bench_function("i256_rem", |b| { + let divisors: Vec<_> = (0..SIZE) + .map(|_| i256::from(rng.gen_range(1..100_i32))) + .collect(); + + c.bench_function("i256_div_rem small divisor", |b| { b.iter(|| { - for number_a in create_i256_vec(10) { - for number_b in create_i256_vec(5) { - number_a.checked_rem(number_b); - number_a.wrapping_rem(number_b); - } + for (n, d) in numerators.iter().zip(&divisors) { + black_box(n.wrapping_div(*d)); } }); }); diff --git a/arrow-buffer/src/bigint/div.rs b/arrow-buffer/src/bigint/div.rs new file mode 100644 index 000000000000..ba530ffcc6c8 --- /dev/null +++ b/arrow-buffer/src/bigint/div.rs @@ -0,0 +1,312 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! N-digit division +//! +//! Implementation heavily inspired by [uint] +//! +//! [uint]: https://github.com/paritytech/parity-common/blob/d3a9327124a66e52ca1114bb8640c02c18c134b8/uint/src/uint.rs#L844 + +/// Unsigned, little-endian, n-digit division with remainder +/// +/// # Panics +/// +/// Panics if divisor is zero +pub fn div_rem( + numerator: &[u64; N], + divisor: &[u64; N], +) -> ([u64; N], [u64; N]) { + let numerator_bits = bits(numerator); + let divisor_bits = bits(divisor); + assert_ne!(divisor_bits, 0, "division by zero"); + + if numerator_bits < divisor_bits { + return ([0; N], *numerator); + } + + if divisor_bits <= 64 { + return div_rem_small(numerator, divisor[0]); + } + + let numerator_words = (numerator_bits + 63) / 64; + let divisor_words = (divisor_bits + 63) / 64; + let n = divisor_words; + let m = numerator_words - divisor_words; + + div_rem_knuth(numerator, divisor, n, m) +} + +/// Return the least number of bits needed to represent the number +fn bits(arr: &[u64]) -> usize { + for (idx, v) in arr.iter().enumerate().rev() { + if *v > 0 { + return 64 - v.leading_zeros() as usize + 64 * idx; + } + } + 0 +} + +/// Division of numerator by a u64 divisor +fn div_rem_small( + numerator: &[u64; N], + divisor: u64, +) -> ([u64; N], [u64; N]) { + let mut rem = 0u64; + let mut numerator = *numerator; + numerator.iter_mut().rev().for_each(|d| { + let (q, r) = div_rem_word(rem, *d, divisor); + *d = q; + rem = r; + }); + + let mut rem_padded = [0; N]; + rem_padded[0] = rem; + (numerator, rem_padded) +} + +/// Use Knuth Algorithm D to compute `numerator / divisor` returning the +/// quotient and remainder +/// +/// `n` is the number of non-zero 64-bit words in `divisor` +/// `m` is the number of non-zero 64-bit words present in `numerator` beyond `divisor`, and +/// therefore the number of words in the quotient +/// +/// A good explanation of the algorithm can be found [here](https://ridiculousfish.com/blog/posts/labor-of-division-episode-iv.html) +fn div_rem_knuth( + numerator: &[u64; N], + divisor: &[u64; N], + n: usize, + m: usize, +) -> ([u64; N], [u64; N]) { + assert!(n + m <= N); + + // The algorithm works by incrementally generating guesses `q_hat`, for the next digit + // of the quotient, starting from the most significant digit. + // + // This relies on the property that for any `q_hat` where + // + // (q_hat << (j * 64)) * divisor <= numerator` + // + // We can set + // + // q += q_hat << (j * 64) + // numerator -= (q_hat << (j * 64)) * divisor + // + // And then iterate until `numerator < divisor` + + // We normalize the divisor so that the highest bit in the highest digit of the + // divisor is set, this ensures our initial guess of `q_hat` is at most 2 off from + // the correct value for q[j] + let shift = divisor[n - 1].leading_zeros(); + // As the shift is computed based on leading zeros, don't need to perform full_shl + let divisor = shl_word(divisor, shift); + // numerator may have fewer leading zeros than divisor, so must add another digit + let mut numerator = full_shl(numerator, shift); + + // The two most significant digits of the divisor + let b0 = divisor[n - 1]; + let b1 = divisor[n - 2]; + + let mut q = [0; N]; + + for j in (0..=m).rev() { + let a0 = numerator[j + n]; + let a1 = numerator[j + n - 1]; + + let mut q_hat = if a0 < b0 { + // The first estimate is [a1, a0] / b0, it may be too large by at most 2 + let (mut q_hat, mut r_hat) = div_rem_word(a0, a1, b0); + + // r_hat = [a1, a0] - q_hat * b0 + // + // Now we want to compute a more precise estimate [a2,a1,a0] / [b1,b0] + // which can only be less or equal to the current q_hat + // + // q_hat is too large if: + // [a2,a1,a0] < q_hat * [b1,b0] + // [a2,r_hat] < q_hat * b1 + let a2 = numerator[j + n - 2]; + loop { + let r = u128::from(q_hat) * u128::from(b1); + let (lo, hi) = (r as u64, (r >> 64) as u64); + if (hi, lo) <= (r_hat, a2) { + break; + } + + q_hat -= 1; + let (new_r_hat, overflow) = r_hat.overflowing_add(b0); + r_hat = new_r_hat; + + if overflow { + break; + } + } + q_hat + } else { + u64::MAX + }; + + // q_hat is now either the correct quotient digit, or in rare cases 1 too large + + // Compute numerator -= (q_hat * divisor) << (j * 64) + let q_hat_v = full_mul_u64(&divisor, q_hat); + let c = sub_assign(&mut numerator[j..], &q_hat_v[..n + 1]); + + // If underflow, q_hat was too large by 1 + if c { + // Reduce q_hat by 1 + q_hat -= 1; + + // Add back one multiple of divisor + let c = add_assign(&mut numerator[j..], &divisor[..n]); + numerator[j + n] = numerator[j + n].wrapping_add(u64::from(c)); + } + + // q_hat is the correct value for q[j] + q[j] = q_hat; + } + + // The remainder is what is left in numerator, with the initial normalization shl reversed + let remainder = full_shr(&numerator, shift); + (q, remainder) +} + +/// Perform narrowing division of a u128 by a u64 divisor, returning the quotient and remainder +/// +/// This method may trap or panic if hi >= divisor, i.e. the quotient would not fit +/// into a 64-bit integer +fn div_rem_word(hi: u64, lo: u64, divisor: u64) -> (u64, u64) { + debug_assert!(hi < divisor); + debug_assert_ne!(divisor, 0); + + // LLVM fails to use the div instruction as it is not able to prove + // that hi < divisor, and therefore the result will fit into 64-bits + #[cfg(target_arch = "x86_64")] + unsafe { + let mut quot = lo; + let mut rem = hi; + std::arch::asm!( + "div {divisor}", + divisor = in(reg) divisor, + inout("rax") quot, + inout("rdx") rem, + options(pure, nomem, nostack) + ); + (quot, rem) + } + #[cfg(not(target_arch = "x86_64"))] + { + let x = (u128::from(hi) << 64) + u128::from(lo); + let y = u128::from(divisor); + ((x / y) as u64, (x % y) as u64) + } +} + +/// Perform `a += b` +fn add_assign(a: &mut [u64], b: &[u64]) -> bool { + binop_slice(a, b, u64::overflowing_add) +} + +/// Perform `a -= b` +fn sub_assign(a: &mut [u64], b: &[u64]) -> bool { + binop_slice(a, b, u64::overflowing_sub) +} + +/// Converts an overflowing binary operation on scalars to one on slices +fn binop_slice( + a: &mut [u64], + b: &[u64], + binop: impl Fn(u64, u64) -> (u64, bool) + Copy, +) -> bool { + let mut c = false; + a.iter_mut().zip(b.iter()).for_each(|(x, y)| { + let (res1, overflow1) = y.overflowing_add(u64::from(c)); + let (res2, overflow2) = binop(*x, res1); + *x = res2; + c = overflow1 || overflow2; + }); + c +} + +/// Widening multiplication of an N-digit array with a u64 +fn full_mul_u64(a: &[u64; N], b: u64) -> ArrayPlusOne { + let mut carry = 0; + let mut out = [0; N]; + out.iter_mut().zip(a).for_each(|(o, v)| { + let r = *v as u128 * b as u128 + carry as u128; + *o = r as u64; + carry = (r >> 64) as u64; + }); + ArrayPlusOne(out, carry) +} + +/// Left shift of an N-digit array by at most 63 bits +fn shl_word(v: &[u64; N], shift: u32) -> [u64; N] { + full_shl(v, shift).0 +} + +/// Widening left shift of an N-digit array by at most 63 bits +fn full_shl(v: &[u64; N], shift: u32) -> ArrayPlusOne { + debug_assert!(shift < 64); + if shift == 0 { + return ArrayPlusOne(*v, 0); + } + let mut out = [0u64; N]; + out[0] = v[0] << shift; + for i in 1..N { + out[i] = v[i - 1] >> (64 - shift) | v[i] << shift + } + let carry = v[N - 1] >> (64 - shift); + ArrayPlusOne(out, carry) +} + +/// Narrowing right shift of an (N+1)-digit array by at most 63 bits +fn full_shr(a: &ArrayPlusOne, shift: u32) -> [u64; N] { + debug_assert!(shift < 64); + if shift == 0 { + return a.0; + } + let mut out = [0; N]; + for i in 0..N - 1 { + out[i] = a[i] >> shift | a[i + 1] << (64 - shift) + } + out[N - 1] = a[N - 1] >> shift; + out +} + +/// An array of N + 1 elements +/// +/// This is a hack around lack of support for const arithmetic +#[repr(C)] +struct ArrayPlusOne([T; N], T); + +impl std::ops::Deref for ArrayPlusOne { + type Target = [T]; + + #[inline] + fn deref(&self) -> &Self::Target { + let x = self as *const Self; + unsafe { std::slice::from_raw_parts(x as *const T, N + 1) } + } +} + +impl std::ops::DerefMut for ArrayPlusOne { + fn deref_mut(&mut self) -> &mut Self::Target { + let x = self as *mut Self; + unsafe { std::slice::from_raw_parts_mut(x as *mut T, N + 1) } + } +} diff --git a/arrow-buffer/src/bigint.rs b/arrow-buffer/src/bigint/mod.rs similarity index 94% rename from arrow-buffer/src/bigint.rs rename to arrow-buffer/src/bigint/mod.rs index 86150e67fd91..fe0774539989 100644 --- a/arrow-buffer/src/bigint.rs +++ b/arrow-buffer/src/bigint/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::bigint::div::div_rem; use num::cast::AsPrimitive; use num::{BigInt, FromPrimitive, ToPrimitive}; use std::cmp::Ordering; @@ -22,6 +23,8 @@ use std::num::ParseIntError; use std::ops::{BitAnd, BitOr, BitXor, Neg, Shl, Shr}; use std::str::FromStr; +mod div; + /// An opaque error similar to [`std::num::ParseIntError`] #[derive(Debug)] pub struct ParseI256Error {} @@ -428,25 +431,6 @@ impl i256 { .then_some(Self { low, high }) } - /// Return the least number of bits needed to represent the number - #[inline] - fn bits_required(&self) -> usize { - let le_bytes = self.to_le_bytes(); - let arr: [u128; 2] = [ - u128::from_le_bytes(le_bytes[0..16].try_into().unwrap()), - u128::from_le_bytes(le_bytes[16..32].try_into().unwrap()), - ]; - - let iter = arr.iter().rev().take(2 - 1); - if self.is_negative() { - let ctr = iter.take_while(|&&b| b == ::core::u128::MAX).count(); - (128 * (2 - ctr)) + 1 - (!arr[2 - ctr - 1]).leading_zeros() as usize - } else { - let ctr = iter.take_while(|&&b| b == ::core::u128::MIN).count(); - (128 * (2 - ctr)) + 1 - arr[2 - ctr - 1].leading_zeros() as usize - } - } - /// Division operation, returns (quotient, remainder). /// This basically implements [Long division]: `` #[inline] @@ -458,41 +442,45 @@ impl i256 { return Err(DivRemError::DivideOverflow); } - if self == Self::MIN || other == Self::MIN { - let l = BigInt::from_signed_bytes_le(&self.to_le_bytes()); - let r = BigInt::from_signed_bytes_le(&other.to_le_bytes()); - let d = i256::from_bigint_with_overflow(&l / &r).0; - let r = i256::from_bigint_with_overflow(&l % &r).0; - return Ok((d, r)); - } - - let mut me = self.checked_abs().unwrap(); - let mut you = other.checked_abs().unwrap(); - let mut ret = [0u128; 2]; - if me < you { - return Ok((Self::from_parts(ret[0], ret[1] as i128), self)); - } + let a = self.wrapping_abs(); + let b = other.wrapping_abs(); - let shift = me.bits_required() - you.bits_required(); - you = you.shl(shift as u8); - for i in (0..=shift).rev() { - if me >= you { - ret[i / 128] |= 1 << (i % 128); - me = me.checked_sub(you).unwrap(); - } - you = you.shr(1); - } + let (div, rem) = div_rem(&a.as_digits(), &b.as_digits()); + let div = Self::from_digits(div); + let rem = Self::from_digits(rem); Ok(( if self.is_negative() == other.is_negative() { - Self::from_parts(ret[0], ret[1] as i128) + div + } else { + div.wrapping_neg() + }, + if self.is_negative() { + rem.wrapping_neg() } else { - -Self::from_parts(ret[0], ret[1] as i128) + rem }, - if self.is_negative() { -me } else { me }, )) } + /// Interpret this [`i256`] as 4 `u64` digits, least significant first + fn as_digits(self) -> [u64; 4] { + [ + self.low as u64, + (self.low >> 64) as u64, + self.high as u64, + (self.high as u128 >> 64) as u64, + ] + } + + /// Interpret 4 `u64` digits, least significant first, as a [`i256`] + fn from_digits(digits: [u64; 4]) -> Self { + Self::from_parts( + digits[0] as u128 | (digits[1] as u128) << 64, + digits[2] as i128 | (digits[3] as i128) << 64, + ) + } + /// Performs wrapping division #[inline] pub fn wrapping_div(self, other: Self) -> Self { @@ -969,7 +957,7 @@ mod tests { let expected = bl.clone() % br.clone(); let checked = il.checked_rem(ir); - assert_eq!(actual.to_string(), expected.to_string()); + assert_eq!(actual.to_string(), expected.to_string(), "{il} % {ir}"); if ir == i256::MINUS_ONE && il == i256::MIN { assert!(checked.is_none());