diff --git a/examples/axis.rs b/examples/axis.rs new file mode 100644 index 000000000..2982830d8 --- /dev/null +++ b/examples/axis.rs @@ -0,0 +1,16 @@ +extern crate ndarray; + +use ndarray::{ + OwnedArray, + Axis, + Axis0, Axis1, Axis2, +}; + +fn main() { + let mut a = OwnedArray::::linspace(0., 24., 25).into_shape((5, 5)).unwrap(); + println!("{:?}", a.subview(Axis0, 0)); + println!("{:?}", a.subview(Axis0, 1)); + println!("{:?}", a.subview(Axis1, 1)); + //println!("{:?}", a.subview(Axis2, 1)); + println!("{:?}", a.subview(Axis(0), 1)); +} diff --git a/src/dimension.rs b/src/dimension.rs index 1440ffdf5..5eed27c3a 100644 --- a/src/dimension.rs +++ b/src/dimension.rs @@ -1,4 +1,5 @@ use std::slice; +use std::marker::PhantomData; use super::{Si, Ix, Ixs}; use super::zipsl; @@ -715,3 +716,80 @@ mod test { assert!(super::dim_stride_overlap(&dim, &strides)); } } + +#[derive(Debug)] +pub struct AxisForDimension { + axis: usize, + dim: PhantomData, +} + +impl AxisForDimension { + pub fn axis(&self) -> usize { self.axis } +} + +impl Copy for AxisForDimension { } +impl Clone for AxisForDimension { + fn clone(&self) -> Self { *self } +} + +impl PartialEq for AxisForDimension { + fn eq(&self, rhs: &Self) -> bool { + self.axis == rhs.axis + } +} + +impl From for AxisForDimension { + fn from(x: Axis) -> Self { + AxisForDimension { + axis: x.0, + dim: PhantomData, + } + } +} + +#[cfg_attr(has_deprecated, deprecated(note="Usize arguments for `axis` are deprecated. Use `Axis` instead."))] +impl From for AxisForDimension { + fn from(x: usize) -> Self { + AxisForDimension { + axis: x, + dim: PhantomData, + } + } +} + +#[derive(Copy, Clone, Debug)] +pub struct Axis0; +#[derive(Copy, Clone, Debug)] +pub struct Axis1; +#[derive(Copy, Clone, Debug)] +pub struct Axis2; +#[derive(Copy, Clone, Debug)] +pub struct Axis3; + +#[derive(Copy, Clone, Debug)] +pub struct Axis(pub usize); + +impl Into for Axis0 { #[inline] fn into(self) -> usize { 0 } } +impl Into for Axis1 { #[inline] fn into(self) -> usize { 1 } } +impl Into for Axis2 { #[inline] fn into(self) -> usize { 2 } } +impl Into for Axis3 { #[inline] fn into(self) -> usize { 3 } } +impl Into for Axis { #[inline] fn into(self) -> usize { self.0 } } + +macro_rules! ax_for_dim { + ($dim:ty, $($ax:ident),*) => { + $( + impl From<$ax> for AxisForDimension<$dim> { + fn from(x: $ax) -> Self { + AxisForDimension { + axis: x.into(), + dim: PhantomData, + } + } + } + )* + } +} +ax_for_dim!{Ix, Axis0} +ax_for_dim!{(Ix, Ix), Axis0, Axis1} +ax_for_dim!{(Ix, Ix, Ix), Axis0, Axis1, Axis2} +ax_for_dim!{(Ix, Ix, Ix, Ix), Axis0, Axis1, Axis2, Axis3} diff --git a/src/lib.rs b/src/lib.rs index f840aaa5f..bd0691dce 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -92,6 +92,7 @@ use itertools::free::enumerate; pub use dimension::{ Dimension, RemoveAxis, + Axis0, Axis1, Axis2, Axis3, Axis, }; pub use dimension::NdIndex; @@ -112,6 +113,10 @@ pub use iterators::{ pub use linalg::LinalgScalar; +pub use dimension::{ + AxisForDimension, +}; + mod arraytraits; #[cfg(feature = "serde")] mod arrayserialize; @@ -1278,9 +1283,10 @@ impl ArrayBase where S: Data, D: Dimension /// a.subview(1, 1) == arr1(&[2., 4., 6.]) /// ); /// ``` - pub fn subview(&self, axis: usize, index: Ix) + pub fn subview(&self, axis: Ax, index: Ix) -> ArrayView::Smaller> - where D: RemoveAxis + where D: RemoveAxis, + Ax: Into>, { self.view().into_subview(axis, index) } @@ -1303,10 +1309,11 @@ impl ArrayBase where S: Data, D: Dimension /// [3., 14.]]) /// ); /// ``` - pub fn subview_mut(&mut self, axis: usize, index: Ix) + pub fn subview_mut(&mut self, axis: Ax, index: Ix) -> ArrayViewMut where S: DataMut, D: RemoveAxis, + Ax: Into>, { self.view_mut().into_subview(axis, index) } @@ -1315,7 +1322,10 @@ impl ArrayBase where S: Data, D: Dimension /// and select the subview of `index` along that axis. /// /// **Panics** if `index` is past the length of the axis. - pub fn isubview(&mut self, axis: usize, index: Ix) { + pub fn isubview(&mut self, axis: Ax, index: Ix) + where Ax: Into>, + { + let axis = axis.into().axis(); dimension::do_sub(&mut self.dim, &mut self.ptr, &self.strides, axis, index) } @@ -1323,11 +1333,14 @@ impl ArrayBase where S: Data, D: Dimension /// with that axis removed. /// /// See [`.subview()`](#method.subview) and [*Subviews*](#subviews) for full documentation. - pub fn into_subview(mut self, axis: usize, index: Ix) + pub fn into_subview(mut self, axis: Ax, index: Ix) -> ArrayBase::Smaller> - where D: RemoveAxis + where D: RemoveAxis, + Ax: Into>, { + let axis = axis.into(); self.isubview(axis, index); + let axis = axis.axis(); // don't use reshape -- we always know it will fit the size, // and we can use remove_axis on the strides as well ArrayBase { @@ -1418,10 +1431,11 @@ impl ArrayBase where S: Data, D: Dimension /// See [*Subviews*](#subviews) for full documentation. /// /// **Panics** if `axis` is out of bounds. - pub fn axis_iter(&self, axis: usize) -> OuterIter - where D: RemoveAxis + pub fn axis_iter(&self, axis: Ax) -> OuterIter + where D: RemoveAxis, + Ax: Into>, { - iterators::new_axis_iter(self.view(), axis) + iterators::new_axis_iter(self.view(), axis.into().axis()) } @@ -2242,11 +2256,13 @@ impl ArrayBase /// ``` /// /// **Panics** if `axis` is out of bounds. - pub fn sum(&self, axis: usize) -> OwnedArray::Smaller> + pub fn sum(&self, axis: Ax) -> OwnedArray::Smaller> where A: Clone + Add, D: RemoveAxis, + Ax: Into>, { - let n = self.shape()[axis]; + let axis = axis.into(); + let n = self.shape()[axis.axis()]; let mut res = self.subview(axis, 0).to_owned(); for i in 1..n { let view = self.subview(axis, i); @@ -2296,11 +2312,13 @@ impl ArrayBase /// /// /// **Panics** if `axis` is out of bounds. - pub fn mean(&self, axis: usize) -> OwnedArray::Smaller> + pub fn mean(&self, axis: Ax) -> OwnedArray::Smaller> where A: LinalgScalar, D: RemoveAxis, + Ax: Into>, { - let n = self.shape()[axis]; + let axis = axis.into(); + let n = self.shape()[axis.axis()]; let mut sum = self.sum(axis); let one = libnum::one::(); let mut cnt = one; @@ -2413,7 +2431,7 @@ impl ArrayBase /// **Panics** if `index` is out of bounds. pub fn row(&self, index: Ix) -> ArrayView { - self.subview(0, index) + self.subview(Axis0, index) } /// Return a mutable array view of row `index`. @@ -2422,7 +2440,7 @@ impl ArrayBase pub fn row_mut(&mut self, index: Ix) -> ArrayViewMut where S: DataMut { - self.subview_mut(0, index) + self.subview_mut(Axis0, index) } /// Return an array view of column `index`. @@ -2430,7 +2448,7 @@ impl ArrayBase /// **Panics** if `index` is out of bounds. pub fn column(&self, index: Ix) -> ArrayView { - self.subview(1, index) + self.subview(Axis1, index) } /// Return a mutable array view of column `index`. @@ -2439,7 +2457,7 @@ impl ArrayBase pub fn column_mut(&mut self, index: Ix) -> ArrayViewMut where S: DataMut { - self.subview_mut(1, index) + self.subview_mut(Axis1, index) } /// Perform matrix multiplication of rectangular arrays `self` and `rhs`. diff --git a/tests/array.rs b/tests/array.rs index 0274dd428..0e02cf7ad 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -14,6 +14,7 @@ use ndarray::{arr0, arr1, arr2, aview_mut1, }; use ndarray::Indexes; +use ndarray::{Axis, Axis0, Axis1, Axis2}; #[test] fn test_matmul_rcarray() @@ -284,12 +285,12 @@ fn assign() fn sum_mean() { let a = arr2(&[[1., 2.], [3., 4.]]); - assert_eq!(a.sum(0), arr1(&[4., 6.])); - assert_eq!(a.sum(1), arr1(&[3., 7.])); - assert_eq!(a.mean(0), arr1(&[2., 3.])); - assert_eq!(a.mean(1), arr1(&[1.5, 3.5])); - assert_eq!(a.sum(1).sum(0), arr0(10.)); - assert_eq!(a.view().mean(1), aview1(&[1.5, 3.5])); + assert_eq!(a.sum(Axis0), arr1(&[4., 6.])); + assert_eq!(a.sum(Axis1), arr1(&[3., 7.])); + assert_eq!(a.mean(Axis0), arr1(&[2., 3.])); + assert_eq!(a.mean(Axis1), arr1(&[1.5, 3.5])); + assert_eq!(a.sum(Axis1).sum(Axis0), arr0(10.)); + assert_eq!(a.view().mean(Axis1), aview1(&[1.5, 3.5])); assert_eq!(a.scalar_sum(), 10.); } diff --git a/tests/dimension.rs b/tests/dimension.rs index d590ce36e..b13e0236b 100644 --- a/tests/dimension.rs +++ b/tests/dimension.rs @@ -5,6 +5,8 @@ use ndarray::{ OwnedArray, RemoveAxis, arr2, + Axis, + Axis0, Axis1, Axis2, }; #[test] @@ -17,8 +19,11 @@ fn remove_axis() assert_eq!(vec![1,2].remove_axis(0), vec![2]); assert_eq!(vec![4, 5, 6].remove_axis(1), vec![4, 6]); + let a = RcArray::::zeros((4,5)); + a.subview(Axis1, 0); + let a = RcArray::::zeros(vec![4,5,6]); - let _b = a.into_subview(1, 0).reshape((4, 6)).reshape(vec![2, 3, 4]); + let _b = a.into_subview(Axis(1), 0).reshape((4, 6)).reshape(vec![2, 3, 4]); }