Skip to content

Commit

Permalink
Merge pull request #265 from PyO3/example-downcast
Browse files Browse the repository at this point in the history
Fix type confusion during downcastsing and add a test case showing how to extract an array from a dictionary.
  • Loading branch information
adamreichold authored Jan 24, 2022
2 parents ad49760 + c8390e3 commit 7cc945c
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 20 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
- Unreleased
- Support object arrays ([#216](https://github.com/PyO3/rust-numpy/pull/216))
- Support borrowing arrays that are part of other Python objects via `PyArray::borrow_from_array` ([#230](https://github.com/PyO3/rust-numpy/pull/216))
- Fixed downcasting ignoring element type and dimensionality ([#265](https://github.com/PyO3/rust-numpy/pull/265))
- `PyArray::new` is now `unsafe`, as it produces uninitialized arrays ([#220](https://github.com/PyO3/rust-numpy/pull/220))
- `PyArray::from_exact_iter` does not unsoundly trust `ExactSizeIterator::len` any more ([#262](https://github.com/PyO3/rust-numpy/pull/262))
- `PyArray::as_cell_slice` was removed as it unsoundly interacts with `PyReadonlyArray` allowing safe code to violate aliasing rules ([#260](https://github.com/PyO3/rust-numpy/pull/260))
Expand Down
20 changes: 18 additions & 2 deletions examples/simple-extension/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use numpy::ndarray::{ArrayD, ArrayViewD, ArrayViewMutD};
use numpy::{Complex64, IntoPyArray, PyArrayDyn, PyReadonlyArrayDyn};
use pyo3::{pymodule, types::PyModule, PyResult, Python};
use numpy::{Complex64, IntoPyArray, PyArray1, PyArrayDyn, PyReadonlyArrayDyn};
use pyo3::{
pymodule,
types::{PyDict, PyModule},
PyResult, Python,
};

#[pymodule]
fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
Expand Down Expand Up @@ -52,5 +56,17 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
conj(x.as_array()).into_pyarray(py)
}

#[pyfn(m)]
#[pyo3(name = "extract")]
fn extract(d: &PyDict) -> f64 {
let x = d
.get_item("x")
.unwrap()
.downcast::<PyArray1<f64>>()
.unwrap();

x.readonly().as_array().sum()
}

Ok(())
}
8 changes: 7 additions & 1 deletion examples/simple-extension/tests/test_ext.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from rust_ext import axpy, conj, mult
from rust_ext import axpy, conj, mult, extract


def test_axpy():
Expand All @@ -22,3 +22,9 @@ def test_mult():
def test_conj():
x = np.array([1.0 + 2j, 2.0 + 3j, 3.0 + 4j])
np.testing.assert_array_almost_equal(conj(x), np.conj(x))


def test_extract():
x = np.arange(5.0)
d = { "x": x }
np.testing.assert_almost_equal(extract(d), 10.0)
36 changes: 22 additions & 14 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ use ndarray::{
};
use num_traits::AsPrimitive;
use pyo3::{
ffi, pyobject_native_type_info, pyobject_native_type_named, type_object, types::PyModule,
AsPyPointer, FromPyObject, IntoPy, Py, PyAny, PyDowncastError, PyErr, PyNativeType, PyObject,
PyResult, Python, ToPyObject,
ffi, pyobject_native_type_named, type_object, types::PyModule, AsPyPointer, FromPyObject,
IntoPy, Py, PyAny, PyDowncastError, PyErr, PyNativeType, PyObject, PyResult, PyTypeInfo,
Python, ToPyObject,
};

use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
Expand Down Expand Up @@ -110,16 +110,24 @@ pub fn get_array_module(py: Python<'_>) -> PyResult<&PyModule> {
}

unsafe impl<T, D> type_object::PyLayout<PyArray<T, D>> for npyffi::PyArrayObject {}

impl<T, D> type_object::PySizedLayout<PyArray<T, D>> for npyffi::PyArrayObject {}

pyobject_native_type_info!(
PyArray<T, D>,
*npyffi::PY_ARRAY_API.get_type_object(npyffi::NpyTypes::PyArray_Type),
Some("numpy"),
#checkfunction=npyffi::PyArray_Check
; T
; D
);
unsafe impl<T: Element, D: Dimension> PyTypeInfo for PyArray<T, D> {
type AsRefTarget = Self;

const NAME: &'static str = "PyArray<T, D>";
const MODULE: ::std::option::Option<&'static str> = Some("numpy");

#[inline]
fn type_object_raw(_py: Python) -> *mut ffi::PyTypeObject {
unsafe { npyffi::PY_ARRAY_API.get_type_object(npyffi::NpyTypes::PyArray_Type) }
}

fn is_type_of(ob: &PyAny) -> bool {
<&Self>::extract(ob).is_ok()
}
}

pyobject_native_type_named!(PyArray<T, D> ; T ; D);

Expand All @@ -129,12 +137,12 @@ impl<T, D> IntoPy<PyObject> for PyArray<T, D> {
}
}

impl<'a, T: Element, D: Dimension> FromPyObject<'a> for &'a PyArray<T, D> {
impl<'py, T: Element, D: Dimension> FromPyObject<'py> for &'py PyArray<T, D> {
// here we do type-check three times
// 1. Checks if the object is PyArray
// 2. Checks if the data type of the array is T
// 3. Checks if the dimension is same as D
fn extract(ob: &'a PyAny) -> PyResult<Self> {
fn extract(ob: &'py PyAny) -> PyResult<Self> {
let array = unsafe {
if npyffi::PyArray_Check(ob.as_ptr()) == 0 {
return Err(PyDowncastError::new(ob, "PyArray<T, D>").into());
Expand Down Expand Up @@ -207,7 +215,7 @@ impl<T, D> PyArray<T, D> {
/// assert!(array.is_contiguous());
/// let locals = [("np", numpy::get_array_module(py).unwrap())].into_py_dict(py);
/// let not_contiguous: &numpy::PyArray1<f32> = py
/// .eval("np.zeros((3, 5))[::2, 4]", Some(locals), None)
/// .eval("np.zeros((3, 5), dtype='float32')[::2, 4]", Some(locals), None)
/// .unwrap()
/// .downcast()
/// .unwrap();
Expand Down
2 changes: 1 addition & 1 deletion src/readonly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ impl<'py, T: Element, D: Dimension> PyReadonlyArray<'py, T, D> {
/// assert_eq!(readonly.as_slice().unwrap(), &[0, 1, 2, 3]);
/// let locals = [("np", numpy::get_array_module(py).unwrap())].into_py_dict(py);
/// let not_contiguous: &PyArray1<i32> = py
/// .eval("np.arange(10)[::2]", Some(locals), None)
/// .eval("np.arange(10, dtype='int32')[::2]", Some(locals), None)
/// .unwrap()
/// .downcast()
/// .unwrap();
Expand Down
30 changes: 28 additions & 2 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ use ndarray::*;
use numpy::*;
use pyo3::{
prelude::*,
types::PyList,
types::{IntoPyDict, PyDict},
types::{IntoPyDict, PyDict, PyList},
};

fn get_np_locals(py: Python) -> &PyDict {
Expand Down Expand Up @@ -300,3 +299,30 @@ fn borrow_from_array() {
py_run!(py, array, "assert array.shape == (10,)");
});
}

#[test]
fn downcasting_works() {
Python::with_gil(|py| {
let ob: &PyAny = PyArray::from_slice(py, &[1_i32, 2, 3]);

assert!(ob.downcast::<PyArray1<i32>>().is_ok());
})
}

#[test]
fn downcasting_respects_element_type() {
Python::with_gil(|py| {
let ob: &PyAny = PyArray::from_slice(py, &[1_i32, 2, 3]);

assert!(ob.downcast::<PyArray1<f64>>().is_err());
})
}

#[test]
fn downcasting_respects_dimensionality() {
Python::with_gil(|py| {
let ob: &PyAny = PyArray::from_slice(py, &[1_i32, 2, 3]);

assert!(ob.downcast::<PyArray2<i32>>().is_err());
})
}

0 comments on commit 7cc945c

Please sign in to comment.