Skip to content

Commit

Permalink
migration to pyo3 0.21 beta using the new Bound API
Browse files Browse the repository at this point in the history
This does still use GIL Refs in numpy's API but switches
our internals to use the Bound API where appropriate.
  • Loading branch information
Icxolu authored and adamreichold committed Mar 11, 2024
1 parent 32740b3 commit 456663d
Show file tree
Hide file tree
Showing 15 changed files with 84 additions and 72 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ num-complex = ">= 0.2, < 0.5"
num-integer = "0.1"
num-traits = "0.2"
ndarray = ">= 0.13, < 0.16"
pyo3 = { version = "0.20", default-features = false, features = ["macros"] }
pyo3 = { version = "0.21.0-beta", default-features = false, features = ["gil-refs", "macros"] }
rustc-hash = "1.1"

[dev-dependencies]
pyo3 = { version = "0.20", default-features = false, features = ["auto-initialize"] }
pyo3 = { version = "0.21.0-beta", default-features = false, features = ["auto-initialize", "gil-refs"] }
nalgebra = { version = "0.32", default-features = false, features = ["std"] }

[package.metadata.docs.rs]
Expand Down
2 changes: 1 addition & 1 deletion examples/linalg/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ name = "rust_linalg"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.20", features = ["extension-module"] }
pyo3 = { version = "0.21.0-beta", features = ["extension-module"] }
numpy = { path = "../.." }
ndarray-linalg = { version = "0.14.1", features = ["openblas-system"] }

Expand Down
2 changes: 1 addition & 1 deletion examples/parallel/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ name = "rust_parallel"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.20", features = ["extension-module", "multiple-pymethods"] }
pyo3 = { version = "0.21.0-beta", features = ["extension-module", "multiple-pymethods"] }
numpy = { path = "../.." }
ndarray = { version = "0.15", features = ["rayon", "blas"] }
blas-src = { version = "0.8", features = ["openblas"] }
Expand Down
2 changes: 1 addition & 1 deletion examples/simple/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ name = "rust_ext"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.20", features = ["extension-module", "abi3-py37"] }
pyo3 = { version = "0.21.0-beta", features = ["extension-module", "abi3-py37"] }
numpy = { path = "../.." }

[workspace]
44 changes: 25 additions & 19 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ use ndarray::{
};
use num_traits::AsPrimitive;
use pyo3::{
ffi, pyobject_native_type_base, types::PyModule, AsPyPointer, FromPyObject, IntoPy, Py, PyAny,
PyClassInitializer, PyDowncastError, PyErr, PyNativeType, PyObject, PyResult, PyTypeInfo,
Python, ToPyObject,
ffi, pyobject_native_type_base,
types::{DerefToPyAny, PyAnyMethods, PyModule},
AsPyPointer, Bound, DowncastError, FromPyObject, IntoPy, Py, PyAny, PyErr, PyNativeType,
PyObject, PyResult, PyTypeInfo, Python, ToPyObject,
};

use crate::borrow::{PyReadonlyArray, PyReadwriteArray};
Expand Down Expand Up @@ -118,21 +119,21 @@ pub type PyArray6<T> = PyArray<T, Ix6>;
pub type PyArrayDyn<T> = PyArray<T, IxDyn>;

/// Returns a handle to NumPy's multiarray module.
pub fn get_array_module<'py>(py: Python<'py>) -> PyResult<&PyModule> {
PyModule::import(py, npyffi::array::MOD_NAME)
pub fn get_array_module<'py>(py: Python<'py>) -> PyResult<Bound<'_, PyModule>> {
PyModule::import_bound(py, npyffi::array::MOD_NAME)
}

unsafe impl<T: Element, D: Dimension> PyTypeInfo for PyArray<T, D> {
type AsRefTarget = Self;
impl<T, D> DerefToPyAny for PyArray<T, D> {}

unsafe impl<T: Element, D: Dimension> PyTypeInfo for PyArray<T, D> {
const NAME: &'static str = "PyArray<T, D>";
const MODULE: Option<&'static str> = Some("numpy");

fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject {
unsafe { npyffi::PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type) }
}

fn is_type_of(ob: &PyAny) -> bool {
fn is_type_of_bound(ob: &Bound<'_, PyAny>) -> bool {
Self::extract::<IgnoreError>(ob).is_ok()
}
}
Expand Down Expand Up @@ -189,8 +190,11 @@ impl<T, D> IntoPy<PyObject> for PyArray<T, D> {
}

impl<'py, T: Element, D: Dimension> FromPyObject<'py> for &'py PyArray<T, D> {
fn extract(ob: &'py PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
#[allow(clippy::map_clone)] // due to MSRV
PyArray::extract(ob)
.map(Clone::clone)
.map(Bound::into_gil_ref)
}
}

