Skip to content

Array stacking #117

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Mar 5, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ pub use iterators::{

pub use arraytraits::AsArray;
pub use linalg::{LinalgScalar, NdFloat};
pub use stacking::stack;

mod arraytraits;
#[cfg(feature = "serde")]
Expand Down Expand Up @@ -129,6 +130,7 @@ mod linspace;
mod numeric_util;
mod si;
mod error;
pub mod stacking;

/// Implementation's prelude. Common types used everywhere.
mod imp_prelude {
Expand Down
53 changes: 53 additions & 0 deletions src/stacking.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@

use imp_prelude::*;
use error::{ShapeError, ErrorKind, from_kind};

/// Stack arrays along the given axis.
pub fn stack<'a, A, D>(axis: Axis, arrays: &[ArrayView<'a, A, D>])
-> Result<OwnedArray<A, D>, ShapeError>
where A: Copy,
D: Dimension + RemoveAxis
{
if arrays.len() == 0 {
return Err(from_kind(ErrorKind::Unsupported));
}
let mut res_dim = arrays[0].dim();
if axis.axis() >= res_dim.ndim() {
return Err(from_kind(ErrorKind::OutOfBounds));
}
let common_dim = res_dim.remove_axis(axis);
if arrays.iter().any(|a| a.dim().remove_axis(axis) != common_dim) {
return Err(from_kind(ErrorKind::IncompatibleShape));
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is in fact overly strict -- assign allows broadcasting.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I had forgotten about broadcasting, I'll have to think about it.


let stacked_dim = arrays.iter()
.fold(0, |acc, a| acc + a.dim().index(axis));
*res_dim.index_mut(axis) = stacked_dim;

// we can safely use uninitialized values here because they are Copy
// and we will only ever write to them
let size = res_dim.size();
let mut v = Vec::with_capacity(size);
unsafe {
v.set_len(size);
}
let mut res = try!(OwnedArray::from_vec_dim(res_dim, v));

{
let mut assign_view = res.view_mut();
for array in arrays {
let len = *array.dim().index(axis);
let (mut front, rest) = assign_view.split_at(axis, len);
front.assign(array);
assign_view = rest;
}
}
Ok(res)
}

#[macro_export]
macro_rules! stack {
($axis:expr, $( $a:expr ),+ ) => {
ndarray::stack($axis, &[ $(ndarray::ArrayView::from($a) ),* ]).unwrap()
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's a first attempt at a macro for stacking. I could not manage to use the AsArray trait though, I supposed I would call .into() here instead of view(), but that does not seem to work that easily. I haven't thought about it that much right now though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think using the new ArrayView::from works well. Shouldn't do much different than the .view() call. But .view() works well too? View accepts all kinds of arrays and references to arrays.

Should we make the syntax lighter? stack![Axis(0), a, b, c] would work too I think. And using + to enforce that there is at least one array.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think using the new ArrayView::from works well. Shouldn't do much
different than the .view() call. But .view() works well too? View accepts
all kinds of arrays and references to arrays.

Yes .view() works too. But maybe ArrayView::from is more generic ie it's
open for extension from the trait.

Should we make the syntax lighter? stack![Axis(0), a, b, c] would work too I
think. And using + to enforce that there is at least one array.

Yes looks nice that way.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After trying, using ArrayView::from requires the user to import ArrayView at the call site, while using view() works without such requirement. So maybe it's better to use view()?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that part is easy to fix by using $crate::ArrayView::from i.e. the full path. Apart from that, I'm open to any feedback on how the from trait actually works/doesn't work for this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The full path resolves, alas there are still issues:

error: the trait `core::convert::From<ndarray::ArrayBase<collections::vec::Vec<_>, (usize, usize)>>` is not implemented for the type `ndarray::ArrayBase<ndarray::ViewRepr<&_>, _>`

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good error, since we should not consume an OwnedArray just to create a view from it. We can solve it later.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, so there's no problem when passing a reference to the macro. Maybe that's better than .view()then, makes it clearer that the ownership isn't taken?

}
40 changes: 40 additions & 0 deletions tests/stacking.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@

#[macro_use(stack)]
extern crate ndarray;


use ndarray::{
arr2,
Axis,
Ix,
OwnedArray,
ErrorKind,
};

#[test]
fn stacking() {
let a = arr2(&[[2., 2.],
[3., 3.]]);
let b = ndarray::stack(Axis(0), &[a.view(), a.view()]).unwrap();
assert_eq!(b, arr2(&[[2., 2.],
[3., 3.],
[2., 2.],
[3., 3.]]));

let c = stack!(Axis(0), a.view(), &b);
assert_eq!(c, arr2(&[[2., 2.],
[3., 3.],
[2., 2.],
[3., 3.],
[2., 2.],
[3., 3.]]));

let res = ndarray::stack(Axis(1), &[a.view(), c.view()]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape);

let res = ndarray::stack(Axis(2), &[a.view(), c.view()]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds);

let res: Result<OwnedArray<f64, (Ix, Ix)>, _> = ndarray::stack(Axis(0), &[]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);
}