Skip to content

Commit

Permalink
Add support for inserting new axes while slicing
Browse files Browse the repository at this point in the history
  • Loading branch information
jturner314 committed Dec 10, 2018
1 parent 9a99561 commit 1e978ca
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 67 deletions.
23 changes: 14 additions & 9 deletions src/dimension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,11 @@ pub fn slices_intersect<D: Dimension>(
indices2: &impl CanSlice<D>,
) -> bool {
debug_assert_eq!(indices1.in_ndim(), indices2.in_ndim());
for (&axis_len, &si1, &si2) in izip!(dim.slice(), indices1.as_ref(), indices2.as_ref()) {
for (&axis_len, &si1, &si2) in izip!(
dim.slice(),
indices1.as_ref().iter().filter(|si| !si.is_new_axis()),
indices2.as_ref().iter().filter(|si| !si.is_new_axis()),
) {
// The slices do not intersect iff any pair of `AxisSliceInfo` does not intersect.
match (si1, si2) {
(
Expand Down Expand Up @@ -582,6 +586,7 @@ pub fn slices_intersect<D: Dimension>(
return false;
}
}
(AxisSliceInfo::NewAxis, _) | (_, AxisSliceInfo::NewAxis) => unreachable!(),
}
}
true
Expand Down Expand Up @@ -626,7 +631,7 @@ mod test {
use num_integer::gcd;
use quickcheck::{quickcheck, TestResult};
use slice::Slice;
use {Dim, Dimension, Ix0, Ix1, Ix2, Ix3, IxDyn};
use {Dim, Dimension, Ix0, Ix1, Ix2, Ix3, IxDyn, NewAxis};

#[test]
fn slice_indexing_uncommon_strides() {
Expand Down Expand Up @@ -882,17 +887,17 @@ mod test {

#[test]
fn slices_intersect_true() {
assert!(slices_intersect(&Dim([4, 5]), s![.., ..], s![.., ..]));
assert!(slices_intersect(&Dim([4, 5]), s![0, ..], s![0, ..]));
assert!(slices_intersect(&Dim([4, 5]), s![..;2, ..], s![..;3, ..]));
assert!(slices_intersect(&Dim([4, 5]), s![.., ..;2], s![.., 1..;3]));
assert!(slices_intersect(&Dim([4, 5]), s![NewAxis, .., NewAxis, ..], s![.., NewAxis, .., NewAxis]));
assert!(slices_intersect(&Dim([4, 5]), s![NewAxis, 0, ..], s![0, ..]));
assert!(slices_intersect(&Dim([4, 5]), s![..;2, ..], s![..;3, NewAxis, ..]));
assert!(slices_intersect(&Dim([4, 5]), s![.., ..;2], s![.., 1..;3, NewAxis]));
assert!(slices_intersect(&Dim([4, 10]), s![.., ..;9], s![.., 3..;6]));
}

#[test]
fn slices_intersect_false() {
assert!(!slices_intersect(&Dim([4, 5]), s![..;2, ..], s![1..;2, ..]));
assert!(!slices_intersect(&Dim([4, 5]), s![..;2, ..], s![1..;3, ..]));
assert!(!slices_intersect(&Dim([4, 5]), s![.., ..;9], s![.., 3..;6]));
assert!(!slices_intersect(&Dim([4, 5]), s![..;2, ..], s![NewAxis, 1..;2, ..]));
assert!(!slices_intersect(&Dim([4, 5]), s![..;2, NewAxis, ..], s![1..;3, ..]));
assert!(!slices_intersect(&Dim([4, 5]), s![.., ..;9], s![.., 3..;6, NewAxis]));
}
}
2 changes: 1 addition & 1 deletion src/doc/ndarray_for_numpy_users/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@
//! `a[:] = 3.` | [`a.fill(3.)`][.fill()] | set all array elements to the same scalar value
//! `a[:] = b` | [`a.assign(&b)`][.assign()] | copy the data from array `b` into array `a`
//! `np.concatenate((a,b), axis=1)` | [`stack![Axis(1), a, b]`][stack!] or [`stack(Axis(1), &[a.view(), b.view()])`][stack()] | concatenate arrays `a` and `b` along axis 1
//! `a[:,np.newaxis]` or `np.expand_dims(a, axis=1)` | [`a.insert_axis(Axis(1))`][.insert_axis()] | create an array from `a`, inserting a new axis 1
//! `a[:,np.newaxis]` or `np.expand_dims(a, axis=1)` | [`a.slice(s![.., NewAxis])`][.slice()] or [`a.insert_axis(Axis(1))`][.insert_axis()] | create an view of 1-D array `a`, inserting a new axis 1
//! `a.transpose()` or `a.T` | [`a.t()`][.t()] or [`a.reversed_axes()`][.reversed_axes()] | transpose of array `a` (view for `.t()` or by-move for `.reversed_axes()`)
//! `np.diag(a)` | [`a.diag()`][.diag()] | view the diagonal of `a`
//! `a.flatten()` | [`Array::from_iter(a.iter())`][::from_iter()] | create a 1-D array by flattening `a`
Expand Down
22 changes: 16 additions & 6 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,12 @@ where
// Skip the old axis since it should be removed.
old_axis += 1;
}
&AxisSliceInfo::NewAxis => {
// Set the dim and stride of the new axis.
new_dim[new_axis] = 1;
new_strides[new_axis] = 0;
new_axis += 1;
}
});
debug_assert_eq!(old_axis, self.ndim());
debug_assert_eq!(new_axis, out_ndim);
Expand All @@ -381,6 +387,8 @@ where

/// Slice the array in place without changing the number of dimensions.
///
/// Note that `NewAxis` elements in `info` are ignored.
///
/// See [*Slicing*](#slicing) for full documentation.
///
/// **Panics** if an index is out of bounds or step size is zero.<br>
Expand All @@ -394,18 +402,20 @@ where
self.ndim(),
"The input dimension of `info` must match the array to be sliced.",
);
info.as_ref()
.iter()
.enumerate()
.for_each(|(axis, ax_info)| match ax_info {
let mut axis = 0;
info.as_ref().iter().for_each(|ax_info| match ax_info {
&AxisSliceInfo::Slice { start, end, step } => {
self.slice_axis_inplace(Axis(axis), Slice { start, end, step })
self.slice_axis_inplace(Axis(axis), Slice { start, end, step });
axis += 1;
}
&AxisSliceInfo::Index(index) => {
let i_usize = abs_index(self.len_of(Axis(axis)), index);
self.collapse_axis(Axis(axis), i_usize)
self.collapse_axis(Axis(axis), i_usize);
axis += 1;
}
&AxisSliceInfo::NewAxis => {}
});
debug_assert_eq!(axis, self.ndim());
}

/// Slice the array in place without changing the number of dimensions.
Expand Down
23 changes: 13 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ pub use indexes::{indices, indices_of};
pub use error::{ShapeError, ErrorKind};
pub use slice::{
deref_raw_view_mut_into_view_with_life, deref_raw_view_mut_into_view_mut_with_life,
life_of_view_mut, AxisSliceInfo, Slice, SliceInfo, SliceNextInDim, SliceNextOutDim,
life_of_view_mut, AxisSliceInfo, NewAxis, Slice, SliceInfo, SliceNextInDim, SliceNextOutDim,
};

use iterators::Baseiter;
Expand Down Expand Up @@ -467,22 +467,24 @@ pub type Ixs = isize;
///
/// If a range is used, the axis is preserved. If an index is used, that index
/// is selected and the axis is removed; this selects a subview. See
/// [*Subviews*](#subviews) for more information about subviews. Note that
/// [`.slice_collapse()`] behaves like [`.collapse_axis()`] by preserving
/// the number of dimensions.
/// [*Subviews*](#subviews) for more information about subviews. If a
/// [`NewAxis`] instance is used, a new axis is inserted. Note that
/// [`.slice_collapse()`] ignores `NewAxis` elements and behaves like
/// [`.collapse_axis()`] by preserving the number of dimensions.
///
/// [`.slice()`]: #method.slice
/// [`.slice_mut()`]: #method.slice_mut
/// [`.slice_move()`]: #method.slice_move
/// [`.slice_collapse()`]: #method.slice_collapse
/// [`NewAxis`]: struct.NewAxis.html
///
/// It's possible to take multiple simultaneous *mutable* slices with the
/// [`multislice!()`](macro.multislice!.html) macro.
///
/// ```
/// extern crate ndarray;
///
/// use ndarray::{arr2, arr3, multislice, s};
/// use ndarray::{arr2, arr3, multislice, s, NewAxis};
///
/// fn main() {
///
Expand Down Expand Up @@ -519,16 +521,17 @@ pub type Ixs = isize;
/// assert_eq!(d, e);
/// assert_eq!(d.shape(), &[2, 1, 3]);
///
/// // Let’s create a slice while selecting a subview with
/// // Let’s create a slice while selecting a subview and inserting a new axis with
/// //
/// // - Both submatrices of the greatest dimension: `..`
/// // - The last row in each submatrix, removing that axis: `-1`
/// // - Row elements in reverse order: `..;-1`
/// let f = a.slice(s![.., -1, ..;-1]);
/// let g = arr2(&[[ 6, 5, 4],
/// [12, 11, 10]]);
/// // - A new axis at the end.
/// let f = a.slice(s![.., -1, ..;-1, NewAxis]);
/// let g = arr3(&[[ [6], [5], [4]],
/// [[12], [11], [10]]]);
/// assert_eq!(f, g);
/// assert_eq!(f.shape(), &[2, 3]);
/// assert_eq!(f.shape(), &[2, 3, 1]);
///
/// // Let's take two disjoint, mutable slices of a matrix with
/// //
Expand Down
5 changes: 5 additions & 0 deletions src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ pub use {
ShapeBuilder,
};

#[doc(no_inline)]
pub use {
NewAxis,
};

#[doc(no_inline)]
pub use {
NdFloat,
Expand Down
92 changes: 70 additions & 22 deletions src/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,12 @@ impl Slice {
}
}

/// A slice (range with step) or an index.
/// Token to represent a new axis in a slice description.
///
/// See also the [`s![]`](macro.s!.html) macro.
pub struct NewAxis;

/// A slice (range with step), an index, or a new axis token.
///
/// See also the [`s![]`](macro.s!.html) macro for a convenient way to create a
/// `&SliceInfo<[AxisSliceInfo; n], Di, Do>`.
Expand All @@ -91,6 +96,10 @@ impl Slice {
/// from `a` until the end, in reverse order. It can also be created with
/// `AxisSliceInfo::from(a..).step_by(-1)`. The Python equivalent is `[a::-1]`.
/// The macro equivalent is `s![a..;-1]`.
///
/// `AxisSliceInfo::NewAxis` is a new axis of length 1. It can also be created
/// with `AxisSliceInfo::from(NewAxis)`. The Python equivalent is
/// `[np.newaxis]`. The macro equivalent is `s![NewAxis]`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub enum AxisSliceInfo {
/// A range with step size. `end` is an exclusive index. Negative `begin`
Expand All @@ -103,6 +112,8 @@ pub enum AxisSliceInfo {
},
/// A single index.
Index(isize),
/// A new axis of length 1.
NewAxis,
}

copy_and_clone!{AxisSliceInfo}
Expand All @@ -124,6 +135,14 @@ impl AxisSliceInfo {
}
}