Expand Down Expand Up @@ -251,28 +255,30 @@ impl<T, D> PyArray<T, D> {
}

impl<T: Element, D: Dimension> PyArray<T, D> {
fn extract<'py, E>(ob: &'py PyAny) -> Result<&'py Self, E>
fn extract<'a, 'py, E>(ob: &'a Bound<'py, PyAny>) -> Result<&'a Bound<'py, Self>, E>
where
E: From<PyDowncastError<'py>> + From<DimensionalityError> + From<TypeError<'py>>,
E: From<DowncastError<'a, 'py>> + From<DimensionalityError> + From<TypeError<'a>>,
{
// Check if the object is an array.
let array = unsafe {
if npyffi::PyArray_Check(ob.py(), ob.as_ptr()) == 0 {
return Err(PyDowncastError::new(ob, Self::NAME).into());
return Err(DowncastError::new(ob, <Self as PyTypeInfo>::NAME).into());
}
&*(ob as *const PyAny as *const Self)
ob.downcast_unchecked()
};

let arr_gil_ref: &PyArray<T, D> = array.as_gil_ref();

// Check if the dimensionality matches `D`.
let src_ndim = array.ndim();
let src_ndim = arr_gil_ref.ndim();
if let Some(dst_ndim) = D::NDIM {
if src_ndim != dst_ndim {
return Err(DimensionalityError::new(src_ndim, dst_ndim).into());
}
}

// Check if the element type matches `T`.
let src_dtype = array.dtype();
let src_dtype = arr_gil_ref.dtype();
let dst_dtype = T::get_dtype(ob.py());
if !src_dtype.is_equiv_to(dst_dtype) {
return Err(TypeError::new(src_dtype, dst_dtype).into());
Expand Down Expand Up @@ -399,11 +405,11 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
data_ptr: *const T,
container: PySliceContainer,
) -> &'py Self {
let container = PyClassInitializer::from(container)
.create_cell(py)
.expect("Failed to create slice container");
let container = Bound::new(py, container)
.expect("Failed to create slice container")
.into_ptr();

Self::new_with_data(py, dims, strides, data_ptr, container as *mut PyAny)
Self::new_with_data(py, dims, strides, data_ptr, container.cast())
}

/// Creates a NumPy array backed by `array` and ties its ownership to the Python object `container`.
Expand Down
7 changes: 6 additions & 1 deletion src/array_like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@ use std::marker::PhantomData;
use std::ops::Deref;

use ndarray::{Array1, Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
use pyo3::{intern, sync::GILOnceCell, types::PyDict, FromPyObject, Py, PyAny, PyResult};
use pyo3::{
intern,
sync::GILOnceCell,
types::{PyAnyMethods, PyDict},
FromPyObject, Py, PyAny, PyResult,
};

use crate::sealed::Sealed;
use crate::{get_array_module, Element, IntoPyArray, PyArray, PyReadonlyArray};
Expand Down
22 changes: 11 additions & 11 deletions src/borrow/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@ use std::os::raw::{c_char, c_int};
use std::slice::from_raw_parts;

use num_integer::gcd;
use pyo3::{
exceptions::PyTypeError, once_cell::GILOnceCell, types::PyCapsule, Py, PyResult, PyTryInto,
Python,
};
use pyo3::types::{PyAnyMethods, PyCapsuleMethods};
use pyo3::{exceptions::PyTypeError, sync::GILOnceCell, types::PyCapsule, PyResult, Python};
use rustc_hash::FxHashMap;

use crate::array::get_array_module;
Expand Down Expand Up @@ -124,8 +122,8 @@ fn get_or_insert_shared<'py>(py: Python<'py>) -> PyResult<&'py Shared> {
fn insert_shared<'py>(py: Python<'py>) -> PyResult<*const Shared> {
let module = get_array_module(py)?;

let capsule: &PyCapsule = match module.getattr("_RUST_NUMPY_BORROW_CHECKING_API") {
Ok(capsule) => PyTryInto::try_into(capsule)?,
let capsule = match module.getattr("_RUST_NUMPY_BORROW_CHECKING_API") {
Ok(capsule) => capsule.downcast_into::<PyCapsule>()?,
Err(_err) => {
let flags: *mut BorrowFlags = Box::into_raw(Box::default());

Expand All @@ -138,7 +136,7 @@ fn insert_shared<'py>(py: Python<'py>) -> PyResult<*const Shared> {
release_mut: release_mut_shared,
};

let capsule = PyCapsule::new_with_destructor(
let capsule = PyCapsule::new_bound_with_destructor(
py,
shared,
Some(CString::new("_RUST_NUMPY_BORROW_CHECKING_API").unwrap()),
Expand All @@ -147,25 +145,27 @@ fn insert_shared<'py>(py: Python<'py>) -> PyResult<*const Shared> {
let _ = unsafe { Box::from_raw(shared.flags as *mut BorrowFlags) };
},
)?;
module.setattr("_RUST_NUMPY_BORROW_CHECKING_API", capsule)?;
module.setattr("_RUST_NUMPY_BORROW_CHECKING_API", &capsule)?;
capsule
}
};

// SAFETY: All versions of the shared borrow checking API start with a version field.
let version = unsafe { *(capsule.pointer() as *mut u64) };
let version = unsafe { *capsule.pointer().cast::<u64>() };
if version < 1 {
return Err(PyTypeError::new_err(format!(
"Version {} of borrow checking API is not supported by this version of rust-numpy",
version
)));
}

let ptr = capsule.pointer();

// Intentionally leak a reference to the capsule
// so we can safely cache a pointer into its interior.
forget(Py::<PyCapsule>::from(capsule));
forget(capsule);

Ok(capsule.pointer() as *const Shared)
Ok(ptr.cast())
}

