diff --git a/src/maybe_nan/mod.rs b/src/maybe_nan/mod.rs index 3905fa9b..15adbafd 100644 --- a/src/maybe_nan/mod.rs +++ b/src/maybe_nan/mod.rs @@ -241,6 +241,15 @@ where A: 'a, F: FnMut(B, &'a A::NotNan) -> B; + /// Traverse the non-NaN elements and their indices and apply a fold, + /// returning the resulting value. + /// + /// Elements are visited in arbitrary order. + fn indexed_fold_skipnan<'a, F, B>(&'a self, init: B, f: F) -> B + where + A: 'a, + F: FnMut(B, (D::Pattern, &'a A::NotNan)) -> B; + /// Visit each non-NaN element in the array by calling `f` on each element. /// /// Elements are visited in arbitrary order. @@ -302,6 +311,20 @@ where }) } + fn indexed_fold_skipnan<'a, F, B>(&'a self, init: B, mut f: F) -> B + where + A: 'a, + F: FnMut(B, (D::Pattern, &'a A::NotNan)) -> B, + { + self.indexed_iter().fold(init, |acc, (idx, elem)| { + if let Some(not_nan) = elem.try_as_not_nan() { + f(acc, (idx, not_nan)) + } else { + acc + } + }) + } + fn visit_skipnan<'a, F>(&'a self, mut f: F) where A: 'a, diff --git a/src/quantile.rs b/src/quantile.rs index 76c9c457..1b4b4fd6 100644 --- a/src/quantile.rs +++ b/src/quantile.rs @@ -211,6 +211,33 @@ where where A: PartialOrd; + /// Finds the index of the minimum value of the array skipping NaN values. + /// + /// Returns `None` if the array is empty or none of the values in the array + /// are non-NaN values. + /// + /// Even if there are multiple (equal) elements that are minima, only one + /// index is returned. (Which one is returned is unspecified and may depend + /// on the memory layout of the array.) + /// + /// # Example + /// + /// ``` + /// extern crate ndarray; + /// extern crate ndarray_stats; + /// + /// use ndarray::array; + /// use ndarray_stats::QuantileExt; + /// + /// let a = array![[::std::f64::NAN, 3., 5.], + /// [2., 0., 6.]]; + /// assert_eq!(a.argmin_skipnan(), Some((1, 1))); + /// ``` + fn argmin_skipnan(&self) -> Option + where + A: MaybeNan, + A::NotNan: Ord; + /// Finds the elementwise minimum of the array. /// /// Returns `None` if any of the pairwise orderings tested by the function @@ -269,6 +296,33 @@ where where A: PartialOrd; + /// Finds the index of the maximum value of the array skipping NaN values. + /// + /// Returns `None` if the array is empty or none of the values in the array + /// are non-NaN values. + /// + /// Even if there are multiple (equal) elements that are maxima, only one + /// index is returned. (Which one is returned is unspecified and may depend + /// on the memory layout of the array.) + /// + /// # Example + /// + /// ``` + /// extern crate ndarray; + /// extern crate ndarray_stats; + /// + /// use ndarray::array; + /// use ndarray_stats::QuantileExt; + /// + /// let a = array![[::std::f64::NAN, 3., 5.], + /// [2., 0., 6.]]; + /// assert_eq!(a.argmax_skipnan(), Some((1, 2))); + /// ``` + fn argmax_skipnan(&self) -> Option + where + A: MaybeNan, + A::NotNan: Ord; + /// Finds the elementwise maximum of the array. /// /// Returns `None` if any of the pairwise orderings tested by the function @@ -369,6 +423,28 @@ where Some(current_pattern_min) } + fn argmin_skipnan(&self) -> Option + where + A: MaybeNan, + A::NotNan: Ord, + { + let mut pattern_min = D::zeros(self.ndim()).into_pattern(); + let min = self.indexed_fold_skipnan(None, |current_min, (pattern, elem)| { + Some(match current_min { + Some(m) if (m <= elem) => m, + _ => { + pattern_min = pattern; + elem + } + }) + }); + if min.is_some() { + Some(pattern_min) + } else { + None + } + } + fn min(&self) -> Option<&A> where A: PartialOrd, @@ -411,6 +487,28 @@ where Some(current_pattern_max) } + fn argmax_skipnan(&self) -> Option + where + A: MaybeNan, + A::NotNan: Ord, + { + let mut pattern_max = D::zeros(self.ndim()).into_pattern(); + let max = self.indexed_fold_skipnan(None, |current_max, (pattern, elem)| { + Some(match current_max { + Some(m) if m >= elem => m, + _ => { + pattern_max = pattern; + elem + } + }) + }); + if max.is_some() { + Some(pattern_max) + } else { + None + } + } + fn max(&self) -> Option<&A> where A: PartialOrd, diff --git a/tests/quantile.rs b/tests/quantile.rs index 31b51907..3e5ba53b 100644 --- a/tests/quantile.rs +++ b/tests/quantile.rs @@ -32,6 +32,37 @@ quickcheck! { } } +#[test] +fn test_argmin_skipnan() { + let a = array![[1., 5., 3.], [2., 0., 6.]]; + assert_eq!(a.argmin_skipnan(), Some((1, 1))); + + let a = array![[1., 5., 3.], [2., ::std::f64::NAN, 6.]]; + assert_eq!(a.argmin_skipnan(), Some((0, 0))); + + let a = array![[::std::f64::NAN, 5., 3.], [2., ::std::f64::NAN, 6.]]; + assert_eq!(a.argmin_skipnan(), Some((1, 0))); + + let a: Array2 = array![[], []]; + assert_eq!(a.argmin_skipnan(), None); + + let a = arr2(&[[::std::f64::NAN; 2]; 2]); + assert_eq!(a.argmin_skipnan(), None); +} + +quickcheck! { + fn argmin_skipnan_matches_min_skipnan(data: Vec>) -> bool { + let a = Array1::from(data); + let min = a.min_skipnan(); + let argmin = a.argmin_skipnan(); + if min.is_none() { + argmin == None + } else { + a[argmin.unwrap()] == *min + } + } +} + #[test] fn test_min() { let a = array![[1, 5, 3], [2, 0, 6]]; @@ -81,6 +112,40 @@ quickcheck! { } } +#[test] +fn test_argmax_skipnan() { + let a = array![[1., 5., 3.], [2., 0., 6.]]; + assert_eq!(a.argmax_skipnan(), Some((1, 2))); + + let a = array![[1., 5., 3.], [2., ::std::f64::NAN, ::std::f64::NAN]]; + assert_eq!(a.argmax_skipnan(), Some((0, 1))); + + let a = array![ + [::std::f64::NAN, ::std::f64::NAN, 3.], + [2., ::std::f64::NAN, 6.] + ]; + assert_eq!(a.argmax_skipnan(), Some((1, 2))); + + let a: Array2 = array![[], []]; + assert_eq!(a.argmax_skipnan(), None); + + let a = arr2(&[[::std::f64::NAN; 2]; 2]); + assert_eq!(a.argmax_skipnan(), None); +} + +quickcheck! { + fn argmax_skipnan_matches_max_skipnan(data: Vec>) -> bool { + let a = Array1::from(data); + let max = a.max_skipnan(); + let argmax = a.argmax_skipnan(); + if max.is_none() { + argmax == None + } else { + a[argmax.unwrap()] == *max + } + } +} + #[test] fn test_max() { let a = array![[1, 5, 7], [2, 0, 6]];