Skip to content

Commit d137212

Browse files
committed
Use .scalar_sum() in .sum() when possible
Special case for when we are summing a contiguous axis.
1 parent b06c8a1 commit d137212

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

src/numeric/impl_numeric.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
use std::ops::Add;
1010
use libnum::{self, Zero, Float};
11+
use itertools::free::enumerate;
1112

1213
use imp_prelude::*;
1314
use numeric_util;
@@ -70,9 +71,18 @@ impl<A, S, D> ArrayBase<S, D>
7071
{
7172
let n = self.shape().axis(axis);
7273
let mut res = self.subview(axis, 0).to_owned();
73-
for i in 1..n {
74-
let view = self.subview(axis, i);
75-
res = res + &view;
74+
let stride = self.strides()[axis.axis()];
75+
if self.ndim() == 2 && stride == 1 {
76+
// contiguous along the axis we are summing
77+
let ax = axis.axis();
78+
for (i, elt) in enumerate(&mut res) {
79+
*elt = self.subview(Axis(1 - ax), i).scalar_sum();
80+
}
81+
} else {
82+
for i in 1..n {
83+
let view = self.subview(axis, i);
84+
res = res + &view;
85+
}
7686
}
7787
res
7888
}

0 commit comments

Comments
 (0)