Skip to content

Not going to be merged: Statically checked axis numbers #96

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions examples/axis.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
extern crate ndarray;

use ndarray::{
OwnedArray,
Axis,
Axis0, Axis1, Axis2,
};

fn main() {
let mut a = OwnedArray::<f32, _>::linspace(0., 24., 25).into_shape((5, 5)).unwrap();
println!("{:?}", a.subview(Axis0, 0));
println!("{:?}", a.subview(Axis0, 1));
println!("{:?}", a.subview(Axis1, 1));
//println!("{:?}", a.subview(Axis2, 1));
println!("{:?}", a.subview(Axis(0), 1));
}
78 changes: 78 additions & 0 deletions src/dimension.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::slice;
use std::marker::PhantomData;

use super::{Si, Ix, Ixs};
use super::zipsl;
Expand Down Expand Up @@ -715,3 +716,80 @@ mod test {
assert!(super::dim_stride_overlap(&dim, &strides));
}
}

#[derive(Debug)]
pub struct AxisForDimension<D> {
axis: usize,
dim: PhantomData<D>,
}

impl<D> AxisForDimension<D> {
pub fn axis(&self) -> usize { self.axis }
}

impl<D> Copy for AxisForDimension<D> { }
impl<D> Clone for AxisForDimension<D> {
fn clone(&self) -> Self { *self }
}

impl<D> PartialEq for AxisForDimension<D> {
fn eq(&self, rhs: &Self) -> bool {
self.axis == rhs.axis
}
}

impl<D> From<Axis> for AxisForDimension<D> {
fn from(x: Axis) -> Self {
AxisForDimension {
axis: x.0,
dim: PhantomData,
}
}
}

#[cfg_attr(has_deprecated, deprecated(note="Usize arguments for `axis` are deprecated. Use `Axis` instead."))]
impl<D> From<usize> for AxisForDimension<D> {
fn from(x: usize) -> Self {
AxisForDimension {
axis: x,
dim: PhantomData,
}
}
}

#[derive(Copy, Clone, Debug)]
pub struct Axis0;
#[derive(Copy, Clone, Debug)]
pub struct Axis1;
#[derive(Copy, Clone, Debug)]
pub struct Axis2;
#[derive(Copy, Clone, Debug)]
pub struct Axis3;

#[derive(Copy, Clone, Debug)]
pub struct Axis(pub usize);

