-
Notifications
You must be signed in to change notification settings - Fork 324
Array stacking #117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Array stacking #117
Changes from all commits
3c7de6e
18606d6
ab0527c
fbeb262
645d0f2
1a709b7
a2283e9
86cb801
ce4fc93
b34a0c8
3a3e278
ddcc85c
99c0145
4762141
806b2bd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
|
||
use imp_prelude::*; | ||
use error::{ShapeError, ErrorKind, from_kind}; | ||
|
||
/// Stack arrays along the given axis. | ||
pub fn stack<'a, A, D>(axis: Axis, arrays: &[ArrayView<'a, A, D>]) | ||
-> Result<OwnedArray<A, D>, ShapeError> | ||
where A: Copy, | ||
D: Dimension + RemoveAxis | ||
{ | ||
if arrays.len() == 0 { | ||
return Err(from_kind(ErrorKind::Unsupported)); | ||
} | ||
let mut res_dim = arrays[0].dim(); | ||
if axis.axis() >= res_dim.ndim() { | ||
return Err(from_kind(ErrorKind::OutOfBounds)); | ||
} | ||
let common_dim = res_dim.remove_axis(axis); | ||
if arrays.iter().any(|a| a.dim().remove_axis(axis) != common_dim) { | ||
return Err(from_kind(ErrorKind::IncompatibleShape)); | ||
} | ||
|
||
let stacked_dim = arrays.iter() | ||
.fold(0, |acc, a| acc + a.dim().index(axis)); | ||
*res_dim.index_mut(axis) = stacked_dim; | ||
|
||
// we can safely use uninitialized values here because they are Copy | ||
// and we will only ever write to them | ||
let size = res_dim.size(); | ||
let mut v = Vec::with_capacity(size); | ||
unsafe { | ||
v.set_len(size); | ||
} | ||
let mut res = try!(OwnedArray::from_vec_dim(res_dim, v)); | ||
|
||
{ | ||
let mut assign_view = res.view_mut(); | ||
for array in arrays { | ||
let len = *array.dim().index(axis); | ||
let (mut front, rest) = assign_view.split_at(axis, len); | ||
front.assign(array); | ||
assign_view = rest; | ||
} | ||
} | ||
Ok(res) | ||
} | ||
|
||
#[macro_export] | ||
macro_rules! stack { | ||
($axis:expr, $( $a:expr ),+ ) => { | ||
ndarray::stack($axis, &[ $(ndarray::ArrayView::from($a) ),* ]).unwrap() | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here's a first attempt at a macro for stacking. I could not manage to use the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think using the new Should we make the syntax lighter? stack![Axis(0), a, b, c] would work too I think. And using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Yes
Yes looks nice that way. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After trying, using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. that part is easy to fix by using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The full path resolves, alas there are still issues:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's a good error, since we should not consume an OwnedArray just to create a view from it. We can solve it later. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right, so there's no problem when passing a reference to the macro. Maybe that's better than |
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
|
||
#[macro_use(stack)] | ||
extern crate ndarray; | ||
|
||
|
||
use ndarray::{ | ||
arr2, | ||
Axis, | ||
Ix, | ||
OwnedArray, | ||
ErrorKind, | ||
}; | ||
|
||
#[test] | ||
fn stacking() { | ||
let a = arr2(&[[2., 2.], | ||
[3., 3.]]); | ||
let b = ndarray::stack(Axis(0), &[a.view(), a.view()]).unwrap(); | ||
assert_eq!(b, arr2(&[[2., 2.], | ||
[3., 3.], | ||
[2., 2.], | ||
[3., 3.]])); | ||
|
||
let c = stack!(Axis(0), a.view(), &b); | ||
assert_eq!(c, arr2(&[[2., 2.], | ||
[3., 3.], | ||
[2., 2.], | ||
[3., 3.], | ||
[2., 2.], | ||
[3., 3.]])); | ||
|
||
let res = ndarray::stack(Axis(1), &[a.view(), c.view()]); | ||
assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape); | ||
|
||
let res = ndarray::stack(Axis(2), &[a.view(), c.view()]); | ||
assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds); | ||
|
||
let res: Result<OwnedArray<f64, (Ix, Ix)>, _> = ndarray::stack(Axis(0), &[]); | ||
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is in fact overly strict -- assign allows broadcasting.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, I had forgotten about broadcasting, I'll have to think about it.