Skip to content

Commit 3faaa7e

Browse files
committed
Add accumulate_axis_inplace method
1 parent 26f7762 commit 3faaa7e

File tree

2 files changed

+97
-0
lines changed

2 files changed

+97
-0
lines changed

src/impl_methods.rs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2138,4 +2138,60 @@ where
21382138
}
21392139
})
21402140
}
2141+
2142+
/// Iterates over pairs of consecutive elements along the axis.
2143+
///
2144+
/// The first argument to the closure is an element, and the second
2145+
/// argument is the next element along the axis. Iteration is guaranteed to
2146+
/// proceed in order along the specified axis, but in all other respects
2147+
/// the iteration order is unspecified.
2148+
///
2149+
/// # Example
2150+
///
2151+
/// For example, this can be used to compute the cumulative sum along an
2152+
/// axis:
2153+
///
2154+
/// ```
2155+
/// use ndarray::{array, Axis};
2156+
///
2157+
/// let mut arr = array![
2158+
/// [[1, 2], [3, 4], [5, 6]],
2159+
/// [[7, 8], [9, 10], [11, 12]],
2160+
/// ];
2161+
/// arr.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev);
2162+
/// assert_eq!(
2163+
/// arr,
2164+
/// array![
2165+
/// [[1, 2], [4, 6], [9, 12]],
2166+
/// [[7, 8], [16, 18], [27, 30]],
2167+
/// ],
2168+
/// );
2169+
/// ```
2170+
pub fn accumulate_axis_inplace<F>(&mut self, axis: Axis, mut f: F)
2171+
where
2172+
F: FnMut(&A, &mut A),
2173+
S: DataMut,
2174+
{
2175+
if self.len_of(axis) <= 1 {
2176+
return;
2177+
}
2178+
let mut prev = self.raw_view();
2179+
prev.slice_axis_inplace(axis, Slice::from(..-1));
2180+
let mut curr = self.raw_view_mut();
2181+
curr.slice_axis_inplace(axis, Slice::from(1..));
2182+
// This implementation relies on `Zip` iterating along `axis` in order.
2183+
Zip::from(prev).and(curr).apply(|prev, curr| unsafe {
2184+
// These pointer dereferences and borrows are safe because:
2185+
//
2186+
// 1. They're pointers to elements in the array.
2187+
//
2188+
// 2. `S: DataMut` guarantees that elements are safe to borrow
2189+
// mutably and that they don't alias.
2190+
//
2191+
// 3. The lifetimes of the borrows last only for the duration
2192+
// of the call to `f`, so aliasing across calls to `f`
2193+
// cannot occur.
2194+
f(&*prev, &mut *curr)
2195+
});
2196+
}
21412197
}

tests/array.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2004,6 +2004,47 @@ fn test_map_axis() {
20042004
assert_eq!(c, answer2);
20052005
}
20062006

2007+
#[test]
2008+
fn test_accumulate_axis_inplace_noop() {
2009+
let mut a = Array2::<u8>::zeros((0, 3));
2010+
a.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev);
2011+
assert_eq!(a, Array2::zeros((0, 3)));
2012+
2013+
let mut a = Array2::<u8>::zeros((3, 1));
2014+
a.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev);
2015+
assert_eq!(a, Array2::zeros((3, 1)));
2016+
}
2017+
2018+
#[test]
2019+
fn test_accumulate_axis_inplace_nonstandard_layout() {
2020+
let a = arr2(&[[1, 2, 3],
2021+
[4, 5, 6],
2022+
[7, 8, 9],
2023+
[10,11,12]]);
2024+
2025+
let mut a_t = a.clone().reversed_axes();
2026+
a_t.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev);
2027+
assert_eq!(a_t, aview2(&[[1, 4, 7, 10],
2028+
[3, 9, 15, 21],
2029+
[6, 15, 24, 33]]));
2030+
2031+
let mut a0 = a.clone();
2032+
a0.invert_axis(Axis(0));
2033+
a0.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev);
2034+
assert_eq!(a0, aview2(&[[10, 11, 12],
2035+
[17, 19, 21],
2036+
[21, 24, 27],
2037+
[22, 26, 30]]));
2038+
2039+
let mut a1 = a.clone();
2040+
a1.invert_axis(Axis(1));
2041+
a1.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev);
2042+
assert_eq!(a1, aview2(&[[3, 5, 6],
2043+
[6, 11, 15],
2044+
[9, 17, 24],
2045+
[12, 23, 33]]));
2046+
}
2047+
20072048
#[test]
20082049
fn test_to_vec() {
20092050
let mut a = arr2(&[[1, 2, 3],

0 commit comments

Comments
 (0)