diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index efde123f8..2584152b3 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -728,36 +728,6 @@ where } } -/// Move the axis which has the smallest absolute stride and a length -/// greater than one to be the last axis. -pub fn move_min_stride_axis_to_last(dim: &mut D, strides: &mut D) -where - D: Dimension, -{ - debug_assert_eq!(dim.ndim(), strides.ndim()); - match dim.ndim() { - 0 | 1 => {} - 2 => { - if dim[1] <= 1 - || dim[0] > 1 && (strides[0] as isize).abs() < (strides[1] as isize).abs() - { - dim.slice_mut().swap(0, 1); - strides.slice_mut().swap(0, 1); - } - } - n => { - if let Some(min_stride_axis) = (0..n) - .filter(|&ax| dim[ax] > 1) - .min_by_key(|&ax| (strides[ax] as isize).abs()) - { - let last = n - 1; - dim.slice_mut().swap(last, min_stride_axis); - strides.slice_mut().swap(last, min_stride_axis); - } - } - } -} - /// Remove axes with length one, except never removing the last axis. pub(crate) fn squeeze(dim: &mut D, strides: &mut D) where @@ -801,7 +771,9 @@ pub(crate) fn sort_axes_to_standard(dim: &mut D, strides: &mut D) where D: Dimension, { - debug_assert!(dim.ndim() > 1); + if dim.ndim() <= 1 { + return; + } debug_assert_eq!(dim.ndim(), strides.ndim()); // bubble sort axes let mut changed = true; @@ -809,6 +781,7 @@ where changed = false; for i in 0..dim.ndim() - 1 { // make sure higher stride axes sort before. + debug_assert!(strides.get_stride(Axis(i)) >= 0); if strides.get_stride(Axis(i)).abs() < strides.get_stride(Axis(i + 1)).abs() { changed = true; dim.slice_mut().swap(i, i + 1); diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 03ca09d74..aab3dabff 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -19,7 +19,7 @@ use crate::argument_traits::AssignElem; use crate::dimension; use crate::dimension::IntoDimension; use crate::dimension::{ - abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last, + abs_index, axes_of, do_slice, merge_axes, offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes, }; use crate::dimension::broadcast::co_broadcast; @@ -316,7 +316,7 @@ where where S: Data, { - IndexedIter::new(self.view().into_elements_base()) + IndexedIter::new(self.view().into_elements_base_keep_dims()) } /// Return an iterator of indexes and mutable references to the elements of the array. @@ -329,7 +329,7 @@ where where S: DataMut, { - IndexedIterMut::new(self.view_mut().into_elements_base()) + IndexedIterMut::new(self.view_mut().into_elements_base_keep_dims()) } /// Return a sliced view of the array. @@ -2175,9 +2175,7 @@ where if let Some(slc) = self.as_slice_memory_order() { slc.iter().fold(init, f) } else { - let mut v = self.view(); - move_min_stride_axis_to_last(&mut v.dim, &mut v.strides); - v.into_elements_base().fold(init, f) + self.view().into_elements_base_any_order().fold(init, f) } } @@ -2295,9 +2293,7 @@ where match self.try_as_slice_memory_order_mut() { Ok(slc) => slc.iter_mut().for_each(f), Err(arr) => { - let mut v = arr.view_mut(); - move_min_stride_axis_to_last(&mut v.dim, &mut v.strides); - v.into_elements_base().for_each(f); + arr.view_mut().into_elements_base_any_order().for_each(f); } } } diff --git a/src/impl_views/conversions.rs b/src/impl_views/conversions.rs index cfd7f9aa0..f0e0b4dd2 100644 --- a/src/impl_views/conversions.rs +++ b/src/impl_views/conversions.rs @@ -10,9 +10,9 @@ use alloc::slice; use crate::imp_prelude::*; -use crate::{Baseiter, ElementsBase, ElementsBaseMut, Iter, IterMut}; - -use crate::iter::{self, AxisIter, AxisIterMut}; +use crate::iter::{self, Iter, IterMut, AxisIter, AxisIterMut}; +use crate::iterators::base::{Baseiter, ElementsBase, ElementsBaseMut, OrderOption, PreserveOrder, + ArbitraryOrder, NoOptimization}; use crate::math_cell::MathCell; use crate::IndexLonger; @@ -140,14 +140,25 @@ impl<'a, A, D> ArrayView<'a, A, D> where D: Dimension, { + /// Create a base iter fromt the view with the given order option + #[inline] + pub(crate) fn into_base_iter(self) -> Baseiter { + unsafe { Baseiter::new_with_order::(self.ptr.as_ptr(), self.dim, self.strides) } + } + + #[inline] + pub(crate) fn into_elements_base_keep_dims(self) -> ElementsBase<'a, A, D> { + ElementsBase::new::(self) + } + #[inline] - pub(crate) fn into_base_iter(self) -> Baseiter { - unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) } + pub(crate) fn into_elements_base_preserve_order(self) -> ElementsBase<'a, A, D> { + ElementsBase::new::(self) } #[inline] - pub(crate) fn into_elements_base(self) -> ElementsBase<'a, A, D> { - ElementsBase::new(self) + pub(crate) fn into_elements_base_any_order(self) -> ElementsBase<'a, A, D> { + ElementsBase::new::(self) } pub(crate) fn into_iter_(self) -> Iter<'a, A, D> { @@ -179,16 +190,28 @@ where unsafe { RawArrayViewMut::new(self.ptr, self.dim, self.strides) } } + /// Create a base iter fromt the view with the given order option #[inline] - pub(crate) fn into_base_iter(self) -> Baseiter { - unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) } + pub(crate) fn into_base_iter(self) -> Baseiter { + unsafe { Baseiter::new_with_order::(self.ptr.as_ptr(), self.dim, self.strides) } } #[inline] - pub(crate) fn into_elements_base(self) -> ElementsBaseMut<'a, A, D> { - ElementsBaseMut::new(self) + pub(crate) fn into_elements_base_keep_dims(self) -> ElementsBaseMut<'a, A, D> { + ElementsBaseMut::new::(self) } + #[inline] + pub(crate) fn into_elements_base_preserve_order(self) -> ElementsBaseMut<'a, A, D> { + ElementsBaseMut::new::(self) + } + + #[inline] + pub(crate) fn into_elements_base_any_order(self) -> ElementsBaseMut<'a, A, D> { + ElementsBaseMut::new::(self) + } + + /// Return the array’s data as a slice, if it is contiguous and in standard order. /// Otherwise return self in the Err branch of the result. pub(crate) fn try_into_slice(self) -> Result<&'a mut [A], Self> { diff --git a/src/iterators/base.rs b/src/iterators/base.rs index fbd98a2a2..441dd0f8d 100644 --- a/src/iterators/base.rs +++ b/src/iterators/base.rs @@ -57,16 +57,6 @@ pub(crate) struct Baseiter { index: Option, } -impl Baseiter { - /// Creating a Baseiter is unsafe because shape and stride parameters need - /// to be correct to avoid performing an unsafe pointer offset while - /// iterating. - #[inline] - pub unsafe fn new(ptr: *mut A, dim: D, strides: D) -> Baseiter { - Self::new_with_order::(ptr, dim, strides) - } -} - impl Baseiter { /// Creating a Baseiter is unsafe because shape and stride parameters need /// to be correct to avoid performing an unsafe pointer offset while @@ -246,9 +236,9 @@ clone_bounds!( ); impl<'a, A, D: Dimension> ElementsBase<'a, A, D> { - pub fn new(v: ArrayView<'a, A, D>) -> Self { + pub fn new(v: ArrayView<'a, A, D>) -> Self { ElementsBase { - inner: v.into_base_iter(), + inner: v.into_base_iter::(), life: PhantomData, } } @@ -332,7 +322,7 @@ where inner: if let Some(slc) = self_.to_slice() { ElementsRepr::Slice(slc.iter()) } else { - ElementsRepr::Counted(self_.into_elements_base()) + ElementsRepr::Counted(self_.into_elements_base_preserve_order()) }, } } @@ -346,7 +336,7 @@ where IterMut { inner: match self_.try_into_slice() { Ok(x) => ElementsRepr::Slice(x.iter_mut()), - Err(self_) => ElementsRepr::Counted(self_.into_elements_base()), + Err(self_) => ElementsRepr::Counted(self_.into_elements_base_preserve_order()), }, } } @@ -391,9 +381,9 @@ pub(crate) struct ElementsBaseMut<'a, A, D> { } impl<'a, A, D: Dimension> ElementsBaseMut<'a, A, D> { - pub fn new(v: ArrayViewMut<'a, A, D>) -> Self { + pub fn new(v: ArrayViewMut<'a, A, D>) -> Self { ElementsBaseMut { - inner: v.into_base_iter(), + inner: v.into_base_iter::(), life: PhantomData, } } diff --git a/src/iterators/chunks.rs b/src/iterators/chunks.rs index e41c1bf25..ba84cedc3 100644 --- a/src/iterators/chunks.rs +++ b/src/iterators/chunks.rs @@ -79,7 +79,7 @@ where type IntoIter = ExactChunksIter<'a, A, D>; fn into_iter(self) -> Self::IntoIter { ExactChunksIter { - iter: self.base.into_elements_base(), + iter: self.base.into_elements_base_any_order(), chunk: self.chunk, inner_strides: self.inner_strides, } @@ -169,7 +169,7 @@ where type IntoIter = ExactChunksIterMut<'a, A, D>; fn into_iter(self) -> Self::IntoIter { ExactChunksIterMut { - iter: self.base.into_elements_base(), + iter: self.base.into_elements_base_any_order(), chunk: self.chunk, inner_strides: self.inner_strides, } diff --git a/src/iterators/lanes.rs b/src/iterators/lanes.rs index a06ee906a..921ae9b1b 100644 --- a/src/iterators/lanes.rs +++ b/src/iterators/lanes.rs @@ -3,6 +3,7 @@ use std::marker::PhantomData; use crate::imp_prelude::*; use crate::{Layout, NdProducer}; use crate::iterators::Baseiter; +use crate::iterators::base::NoOptimization; impl_ndproducer! { ['a, A, D: Dimension] @@ -83,7 +84,7 @@ where type IntoIter = LanesIter<'a, A, D>; fn into_iter(self) -> Self::IntoIter { LanesIter { - iter: self.base.into_base_iter(), + iter: self.base.into_base_iter::(), inner_len: self.inner_len, inner_stride: self.inner_stride, life: PhantomData, @@ -134,7 +135,7 @@ where type IntoIter = LanesIterMut<'a, A, D>; fn into_iter(self) -> Self::IntoIter { LanesIterMut { - iter: self.base.into_base_iter(), + iter: self.base.into_base_iter::(), inner_len: self.inner_len, inner_stride: self.inner_stride, life: PhantomData, diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index 8e4ca52aa..7b38c6cc2 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -10,7 +10,7 @@ mod macros; mod axis; -mod base; +pub(crate) mod base; mod chunks; pub mod iter; mod lanes; diff --git a/src/iterators/windows.rs b/src/iterators/windows.rs index 4538f7abb..2933b5317 100644 --- a/src/iterators/windows.rs +++ b/src/iterators/windows.rs @@ -77,7 +77,7 @@ where type IntoIter = WindowsIter<'a, A, D>; fn into_iter(self) -> Self::IntoIter { WindowsIter { - iter: self.base.into_elements_base(), + iter: self.base.into_elements_base_any_order(), window: self.window, strides: self.strides, } diff --git a/src/lib.rs b/src/lib.rs index 7e218c54e..f35f42385 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -147,8 +147,7 @@ pub use crate::slice::{ MultiSliceArg, NewAxis, Slice, SliceArg, SliceInfo, SliceInfoElem, SliceNextDim, }; -use crate::iterators::Baseiter; -use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut}; +use crate::iterators::{ElementsBase, ElementsBaseMut}; pub use crate::arraytraits::AsArray; #[cfg(feature = "std")] diff --git a/tests/windows.rs b/tests/windows.rs index b0482e4bd..134f0ecb2 100644 --- a/tests/windows.rs +++ b/tests/windows.rs @@ -5,6 +5,9 @@ clippy::many_single_char_names )] +use std::collections::HashSet; +use std::hash::Hash; + use ndarray::prelude::*; use ndarray::Zip; @@ -117,6 +120,20 @@ fn test_window_zip() { } } +fn set(iter: impl IntoIterator) -> HashSet +where + T: Eq + Hash +{ + iter.into_iter().collect() +} + +/// Assert equal sets (same collection but order doesn't matter) +macro_rules! assert_set_eq { + ($a:expr, $b:expr) => { + assert_eq!(set($a), set($b)) + } +} + #[test] fn test_window_neg_stride() { let array = Array::from_iter(1..10).into_shape((3, 3)).unwrap(); @@ -131,24 +148,24 @@ fn test_window_neg_stride() { answer.invert_axis(Axis(1)); answer.map_inplace(|a| a.invert_axis(Axis(1))); - itertools::assert_equal( + assert_set_eq!( array.slice(s![.., ..;-1]).windows((2, 2)), - answer.iter() + answer.iter().map(Array::view) ); answer.invert_axis(Axis(0)); answer.map_inplace(|a| a.invert_axis(Axis(0))); - itertools::assert_equal( + assert_set_eq!( array.slice(s![..;-1, ..;-1]).windows((2, 2)), - answer.iter() + answer.iter().map(Array::view) ); answer.invert_axis(Axis(1)); answer.map_inplace(|a| a.invert_axis(Axis(1))); - itertools::assert_equal( + assert_set_eq!( array.slice(s![..;-1, ..]).windows((2, 2)), - answer.iter() + answer.iter().map(Array::view) ); }