diff --git a/src/array.rs b/src/array.rs index 3463b0c3d..fefb2e23f 100644 --- a/src/array.rs +++ b/src/array.rs @@ -127,8 +127,9 @@ impl IntoPy for PyArray { impl<'a, T: Element, D: Dimension> FromPyObject<'a> for &'a PyArray { // here we do type-check three times // 1. Checks if the object is PyArray - // 2. Checks if the data type of the array is T - // 3. Checks if the dimension is same as D + // 2. Checks if the dimension is same as D + // 3. Checks if the data type of the array is T + // 4. Optionally checks if the elements of the array match T fn extract(ob: &'a PyAny) -> PyResult { let array = unsafe { if npyffi::PyArray_Check(ob.as_ptr()) == 0 { @@ -137,12 +138,6 @@ impl<'a, T: Element, D: Dimension> FromPyObject<'a> for &'a PyArray { &*(ob as *const PyAny as *const PyArray) }; - let src_dtype = array.dtype(); - let dst_dtype = T::get_dtype(ob.py()); - if !src_dtype.is_equiv_to(dst_dtype) { - return Err(TypeError::new(src_dtype, dst_dtype).into()); - } - let src_ndim = array.shape().len(); if let Some(dst_ndim) = D::NDIM { if src_ndim != dst_ndim { @@ -150,6 +145,14 @@ impl<'a, T: Element, D: Dimension> FromPyObject<'a> for &'a PyArray { } } + let src_dtype = array.dtype(); + let dst_dtype = T::get_dtype(ob.py()); + if !src_dtype.is_equiv_to(dst_dtype) { + return Err(TypeError::new(src_dtype, dst_dtype).into()); + } + + T::check_element_types(ob.py(), array)?; + Ok(array) } } diff --git a/src/dtype.rs b/src/dtype.rs index e7b744917..5f1402d06 100644 --- a/src/dtype.rs +++ b/src/dtype.rs @@ -1,10 +1,17 @@ +use std::any::{type_name, TypeId}; use std::mem::size_of; use std::os::raw::{c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort}; +use ndarray::Dimension; use num_traits::{Bounded, Zero}; -use pyo3::{ffi, prelude::*, pyobject_native_type_core, types::PyType, AsPyPointer, PyNativeType}; +use pyo3::{ + ffi, prelude::*, pyobject_native_type_core, type_object::PyTypeObject, types::PyType, + AsPyPointer, PyDowncastError, PyNativeType, PyTypeInfo, +}; +use crate::array::PyArray; use crate::npyffi::{NpyTypes, PyArray_Descr, NPY_TYPES, PY_ARRAY_API}; +use crate::NpySingleIterBuilder; pub use num_complex::{Complex32, Complex64}; @@ -132,26 +139,12 @@ impl PyArrayDescr { /// #[pyclass] /// pub struct CustomElement; /// -/// // The transparent wrapper is necessary as one cannot implement -/// // a foreign trait (`Element`) on a foreign type (`Py`) directly. -/// #[derive(Clone)] -/// #[repr(transparent)] -/// pub struct Wrapper(pub Py); -/// -/// unsafe impl Element for Wrapper { -/// const IS_COPY: bool = false; -/// -/// fn get_dtype(py: Python) -> &PyArrayDescr { -/// PyArrayDescr::object(py) -/// } -/// } -/// /// Python::with_gil(|py| { -/// let array = Array2::::from_shape_fn((2, 3), |(_i, _j)| { -/// Wrapper(Py::new(py, CustomElement).unwrap()) +/// let array = Array2::>::from_shape_fn((2, 3), |(_i, _j)| { +/// Py::new(py, CustomElement).unwrap() /// }); /// -/// let _array: &PyArray = array.to_pyarray(py); +/// let _array: &PyArray, _> = array.to_pyarray(py); /// }); /// ``` pub unsafe trait Element: Clone + Send { @@ -164,6 +157,9 @@ pub unsafe trait Element: Clone + Send { /// that contain object-type fields. const IS_COPY: bool; + /// TODO + fn check_element_types(py: Python, array: &PyArray) -> PyResult<()>; + /// Returns the associated array descriptor ("dtype") for the given type. fn get_dtype(py: Python) -> &PyArrayDescr; } @@ -218,6 +214,12 @@ macro_rules! impl_element_scalar { $(#[$meta])* unsafe impl Element for $ty { const IS_COPY: bool = true; + + fn check_element_types(_py: Python, _array: &PyArray) -> PyResult<()> { + // For scalar types, checking the dtype is sufficient. + Ok(()) + } + fn get_dtype(py: Python) -> &PyArrayDescr { PyArrayDescr::from_npy_type(py, $npy_type) } @@ -244,9 +246,30 @@ impl_element_scalar!(Complex64 => NPY_CDOUBLE, #[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))] impl_element_scalar!(usize, isize); -unsafe impl Element for PyObject { +unsafe impl Element for Py +where + T: PyTypeInfo + 'static, +{ const IS_COPY: bool = false; + fn check_element_types(py: Python, array: &PyArray) -> PyResult<()> { + // `PyAny` can represent any Python object. + if TypeId::of::() == TypeId::of::() { + return Ok(()); + } + + let type_object = T::type_object(py); + let iterator = NpySingleIterBuilder::readwrite(array).build()?; + + for element in iterator { + if !type_object.is_instance(element)? { + return Err(PyDowncastError::new(array, type_name::()).into()); + } + } + + Ok(()) + } + fn get_dtype(py: Python) -> &PyArrayDescr { PyArrayDescr::object(py) }