// These entry points will be used to access the shared borrow checking API from this extension:
Expand Down
46 changes: 24 additions & 22 deletions src/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@ use pyo3::{
exceptions::{PyIndexError, PyValueError},
ffi::{self, PyTuple_Size},
pyobject_native_type_extract, pyobject_native_type_named,
types::{PyDict, PyTuple, PyType},
AsPyPointer, FromPyObject, FromPyPointer, PyAny, PyNativeType, PyObject, PyResult, PyTypeInfo,
Python, ToPyObject,
types::{PyAnyMethods, PyDict, PyDictMethods, PyTuple, PyType},
AsPyPointer, Borrowed, PyAny, PyNativeType, PyObject, PyResult, PyTypeInfo, Python, ToPyObject,
};
#[cfg(feature = "half")]
use pyo3::{sync::GILOnceCell, IntoPy, Py};
Expand Down Expand Up @@ -53,8 +52,6 @@ pub struct PyArrayDescr(PyAny);
pyobject_native_type_named!(PyArrayDescr);

unsafe impl PyTypeInfo for PyArrayDescr {
type AsRefTarget = Self;

const NAME: &'static str = "PyArrayDescr";
const MODULE: Option<&'static str> = Some("numpy");

Expand Down Expand Up @@ -249,7 +246,9 @@ impl PyArrayDescr {
if !self.has_subarray() {
self
} else {
#[allow(deprecated)]
unsafe {
use pyo3::FromPyPointer;
Self::from_borrowed_ptr(self.py(), (*(*self.as_dtype_ptr()).subarray).base as _)
}
}
Expand All @@ -267,11 +266,9 @@ impl PyArrayDescr {
Vec::new()
} else {
// NumPy guarantees that shape is a tuple of non-negative integers so this should never panic.
unsafe {
PyTuple::from_borrowed_ptr(self.py(), (*(*self.as_dtype_ptr()).subarray).shape)
}
.extract()
.unwrap()
unsafe { Borrowed::from_ptr(self.py(), (*(*self.as_dtype_ptr()).subarray).shape) }
.extract()
.unwrap()
}
}

Expand Down Expand Up @@ -329,8 +326,8 @@ impl PyArrayDescr {
if !self.has_fields() {
return None;
}
let names = unsafe { PyTuple::from_borrowed_ptr(self.py(), (*self.as_dtype_ptr()).names) };
FromPyObject::extract(names).ok()
let names = unsafe { Borrowed::from_ptr(self.py(), (*self.as_dtype_ptr()).names) };
names.extract().ok()
}

/// Returns the type descriptor and offset of the field with the given name.
Expand All @@ -349,17 +346,22 @@ impl PyArrayDescr {
"cannot get field information: type descriptor has no fields",
));
}
let dict = unsafe { PyDict::from_borrowed_ptr(self.py(), (*self.as_dtype_ptr()).fields) };
let dict = unsafe { Borrowed::from_ptr(self.py(), (*self.as_dtype_ptr()).fields) };
let dict = unsafe { dict.downcast_unchecked::<PyDict>() };
// NumPy guarantees that fields are tuples of proper size and type, so this should never panic.
let tuple = dict
.get_item(name)?
.ok_or_else(|| PyIndexError::new_err(name.to_owned()))?
.downcast::<PyTuple>()
.downcast_into::<PyTuple>()
.unwrap();
// Note that we cannot just extract the entire tuple since the third element can be a title.
let dtype = FromPyObject::extract(tuple.as_ref().get_item(0).unwrap()).unwrap();
let offset = FromPyObject::extract(tuple.as_ref().get_item(1).unwrap()).unwrap();
Ok((dtype, offset))
let dtype = tuple
.get_item(0)
.unwrap()
.downcast_into::<PyArrayDescr>()
.unwrap();
let offset = tuple.get_item(1).unwrap().extract().unwrap();
Ok((dtype.into_gil_ref(), offset))
}
}

