Skip to content

Properly handle negative strides #156

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 44 additions & 17 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,15 +298,6 @@ impl<T, D> PyArray<T, D> {
self.len() == 0
}

fn strides_usize(&self) -> &[usize] {
let n = self.ndim();
let ptr = self.as_array_ptr();
unsafe {
let p = (*ptr).strides;
slice::from_raw_parts(p as *const _, n)
}
}

/// Returns the pointer to the first element of the inner array.
pub(crate) unsafe fn data(&self) -> *mut T {
let ptr = self.as_array_ptr();
Expand All @@ -318,20 +309,50 @@ impl<T, D> PyArray<T, D> {
}
}

struct InvertedAxises(Vec<Axis>);

impl InvertedAxises {
fn invert<S: RawData, D: Dimension>(self, array: &mut ArrayBase<S, D>) {
for axis in self.0 {
array.invert_axis(axis);
}
}
}

impl<T: Element, D: Dimension> PyArray<T, D> {
/// Same as [shape](#method.shape), but returns `D`
#[inline(always)]
pub fn dims(&self) -> D {
D::from_dimension(&Dim(self.shape())).expect("PyArray::dims different dimension")
}

fn ndarray_shape(&self) -> StrideShape<D> {
fn ndarray_shape_ptr(&self) -> (StrideShape<D>, *mut T, InvertedAxises) {
const ERR_MSG: &str = "PyArray::ndarray_shape: dimension mismatching";
let shape_slice = self.shape();
let shape: Shape<_> = Dim(self.dims()).into();
let size = mem::size_of::<T>();
let mut st = D::from_dimension(&Dim(self.strides_usize()))
.expect("PyArray::ndarray_shape: dimension mismatching");
st.slice_mut().iter_mut().for_each(|e| *e /= size);
shape.strides(st)
let sizeof_t = mem::size_of::<T>();
let strides = self.strides();
let mut new_strides = D::zeros(strides.len());
let mut data_ptr = unsafe { self.data() };
let mut inverted_axises = vec![];
for i in 0..strides.len() {
// TODO(kngwyu): Replace this hacky negative strides support with
// a proper constructor, when it's implemented.
// See https://github.com/rust-ndarray/ndarray/issues/842 for more.
if strides[i] < 0 {
// Move the pointer to the start position
let offset = strides[i] * (shape_slice[i] as isize - 1) / sizeof_t as isize;
unsafe {
data_ptr = data_ptr.offset(offset);
}
new_strides[i] = (-strides[i]) as usize / sizeof_t;
inverted_axises.push(Axis(i));
} else {
new_strides[i] = strides[i] as usize / sizeof_t;
}
}
let st = D::from_dimension(&Dim(new_strides)).expect(ERR_MSG);
(shape.strides(st), data_ptr, InvertedAxises(inverted_axises))
}

/// Creates a new uninitialized PyArray in python heap.
Expand Down Expand Up @@ -632,7 +653,10 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
/// If the internal array is not readonly and can be mutated from Python code,
/// holding the `ArrayView` might cause undefined behavior.
pub unsafe fn as_array(&self) -> ArrayView<'_, T, D> {
ArrayView::from_shape_ptr(self.ndarray_shape(), self.data())
let (shape, ptr, inverted_axises) = self.ndarray_shape_ptr();
let mut res = ArrayView::from_shape_ptr(shape, ptr);
inverted_axises.invert(&mut res);
res
}

/// Returns the internal array as `ArrayViewMut`. See also [`as_array`](#method.as_array).
Expand All @@ -641,7 +665,10 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
/// If another reference to the internal data exists(e.g., `&[T]` or `ArrayView`),
/// it might cause undefined behavior.
pub unsafe fn as_array_mut(&self) -> ArrayViewMut<'_, T, D> {
ArrayViewMut::from_shape_ptr(self.ndarray_shape(), self.data())
let (shape, ptr, inverted_axises) = self.ndarray_shape_ptr();
let mut res = ArrayViewMut::from_shape_ptr(shape, ptr);
inverted_axises.invert(&mut res);
res
}

/// Get a copy of `PyArray` as
Expand Down
14 changes: 14 additions & 0 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,17 @@ fn array_cast() {
let arr_i32: &PyArray2<i32> = arr_f64.cast(false).unwrap();
assert_eq!(arr_i32.readonly().as_array(), array![[1, 2, 3], [1, 2, 3]]);
}

#[test]
fn handle_negative_strides() {
let gil = pyo3::Python::acquire_gil();
let py = gil.python();
let arr = array![[2, 3], [4, 5u32]];
let pyarr = arr.to_pyarray(py);
let negstr_pyarr: &numpy::PyArray2<u32> = py
.eval("a[::-1]", Some([("a", pyarr)].into_py_dict(py)), None)
.unwrap()
.downcast()
.unwrap();
assert_eq!(negstr_pyarr.to_owned_array(), arr.slice(s![..;-1, ..]));
}