From 8cabaaf333f8e25b05f1090a63b5f0a1d97d222e Mon Sep 17 00:00:00 2001 From: bluss Date: Thu, 17 Dec 2015 04:25:40 +0100 Subject: [PATCH 1/4] Axis-typenum experiments: Integers for axis numbers and optional checked axis param --- src/dimension.rs | 99 ++++++++++++++++++++++++++++++++++++++++++++++ tests/dimension.rs | 6 ++- 2 files changed, 104 insertions(+), 1 deletion(-) diff --git a/src/dimension.rs b/src/dimension.rs index 1440ffdf5..d656566dc 100644 --- a/src/dimension.rs +++ b/src/dimension.rs @@ -141,6 +141,7 @@ pub unsafe trait Dimension : Clone + Eq { /// The easiest way to create a `&SliceArg` is using the macro /// [`s![]`](macro.s!.html). type SliceArg: ?Sized + AsRef<[Si]>; + type AxisArg: Copy + Axis; #[doc(hidden)] fn ndim(&self) -> usize; #[doc(hidden)] @@ -355,6 +356,7 @@ pub fn do_sub(dims: &mut D, ptr: &mut *mut A, strides: &D, unsafe impl Dimension for () { type SliceArg = [Si; 0]; + type AxisArg = VoidAxis; // empty product is 1 -> size is 1 #[inline] fn ndim(&self) -> usize { 0 } @@ -364,6 +366,7 @@ unsafe impl Dimension for () { unsafe impl Dimension for Ix { type SliceArg = [Si; 1]; + type AxisArg = Axes0; #[inline] fn ndim(&self) -> usize { 1 } #[inline] @@ -411,6 +414,7 @@ unsafe impl Dimension for Ix { unsafe impl Dimension for (Ix, Ix) { type SliceArg = [Si; 2]; + type AxisArg = Axes1; #[inline] fn ndim(&self) -> usize { 2 } @@ -479,6 +483,7 @@ unsafe impl Dimension for (Ix, Ix) { unsafe impl Dimension for (Ix, Ix, Ix) { type SliceArg = [Si; 3]; + type AxisArg = Axes2; #[inline] fn ndim(&self) -> usize { 3 } #[inline] @@ -515,6 +520,7 @@ macro_rules! large_dim { ($n:expr, $($ix:ident),+) => ( unsafe impl Dimension for ($($ix),+) { type SliceArg = [Si; $n]; + type AxisArg = usize; #[inline] fn ndim(&self) -> usize { $n } } @@ -536,6 +542,7 @@ large_dim!(12, Ix, Ix, Ix, Ix, Ix, Ix, Ix, Ix, Ix, Ix, Ix, Ix); unsafe impl Dimension for Vec { type SliceArg = [Si]; + type AxisArg = usize; fn ndim(&self) -> usize { self.len() } fn slice(&self) -> &[Ix] { self } fn slice_mut(&mut self) -> &mut [Ix] { self } @@ -715,3 +722,95 @@ mod test { assert!(super::dim_stride_overlap(&dim, &strides)); } } + +pub trait Axis { + fn axis(&self) -> usize; +} + +#[derive(Copy, Clone, Debug)] +pub enum VoidAxis { } +#[derive(Copy, Clone, Debug)] +pub struct Axes0(usize); +#[derive(Copy, Clone, Debug)] +pub struct Axes1(usize); +#[derive(Copy, Clone, Debug)] +pub struct Axes2(usize); +#[derive(Copy, Clone, Debug)] +pub struct Axes3(usize); + +impl Axis for VoidAxis { fn axis(&self) -> usize { match *self { } } } +impl Axis for Axes0 { fn axis(&self) -> usize { self.0 } } +impl Axis for Axes1 { fn axis(&self) -> usize { self.0 } } +impl Axis for Axes2 { fn axis(&self) -> usize { self.0 } } +impl Axis for Axes3 { fn axis(&self) -> usize { self.0 } } + +#[derive(Copy, Clone, Debug)] +pub struct Ax0; +#[derive(Copy, Clone, Debug)] +pub struct Ax1; +#[derive(Copy, Clone, Debug)] +pub struct Ax2; +#[derive(Copy, Clone, Debug)] +pub struct Ax3; + +impl Axis for Ax0 { fn axis(&self) -> usize { 0 } } +impl Axis for Ax1 { fn axis(&self) -> usize { 1 } } +impl Axis for Ax2 { fn axis(&self) -> usize { 2 } } +impl Axis for Ax3 { fn axis(&self) -> usize { 3 } } +impl Axis for usize { fn axis(&self) -> usize { *self } } + +impl Into for usize { + fn into(self) -> VoidAxis { + panic!("VoidAxis: zero-dimensional arrays have no axes") + } +} + +impl Into for Ax0 { fn into(self) -> usize { 0 } } +impl Into for Ax1 { fn into(self) -> usize { 1 } } +impl Into for Ax2 { fn into(self) -> usize { 2 } } +impl Into for Ax3 { fn into(self) -> usize { 3 } } + +impl Into for Ax0 { + fn into(self) -> Axes0 { Axes0(0) } +} + +impl Into for usize { + fn into(self) -> Axes0 { + assert!(self == 0); + Axes0(self) + } +} + +impl Into for Ax0 { + fn into(self) -> Axes1 { Axes1(self.axis()) } +} + +impl Into for Ax1 { + fn into(self) -> Axes1 { Axes1(self.axis()) } +} + +impl Into for usize { + fn into(self) -> Axes1 { + assert!(self <= 1); + Axes1(self) + } +} + +impl Into for Ax0 { + fn into(self) -> Axes2 { Axes2(self.axis()) } +} + +impl Into for Ax1 { + fn into(self) -> Axes2 { Axes2(self.axis()) } +} + +impl Into for Ax2 { + fn into(self) -> Axes2 { Axes2(self.axis()) } +} + +impl Into for usize { + fn into(self) -> Axes2 { + assert!(self <= 2); + Axes2(self) + } +} diff --git a/tests/dimension.rs b/tests/dimension.rs index d590ce36e..a9625d597 100644 --- a/tests/dimension.rs +++ b/tests/dimension.rs @@ -5,6 +5,7 @@ use ndarray::{ OwnedArray, RemoveAxis, arr2, + Ax0, Ax1, Ax2, }; #[test] @@ -17,8 +18,11 @@ fn remove_axis() assert_eq!(vec![1,2].remove_axis(0), vec![2]); assert_eq!(vec![4, 5, 6].remove_axis(1), vec![4, 6]); + let a = RcArray::::zeros((4,5)); + a.subview(Ax1, 0); + let a = RcArray::::zeros(vec![4,5,6]); - let _b = a.into_subview(1, 0).reshape((4, 6)).reshape(vec![2, 3, 4]); + let _b = a.subview(Ax1, 0).reshape((4, 6)).reshape(vec![2, 3, 4]); } From 9c90145469c74047701372584bdaababbd70dcc2 Mon Sep 17 00:00:00 2001 From: bluss Date: Sun, 28 Feb 2016 22:27:44 +0100 Subject: [PATCH 2/4] Use a simpler Axis typenum scheme --- examples/axis.rs | 23 +++++++++++++++++++ src/dimension.rs | 57 +++++++++++++++++++++++++++++++++++++++++++++++- src/lib.rs | 46 ++++++++++++++++++++++++++------------ 3 files changed, 111 insertions(+), 15 deletions(-) create mode 100644 examples/axis.rs diff --git a/examples/axis.rs b/examples/axis.rs new file mode 100644 index 000000000..6a61bc4a2 --- /dev/null +++ b/examples/axis.rs @@ -0,0 +1,23 @@ +extern crate ndarray; + +use ndarray::{ + OwnedArray, + Ax, + Ax0, Ax1, Ax2, + AxisForDimension, + Ix, +}; +use ndarray::{Ax as Axis, Ax0 as Axis0, Ax1 as Axis1}; + +fn main() { + let mut a = OwnedArray::::linspace(0., 24., 25).into_shape((5, 5)).unwrap(); + let x: AxisForDimension<(Ix, Ix)> = Ax(2).into(); + let x: AxisForDimension<(Ix, Ix)> = Ax0.into(); + let x: AxisForDimension<(Ix, Ix)> = Ax1.into(); + //let x: AxisForDimension<(Ix, Ix)> = Ax2.into(); + println!("{:?}", x); + println!("{:?}", a.subview(Axis0, 0)); + println!("{:?}", a.subview(Axis0, 1)); + println!("{:?}", a.subview(Axis1, 1)); + println!("{:?}", a.subview(Axis(0), 1)); +} diff --git a/src/dimension.rs b/src/dimension.rs index d656566dc..44347c4e9 100644 --- a/src/dimension.rs +++ b/src/dimension.rs @@ -1,4 +1,5 @@ use std::slice; +use std::marker::PhantomData; use super::{Si, Ix, Ixs}; use super::zipsl; @@ -723,7 +724,7 @@ mod test { } } -pub trait Axis { +pub trait Axis : Copy { fn axis(&self) -> usize; } @@ -738,6 +739,36 @@ pub struct Axes2(usize); #[derive(Copy, Clone, Debug)] pub struct Axes3(usize); +#[derive(Debug)] +pub struct AxisForDimension { + axis: usize, + dim: PhantomData, +} + +impl AxisForDimension { + pub fn axis(&self) -> usize { self.axis } +} + +impl Copy for AxisForDimension { } +impl Clone for AxisForDimension { + fn clone(&self) -> Self { *self } +} + +impl PartialEq for AxisForDimension { + fn eq(&self, rhs: &Self) -> bool { + self.axis == rhs.axis + } +} + +impl From for AxisForDimension { + fn from(x: Ax) -> Self { + AxisForDimension { + axis: x.0, + dim: PhantomData, + } + } +} + impl Axis for VoidAxis { fn axis(&self) -> usize { match *self { } } } impl Axis for Axes0 { fn axis(&self) -> usize { self.0 } } impl Axis for Axes1 { fn axis(&self) -> usize { self.0 } } @@ -753,11 +784,15 @@ pub struct Ax2; #[derive(Copy, Clone, Debug)] pub struct Ax3; +#[derive(Copy, Clone, Debug)] +pub struct Ax(pub usize); + impl Axis for Ax0 { fn axis(&self) -> usize { 0 } } impl Axis for Ax1 { fn axis(&self) -> usize { 1 } } impl Axis for Ax2 { fn axis(&self) -> usize { 2 } } impl Axis for Ax3 { fn axis(&self) -> usize { 3 } } impl Axis for usize { fn axis(&self) -> usize { *self } } +impl Axis for Ax { fn axis(&self) -> usize { self.0 } } impl Into for usize { fn into(self) -> VoidAxis { @@ -769,11 +804,31 @@ impl Into for Ax0 { fn into(self) -> usize { 0 } } impl Into for Ax1 { fn into(self) -> usize { 1 } } impl Into for Ax2 { fn into(self) -> usize { 2 } } impl Into for Ax3 { fn into(self) -> usize { 3 } } +impl Into for Ax { fn into(self) -> usize { self.0 } } impl Into for Ax0 { fn into(self) -> Axes0 { Axes0(0) } } +macro_rules! ax_for_dim { + ($dim:ty, $($ax:ident),*) => { + $( + impl From<$ax> for AxisForDimension<$dim> { + fn from(x: $ax) -> Self { + AxisForDimension { + axis: x.into(), + dim: PhantomData, + } + } + } + )* + } +} +ax_for_dim!{Ix, Ax0} +ax_for_dim!{(Ix, Ix), Ax0, Ax1} +ax_for_dim!{(Ix, Ix, Ix), Ax0, Ax1, Ax2} +ax_for_dim!{(Ix, Ix, Ix, Ix), Ax0, Ax1, Ax2, Ax3} + impl Into for usize { fn into(self) -> Axes0 { assert!(self == 0); diff --git a/src/lib.rs b/src/lib.rs index f840aaa5f..0074244fc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -92,6 +92,7 @@ use itertools::free::enumerate; pub use dimension::{ Dimension, RemoveAxis, + Ax0, Ax1, Ax2, Ax3, Ax, }; pub use dimension::NdIndex; @@ -112,6 +113,11 @@ pub use iterators::{ pub use linalg::LinalgScalar; +pub use dimension::{ + AxisForDimension, + Axis, +}; + mod arraytraits; #[cfg(feature = "serde")] mod arrayserialize; @@ -1278,9 +1284,10 @@ impl ArrayBase where S: Data, D: Dimension /// a.subview(1, 1) == arr1(&[2., 4., 6.]) /// ); /// ``` - pub fn subview(&self, axis: usize, index: Ix) + pub fn subview(&self, axis: Ax, index: Ix) -> ArrayView::Smaller> - where D: RemoveAxis + where D: RemoveAxis, + Ax: Into>, { self.view().into_subview(axis, index) } @@ -1303,10 +1310,11 @@ impl ArrayBase where S: Data, D: Dimension /// [3., 14.]]) /// ); /// ``` - pub fn subview_mut(&mut self, axis: usize, index: Ix) + pub fn subview_mut(&mut self, axis: Ax, index: Ix) -> ArrayViewMut where S: DataMut, D: RemoveAxis, + Ax: Into>, { self.view_mut().into_subview(axis, index) } @@ -1315,7 +1323,10 @@ impl ArrayBase where S: Data, D: Dimension /// and select the subview of `index` along that axis. /// /// **Panics** if `index` is past the length of the axis. - pub fn isubview(&mut self, axis: usize, index: Ix) { + pub fn isubview(&mut self, axis: Ax, index: Ix) + where Ax: Into>, + { + let axis = axis.into().axis(); dimension::do_sub(&mut self.dim, &mut self.ptr, &self.strides, axis, index) } @@ -1323,11 +1334,14 @@ impl ArrayBase where S: Data, D: Dimension /// with that axis removed. /// /// See [`.subview()`](#method.subview) and [*Subviews*](#subviews) for full documentation. - pub fn into_subview(mut self, axis: usize, index: Ix) + pub fn into_subview(mut self, axis: Ax, index: Ix) -> ArrayBase::Smaller> - where D: RemoveAxis + where D: RemoveAxis, + Ax: Into>, { + let axis = axis.into(); self.isubview(axis, index); + let axis = axis.axis(); // don't use reshape -- we always know it will fit the size, // and we can use remove_axis on the strides as well ArrayBase { @@ -2242,11 +2256,13 @@ impl ArrayBase /// ``` /// /// **Panics** if `axis` is out of bounds. - pub fn sum(&self, axis: usize) -> OwnedArray::Smaller> + pub fn sum(&self, axis: Ax) -> OwnedArray::Smaller> where A: Clone + Add, D: RemoveAxis, + Ax: Into>, { - let n = self.shape()[axis]; + let axis = axis.into(); + let n = self.shape()[axis.axis()]; let mut res = self.subview(axis, 0).to_owned(); for i in 1..n { let view = self.subview(axis, i); @@ -2296,11 +2312,13 @@ impl ArrayBase /// /// /// **Panics** if `axis` is out of bounds. - pub fn mean(&self, axis: usize) -> OwnedArray::Smaller> + pub fn mean(&self, axis: Ax) -> OwnedArray::Smaller> where A: LinalgScalar, D: RemoveAxis, + Ax: Into>, { - let n = self.shape()[axis]; + let axis = axis.into(); + let n = self.shape()[axis.axis()]; let mut sum = self.sum(axis); let one = libnum::one::(); let mut cnt = one; @@ -2413,7 +2431,7 @@ impl ArrayBase /// **Panics** if `index` is out of bounds. pub fn row(&self, index: Ix) -> ArrayView { - self.subview(0, index) + self.subview(Ax0, index) } /// Return a mutable array view of row `index`. @@ -2422,7 +2440,7 @@ impl ArrayBase pub fn row_mut(&mut self, index: Ix) -> ArrayViewMut where S: DataMut { - self.subview_mut(0, index) + self.subview_mut(Ax0, index) } /// Return an array view of column `index`. @@ -2430,7 +2448,7 @@ impl ArrayBase /// **Panics** if `index` is out of bounds. pub fn column(&self, index: Ix) -> ArrayView { - self.subview(1, index) + self.subview(Ax1, index) } /// Return a mutable array view of column `index`. @@ -2439,7 +2457,7 @@ impl ArrayBase pub fn column_mut(&mut self, index: Ix) -> ArrayViewMut where S: DataMut { - self.subview_mut(1, index) + self.subview_mut(Ax1, index) } /// Perform matrix multiplication of rectangular arrays `self` and `rhs`. From f12ee302151a0fd9bf2b78c7b46893d455b218db Mon Sep 17 00:00:00 2001 From: bluss Date: Sun, 28 Feb 2016 22:34:52 +0100 Subject: [PATCH 3/4] Simple axis typenums --- examples/axis.rs | 13 ++--- src/dimension.rs | 118 ++++++--------------------------------------- src/lib.rs | 11 ++--- tests/dimension.rs | 6 +-- 4 files changed, 27 insertions(+), 121 deletions(-) diff --git a/examples/axis.rs b/examples/axis.rs index 6a61bc4a2..2982830d8 100644 --- a/examples/axis.rs +++ b/examples/axis.rs @@ -2,22 +2,15 @@ extern crate ndarray; use ndarray::{ OwnedArray, - Ax, - Ax0, Ax1, Ax2, - AxisForDimension, - Ix, + Axis, + Axis0, Axis1, Axis2, }; -use ndarray::{Ax as Axis, Ax0 as Axis0, Ax1 as Axis1}; fn main() { let mut a = OwnedArray::::linspace(0., 24., 25).into_shape((5, 5)).unwrap(); - let x: AxisForDimension<(Ix, Ix)> = Ax(2).into(); - let x: AxisForDimension<(Ix, Ix)> = Ax0.into(); - let x: AxisForDimension<(Ix, Ix)> = Ax1.into(); - //let x: AxisForDimension<(Ix, Ix)> = Ax2.into(); - println!("{:?}", x); println!("{:?}", a.subview(Axis0, 0)); println!("{:?}", a.subview(Axis0, 1)); println!("{:?}", a.subview(Axis1, 1)); + //println!("{:?}", a.subview(Axis2, 1)); println!("{:?}", a.subview(Axis(0), 1)); } diff --git a/src/dimension.rs b/src/dimension.rs index 44347c4e9..c1e7285d8 100644 --- a/src/dimension.rs +++ b/src/dimension.rs @@ -142,7 +142,6 @@ pub unsafe trait Dimension : Clone + Eq { /// The easiest way to create a `&SliceArg` is using the macro /// [`s![]`](macro.s!.html). type SliceArg: ?Sized + AsRef<[Si]>; - type AxisArg: Copy + Axis; #[doc(hidden)] fn ndim(&self) -> usize; #[doc(hidden)] @@ -357,7 +356,6 @@ pub fn do_sub(dims: &mut D, ptr: &mut *mut A, strides: &D, unsafe impl Dimension for () { type SliceArg = [Si; 0]; - type AxisArg = VoidAxis; // empty product is 1 -> size is 1 #[inline] fn ndim(&self) -> usize { 0 } @@ -367,7 +365,6 @@ unsafe impl Dimension for () { unsafe impl Dimension for Ix { type SliceArg = [Si; 1]; - type AxisArg = Axes0; #[inline] fn ndim(&self) -> usize { 1 } #[inline] @@ -415,7 +412,6 @@ unsafe impl Dimension for Ix { unsafe impl Dimension for (Ix, Ix) { type SliceArg = [Si; 2]; - type AxisArg = Axes1; #[inline] fn ndim(&self) -> usize { 2 } @@ -484,7 +480,6 @@ unsafe impl Dimension for (Ix, Ix) { unsafe impl Dimension for (Ix, Ix, Ix) { type SliceArg = [Si; 3]; - type AxisArg = Axes2; #[inline] fn ndim(&self) -> usize { 3 } #[inline] @@ -521,7 +516,6 @@ macro_rules! large_dim { ($n:expr, $($ix:ident),+) => ( unsafe impl Dimension for ($($ix),+) { type SliceArg = [Si; $n]; - type AxisArg = usize; #[inline] fn ndim(&self) -> usize { $n } } @@ -543,7 +537,6 @@ large_dim!(12, Ix, Ix, Ix, Ix, Ix, Ix, Ix, Ix, Ix, Ix, Ix, Ix); unsafe impl Dimension for Vec { type SliceArg = [Si]; - type AxisArg = usize; fn ndim(&self) -> usize { self.len() } fn slice(&self) -> &[Ix] { self } fn slice_mut(&mut self) -> &mut [Ix] { self } @@ -724,21 +717,6 @@ mod test { } } -pub trait Axis : Copy { - fn axis(&self) -> usize; -} - -#[derive(Copy, Clone, Debug)] -pub enum VoidAxis { } -#[derive(Copy, Clone, Debug)] -pub struct Axes0(usize); -#[derive(Copy, Clone, Debug)] -pub struct Axes1(usize); -#[derive(Copy, Clone, Debug)] -pub struct Axes2(usize); -#[derive(Copy, Clone, Debug)] -pub struct Axes3(usize); - #[derive(Debug)] pub struct AxisForDimension { axis: usize, @@ -760,8 +738,8 @@ impl PartialEq for AxisForDimension { } } -impl From for AxisForDimension { - fn from(x: Ax) -> Self { +impl From for AxisForDimension { + fn from(x: Axis) -> Self { AxisForDimension { axis: x.0, dim: PhantomData, @@ -769,46 +747,23 @@ impl From for AxisForDimension { } } -impl Axis for VoidAxis { fn axis(&self) -> usize { match *self { } } } -impl Axis for Axes0 { fn axis(&self) -> usize { self.0 } } -impl Axis for Axes1 { fn axis(&self) -> usize { self.0 } } -impl Axis for Axes2 { fn axis(&self) -> usize { self.0 } } -impl Axis for Axes3 { fn axis(&self) -> usize { self.0 } } - #[derive(Copy, Clone, Debug)] -pub struct Ax0; +pub struct Axis0; #[derive(Copy, Clone, Debug)] -pub struct Ax1; +pub struct Axis1; #[derive(Copy, Clone, Debug)] -pub struct Ax2; +pub struct Axis2; #[derive(Copy, Clone, Debug)] -pub struct Ax3; +pub struct Axis3; #[derive(Copy, Clone, Debug)] -pub struct Ax(pub usize); +pub struct Axis(pub usize); -impl Axis for Ax0 { fn axis(&self) -> usize { 0 } } -impl Axis for Ax1 { fn axis(&self) -> usize { 1 } } -impl Axis for Ax2 { fn axis(&self) -> usize { 2 } } -impl Axis for Ax3 { fn axis(&self) -> usize { 3 } } -impl Axis for usize { fn axis(&self) -> usize { *self } } -impl Axis for Ax { fn axis(&self) -> usize { self.0 } } - -impl Into for usize { - fn into(self) -> VoidAxis { - panic!("VoidAxis: zero-dimensional arrays have no axes") - } -} - -impl Into for Ax0 { fn into(self) -> usize { 0 } } -impl Into for Ax1 { fn into(self) -> usize { 1 } } -impl Into for Ax2 { fn into(self) -> usize { 2 } } -impl Into for Ax3 { fn into(self) -> usize { 3 } } -impl Into for Ax { fn into(self) -> usize { self.0 } } - -impl Into for Ax0 { - fn into(self) -> Axes0 { Axes0(0) } -} +impl Into for Axis0 { #[inline] fn into(self) -> usize { 0 } } +impl Into for Axis1 { #[inline] fn into(self) -> usize { 1 } } +impl Into for Axis2 { #[inline] fn into(self) -> usize { 2 } } +impl Into for Axis3 { #[inline] fn into(self) -> usize { 3 } } +impl Into for Axis { #[inline] fn into(self) -> usize { self.0 } } macro_rules! ax_for_dim { ($dim:ty, $($ax:ident),*) => { @@ -824,48 +779,7 @@ macro_rules! ax_for_dim { )* } } -ax_for_dim!{Ix, Ax0} -ax_for_dim!{(Ix, Ix), Ax0, Ax1} -ax_for_dim!{(Ix, Ix, Ix), Ax0, Ax1, Ax2} -ax_for_dim!{(Ix, Ix, Ix, Ix), Ax0, Ax1, Ax2, Ax3} - -impl Into for usize { - fn into(self) -> Axes0 { - assert!(self == 0); - Axes0(self) - } -} - -impl Into for Ax0 { - fn into(self) -> Axes1 { Axes1(self.axis()) } -} - -impl Into for Ax1 { - fn into(self) -> Axes1 { Axes1(self.axis()) } -} - -impl Into for usize { - fn into(self) -> Axes1 { - assert!(self <= 1); - Axes1(self) - } -} - -impl Into for Ax0 { - fn into(self) -> Axes2 { Axes2(self.axis()) } -} - -impl Into for Ax1 { - fn into(self) -> Axes2 { Axes2(self.axis()) } -} - -impl Into for Ax2 { - fn into(self) -> Axes2 { Axes2(self.axis()) } -} - -impl Into for usize { - fn into(self) -> Axes2 { - assert!(self <= 2); - Axes2(self) - } -} +ax_for_dim!{Ix, Axis0} +ax_for_dim!{(Ix, Ix), Axis0, Axis1} +ax_for_dim!{(Ix, Ix, Ix), Axis0, Axis1, Axis2} +ax_for_dim!{(Ix, Ix, Ix, Ix), Axis0, Axis1, Axis2, Axis3} diff --git a/src/lib.rs b/src/lib.rs index 0074244fc..73ed7339f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -92,7 +92,7 @@ use itertools::free::enumerate; pub use dimension::{ Dimension, RemoveAxis, - Ax0, Ax1, Ax2, Ax3, Ax, + Axis0, Axis1, Axis2, Axis3, Axis, }; pub use dimension::NdIndex; @@ -115,7 +115,6 @@ pub use linalg::LinalgScalar; pub use dimension::{ AxisForDimension, - Axis, }; mod arraytraits; @@ -2431,7 +2430,7 @@ impl ArrayBase /// **Panics** if `index` is out of bounds. pub fn row(&self, index: Ix) -> ArrayView { - self.subview(Ax0, index) + self.subview(Axis0, index) } /// Return a mutable array view of row `index`. @@ -2440,7 +2439,7 @@ impl ArrayBase pub fn row_mut(&mut self, index: Ix) -> ArrayViewMut where S: DataMut { - self.subview_mut(Ax0, index) + self.subview_mut(Axis0, index) } /// Return an array view of column `index`. @@ -2448,7 +2447,7 @@ impl ArrayBase /// **Panics** if `index` is out of bounds. pub fn column(&self, index: Ix) -> ArrayView { - self.subview(Ax1, index) + self.subview(Axis1, index) } /// Return a mutable array view of column `index`. @@ -2457,7 +2456,7 @@ impl ArrayBase pub fn column_mut(&mut self, index: Ix) -> ArrayViewMut where S: DataMut { - self.subview_mut(Ax1, index) + self.subview_mut(Axis1, index) } /// Perform matrix multiplication of rectangular arrays `self` and `rhs`. diff --git a/tests/dimension.rs b/tests/dimension.rs index a9625d597..84b825eb8 100644 --- a/tests/dimension.rs +++ b/tests/dimension.rs @@ -5,7 +5,7 @@ use ndarray::{ OwnedArray, RemoveAxis, arr2, - Ax0, Ax1, Ax2, + Axis0, Axis1, Axis2, }; #[test] @@ -19,10 +19,10 @@ fn remove_axis() assert_eq!(vec![4, 5, 6].remove_axis(1), vec![4, 6]); let a = RcArray::::zeros((4,5)); - a.subview(Ax1, 0); + a.subview(Axis1, 0); let a = RcArray::::zeros(vec![4,5,6]); - let _b = a.subview(Ax1, 0).reshape((4, 6)).reshape(vec![2, 3, 4]); + let _b = a.subview(Axis1, 0).reshape((4, 6)).reshape(vec![2, 3, 4]); } From a47fdc58b92170a81611a4440121a953d5022e9e Mon Sep 17 00:00:00 2001 From: bluss Date: Sun, 28 Feb 2016 22:57:21 +0100 Subject: [PATCH 4/4] Add usize -> axis for dimension conversion --- src/dimension.rs | 10 ++++++++++ src/lib.rs | 7 ++++--- tests/array.rs | 13 +++++++------ tests/dimension.rs | 3 ++- 4 files changed, 23 insertions(+), 10 deletions(-) diff --git a/src/dimension.rs b/src/dimension.rs index c1e7285d8..5eed27c3a 100644 --- a/src/dimension.rs +++ b/src/dimension.rs @@ -747,6 +747,16 @@ impl From for AxisForDimension { } } +#[cfg_attr(has_deprecated, deprecated(note="Usize arguments for `axis` are deprecated. Use `Axis` instead."))] +impl From for AxisForDimension { + fn from(x: usize) -> Self { + AxisForDimension { + axis: x, + dim: PhantomData, + } + } +} + #[derive(Copy, Clone, Debug)] pub struct Axis0; #[derive(Copy, Clone, Debug)] diff --git a/src/lib.rs b/src/lib.rs index 73ed7339f..bd0691dce 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1431,10 +1431,11 @@ impl ArrayBase where S: Data, D: Dimension /// See [*Subviews*](#subviews) for full documentation. /// /// **Panics** if `axis` is out of bounds. - pub fn axis_iter(&self, axis: usize) -> OuterIter - where D: RemoveAxis + pub fn axis_iter(&self, axis: Ax) -> OuterIter + where D: RemoveAxis, + Ax: Into>, { - iterators::new_axis_iter(self.view(), axis) + iterators::new_axis_iter(self.view(), axis.into().axis()) } diff --git a/tests/array.rs b/tests/array.rs index 0274dd428..0e02cf7ad 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -14,6 +14,7 @@ use ndarray::{arr0, arr1, arr2, aview_mut1, }; use ndarray::Indexes; +use ndarray::{Axis, Axis0, Axis1, Axis2}; #[test] fn test_matmul_rcarray() @@ -284,12 +285,12 @@ fn assign() fn sum_mean() { let a = arr2(&[[1., 2.], [3., 4.]]); - assert_eq!(a.sum(0), arr1(&[4., 6.])); - assert_eq!(a.sum(1), arr1(&[3., 7.])); - assert_eq!(a.mean(0), arr1(&[2., 3.])); - assert_eq!(a.mean(1), arr1(&[1.5, 3.5])); - assert_eq!(a.sum(1).sum(0), arr0(10.)); - assert_eq!(a.view().mean(1), aview1(&[1.5, 3.5])); + assert_eq!(a.sum(Axis0), arr1(&[4., 6.])); + assert_eq!(a.sum(Axis1), arr1(&[3., 7.])); + assert_eq!(a.mean(Axis0), arr1(&[2., 3.])); + assert_eq!(a.mean(Axis1), arr1(&[1.5, 3.5])); + assert_eq!(a.sum(Axis1).sum(Axis0), arr0(10.)); + assert_eq!(a.view().mean(Axis1), aview1(&[1.5, 3.5])); assert_eq!(a.scalar_sum(), 10.); } diff --git a/tests/dimension.rs b/tests/dimension.rs index 84b825eb8..b13e0236b 100644 --- a/tests/dimension.rs +++ b/tests/dimension.rs @@ -5,6 +5,7 @@ use ndarray::{ OwnedArray, RemoveAxis, arr2, + Axis, Axis0, Axis1, Axis2, }; @@ -22,7 +23,7 @@ fn remove_axis() a.subview(Axis1, 0); let a = RcArray::::zeros(vec![4,5,6]); - let _b = a.subview(Axis1, 0).reshape((4, 6)).reshape(vec![2, 3, 4]); + let _b = a.into_subview(Axis(1), 0).reshape((4, 6)).reshape(vec![2, 3, 4]); }