From 189f6eead8c1fed229504dfdc90623a6092cbf16 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Mon, 7 Aug 2023 15:10:28 +0100 Subject: [PATCH] Faster i256 Division (2-100x) (#4663) --- arrow-buffer/benches/i256.rs | 53 ++-- arrow-buffer/src/bigint/div.rs | 241 ++++++++++++++++++ arrow-buffer/src/{bigint.rs => bigint/mod.rs} | 80 +++--- 3 files changed, 304 insertions(+), 70 deletions(-) create mode 100644 arrow-buffer/src/bigint/div.rs rename arrow-buffer/src/{bigint.rs => bigint/mod.rs} (94%) diff --git a/arrow-buffer/benches/i256.rs b/arrow-buffer/benches/i256.rs index 2c43e0e91070..ebb45e793bd0 100644 --- a/arrow-buffer/benches/i256.rs +++ b/arrow-buffer/benches/i256.rs @@ -21,18 +21,7 @@ use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use std::str::FromStr; -/// Returns fixed seedable RNG -fn seedable_rng() -> StdRng { - StdRng::seed_from_u64(42) -} - -fn create_i256_vec(size: usize) -> Vec { - let mut rng = seedable_rng(); - - (0..size) - .map(|_| i256::from_i128(rng.gen::())) - .collect() -} +const SIZE: usize = 1024; fn criterion_benchmark(c: &mut Criterion) { let numbers = vec![ @@ -54,24 +43,40 @@ fn criterion_benchmark(c: &mut Criterion) { }); } - c.bench_function("i256_div", |b| { + let mut rng = StdRng::seed_from_u64(42); + + let numerators: Vec<_> = (0..SIZE) + .map(|_| { + let high = rng.gen_range(1000..i128::MAX); + let low = rng.gen(); + i256::from_parts(low, high) + }) + .collect(); + + let divisors: Vec<_> = numerators + .iter() + .map(|n| { + let quotient = rng.gen_range(1..100_i32); + n.wrapping_div(i256::from(quotient)) + }) + .collect(); + + c.bench_function("i256_div_rem small quotient", |b| { b.iter(|| { - for number_a in create_i256_vec(10) { - for number_b in create_i256_vec(5) { - number_a.checked_div(number_b); - number_a.wrapping_div(number_b); - } + for (n, d) in numerators.iter().zip(&divisors) { + black_box(n.wrapping_div(*d)); } }); }); - c.bench_function("i256_rem", |b| { + let divisors: Vec<_> = (0..SIZE) + .map(|_| i256::from(rng.gen_range(1..100_i32))) + .collect(); + + c.bench_function("i256_div_rem small divisor", |b| { b.iter(|| { - for number_a in create_i256_vec(10) { - for number_b in create_i256_vec(5) { - number_a.checked_rem(number_b); - number_a.wrapping_rem(number_b); - } + for (n, d) in numerators.iter().zip(&divisors) { + black_box(n.wrapping_div(*d)); } }); }); diff --git a/arrow-buffer/src/bigint/div.rs b/arrow-buffer/src/bigint/div.rs new file mode 100644 index 000000000000..57a36e5eb12c --- /dev/null +++ b/arrow-buffer/src/bigint/div.rs @@ -0,0 +1,241 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! N-digit division +//! +//! Implementation heavily inspired by [uint] +//! +//! [uint]: https://github.com/paritytech/parity-common/blob/d3a9327124a66e52ca1114bb8640c02c18c134b8/uint/src/uint.rs#L844 + +/// Unsigned, little-endian, n-digit division with remainder +/// +/// # Panics +/// +/// Panics if divisor is zero +pub fn div_rem( + numerator: &[u64; N], + divisor: &[u64; N], +) -> ([u64; N], [u64; N]) { + let numerator_bits = bits(numerator); + let divisor_bits = bits(divisor); + assert_ne!(divisor_bits, 0, "division by zero"); + + if numerator_bits < divisor_bits { + return ([0; N], *numerator); + } + + if divisor_bits <= 64 { + return div_rem_small(numerator, divisor[0]); + } + + let numerator_words = (numerator_bits + 63) / 64; + let divisor_words = (divisor_bits + 63) / 64; + let n = divisor_words; + let m = numerator_words - divisor_words; + + div_rem_knuth(numerator, divisor, n, m) +} + +/// Return the least number of bits needed to represent the number +fn bits(arr: &[u64]) -> usize { + for (idx, v) in arr.iter().enumerate().rev() { + if *v > 0 { + return 64 - v.leading_zeros() as usize + 64 * idx; + } + } + 0 +} + +/// Division of numerator by a u64 divisor +fn div_rem_small( + numerator: &[u64; N], + divisor: u64, +) -> ([u64; N], [u64; N]) { + let mut rem = 0u64; + let mut numerator = *numerator; + numerator.iter_mut().rev().for_each(|d| { + let (q, r) = div_rem_word(rem, *d, divisor); + *d = q; + rem = r; + }); + + let mut rem_padded = [0; N]; + rem_padded[0] = rem; + (numerator, rem_padded) +} + +fn div_rem_knuth( + numerator: &[u64; N], + divisor: &[u64; N], + n: usize, + m: usize, +) -> ([u64; N], [u64; N]) { + assert!(n + m <= N); + + let shift = divisor[n - 1].leading_zeros(); + let divisor = shl_word(divisor, shift); + let mut u = full_shl(numerator, shift); + + let mut q = [0; N]; + let v_n_1 = divisor[n - 1]; + let v_n_2 = divisor[n - 2]; + + for j in (0..=m).rev() { + let u_jn = u[j + n]; + + let mut q_hat = if u_jn < v_n_1 { + let (mut q_hat, mut r_hat) = div_rem_word(u_jn, u[j + n - 1], v_n_1); + + loop { + let r = u128::from(q_hat) * u128::from(v_n_2); + let (lo, hi) = (r as u64, (r >> 64) as u64); + if (hi, lo) <= (r_hat, u[j + n - 2]) { + break; + } + + q_hat -= 1; + let (new_r_hat, overflow) = r_hat.overflowing_add(v_n_1); + r_hat = new_r_hat; + + if overflow { + break; + } + } + q_hat + } else { + u64::MAX + }; + + let q_hat_v = full_mul_u64(&divisor, q_hat); + + let c = sub_assign(&mut u[j..], &q_hat_v[..n + 1]); + + if c { + q_hat -= 1; + + let c = add_assign(&mut u[j..], &divisor[..n]); + u[j + n] = u[j + n].wrapping_add(u64::from(c)); + } + + q[j] = q_hat; + } + + let remainder = full_shr(&u, shift); + (q, remainder) +} + +/// Divide a u128 by a u64 divisor, returning the quotient and remainder +fn div_rem_word(hi: u64, lo: u64, y: u64) -> (u64, u64) { + debug_assert!(hi < y); + let x = (u128::from(hi) << 64) + u128::from(lo); + let y = u128::from(y); + ((x / y) as u64, (x % y) as u64) +} + +/// Perform `a += b` +fn add_assign(a: &mut [u64], b: &[u64]) -> bool { + binop_slice(a, b, u64::overflowing_add) +} + +/// Perform `a -= b` +fn sub_assign(a: &mut [u64], b: &[u64]) -> bool { + binop_slice(a, b, u64::overflowing_sub) +} + +/// Converts an overflowing binary operation on scalars to one on slices +fn binop_slice( + a: &mut [u64], + b: &[u64], + binop: impl Fn(u64, u64) -> (u64, bool) + Copy, +) -> bool { + let mut c = false; + a.iter_mut().zip(b.iter()).for_each(|(x, y)| { + let (res1, overflow1) = y.overflowing_add(u64::from(c)); + let (res2, overflow2) = binop(*x, res1); + *x = res2; + c = overflow1 || overflow2; + }); + c +} + +/// Widening multiplication of an N-digit array with a u64 +fn full_mul_u64(a: &[u64; N], b: u64) -> ArrayPlusOne { + let mut carry = 0; + let mut out = [0; N]; + out.iter_mut().zip(a).for_each(|(o, v)| { + let r = *v as u128 * b as u128 + carry as u128; + *o = r as u64; + carry = (r >> 64) as u64; + }); + ArrayPlusOne(out, carry) +} + +/// Left shift of an N-digit array by at most 63 bits +fn shl_word(v: &[u64; N], shift: u32) -> [u64; N] { + full_shl(v, shift).0 +} + +/// Widening left shift of an N-digit array by at most 63 bits +fn full_shl(v: &[u64; N], shift: u32) -> ArrayPlusOne { + debug_assert!(shift < 64); + if shift == 0 { + return ArrayPlusOne(*v, 0); + } + let mut out = [0u64; N]; + out[0] = v[0] << shift; + for i in 1..N { + out[i] = v[i - 1] >> (64 - shift) | v[i] << shift + } + let carry = v[N - 1] >> (64 - shift); + return ArrayPlusOne(out, carry); +} + +/// Narrowing right shift of an (N+1)-digit array by at most 63 bits +fn full_shr(a: &ArrayPlusOne, shift: u32) -> [u64; N] { + debug_assert!(shift < 64); + if shift == 0 { + return a.0; + } + let mut out = [0; N]; + for i in 0..N - 1 { + out[i] = a[i] >> shift | a[i + 1] << (64 - shift) + } + out[N - 1] = a[N - 1] >> shift; + out +} + +/// An array of N + 1 elements +/// +/// This is a hack around lack of support for const arithmetic +struct ArrayPlusOne([T; N], T); + +impl std::ops::Deref for ArrayPlusOne { + type Target = [T]; + + #[inline] + fn deref(&self) -> &Self::Target { + let x = self as *const Self; + unsafe { std::slice::from_raw_parts(x as *const T, N + 1) } + } +} + +impl std::ops::DerefMut for ArrayPlusOne { + fn deref_mut(&mut self) -> &mut Self::Target { + let x = self as *mut Self; + unsafe { std::slice::from_raw_parts_mut(x as *mut T, N + 1) } + } +} diff --git a/arrow-buffer/src/bigint.rs b/arrow-buffer/src/bigint/mod.rs similarity index 94% rename from arrow-buffer/src/bigint.rs rename to arrow-buffer/src/bigint/mod.rs index 86150e67fd91..fe0774539989 100644 --- a/arrow-buffer/src/bigint.rs +++ b/arrow-buffer/src/bigint/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::bigint::div::div_rem; use num::cast::AsPrimitive; use num::{BigInt, FromPrimitive, ToPrimitive}; use std::cmp::Ordering; @@ -22,6 +23,8 @@ use std::num::ParseIntError; use std::ops::{BitAnd, BitOr, BitXor, Neg, Shl, Shr}; use std::str::FromStr; +mod div; + /// An opaque error similar to [`std::num::ParseIntError`] #[derive(Debug)] pub struct ParseI256Error {} @@ -428,25 +431,6 @@ impl i256 { .then_some(Self { low, high }) } - /// Return the least number of bits needed to represent the number - #[inline] - fn bits_required(&self) -> usize { - let le_bytes = self.to_le_bytes(); - let arr: [u128; 2] = [ - u128::from_le_bytes(le_bytes[0..16].try_into().unwrap()), - u128::from_le_bytes(le_bytes[16..32].try_into().unwrap()), - ]; - - let iter = arr.iter().rev().take(2 - 1); - if self.is_negative() { - let ctr = iter.take_while(|&&b| b == ::core::u128::MAX).count(); - (128 * (2 - ctr)) + 1 - (!arr[2 - ctr - 1]).leading_zeros() as usize - } else { - let ctr = iter.take_while(|&&b| b == ::core::u128::MIN).count(); - (128 * (2 - ctr)) + 1 - arr[2 - ctr - 1].leading_zeros() as usize - } - } - /// Division operation, returns (quotient, remainder). /// This basically implements [Long division]: `` #[inline] @@ -458,41 +442,45 @@ impl i256 { return Err(DivRemError::DivideOverflow); } - if self == Self::MIN || other == Self::MIN { - let l = BigInt::from_signed_bytes_le(&self.to_le_bytes()); - let r = BigInt::from_signed_bytes_le(&other.to_le_bytes()); - let d = i256::from_bigint_with_overflow(&l / &r).0; - let r = i256::from_bigint_with_overflow(&l % &r).0; - return Ok((d, r)); - } - - let mut me = self.checked_abs().unwrap(); - let mut you = other.checked_abs().unwrap(); - let mut ret = [0u128; 2]; - if me < you { - return Ok((Self::from_parts(ret[0], ret[1] as i128), self)); - } + let a = self.wrapping_abs(); + let b = other.wrapping_abs(); - let shift = me.bits_required() - you.bits_required(); - you = you.shl(shift as u8); - for i in (0..=shift).rev() { - if me >= you { - ret[i / 128] |= 1 << (i % 128); - me = me.checked_sub(you).unwrap(); - } - you = you.shr(1); - } + let (div, rem) = div_rem(&a.as_digits(), &b.as_digits()); + let div = Self::from_digits(div); + let rem = Self::from_digits(rem); Ok(( if self.is_negative() == other.is_negative() { - Self::from_parts(ret[0], ret[1] as i128) + div + } else { + div.wrapping_neg() + }, + if self.is_negative() { + rem.wrapping_neg() } else { - -Self::from_parts(ret[0], ret[1] as i128) + rem }, - if self.is_negative() { -me } else { me }, )) } + /// Interpret this [`i256`] as 4 `u64` digits, least significant first + fn as_digits(self) -> [u64; 4] { + [ + self.low as u64, + (self.low >> 64) as u64, + self.high as u64, + (self.high as u128 >> 64) as u64, + ] + } + + /// Interpret 4 `u64` digits, least significant first, as a [`i256`] + fn from_digits(digits: [u64; 4]) -> Self { + Self::from_parts( + digits[0] as u128 | (digits[1] as u128) << 64, + digits[2] as i128 | (digits[3] as i128) << 64, + ) + } + /// Performs wrapping division #[inline] pub fn wrapping_div(self, other: Self) -> Self { @@ -969,7 +957,7 @@ mod tests { let expected = bl.clone() % br.clone(); let checked = il.checked_rem(ir); - assert_eq!(actual.to_string(), expected.to_string()); + assert_eq!(actual.to_string(), expected.to_string(), "{il} % {ir}"); if ir == i256::MINUS_ONE && il == i256::MIN { assert!(checked.is_none());