Expand Down Expand Up @@ -548,8 +550,8 @@ mod tests {

#[test]
fn test_dtype_names() {
fn type_name<'py, T: Element>(py: Python<'py>) -> &str {
dtype::<T>(py).typeobj().name().unwrap()
fn type_name<'py, T: Element>(py: Python<'py>) -> String {
dtype::<T>(py).typeobj().qualname().unwrap()
}
Python::with_gil(|py| {
assert_eq!(type_name::<bool>(py), "bool_");
Expand Down Expand Up @@ -589,7 +591,7 @@ mod tests {

assert_eq!(dt.num(), NPY_TYPES::NPY_DOUBLE as c_int);
assert_eq!(dt.flags(), 0);
assert_eq!(dt.typeobj().name().unwrap(), "float64");
assert_eq!(dt.typeobj().qualname().unwrap(), "float64");
assert_eq!(dt.char(), b'd');
assert_eq!(dt.kind(), b'f');
assert_eq!(dt.byteorder(), b'=');
Expand Down Expand Up @@ -625,7 +627,7 @@ mod tests {

assert_eq!(dt.num(), NPY_TYPES::NPY_VOID as c_int);
assert_eq!(dt.flags(), 0);
assert_eq!(dt.typeobj().name().unwrap(), "void");
assert_eq!(dt.typeobj().qualname().unwrap(), "void");
assert_eq!(dt.char(), b'V');
assert_eq!(dt.kind(), b'V');
assert_eq!(dt.byteorder(), b'|');
Expand Down Expand Up @@ -663,7 +665,7 @@ mod tests {
assert_ne!(dt.flags() & NPY_ITEM_HASOBJECT, 0);
assert_ne!(dt.flags() & NPY_NEEDS_PYAPI, 0);
assert_ne!(dt.flags() & NPY_ALIGNED_STRUCT, 0);
assert_eq!(dt.typeobj().name().unwrap(), "void");
assert_eq!(dt.typeobj().qualname().unwrap(), "void");
assert_eq!(dt.char(), b'V');
assert_eq!(dt.kind(), b'V');
assert_eq!(dt.byteorder(), b'|');
Expand Down
2 changes: 1 addition & 1 deletion src/npyffi/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::os::raw::*;
use libc::FILE;
use pyo3::{
ffi::{self, PyObject, PyTypeObject},
once_cell::GILOnceCell,
sync::GILOnceCell,
};

use crate::npyffi::*;
Expand Down
10 changes: 5 additions & 5 deletions src/npyffi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,23 @@ use std::mem::forget;
use std::os::raw::c_void;

use pyo3::{
types::{PyCapsule, PyModule},
Py, PyResult, PyTryInto, Python,
types::{PyAnyMethods, PyCapsule, PyCapsuleMethods, PyModule},
PyResult, Python,
};

fn get_numpy_api<'py>(
py: Python<'py>,
module: &str,
capsule: &str,
) -> PyResult<*const *const c_void> {
let module = PyModule::import(py, module)?;
let capsule: &PyCapsule = PyTryInto::try_into(module.getattr(capsule)?)?;
let module = PyModule::import_bound(py, module)?;
let capsule = module.getattr(capsule)?.downcast_into::<PyCapsule>()?;

let api = capsule.pointer() as *const *const c_void;

// Intentionally leak a reference to the capsule
// so we can safely cache a pointer into its interior.
forget(Py::<PyCapsule>::from(capsule));
forget(capsule);

Ok(api)
}
Expand Down
2 changes: 1 addition & 1 deletion src/npyffi/ufunc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use std::os::raw::*;

use pyo3::{ffi::PyObject, once_cell::GILOnceCell};
use pyo3::{ffi::PyObject, sync::GILOnceCell};

use crate::npyffi::*;

Expand Down
Loading

0 comments on commit 456663d

Please sign in to comment.