/// Returns `true` if `self` is a `NewAxis` value.
pub fn is_new_axis(&self) -> bool {
match self {
&AxisSliceInfo::NewAxis => true,
_ => false,
}
}

/// Returns a new `AxisSliceInfo` with the given step size (multiplied with
/// the previous step size).
///
Expand All @@ -143,6 +162,7 @@ impl AxisSliceInfo {
step: orig_step * step,
},
AxisSliceInfo::Index(s) => AxisSliceInfo::Index(s),
AxisSliceInfo::NewAxis => AxisSliceInfo::NewAxis,
}
}
}
Expand All @@ -163,6 +183,7 @@ impl fmt::Display for AxisSliceInfo {
write!(f, ";{}", step)?;
}
}
AxisSliceInfo::NewAxis => write!(f, "NewAxis")?,
}
Ok(())
}
Expand Down Expand Up @@ -282,6 +303,13 @@ impl_sliceorindex_from_index!(isize);
impl_sliceorindex_from_index!(usize);
impl_sliceorindex_from_index!(i32);

impl From<NewAxis> for AxisSliceInfo {
#[inline]
fn from(_: NewAxis) -> AxisSliceInfo {
AxisSliceInfo::NewAxis
}
}

/// A type that can slice an array of dimension `D`.
///
/// This trait is unsafe to implement because the implementation must ensure
Expand Down Expand Up @@ -402,12 +430,12 @@ where
/// Errors if `Di` or `Do` is not consistent with `indices`.
pub fn new(indices: T) -> Result<SliceInfo<T, Di, Do>, ShapeError> {
if let Some(ndim) = Di::NDIM {
if ndim != indices.as_ref().len() {
if ndim != indices.as_ref().iter().filter(|s| !s.is_new_axis()).count() {
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape));
}
}
if let Some(ndim) = Do::NDIM {
if ndim != indices.as_ref().iter().filter(|s| s.is_slice()).count() {
if ndim != indices.as_ref().iter().filter(|s| !s.is_index()).count() {
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape));
}
}
Expand All @@ -427,8 +455,18 @@ where
{
/// Returns the number of dimensions of the input array for
/// [`.slice()`](struct.ArrayBase.html#method.slice).
///
/// If `Di` is a fixed-size dimension type, then this is equivalent to
/// `Di::NDIM.unwrap()`. Otherwise, the value is calculated by iterating
/// over the `AxisSliceInfo` elements.
pub fn in_ndim(&self) -> usize {
Di::NDIM.unwrap_or_else(|| self.indices.as_ref().len())
Di::NDIM.unwrap_or_else(|| {
self.indices
.as_ref()
.iter()
.filter(|s| !s.is_new_axis())
.count()
})
}

/// Returns the number of dimensions after calling
Expand All @@ -443,7 +481,7 @@ where
self.indices
.as_ref()
.iter()
.filter(|s| s.is_slice())
.filter(|s| !s.is_index())
.count()
})
}
Expand Down Expand Up @@ -506,6 +544,12 @@ pub trait SliceNextInDim<D1, D2> {
fn next_dim(&self, PhantomData<D1>) -> PhantomData<D2>;
}