impl Into<usize> for Axis0 { #[inline] fn into(self) -> usize { 0 } }
impl Into<usize> for Axis1 { #[inline] fn into(self) -> usize { 1 } }
impl Into<usize> for Axis2 { #[inline] fn into(self) -> usize { 2 } }
impl Into<usize> for Axis3 { #[inline] fn into(self) -> usize { 3 } }
impl Into<usize> for Axis { #[inline] fn into(self) -> usize { self.0 } }

macro_rules! ax_for_dim {
($dim:ty, $($ax:ident),*) => {
$(
impl From<$ax> for AxisForDimension<$dim> {
fn from(x: $ax) -> Self {
AxisForDimension {
axis: x.into(),
dim: PhantomData,
}
}
}
)*
}
}
ax_for_dim!{Ix, Axis0}
ax_for_dim!{(Ix, Ix), Axis0, Axis1}
ax_for_dim!{(Ix, Ix, Ix), Axis0, Axis1, Axis2}
ax_for_dim!{(Ix, Ix, Ix, Ix), Axis0, Axis1, Axis2, Axis3}
52 changes: 35 additions & 17 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ use itertools::free::enumerate;
pub use dimension::{
Dimension,
RemoveAxis,
Axis0, Axis1, Axis2, Axis3, Axis,
};

pub use dimension::NdIndex;
Expand All @@ -112,6 +113,10 @@ pub use iterators::{

pub use linalg::LinalgScalar;

pub use dimension::{
AxisForDimension,
};

mod arraytraits;
#[cfg(feature = "serde")]
mod arrayserialize;
Expand Down Expand Up @@ -1278,9 +1283,10 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
/// a.subview(1, 1) == arr1(&[2., 4., 6.])
/// );
/// ```
pub fn subview(&self, axis: usize, index: Ix)
pub fn subview<Ax>(&self, axis: Ax, index: Ix)
-> ArrayView<A, <D as RemoveAxis>::Smaller>
where D: RemoveAxis
where D: RemoveAxis,
Ax: Into<AxisForDimension<D>>,
{
self.view().into_subview(axis, index)
}
Expand All @@ -1303,10 +1309,11 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
/// [3., 14.]])
/// );
/// ```
pub fn subview_mut(&mut self, axis: usize, index: Ix)
pub fn subview_mut<Ax>(&mut self, axis: Ax, index: Ix)
-> ArrayViewMut<A, D::Smaller>
where S: DataMut,
D: RemoveAxis,
Ax: Into<AxisForDimension<D>>,
{
self.view_mut().into_subview(axis, index)
}
Expand All @@ -1315,19 +1322,25 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
/// and select the subview of `index` along that axis.
///
/// **Panics** if `index` is past the length of the axis.
pub fn isubview(&mut self, axis: usize, index: Ix) {
pub fn isubview<Ax>(&mut self, axis: Ax, index: Ix)
where Ax: Into<AxisForDimension<D>>,
{
let axis = axis.into().axis();
dimension::do_sub(&mut self.dim, &mut self.ptr, &self.strides, axis, index)
}

/// Along `axis`, select the subview `index` and return `self`
/// with that axis removed.
///
/// See [`.subview()`](#method.subview) and [*Subviews*](#subviews) for full documentation.
pub fn into_subview(mut self, axis: usize, index: Ix)
pub fn into_subview<Ax>(mut self, axis: Ax, index: Ix)
-> ArrayBase<S, <D as RemoveAxis>::Smaller>
where D: RemoveAxis
where D: RemoveAxis,
Ax: Into<AxisForDimension<D>>,
{
let axis = axis.into();
self.isubview(axis, index);
let axis = axis.axis();
// don't use reshape -- we always know it will fit the size,
// and we can use remove_axis on the strides as well
ArrayBase {
Expand Down Expand Up @@ -1418,10 +1431,11 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
/// See [*Subviews*](#subviews) for full documentation.
///
/// **Panics** if `axis` is out of bounds.
pub fn axis_iter(&self, axis: usize) -> OuterIter<A, D::Smaller>
where D: RemoveAxis
pub fn axis_iter<Ax>(&self, axis: Ax) -> OuterIter<A, D::Smaller>
where D: RemoveAxis,
Ax: Into<AxisForDimension<D>>,
{
iterators::new_axis_iter(self.view(), axis)
iterators::new_axis_iter(self.view(), axis.into().axis())
}


Expand Down Expand Up @@ -2242,11 +2256,13 @@ impl<A, S, D> ArrayBase<S, D>
/// ```
///
/// **Panics** if `axis` is out of bounds.
pub fn sum(&self, axis: usize) -> OwnedArray<A, <D as RemoveAxis>::Smaller>
pub fn sum<Ax>(&self, axis: Ax) -> OwnedArray<A, <D as RemoveAxis>::Smaller>
where A: Clone + Add<Output=A>,
D: RemoveAxis,
Ax: Into<AxisForDimension<D>>,
{
let n = self.shape()[axis];
let axis = axis.into();
let n = self.shape()[axis.axis()];
let mut res = self.subview(axis, 0).to_owned();
for i in 1..n {
let view = self.subview(axis, i);
Expand Down Expand Up @@ -2296,11 +2312,13 @@ impl<A, S, D> ArrayBase<S, D>
///
///
/// **Panics** if `axis` is out of bounds.
pub fn mean(&self, axis: usize) -> OwnedArray<A, <D as RemoveAxis>::Smaller>
pub fn mean<Ax>(&self, axis: Ax) -> OwnedArray<A, <D as RemoveAxis>::Smaller>
where A: LinalgScalar,
D: RemoveAxis,
Ax: Into<AxisForDimension<D>>,
{
let n = self.shape()[axis];
let axis = axis.into();
let n = self.shape()[axis.axis()];
let mut sum = self.sum(axis);
let one = libnum::one::<A>();
let mut cnt = one;
Expand Down Expand Up @@ -2413,7 +2431,7 @@ impl<A, S> ArrayBase<S, (Ix, Ix)>
/// **Panics** if `index` is out of bounds.
pub fn row(&self, index: Ix) -> ArrayView<A, Ix>
{
self.subview(0, index)
self.subview(Axis0, index)
}

/// Return a mutable array view of row `index`.
Expand All @@ -2422,15 +2440,15 @@ impl<A, S> ArrayBase<S, (Ix, Ix)>
pub fn row_mut(&mut self, index: Ix) -> ArrayViewMut<A, Ix>
where S: DataMut
{
self.subview_mut(0, index)
self.subview_mut(Axis0, index)
}

/// Return an array view of column `index`.
///
/// **Panics** if `index` is out of bounds.
pub fn column(&self, index: Ix) -> ArrayView<A, Ix>
{
self.subview(1, index)
self.subview(Axis1, index)
}

/// Return a mutable array view of column `index`.
Expand All @@ -2439,7 +2457,7 @@ impl<A, S> ArrayBase<S, (Ix, Ix)>
pub fn column_mut(&mut self, index: Ix) -> ArrayViewMut<A, Ix>
where S: DataMut
{
self.subview_mut(1, index)
self.subview_mut(Axis1, index)
}

/// Perform matrix multiplication of rectangular arrays `self` and `rhs`.
Expand Down
13 changes: 7 additions & 6 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use ndarray::{arr0, arr1, arr2,
aview_mut1,
};
use ndarray::Indexes;
use ndarray::{Axis, Axis0, Axis1, Axis2};

#[test]
fn test_matmul_rcarray()
Expand Down Expand Up @@ -284,12 +285,12 @@ fn assign()
fn sum_mean()
{
let a = arr2(&[[1., 2.], [3., 4.]]);
assert_eq!(a.sum(0), arr1(&[4., 6.]));
assert_eq!(a.sum(1), arr1(&[3., 7.]));
assert_eq!(a.mean(0), arr1(&[2., 3.]));
assert_eq!(a.mean(1), arr1(&[1.5, 3.5]));
assert_eq!(a.sum(1).sum(0), arr0(10.));
assert_eq!(a.view().mean(1), aview1(&[1.5, 3.5]));
assert_eq!(a.sum(Axis0), arr1(&[4., 6.]));
assert_eq!(a.sum(Axis1), arr1(&[3., 7.]));
assert_eq!(a.mean(Axis0), arr1(&[2., 3.]));
assert_eq!(a.mean(Axis1), arr1(&[1.5, 3.5]));
assert_eq!(a.sum(Axis1).sum(Axis0), arr0(10.));
assert_eq!(a.view().mean(Axis1), aview1(&[1.5, 3.5]));
assert_eq!(a.scalar_sum(), 10.);
}

Expand Down
7 changes: 6 additions & 1 deletion tests/dimension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use ndarray::{
OwnedArray,
RemoveAxis,
arr2,
Axis,
Axis0, Axis1, Axis2,
};

#[test]
Expand All @@ -17,8 +19,11 @@ fn remove_axis()
assert_eq!(vec![1,2].remove_axis(0), vec![2]);
assert_eq!(vec![4, 5, 6].remove_axis(1), vec![4, 6]);

let a = RcArray::<f32, _>::zeros((4,5));
a.subview(Axis1, 0);

let a = RcArray::<f32, _>::zeros(vec![4,5,6]);
let _b = a.into_subview(1, 0).reshape((4, 6)).reshape(vec![2, 3, 4]);
let _b = a.into_subview(Axis(1), 0).reshape((4, 6)).reshape(vec![2, 3, 4]);

}

Expand Down