Skip to content

Commit

Permalink
Add diff method as an equivalent to numpy.diff (#1437)
Browse files Browse the repository at this point in the history
* implement forward finite differneces on arrays

* implement tests for the  method

* remove some heap allocations
  • Loading branch information
johann-cm authored Sep 26, 2024
1 parent c7ebd35 commit fce6034
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 1 deletion.
58 changes: 57 additions & 1 deletion src/numeric/impl_numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
use num_traits::Float;
use num_traits::One;
use num_traits::{FromPrimitive, Zero};
use std::ops::{Add, Div, Mul};
use std::ops::{Add, Div, Mul, Sub};

use crate::imp_prelude::*;
use crate::numeric_util;
use crate::Slice;

/// # Numerical Methods for Arrays
impl<A, S, D> ArrayBase<S, D>
Expand Down Expand Up @@ -437,4 +438,59 @@ where
{
self.var_axis(axis, ddof).mapv_into(|x| x.sqrt())
}

/// Calculates the (forward) finite differences of order `n`, along the `axis`.
/// For the 1D-case, `n==1`, this means: `diff[i] == arr[i+1] - arr[i]`
///
/// For `n>=2`, the process is iterated:
/// ```
/// use ndarray::{array, Axis};
/// let arr = array![1.0, 2.0, 5.0];
/// assert_eq!(arr.diff(2, Axis(0)), arr.diff(1, Axis(0)).diff(1, Axis(0)))
/// ```
/// **Panics** if `axis` is out of bounds
///
/// **Panics** if `n` is too big / the array is to short:
/// ```should_panic
/// use ndarray::{array, Axis};
/// array![1.0, 2.0, 3.0].diff(10, Axis(0));
/// ```
pub fn diff(&self, n: usize, axis: Axis) -> Array<A, D>
where A: Sub<A, Output = A> + Zero + Clone
{
if n == 0 {
return self.to_owned();
}
assert!(axis.0 < self.ndim(), "The array has only ndim {}, but `axis` {:?} is given.", self.ndim(), axis);
assert!(
n < self.shape()[axis.0],
"The array must have length at least `n+1`=={} in the direction of `axis`. It has length {}",
n + 1,
self.shape()[axis.0]
);

let mut inp = self.to_owned();
let mut out = Array::zeros({
let mut inp_dim = self.raw_dim();
// inp_dim[axis.0] >= 1 as per the 2nd assertion.
inp_dim[axis.0] -= 1;
inp_dim
});
for _ in 0..n {
let head = inp.slice_axis(axis, Slice::from(..-1));
let tail = inp.slice_axis(axis, Slice::from(1..));

azip!((o in &mut out, h in head, t in tail) *o = t.clone() - h.clone());

// feed the output as the input to the next iteration
std::mem::swap(&mut inp, &mut out);

// adjust the new output array width along `axis`.
// Current situation: width of `inp`: k, `out`: k+1
// needed width: `inp`: k, `out`: k-1
// slice is possible, since k >= 1.
out.slice_axis_inplace(axis, Slice::from(..-2));
}
inp
}
}
68 changes: 68 additions & 0 deletions tests/numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,3 +336,71 @@ fn std_axis_empty_axis()
assert_eq!(v.shape(), &[2]);
v.mapv(|x| assert!(x.is_nan()));
}

#[test]
fn diff_1d_order1()
{
let data = array![1.0, 2.0, 4.0, 7.0];
let expected = array![1.0, 2.0, 3.0];
assert_eq!(data.diff(1, Axis(0)), expected);
}

#[test]
fn diff_1d_order2()
{
let data = array![1.0, 2.0, 4.0, 7.0];
assert_eq!(
data.diff(2, Axis(0)),
data.diff(1, Axis(0)).diff(1, Axis(0))
);
}

#[test]
fn diff_1d_order3()
{
let data = array![1.0, 2.0, 4.0, 7.0];
assert_eq!(
data.diff(3, Axis(0)),
data.diff(1, Axis(0)).diff(1, Axis(0)).diff(1, Axis(0))
);
}

#[test]
fn diff_2d_order1_ax0()
{
let data = array![
[1.0, 2.0, 4.0, 7.0],
[1.0, 3.0, 6.0, 6.0],
[1.5, 3.5, 5.5, 5.5]
];
let expected = array![[0.0, 1.0, 2.0, -1.0], [0.5, 0.5, -0.5, -0.5]];
assert_eq!(data.diff(1, Axis(0)), expected);
}

#[test]
fn diff_2d_order1_ax1()
{
let data = array![
[1.0, 2.0, 4.0, 7.0],
[1.0, 3.0, 6.0, 6.0],
[1.5, 3.5, 5.5, 5.5]
];
let expected = array![[1.0, 2.0, 3.0], [2.0, 3.0, 0.0], [2.0, 2.0, 0.0]];
assert_eq!(data.diff(1, Axis(1)), expected);
}

#[test]
#[should_panic]
fn diff_panic_n_too_big()
{
let data = array![1.0, 2.0, 4.0, 7.0];
data.diff(10, Axis(0));
}

#[test]
#[should_panic]
fn diff_panic_axis_out_of_bounds()
{
let data = array![1, 2, 4, 7];
data.diff(1, Axis(2));
}

0 comments on commit fce6034

Please sign in to comment.