Skip to content

Commit 189f6ee

Browse files
committed
Faster i256 Division (2-100x) (#4663)
1 parent 08e4692 commit 189f6ee

File tree

3 files changed

+304
-70
lines changed

3 files changed

+304
-70
lines changed

arrow-buffer/benches/i256.rs

+29-24
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,7 @@ use rand::rngs::StdRng;
2121
use rand::{Rng, SeedableRng};
2222
use std::str::FromStr;
2323

24-
/// Returns fixed seedable RNG
25-
fn seedable_rng() -> StdRng {
26-
StdRng::seed_from_u64(42)
27-
}
28-
29-
fn create_i256_vec(size: usize) -> Vec<i256> {
30-
let mut rng = seedable_rng();
31-
32-
(0..size)
33-
.map(|_| i256::from_i128(rng.gen::<i128>()))
34-
.collect()
35-
}
24+
const SIZE: usize = 1024;
3625

3726
fn criterion_benchmark(c: &mut Criterion) {
3827
let numbers = vec![
@@ -54,24 +43,40 @@ fn criterion_benchmark(c: &mut Criterion) {
5443
});
5544
}
5645

57-
c.bench_function("i256_div", |b| {
46+
let mut rng = StdRng::seed_from_u64(42);
47+
48+
let numerators: Vec<_> = (0..SIZE)
49+
.map(|_| {
50+
let high = rng.gen_range(1000..i128::MAX);
51+
let low = rng.gen();
52+
i256::from_parts(low, high)
53+
})
54+
.collect();
55+
56+
let divisors: Vec<_> = numerators
57+
.iter()
58+
.map(|n| {
59+
let quotient = rng.gen_range(1..100_i32);
60+
n.wrapping_div(i256::from(quotient))
61+
})
62+
.collect();
63+
64+
c.bench_function("i256_div_rem small quotient", |b| {
5865
b.iter(|| {
59-
for number_a in create_i256_vec(10) {
60-
for number_b in create_i256_vec(5) {
61-
number_a.checked_div(number_b);
62-
number_a.wrapping_div(number_b);
63-
}
66+
for (n, d) in numerators.iter().zip(&divisors) {
67+
black_box(n.wrapping_div(*d));
6468
}
6569
});
6670
});
6771

68-
c.bench_function("i256_rem", |b| {
72+
let divisors: Vec<_> = (0..SIZE)
73+
.map(|_| i256::from(rng.gen_range(1..100_i32)))
74+
.collect();
75+
76+
c.bench_function("i256_div_rem small divisor", |b| {
6977
b.iter(|| {
70-
for number_a in create_i256_vec(10) {
71-
for number_b in create_i256_vec(5) {
72-
number_a.checked_rem(number_b);
73-
number_a.wrapping_rem(number_b);
74-
}
78+
for (n, d) in numerators.iter().zip(&divisors) {
79+
black_box(n.wrapping_div(*d));
7580
}
7681
});
7782
});

arrow-buffer/src/bigint/div.rs

+241
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! N-digit division
19+
//!
20+
//! Implementation heavily inspired by [uint]
21+
//!
22+
//! [uint]: https://github.com/paritytech/parity-common/blob/d3a9327124a66e52ca1114bb8640c02c18c134b8/uint/src/uint.rs#L844
23+
24+
/// Unsigned, little-endian, n-digit division with remainder
25+
///
26+
/// # Panics
27+
///
28+
/// Panics if divisor is zero
29+
pub fn div_rem<const N: usize>(
30+
numerator: &[u64; N],
31+
divisor: &[u64; N],
32+
) -> ([u64; N], [u64; N]) {
33+
let numerator_bits = bits(numerator);
34+
let divisor_bits = bits(divisor);
35+
assert_ne!(divisor_bits, 0, "division by zero");
36+
37+
if numerator_bits < divisor_bits {
38+
return ([0; N], *numerator);
39+
}
40+
41+
if divisor_bits <= 64 {
42+
return div_rem_small(numerator, divisor[0]);
43+
}
44+
45+
let numerator_words = (numerator_bits + 63) / 64;
46+
let divisor_words = (divisor_bits + 63) / 64;
47+
let n = divisor_words;
48+
let m = numerator_words - divisor_words;
49+
50+
div_rem_knuth(numerator, divisor, n, m)
51+
}
52+
53+
/// Return the least number of bits needed to represent the number
54+
fn bits(arr: &[u64]) -> usize {
55+
for (idx, v) in arr.iter().enumerate().rev() {
56+
if *v > 0 {
57+
return 64 - v.leading_zeros() as usize + 64 * idx;
58+
}
59+
}
60+
0
61+
}
62+
63+
/// Division of numerator by a u64 divisor
64+
fn div_rem_small<const N: usize>(
65+
numerator: &[u64; N],
66+
divisor: u64,
67+
) -> ([u64; N], [u64; N]) {
68+
let mut rem = 0u64;
69+
let mut numerator = *numerator;
70+
numerator.iter_mut().rev().for_each(|d| {
71+
let (q, r) = div_rem_word(rem, *d, divisor);
72+
*d = q;
73+
rem = r;
74+
});
75+
76+
let mut rem_padded = [0; N];
77+
rem_padded[0] = rem;
78+
(numerator, rem_padded)
79+
}
80+
81+
fn div_rem_knuth<const N: usize>(
82+
numerator: &[u64; N],
83+
divisor: &[u64; N],
84+
n: usize,
85+
m: usize,
86+
) -> ([u64; N], [u64; N]) {
87+
assert!(n + m <= N);
88+
89+
let shift = divisor[n - 1].leading_zeros();
90+
let divisor = shl_word(divisor, shift);
91+
let mut u = full_shl(numerator, shift);
92+
93+
let mut q = [0; N];
94+
let v_n_1 = divisor[n - 1];
95+
let v_n_2 = divisor[n - 2];
96+
97+
for j in (0..=m).rev() {
98+
let u_jn = u[j + n];
99+
100+
let mut q_hat = if u_jn < v_n_1 {
101+
let (mut q_hat, mut r_hat) = div_rem_word(u_jn, u[j + n - 1], v_n_1);
102+
103+
loop {
104+
let r = u128::from(q_hat) * u128::from(v_n_2);
105+
let (lo, hi) = (r as u64, (r >> 64) as u64);
106+
if (hi, lo) <= (r_hat, u[j + n - 2]) {
107+
break;
108+
}
109+
110+
q_hat -= 1;
111+
let (new_r_hat, overflow) = r_hat.overflowing_add(v_n_1);
112+
r_hat = new_r_hat;
113+
114+
if overflow {
115+
break;
116+
}
117+
}
118+
q_hat
119+
} else {
120+
u64::MAX
121+
};
122+
123+
let q_hat_v = full_mul_u64(&divisor, q_hat);
124+
125+
let c = sub_assign(&mut u[j..], &q_hat_v[..n + 1]);
126+
127+
if c {
128+
q_hat -= 1;
129+
130+
let c = add_assign(&mut u[j..], &divisor[..n]);
131+
u[j + n] = u[j + n].wrapping_add(u64::from(c));
132+
}
133+
134+
q[j] = q_hat;
135+
}
136+
137+
let remainder = full_shr(&u, shift);
138+
(q, remainder)
139+
}
140+
141+
/// Divide a u128 by a u64 divisor, returning the quotient and remainder
142+
fn div_rem_word(hi: u64, lo: u64, y: u64) -> (u64, u64) {
143+
debug_assert!(hi < y);
144+
let x = (u128::from(hi) << 64) + u128::from(lo);
145+
let y = u128::from(y);
146+
((x / y) as u64, (x % y) as u64)
147+
}
148+
149+
/// Perform `a += b`
150+
fn add_assign(a: &mut [u64], b: &[u64]) -> bool {
151+
binop_slice(a, b, u64::overflowing_add)
152+
}
153+
154+
/// Perform `a -= b`
155+
fn sub_assign(a: &mut [u64], b: &[u64]) -> bool {
156+
binop_slice(a, b, u64::overflowing_sub)
157+
}
158+
159+
/// Converts an overflowing binary operation on scalars to one on slices
160+
fn binop_slice(
161+
a: &mut [u64],
162+
b: &[u64],
163+
binop: impl Fn(u64, u64) -> (u64, bool) + Copy,
164+
) -> bool {
165+
let mut c = false;
166+
a.iter_mut().zip(b.iter()).for_each(|(x, y)| {
167+
let (res1, overflow1) = y.overflowing_add(u64::from(c));
168+
let (res2, overflow2) = binop(*x, res1);
169+
*x = res2;
170+
c = overflow1 || overflow2;
171+
});
172+
c
173+
}
174+
175+
/// Widening multiplication of an N-digit array with a u64
176+
fn full_mul_u64<const N: usize>(a: &[u64; N], b: u64) -> ArrayPlusOne<u64, N> {
177+
let mut carry = 0;
178+
let mut out = [0; N];
179+
out.iter_mut().zip(a).for_each(|(o, v)| {
180+
let r = *v as u128 * b as u128 + carry as u128;
181+
*o = r as u64;
182+
carry = (r >> 64) as u64;
183+
});
184+
ArrayPlusOne(out, carry)
185+
}
186+
187+
/// Left shift of an N-digit array by at most 63 bits
188+
fn shl_word<const N: usize>(v: &[u64; N], shift: u32) -> [u64; N] {
189+
full_shl(v, shift).0
190+
}
191+
192+
/// Widening left shift of an N-digit array by at most 63 bits
193+
fn full_shl<const N: usize>(v: &[u64; N], shift: u32) -> ArrayPlusOne<u64, N> {
194+
debug_assert!(shift < 64);
195+
if shift == 0 {
196+
return ArrayPlusOne(*v, 0);
197+
}
198+
let mut out = [0u64; N];
199+
out[0] = v[0] << shift;
200+
for i in 1..N {
201+
out[i] = v[i - 1] >> (64 - shift) | v[i] << shift
202+
}
203+
let carry = v[N - 1] >> (64 - shift);
204+
return ArrayPlusOne(out, carry);
205+
}
206+
207+
/// Narrowing right shift of an (N+1)-digit array by at most 63 bits
208+
fn full_shr<const N: usize>(a: &ArrayPlusOne<u64, N>, shift: u32) -> [u64; N] {
209+
debug_assert!(shift < 64);
210+
if shift == 0 {
211+
return a.0;
212+
}
213+
let mut out = [0; N];
214+
for i in 0..N - 1 {
215+
out[i] = a[i] >> shift | a[i + 1] << (64 - shift)
216+
}
217+
out[N - 1] = a[N - 1] >> shift;
218+
out
219+
}
220+
221+
/// An array of N + 1 elements
222+
///
223+
/// This is a hack around lack of support for const arithmetic
224+
struct ArrayPlusOne<T, const N: usize>([T; N], T);
225+
226+
impl<T, const N: usize> std::ops::Deref for ArrayPlusOne<T, N> {
227+
type Target = [T];
228+
229+
#[inline]
230+
fn deref(&self) -> &Self::Target {
231+
let x = self as *const Self;
232+
unsafe { std::slice::from_raw_parts(x as *const T, N + 1) }
233+
}
234+
}
235+
236+
impl<T, const N: usize> std::ops::DerefMut for ArrayPlusOne<T, N> {
237+
fn deref_mut(&mut self) -> &mut Self::Target {
238+
let x = self as *mut Self;
239+
unsafe { std::slice::from_raw_parts_mut(x as *mut T, N + 1) }
240+
}
241+
}

0 commit comments

Comments
 (0)