RFC: Check element types #257

Closed

19 changes: 11 additions & 8 deletions src/array.rs
@@ -127,8 +127,9 @@ impl<T, D> IntoPy<PyObject> for PyArray<T, D> {
 impl<'a, T: Element, D: Dimension> FromPyObject<'a> for &'a PyArray<T, D> {
     // here we do type-check three times
     // 1. Checks if the object is PyArray
-    // 2. Checks if the data type of the array is T
-    // 3. Checks if the dimension is same as D
+    // 2. Checks if the dimension is same as D
+    // 3. Checks if the data type of the array is T
+    // 4. Optionally checks if the elements of the array match T
     fn extract(ob: &'a PyAny) -> PyResult<Self> {
         let array = unsafe {
             if npyffi::PyArray_Check(ob.as_ptr()) == 0 {
@@ -137,19 +138,21 @@ impl<'a, T: Element, D: Dimension> FromPyObject<'a> for &'a PyArray<T, D> {
             &*(ob as *const PyAny as *const PyArray<T, D>)
         };

-        let src_dtype = array.dtype();
-        let dst_dtype = T::get_dtype(ob.py());
-        if !src_dtype.is_equiv_to(dst_dtype) {
-            return Err(TypeError::new(src_dtype, dst_dtype).into());
-        }
-
         let src_ndim = array.shape().len();
         if let Some(dst_ndim) = D::NDIM {
             if src_ndim != dst_ndim {
                 return Err(DimensionalityError::new(src_ndim, dst_ndim).into());
             }
         }

+        let src_dtype = array.dtype();
+        let dst_dtype = T::get_dtype(ob.py());
+        if !src_dtype.is_equiv_to(dst_dtype) {
+            return Err(TypeError::new(src_dtype, dst_dtype).into());
+        }
+
+        T::check_element_types(ob.py(), array)?;
+
         Ok(array)
     }
 }
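
The ordering in this hunk (dimension, then dtype, then the optional per-element scan) can be modelled without PyO3. The following is only a minimal sketch in plain Rust: `ToyArray`, `ExtractError` and this `extract` function are hypothetical stand-ins, not rust-numpy APIs, and `std::any::Any` plays the role of the Python instance check. It shows the cheap constant-time checks running before the linear scan over the elements.

```rust
use std::any::Any;

/// Toy stand-in for an object-dtype array (hypothetical, not a rust-numpy type).
struct ToyArray {
    ndim: usize,
    dtype: &'static str,
    elements: Vec<Box<dyn Any>>,
}

#[derive(Debug, PartialEq)]
enum ExtractError {
    Dimensionality { found: usize, expected: usize },
    Dtype { found: &'static str, expected: &'static str },
    Element { index: usize },
}

/// Runs the O(1) checks first and the O(n) element scan last.
fn extract<T: Any>(
    array: &ToyArray,
    expected_ndim: Option<usize>,
    expected_dtype: &'static str,
) -> Result<(), ExtractError> {
    // Dimensionality is only checked when it is statically known.
    if let Some(expected) = expected_ndim {
        if array.ndim != expected {
            return Err(ExtractError::Dimensionality { found: array.ndim, expected });
        }
    }
    // The dtype comparison is a single comparison as well.
    if array.dtype != expected_dtype {
        return Err(ExtractError::Dtype { found: array.dtype, expected: expected_dtype });
    }
    // Only now pay for visiting every element.
    for (index, element) in array.elements.iter().enumerate() {
        if !element.is::<T>() {
            return Err(ExtractError::Element { index });
        }
    }
    Ok(())
}
```

With this order, an array of the wrong rank or dtype is rejected before any element is inspected.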
61 changes: 42 additions & 19 deletions src/dtype.rs
@@ -1,10 +1,17 @@
+use std::any::{type_name, TypeId};
 use std::mem::size_of;
 use std::os::raw::{c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort};

+use ndarray::Dimension;
 use num_traits::{Bounded, Zero};
-use pyo3::{ffi, prelude::*, pyobject_native_type_core, types::PyType, AsPyPointer, PyNativeType};
+use pyo3::{
+    ffi, prelude::*, pyobject_native_type_core, type_object::PyTypeObject, types::PyType,
+    AsPyPointer, PyDowncastError, PyNativeType, PyTypeInfo,
+};

+use crate::array::PyArray;
 use crate::npyffi::{NpyTypes, PyArray_Descr, NPY_TYPES, PY_ARRAY_API};
+use crate::NpySingleIterBuilder;

 pub use num_complex::{Complex32, Complex64};

@@ -132,26 +139,12 @@ impl PyArrayDescr {
 /// #[pyclass]
 /// pub struct CustomElement;
 ///
-/// // The transparent wrapper is necessary as one cannot implement
-/// // a foreign trait (`Element`) on a foreign type (`Py`) directly.
-/// #[derive(Clone)]
-/// #[repr(transparent)]
-/// pub struct Wrapper(pub Py<CustomElement>);
-///
-/// unsafe impl Element for Wrapper {
-///     const IS_COPY: bool = false;
-///
-///     fn get_dtype(py: Python) -> &PyArrayDescr {
-///         PyArrayDescr::object(py)
-///     }
-/// }
-///
 /// Python::with_gil(|py| {
-///     let array = Array2::<Wrapper>::from_shape_fn((2, 3), |(_i, _j)| {
-///         Wrapper(Py::new(py, CustomElement).unwrap())
+///     let array = Array2::<Py<CustomElement>>::from_shape_fn((2, 3), |(_i, _j)| {
+///         Py::new(py, CustomElement).unwrap()
 ///     });
 ///
-///     let _array: &PyArray<Wrapper, _> = array.to_pyarray(py);
+///     let _array: &PyArray<Py<CustomElement>, _> = array.to_pyarray(py);
 /// });
 /// ```
 pub unsafe trait Element: Clone + Send {
@@ -164,6 +157,9 @@ pub unsafe trait Element: Clone + Send {
     /// that contain object-type fields.
     const IS_COPY: bool;

+    /// TODO
+    fn check_element_types<D: Dimension>(py: Python, array: &PyArray<Self, D>) -> PyResult<()>;
+
     /// Returns the associated array descriptor ("dtype") for the given type.
     fn get_dtype(py: Python) -> &PyArrayDescr;
 }
@@ -218,6 +214,12 @@ macro_rules! impl_element_scalar {
         $(#[$meta])*
         unsafe impl Element for $ty {
             const IS_COPY: bool = true;
+
+            fn check_element_types<D: Dimension>(_py: Python, _array: &PyArray<Self, D>) -> PyResult<()> {
+                // For scalar types, checking the dtype is sufficient.
+                Ok(())
+            }
+
             fn get_dtype(py: Python) -> &PyArrayDescr {
                 PyArrayDescr::from_npy_type(py, $npy_type)
             }
@@ -244,9 +246,30 @@ impl_element_scalar!(Complex64 => NPY_CDOUBLE,
 #[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))]
 impl_element_scalar!(usize, isize);

-unsafe impl Element for PyObject {
+unsafe impl<T> Element for Py<T>
+where
+    T: PyTypeInfo + 'static,
+{
     const IS_COPY: bool = false;

+    fn check_element_types<D: Dimension>(py: Python, array: &PyArray<Self, D>) -> PyResult<()> {
@aldanor (Contributor), Jan 13, 2022

What I personally don't like is that there's again logic being added into Element, so it goes from being a descriptive trait to an 'actionable' trait once again.

Perhaps something like this might be cleaner? (and then the element-type checking logic can be moved out of here somewhere into array)

pub unsafe trait Element {
    const IS_COPY: bool;
    fn py_type(py: Python) -> Option<&PyType>; // None for anything with IS_COPY=true or PyAny
    fn get_dtype(py: Python) -> &PyArrayDescr;
}
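
A minimal sketch of how the trait shape suggested above could keep `Element` purely descriptive while the checking logic lives elsewhere. It uses plain Rust rather than PyO3, so everything here is an assumption made for illustration: `expected_type` stands in for the proposed `py_type`, `TypeId` stands in for a Python type object, and `check_elements` is a hypothetical free function that would sit next to the array code.

```rust
use std::any::{Any, TypeId};

/// Hypothetical, simplified analogue of the proposed trait: it only describes the type.
trait Element: 'static {
    const IS_COPY: bool;
    /// `None` means that no per-element check is needed
    /// (scalars, or a catch-all object type).
    fn expected_type() -> Option<TypeId>;
}

impl Element for f64 {
    const IS_COPY: bool = true;
    fn expected_type() -> Option<TypeId> {
        None
    }
}

impl Element for String {
    const IS_COPY: bool = false;
    fn expected_type() -> Option<TypeId> {
        Some(TypeId::of::<String>())
    }
}

/// The checking logic lives outside the trait and only consumes its description.
/// On failure it reports the index of the first mismatching element.
fn check_elements<T: Element>(elements: &[Box<dyn Any>]) -> Result<(), usize> {
    let expected = match T::expected_type() {
        Some(id) => id,
        None => return Ok(()),
    };
    // `(**e)` reaches the `dyn Any` inside the box, so `type_id` reports the
    // concrete element type rather than the type of the box itself.
    match elements.iter().position(|e| (**e).type_id() != expected) {
        Some(index) => Err(index),
        None => Ok(()),
    }
}
```

Downstream code that wants a different policy can then reuse or replace `check_elements` without re-implementing anything on the element type.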

Member Author

> What I personally don't like is that there's again logic being added into Element, so it goes from being a descriptive trait to an 'actionable' trait once again.

The point of this was to make the policy changeable by user code, i.e. if I really know I'll get PyArray<T, D> and do not want to pay the cost of checking the types, I can implement this unsoundly myself. But I guess that, due to the general issues, using PyArray<PyObject, D> together with unchecked_downcast would be the way to make this work. We just need to document this...

@aldanor (Contributor), Jan 15, 2022

I guess my point was, it doesn't feel like it should be stuffed in here. Do you have an example in mind where you would implement something custom for your type that doesn't just go through the elements and check instance types? Like, what else can you reasonably do?

Another disadvantage here is that, if I do want to do something custom downstream but similar to what's being done here, I would basically have to literally copy-paste the rust-numpy source code that iterates over numpy array items - that's not very nice.

And finally, the "don't check this" part doesn't sound like it belongs to the type itself, but rather to the call point. If you bind it to the type, you can't do "here I know it's safe, don't check, but here it's arbitrary user input, need to check" which would be a fairly reasonable use case, I believe. This you could probably achieve by having a separate trait on PyArray<T, D> (pulling this stuff out of T: Element) with default impl that always checks instance types for any Py<T> that's not PyAny, and then add a newtype wrapper like Unchecked<PyArray<T, D>> that would override this and impl a FromPyObject for it. Just a thought. (and, Unchecked is probably too generic a name here)

(Note that you could probably also achieve the above by having an Unchecked<T> but this would be much worse from the ergonomics standpoint as now you have an array of ugly wrappers instead of a single wrapper of an array.)
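
The call-site idea from this comment can be sketched roughly as follows. This is a plain-Rust model under stated assumptions, not a proposal for the actual API: `ObjArray`, `Typed`, `Extract` and `Unchecked` are hypothetical names, with `Extract` playing the role of `FromPyObject` and `Typed<'a, T>` the role of `&'a PyArray<T, D>`.

```rust
use std::any::Any;
use std::marker::PhantomData;

/// Toy array of dynamically typed elements (hypothetical stand-in).
struct ObjArray(Vec<Box<dyn Any>>);

/// Typed view over the array.
struct Typed<'a, T> {
    source: &'a ObjArray,
    _element: PhantomData<T>,
}

/// Newtype selected at the call site to skip the element scan.
struct Unchecked<A>(A);

/// Hypothetical extraction trait, in the role `FromPyObject` has in the PR.
trait Extract<'a>: Sized {
    fn extract(source: &'a ObjArray) -> Result<Self, String>;
}

/// Default policy: every element is checked.
impl<'a, T: Any> Extract<'a> for Typed<'a, T> {
    fn extract(source: &'a ObjArray) -> Result<Self, String> {
        if let Some(index) = source.0.iter().position(|e| !e.is::<T>()) {
            return Err(format!("element {} has an unexpected type", index));
        }
        Ok(Typed { source, _element: PhantomData })
    }
}

/// Opt-out policy: the caller vouches for the contents, so nothing is scanned.
impl<'a, T: Any> Extract<'a> for Unchecked<Typed<'a, T>> {
    fn extract(source: &'a ObjArray) -> Result<Self, String> {
        Ok(Unchecked(Typed { source, _element: PhantomData }))
    }
}

impl<'a, T: Any> Typed<'a, T> {
    /// In this toy model access stays safe because it downcasts on every call;
    /// an FFI-backed array that skipped the scan would not have that safety net.
    fn get(&self, index: usize) -> Option<&'a T> {
        self.source.0.get(index)?.downcast_ref::<T>()
    }
}
```

The same element type can then be extracted as `Typed<T>` for arbitrary user input and as `Unchecked<Typed<T>>` where the contents are known, which is the per-call-site choice described above.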

+        // `PyAny` can represent any Python object.
+        if TypeId::of::<PyAny>() == TypeId::of::<T>() {
+            return Ok(());
+        }
+
+        let type_object = T::type_object(py);
+        let iterator = NpySingleIterBuilder::readwrite(array).build()?;
+
+        for element in iterator {
+            if !type_object.is_instance(element)? {
+                return Err(PyDowncastError::new(array, type_name::<T>()).into());
+            }
+        }
Comment on lines +261 to +268
Contributor

This could probably be written a bit shorter, like:

if NpySingleIterBuilder::readwrite(array).build()?.any(|el| !type_obj.is_instance(el)) {
    return Err(...);
}

Member Author

Possibly, but I am not sure we can use the iterators at all as they produce references to T, but we are still checking whether we actually have T or something else entirely.
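
One detail for a shorter form: the loop in the diff applies `?` to the instance check, so that call returns a `Result`, while the closure passed to `Iterator::any` must return a plain `bool`. A minimal sketch of a short-circuiting alternative with a fallible predicate, in plain Rust: `is_instance` here is a hypothetical stand-in operating on integers, and `CheckError` and `check_all` are illustrative names only.

```rust
#[derive(Debug, PartialEq)]
enum CheckError {
    /// The element at this index is not an instance of the expected type.
    Mismatch(usize),
    /// The instance check itself failed.
    Lookup(String),
}

/// Hypothetical fallible predicate standing in for a `Result`-returning instance check.
fn is_instance(element: i64) -> Result<bool, String> {
    if element < 0 {
        Err(format!("lookup failed for {}", element))
    } else {
        Ok(element % 2 == 0)
    }
}

/// Stops at the first mismatch or the first failed lookup, whichever comes first.
fn check_all(elements: &[i64]) -> Result<(), CheckError> {
    elements
        .iter()
        .enumerate()
        .try_for_each(|(index, &element)| match is_instance(element) {
            Ok(true) => Ok(()),
            Ok(false) => Err(CheckError::Mismatch(index)),
            Err(message) => Err(CheckError::Lookup(message)),
        })
}
```

`try_for_each` keeps the early exit of the explicit loop and still lets an error from the predicate propagate. This does not address the concern above about whether the iterator may hand out references to `T` before the check has run.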


+        Ok(())
+    }
+
     fn get_dtype(py: Python) -> &PyArrayDescr {
         PyArrayDescr::object(py)
     }