diff --git a/src/lib.rs b/src/lib.rs index 53c82e7a2..c882865cc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -100,6 +100,7 @@ pub use iterators::{ pub use arraytraits::AsArray; pub use linalg::{LinalgScalar, NdFloat}; +pub use stacking::stack; mod arraytraits; #[cfg(feature = "serde")] @@ -129,6 +130,7 @@ mod linspace; mod numeric_util; mod si; mod error; +pub mod stacking; /// Implementation's prelude. Common types used everywhere. mod imp_prelude { diff --git a/src/stacking.rs b/src/stacking.rs new file mode 100644 index 000000000..22010b135 --- /dev/null +++ b/src/stacking.rs @@ -0,0 +1,53 @@ + +use imp_prelude::*; +use error::{ShapeError, ErrorKind, from_kind}; + +/// Stack arrays along the given axis. +pub fn stack<'a, A, D>(axis: Axis, arrays: &[ArrayView<'a, A, D>]) + -> Result, ShapeError> + where A: Copy, + D: Dimension + RemoveAxis +{ + if arrays.len() == 0 { + return Err(from_kind(ErrorKind::Unsupported)); + } + let mut res_dim = arrays[0].dim(); + if axis.axis() >= res_dim.ndim() { + return Err(from_kind(ErrorKind::OutOfBounds)); + } + let common_dim = res_dim.remove_axis(axis); + if arrays.iter().any(|a| a.dim().remove_axis(axis) != common_dim) { + return Err(from_kind(ErrorKind::IncompatibleShape)); + } + + let stacked_dim = arrays.iter() + .fold(0, |acc, a| acc + a.dim().index(axis)); + *res_dim.index_mut(axis) = stacked_dim; + + // we can safely use uninitialized values here because they are Copy + // and we will only ever write to them + let size = res_dim.size(); + let mut v = Vec::with_capacity(size); + unsafe { + v.set_len(size); + } + let mut res = try!(OwnedArray::from_vec_dim(res_dim, v)); + + { + let mut assign_view = res.view_mut(); + for array in arrays { + let len = *array.dim().index(axis); + let (mut front, rest) = assign_view.split_at(axis, len); + front.assign(array); + assign_view = rest; + } + } + Ok(res) +} + +#[macro_export] +macro_rules! stack { + ($axis:expr, $( $a:expr ),+ ) => { + ndarray::stack($axis, &[ $(ndarray::ArrayView::from($a) ),* ]).unwrap() + } +} diff --git a/tests/stacking.rs b/tests/stacking.rs new file mode 100644 index 000000000..7f9f9bfe0 --- /dev/null +++ b/tests/stacking.rs @@ -0,0 +1,40 @@ + +#[macro_use(stack)] +extern crate ndarray; + + +use ndarray::{ + arr2, + Axis, + Ix, + OwnedArray, + ErrorKind, +}; + +#[test] +fn stacking() { + let a = arr2(&[[2., 2.], + [3., 3.]]); + let b = ndarray::stack(Axis(0), &[a.view(), a.view()]).unwrap(); + assert_eq!(b, arr2(&[[2., 2.], + [3., 3.], + [2., 2.], + [3., 3.]])); + + let c = stack!(Axis(0), a.view(), &b); + assert_eq!(c, arr2(&[[2., 2.], + [3., 3.], + [2., 2.], + [3., 3.], + [2., 2.], + [3., 3.]])); + + let res = ndarray::stack(Axis(1), &[a.view(), c.view()]); + assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape); + + let res = ndarray::stack(Axis(2), &[a.view(), c.view()]); + assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds); + + let res: Result, _> = ndarray::stack(Axis(0), &[]); + assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported); +}