impl<D1: Dimension> SliceNextInDim<D1, D1> for NewAxis {
fn next_dim(&self, _: PhantomData<D1>) -> PhantomData<D1> {
PhantomData
}
}

macro_rules! impl_slicenextindim_larger {
(($($generics:tt)*), $self:ty) => {
impl<D1: Dimension, $($generics),*> SliceNextInDim<D1, D1::Larger> for $self {
Expand Down Expand Up @@ -560,12 +604,13 @@ impl_slicenextoutdim_larger!((T), RangeTo<T>);
impl_slicenextoutdim_larger!((T), RangeToInclusive<T>);
impl_slicenextoutdim_larger!((), RangeFull);
impl_slicenextoutdim_larger!((), Slice);
impl_slicenextoutdim_larger!((), NewAxis);

/// Slice argument constructor.
///
/// `s![]` takes a list of ranges/slices/indices, separated by comma, with
/// optional step sizes that are separated from the range by a semicolon. It is
/// converted into a [`&SliceInfo`] instance.
/// `s![]` takes a list of ranges/slices/indices/new-axes, separated by comma,
/// with optional step sizes that are separated from the range by a semicolon.
/// It is converted into a [`&SliceInfo`] instance.
///
/// [`&SliceInfo`]: struct.SliceInfo.html
///
Expand All @@ -584,22 +629,25 @@ impl_slicenextoutdim_larger!((), Slice);
/// * *slice*: a [`Slice`] instance to use for slicing that axis.
/// * *slice* `;` *step*: a range constructed from the start and end of a [`Slice`]
/// instance, with new step size *step*, to use for slicing that axis.
/// * *new-axis*: a [`NewAxis`] instance that represents the creation of a new axis.
///
/// [`Slice`]: struct.Slice.html
/// [`NewAxis`]: struct.NewAxis.html
///
/// The number of *axis-slice-info* must match the number of axes in the array.
/// *index*, *range*, *slice*, and *step* can be expressions. *index* must be
/// of type `isize`, `usize`, or `i32`. *range* must be of type `Range<I>`,
/// `RangeTo<I>`, `RangeFrom<I>`, or `RangeFull` where `I` is `isize`, `usize`,
/// or `i32`. *step* must be a type that can be converted to `isize` with the
/// `as` keyword.
/// The number of *axis-slice-info*, not including *new-axis*, must match the
/// number of axes in the array. *index*, *range*, *slice*, *step*, and
/// *new-axis* can be expressions. *index* must be of type `isize`, `usize`, or
/// `i32`. *range* must be of type `Range<I>`, `RangeTo<I>`, `RangeFrom<I>`, or
/// `RangeFull` where `I` is `isize`, `usize`, or `i32`. *step* must be a type
/// that can be converted to `isize` with the `as` keyword.
///
/// For example `s![0..4;2, 6, 1..5]` is a slice of the first axis for 0..4
/// with step size 2, a subview of the second axis at index 6, and a slice of
/// the third axis for 1..5 with default step size 1. The input array must have
/// 3 dimensions. The resulting slice would have shape `[2, 4]` for
/// [`.slice()`], [`.slice_mut()`], and [`.slice_move()`], and shape
/// `[2, 1, 4]` for [`.slice_collapse()`].
/// For example `s![0..4;2, 6, 1..5, NewAxis]` is a slice of the first axis for
/// 0..4 with step size 2, a subview of the second axis at index 6, a slice of
/// the third axis for 1..5 with default step size 1, and a new axis of length
/// 1 at the end of the shape. The input array must have 3 dimensions. The
/// resulting slice would have shape `[2, 4, 1]` for [`.slice()`],
/// [`.slice_mut()`], and [`.slice_move()`], and shape `[2, 1, 4]` for
/// [`.slice_collapse()`].
///
/// [`.slice()`]: struct.ArrayBase.html#method.slice
/// [`.slice_mut()`]: struct.ArrayBase.html#method.slice_mut
Expand Down Expand Up @@ -726,11 +774,11 @@ macro_rules! s(
}
}
};
// convert range/index into AxisSliceInfo
// convert range/index/new-axis into AxisSliceInfo
(@convert $r:expr) => {
<$crate::AxisSliceInfo as ::std::convert::From<_>>::from($r)
};
// convert range/index and step into AxisSliceInfo
// convert range/index/new-axis and step into AxisSliceInfo
(@convert $r:expr, $s:expr) => {
<$crate::AxisSliceInfo as ::std::convert::From<_>>::from($r).step_by($s as isize)
};
Expand Down
Loading

0 comments on commit 1e978ca

Please sign in to comment.