Skip to content

Commit

Permalink
Add new untype PyAnyArray base type of PyArray.
Browse files Browse the repository at this point in the history
  • Loading branch information
adamreichold committed Feb 28, 2023
1 parent d58add3 commit c4adf9a
Show file tree
Hide file tree
Showing 4 changed files with 299 additions and 168 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Changelog

- Unreleased
- Add `PyAnyArray` as an untyped base type for `PyArray` which can be used to inspect arguments before more targeted downcasts. ([#369](https://github.com/PyO3/rust-numpy/pull/369))
- Drop deprecated `PyArray::from_exact_iter` as it does not provide any benefits over `PyArray::from_iter`. ([#370](https://github.com/PyO3/rust-numpy/pull/370))

- v0.18.0
Expand Down
254 changes: 254 additions & 0 deletions src/any_array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
//! Safe, untyped interface for NumPy's [N-dimensional arrays][ndarray]
//!
//! [ndarray]: https://numpy.org/doc/stable/reference/arrays.ndarray.html
use std::{os::raw::c_int, slice};

use pyo3::{
ffi, pyobject_native_type_extract, pyobject_native_type_named, AsPyPointer, IntoPy, PyAny,
PyNativeType, PyObject, PyTypeInfo, Python,
};

use crate::cold;
use crate::dtype::PyArrayDescr;
use crate::npyffi;

/// A safe, untyped wrapper for NumPy's [`ndarray`][ndarray] class.
///
/// Unlike [`PyArray<T,D>`][`PyArray`], this type does not constrain either element type `T` nor the dimensionality `D`.
/// This can be useful to inspect function arguments, but it prevents operating on the elements without further downcasts.
///
/// # Example
///
/// Taking `PyAnyArray` can be helpful to implement polymorphic entry points:
///
/// ```
/// # use pyo3::prelude::*;
/// use pyo3::exceptions::PyTypeError;
/// use numpy::{Element, PyAnyArray, PyArray1, dtype};
///
/// #[pyfunction]
/// fn entry_point(py: Python, array: &PyAnyArray) -> PyResult<()> {
/// fn implementation<T: Element>(array: &PyAnyArray) -> PyResult<()> {
/// let array: &PyArray1<T> = array.downcast()?;
///
/// /* .. */
///
/// Ok(())
/// }
///
/// let element_type = array.dtype();
///
/// if element_type.is_equiv_to(dtype::<f32>(py)) {
/// implementation::<f32>(array)
/// } else if element_type.is_equiv_to(dtype::<f64>(py)) {
/// implementation::<f64>(array)
/// } else {
/// Err(PyTypeError::new_err(format!("Unsupported element type: {}", element_type)))
/// }
/// }
/// #
/// # Python::with_gil(|py| {
/// # let array = PyArray1::<f64>::zeros(py, 42, false);
/// # entry_point(py, array)
/// # }).unwrap();
/// ```
#[repr(transparent)]
pub struct PyAnyArray(PyAny);

unsafe impl PyTypeInfo for PyAnyArray {
type AsRefTarget = Self;

const NAME: &'static str = "PyAnyArray";
const MODULE: Option<&'static str> = Some("numpy");

fn type_object_raw(py: Python) -> *mut ffi::PyTypeObject {
unsafe { npyffi::PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type) }
}

fn is_type_of(ob: &PyAny) -> bool {
unsafe { npyffi::PyArray_Check(ob.py(), ob.as_ptr()) != 0 }
}
}

pyobject_native_type_named!(PyAnyArray);

impl IntoPy<PyObject> for PyAnyArray {
fn into_py(self, py: Python<'_>) -> PyObject {
unsafe { PyObject::from_borrowed_ptr(py, self.as_ptr()) }
}
}

pyobject_native_type_extract!(PyAnyArray);

impl PyAnyArray {
/// Returns a raw pointer to the underlying [`PyArrayObject`][npyffi::PyArrayObject].
#[inline]
pub fn as_array_ptr(&self) -> *mut npyffi::PyArrayObject {
self.as_ptr() as _
}

/// Returns the `dtype` of the array.
///
/// See also [`ndarray.dtype`][ndarray-dtype] and [`PyArray_DTYPE`][PyArray_DTYPE].
///
/// # Example
///
/// ```
/// use numpy::{dtype, PyArray};
/// use pyo3::Python;
///
/// Python::with_gil(|py| {
/// let array = PyArray::from_vec(py, vec![1_i32, 2, 3]);
///
/// assert!(array.dtype().is_equiv_to(dtype::<i32>(py)));
/// });
/// ```
///
/// [ndarray-dtype]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.dtype.html
/// [PyArray_DTYPE]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_DTYPE
pub fn dtype(&self) -> &PyArrayDescr {
unsafe {
let descr_ptr = (*self.as_array_ptr()).descr;
self.py().from_borrowed_ptr(descr_ptr as _)
}
}

#[inline(always)]
pub(crate) fn check_flags(&self, flags: c_int) -> bool {
unsafe { (*self.as_array_ptr()).flags & flags != 0 }
}

/// Returns `true` if the internal data of the array is contiguous,
/// indepedently of whether C-style/row-major or Fortran-style/column-major.
///
/// # Example
///
/// ```
/// use numpy::PyArray1;
/// use pyo3::{types::IntoPyDict, Python};
///
/// Python::with_gil(|py| {
/// let array = PyArray1::arange(py, 0, 10, 1);
/// assert!(array.is_contiguous());
///
/// let view = py
/// .eval("array[::2]", None, Some([("array", array)].into_py_dict(py)))
/// .unwrap()
/// .downcast::<PyArray1<i32>>()
/// .unwrap();
/// assert!(!view.is_contiguous());
/// });
/// ```
pub fn is_contiguous(&self) -> bool {
self.check_flags(npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS)
}

/// Returns `true` if the internal data of the array is Fortran-style/column-major contiguous.
pub fn is_fortran_contiguous(&self) -> bool {
self.check_flags(npyffi::NPY_ARRAY_F_CONTIGUOUS)
}

/// Returns `true` if the internal data of the array is C-style/row-major contiguous.
pub fn is_c_contiguous(&self) -> bool {
self.check_flags(npyffi::NPY_ARRAY_C_CONTIGUOUS)
}

/// Returns the number of dimensions of the array.
///
/// See also [`ndarray.ndim`][ndarray-ndim] and [`PyArray_NDIM`][PyArray_NDIM].
///
/// # Example
///
/// ```
/// use numpy::PyArray3;
/// use pyo3::Python;
///
/// Python::with_gil(|py| {
/// let arr = PyArray3::<f64>::zeros(py, [4, 5, 6], false);
///
/// assert_eq!(arr.ndim(), 3);
/// });
/// ```
///
/// [ndarray-ndim]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.ndim.html
/// [PyArray_NDIM]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_NDIM
#[inline]
pub fn ndim(&self) -> usize {
unsafe { (*self.as_array_ptr()).nd as usize }
}

/// Returns a slice indicating how many bytes to advance when iterating along each axis.
///
/// See also [`ndarray.strides`][ndarray-strides] and [`PyArray_STRIDES`][PyArray_STRIDES].
///
/// # Example
///
/// ```
/// use numpy::PyArray3;
/// use pyo3::Python;
///
/// Python::with_gil(|py| {
/// let arr = PyArray3::<f64>::zeros(py, [4, 5, 6], false);
///
/// assert_eq!(arr.strides(), &[240, 48, 8]);
/// });
/// ```
/// [ndarray-strides]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html
/// [PyArray_STRIDES]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_STRIDES
#[inline]
pub fn strides(&self) -> &[isize] {
let n = self.ndim();
if n == 0 {
cold();
return &[];
}
let ptr = self.as_array_ptr();
unsafe {
let p = (*ptr).strides;
slice::from_raw_parts(p, n)
}
}

/// Returns a slice which contains dimmensions of the array.
///
/// See also [`ndarray.shape`][ndaray-shape] and [`PyArray_DIMS`][PyArray_DIMS].
///
/// # Example
///
/// ```
/// use numpy::PyArray3;
/// use pyo3::Python;
///
/// Python::with_gil(|py| {
/// let arr = PyArray3::<f64>::zeros(py, [4, 5, 6], false);
///
/// assert_eq!(arr.shape(), &[4, 5, 6]);
/// });
/// ```
///
/// [ndarray-shape]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.shape.html
/// [PyArray_DIMS]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_DIMS
#[inline]
pub fn shape(&self) -> &[usize] {
let n = self.ndim();
if n == 0 {
cold();
return &[];
}
let ptr = self.as_array_ptr();
unsafe {
let p = (*ptr).dimensions as *mut usize;
slice::from_raw_parts(p, n)
}
}

/// Calculates the total number of elements in the array.
pub fn len(&self) -> usize {
self.shape().iter().product()
}

/// Returns `true` if the there are no elements in the array.
pub fn is_empty(&self) -> bool {
self.shape().iter().any(|dim| *dim == 0)
}
}
Loading

0 comments on commit c4adf9a

Please sign in to comment.