Skip to content

Commit

Permalink
Merge pull request rust-ndarray#3 from jturner314/pairwise-summation
Browse files Browse the repository at this point in the history
Improve pairwise summation
  • Loading branch information
LukeMathWalker authored Feb 3, 2019
2 parents bbc4a75 + 82453df commit 1d51f70
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 27 deletions.
46 changes: 45 additions & 1 deletion benches/numeric.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

#![feature(test)]
extern crate test;
use test::Bencher;
use test::{black_box, Bencher};

extern crate ndarray;
use ndarray::prelude::*;
Expand Down Expand Up @@ -65,6 +65,38 @@ fn contiguous_sum_1e2(bench: &mut Bencher)
});
}

#[bench]
fn contiguous_sum_ix3_1e2(bench: &mut Bencher)
{
let n = 1e2 as usize;
let a = Array::linspace(-1e6, 1e6, n * n * n)
.into_shape([n, n, n])
.unwrap();
bench.iter(|| black_box(&a).sum());
}

#[bench]
fn inner_discontiguous_sum_ix3_1e2(bench: &mut Bencher)
{
let n = 1e2 as usize;
let a = Array::linspace(-1e6, 1e6, n * n * 2*n)
.into_shape([n, n, 2*n])
.unwrap();
let v = a.slice(s![.., .., ..;2]);
bench.iter(|| black_box(&v).sum());
}

#[bench]
fn middle_discontiguous_sum_ix3_1e2(bench: &mut Bencher)
{
let n = 1e2 as usize;
let a = Array::linspace(-1e6, 1e6, n * 2*n * n)
.into_shape([n, 2*n, n])
.unwrap();
let v = a.slice(s![.., ..;2, ..]);
bench.iter(|| black_box(&v).sum());
}

#[bench]
fn sum_by_row_1e4(bench: &mut Bencher)
{
Expand All @@ -88,3 +120,15 @@ fn sum_by_col_1e4(bench: &mut Bencher)
a.sum_axis(Axis(1))
});
}

#[bench]
fn sum_by_middle_1e2(bench: &mut Bencher)
{
let n = 1e2 as usize;
let a = Array::linspace(-1e6, 1e6, n * n * n)
.into_shape([n, n, n])
.unwrap();
bench.iter(|| {
a.sum_axis(Axis(1))
});
}
26 changes: 8 additions & 18 deletions src/numeric/impl_numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

use std::ops::{Add, Div, Mul};
use num_traits::{self, Zero, Float, FromPrimitive};
use itertools::free::enumerate;

use crate::imp_prelude::*;
use crate::numeric_util;
Expand All @@ -33,17 +32,10 @@ impl<A, S, D> ArrayBase<S, D>
where A: Clone + Add<Output=A> + num_traits::Zero,
{
if let Some(slc) = self.as_slice_memory_order() {
return numeric_util::pairwise_sum(&slc)
}
let mut sum = A::zero();
for row in self.inner_rows() {
if let Some(slc) = row.as_slice() {
sum = sum + numeric_util::pairwise_sum(&slc);
} else {
sum = sum + numeric_util::iterator_pairwise_sum(row.iter());
}
numeric_util::pairwise_sum(&slc)
} else {
numeric_util::iterator_pairwise_sum(self.iter())
}
sum
}

/// Return the sum of all elements in the array.
Expand Down Expand Up @@ -104,16 +96,14 @@ impl<A, S, D> ArrayBase<S, D>
D: RemoveAxis,
{
let n = self.len_of(axis);
let stride = self.strides()[axis.index()];
if self.ndim() == 2 && stride == 1 {
if self.stride_of(axis) == 1 {
// contiguous along the axis we are summing
let mut res = Array::zeros(self.raw_dim().remove_axis(axis));
let ax = axis.index();
for (i, elt) in enumerate(&mut res) {
*elt = self.index_axis(Axis(1 - ax), i).sum();
}
Zip::from(&mut res)
.and(self.lanes(axis))
.apply(|sum, lane| *sum = lane.sum());
res
} else if self.len_of(axis) <= numeric_util::NAIVE_SUM_THRESHOLD {
} else if n <= numeric_util::NAIVE_SUM_THRESHOLD {
self.fold_axis(axis, A::zero(), |acc, x| acc.clone() + x.clone())
} else {
let (v1, v2) = self.view().split_at(axis, n / 2);
Expand Down
32 changes: 24 additions & 8 deletions src/numeric_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,18 @@ where
I: Iterator<Item=&'a A>,
A: Clone + Add<Output=A> + Zero,
{
let mut partial_sums = vec![];
let mut partial_sum = A::zero();
for (i, x) in iter.enumerate() {
partial_sum = partial_sum + x.clone();
if i % NAIVE_SUM_THRESHOLD == NAIVE_SUM_THRESHOLD - 1 {
let (len, _) = iter.size_hint();
let cap = len.saturating_sub(1) / NAIVE_SUM_THRESHOLD + 1; // ceiling of division
let mut partial_sums = Vec::with_capacity(cap);
let (_, last_sum) = iter.fold((0, A::zero()), |(count, partial_sum), x| {
if count < NAIVE_SUM_THRESHOLD {
(count + 1, partial_sum + x.clone())
} else {
partial_sums.push(partial_sum);
partial_sum = A::zero();
(1, x.clone())
}
}
partial_sums.push(partial_sum);
});
partial_sums.push(last_sum);

pure_pairwise_sum(&partial_sums)
}
Expand Down Expand Up @@ -205,3 +207,17 @@ pub fn unrolled_eq<A>(xs: &[A], ys: &[A]) -> bool

true
}

#[cfg(test)]
mod tests {
use quickcheck::quickcheck;
use std::num::Wrapping;
use super::iterator_pairwise_sum;

quickcheck! {
fn iterator_pairwise_sum_is_correct(xs: Vec<i32>) -> bool {
let xs: Vec<_> = xs.into_iter().map(|x| Wrapping(x)).collect();
iterator_pairwise_sum(xs.iter()) == xs.iter().sum()
}
}
}

0 comments on commit 1d51f70

Please sign in to comment.