Skip to content

Commit

Permalink
WIP: Add dynamic borrow checking for dereferencing NumPy arrays.
Browse files Browse the repository at this point in the history
  • Loading branch information
adamreichold committed Feb 18, 2022
1 parent 61882e3 commit 67f8600
Show file tree
Hide file tree
Showing 7 changed files with 292 additions and 363 deletions.
8 changes: 5 additions & 3 deletions examples/simple-extension/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use numpy::ndarray::{ArrayD, ArrayViewD, ArrayViewMutD};
use numpy::{Complex64, IntoPyArray, PyArray1, PyArrayDyn, PyReadonlyArrayDyn};
use numpy::{
Complex64, IntoPyArray, PyArray1, PyArrayDyn, PyReadonlyArrayDyn, PyReadwriteArrayDyn,
};
use pyo3::{
pymodule,
types::{PyDict, PyModule},
Expand Down Expand Up @@ -41,8 +43,8 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
// wrapper of `mult`
#[pyfn(m)]
#[pyo3(name = "mult")]
fn mult_py(a: f64, x: &PyArrayDyn<f64>) {
let x = unsafe { x.as_array_mut() };
fn mult_py(a: f64, mut x: PyReadwriteArrayDyn<f64>) {
let x = x.as_array_mut();
mult(a, x);
}

Expand Down
61 changes: 26 additions & 35 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use pyo3::{
Python, ToPyObject,
};

use crate::borrow::{PyReadonlyArray, PyReadwriteArray};
use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
use crate::dtype::Element;
use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError};
Expand Down Expand Up @@ -190,18 +191,8 @@ impl<T, D> PyArray<T, D> {
}

#[inline(always)]
fn check_flag(&self, flag: c_int) -> bool {
unsafe { *self.as_array_ptr() }.flags & flag == flag
}

#[inline(always)]
pub(crate) fn get_flag(&self) -> c_int {
unsafe { *self.as_array_ptr() }.flags
}

/// Returns a temporally unwriteable reference of the array.
pub fn readonly(&self) -> crate::PyReadonlyArray<T, D> {
self.into()
fn check_flags(&self, flags: c_int) -> bool {
unsafe { *self.as_array_ptr() }.flags & flags != 0
}

/// Returns `true` if the internal data of the array is C-style contiguous
Expand All @@ -223,18 +214,17 @@ impl<T, D> PyArray<T, D> {
/// });
/// ```
pub fn is_contiguous(&self) -> bool {
self.check_flag(npyffi::NPY_ARRAY_C_CONTIGUOUS)
| self.check_flag(npyffi::NPY_ARRAY_F_CONTIGUOUS)
self.check_flags(npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS)
}

/// Returns `true` if the internal data of the array is Fortran-style contiguous.
pub fn is_fortran_contiguous(&self) -> bool {
self.check_flag(npyffi::NPY_ARRAY_F_CONTIGUOUS)
self.check_flags(npyffi::NPY_ARRAY_F_CONTIGUOUS)
}

/// Returns `true` if the internal data of the array is C-style contiguous.
pub fn is_c_contiguous(&self) -> bool {
self.check_flag(npyffi::NPY_ARRAY_C_CONTIGUOUS)
self.check_flags(npyffi::NPY_ARRAY_C_CONTIGUOUS)
}

/// Get `Py<PyArray>` from `&PyArray`, which is the owned wrapper of PyObject.
Expand Down Expand Up @@ -823,28 +813,37 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
ToPyArray::to_pyarray(arr, py)
}

/// Get the immutable view of the internal data of `PyArray`, as
/// [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html).
/// Get an immutable borrow of the NumPy array
pub fn readonly(&self) -> PyReadonlyArray<'_, T, D> {
PyReadonlyArray::try_new(self).expect("NumPy array already borrowed")
}

/// Get a mutable borrow of the NumPy array
pub fn readwrite(&self) -> PyReadwriteArray<'_, T, D> {
PyReadwriteArray::try_new(self).expect("NumPy array already borrowed")
}

/// Returns the internal array as [`ArrayView`].
///
/// Please consider the use of safe alternatives
/// ([`PyReadonlyArray::as_array`](../struct.PyReadonlyArray.html#method.as_array)
/// or [`to_array`](#method.to_array)) instead of this.
/// See also [`PyArrayRef::as_array`].
///
/// # Safety
/// If the internal array is not readonly and can be mutated from Python code,
/// holding the `ArrayView` might cause undefined behavior.
///
/// The existence of an exclusive reference to the internal data, e.g. `&mut [T]` or `ArrayViewMut`, implies undefined behavior.
pub unsafe fn as_array(&self) -> ArrayView<'_, T, D> {
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
let mut res = ArrayView::from_shape_ptr(shape, ptr);
inverted_axes.invert(&mut res);
res
}

/// Returns the internal array as [`ArrayViewMut`]. See also [`as_array`](#method.as_array).
/// Returns the internal array as [`ArrayViewMut`].
///
/// See also [`PyArrayRefMut::as_array_mut`].
///
/// # Safety
/// If another reference to the internal data exists(e.g., `&[T]` or `ArrayView`),
/// it might cause undefined behavior.
///
/// The existence of another reference to the internal data, e.g. `&[T]` or `ArrayView`, implies undefined behavior.
pub unsafe fn as_array_mut(&self) -> ArrayViewMut<'_, T, D> {
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
let mut res = ArrayViewMut::from_shape_ptr(shape, ptr);
Expand Down Expand Up @@ -920,7 +919,7 @@ impl<D: Dimension> PyArray<PyObject, D> {
///
/// let pyarray = PyArray::from_owned_object_array(py, array);
///
/// assert!(pyarray.readonly().get(0).unwrap().as_ref(py).is_instance::<CustomElement>().unwrap());
/// assert!(pyarray.readonly().as_array().get(0).unwrap().as_ref(py).is_instance::<CustomElement>().unwrap());
/// });
/// ```
pub fn from_owned_object_array<'py, T>(py: Python<'py>, arr: Array<Py<T>, D>) -> &'py Self {
Expand Down Expand Up @@ -1069,14 +1068,6 @@ impl<T: Element> PyArray<T, Ix1> {
self.resize_(self.py(), [new_elems], 1, NPY_ORDER::NPY_ANYORDER)
}

/// Iterates all elements of this array.
/// See [NpySingleIter](../npyiter/struct.NpySingleIter.html) for more.
pub fn iter<'py>(
&'py self,
) -> PyResult<crate::NpySingleIter<'py, T, crate::npyiter::ReadWrite>> {
crate::NpySingleIterBuilder::readwrite(self).build()
}

fn resize_<D: IntoDimension>(
&self,
py: Python,
Expand Down
195 changes: 195 additions & 0 deletions src/borrow.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
use std::cell::UnsafeCell;
use std::collections::hash_map::{Entry, HashMap};
use std::ops::Deref;

use ndarray::{ArrayView, ArrayViewMut, Dimension, Ix1, Ix2, IxDyn};
use pyo3::{FromPyObject, PyAny, PyResult};

use crate::array::PyArray;
use crate::dtype::Element;
use crate::error::NotContiguousError;

thread_local! {
static BORROW_FLAGS: UnsafeCell<HashMap<usize, isize>> = UnsafeCell::new(HashMap::new());
}

pub struct PyReadonlyArray<'py, T, D>(&'py PyArray<T, D>);

pub type PyReadonlyArray1<'py, T> = PyReadonlyArray<'py, T, Ix1>;

pub type PyReadonlyArray2<'py, T> = PyReadonlyArray<'py, T, Ix2>;

pub type PyReadonlyArrayDyn<'py, T> = PyReadonlyArray<'py, T, IxDyn>;

impl<'py, T, D> Deref for PyReadonlyArray<'py, T, D> {
type Target = PyArray<T, D>;

fn deref(&self) -> &Self::Target {
self.0
}
}

impl<'py, T: Element, D: Dimension> FromPyObject<'py> for PyReadonlyArray<'py, T, D> {
fn extract(obj: &'py PyAny) -> PyResult<Self> {
let array: &'py PyArray<T, D> = obj.extract()?;
Ok(array.readonly())
}
}

impl<'py, T, D> PyReadonlyArray<'py, T, D>
where
T: Element,
D: Dimension,
{
pub(crate) fn try_new(array: &'py PyArray<T, D>) -> Option<Self> {
let address = array as *const PyArray<T, D> as usize;

BORROW_FLAGS.with(|borrow_flags| {
// SAFETY: Called on a thread local variable in a leaf function.
let borrow_flags = unsafe { &mut *borrow_flags.get() };

match borrow_flags.entry(address) {
Entry::Occupied(entry) => {
let readers = entry.into_mut();

let new_readers = readers.wrapping_add(1);

if new_readers <= 0 {
cold();
return None;
}

*readers = new_readers;
}
Entry::Vacant(entry) => {
entry.insert(1);
}
}

Some(Self(array))
})
}

pub fn as_array(&self) -> ArrayView<T, D> {
// SAFETY: Thread-local borrow flags ensure aliasing discipline on this thread,
// and `PyArray` is neither `Send` nor `Sync`
unsafe { self.0.as_array() }
}

pub fn as_slice(&self) -> Result<&[T], NotContiguousError> {
// SAFETY: Thread-local borrow flags ensure aliasing discipline on this thread,
// and `PyArray` is neither `Send` nor `Sync`
unsafe { self.0.as_slice() }
}
}

impl<'a, T, D> Drop for PyReadonlyArray<'a, T, D> {
fn drop(&mut self) {
let address = self.0 as *const PyArray<T, D> as usize;

BORROW_FLAGS.with(|borrow_flags| {
// SAFETY: Called on a thread local variable in a leaf function.
let borrow_flags = unsafe { &mut *borrow_flags.get() };

let readers = borrow_flags.get_mut(&address).unwrap();

*readers -= 1;

if *readers == 0 {
borrow_flags.remove(&address).unwrap();
}
});
}
}

pub struct PyReadwriteArray<'py, T, D>(&'py PyArray<T, D>);

pub type PyReadwriteArrayDyn<'py, T> = PyReadwriteArray<'py, T, IxDyn>;

impl<'py, T, D> Deref for PyReadwriteArray<'py, T, D> {
type Target = PyArray<T, D>;

fn deref(&self) -> &Self::Target {
self.0
}
}

impl<'py, T: Element, D: Dimension> FromPyObject<'py> for PyReadwriteArray<'py, T, D> {
fn extract(obj: &'py PyAny) -> PyResult<Self> {
let array: &'py PyArray<T, D> = obj.extract()?;
Ok(array.readwrite())
}
}

impl<'py, T, D> PyReadwriteArray<'py, T, D>
where
T: Element,
D: Dimension,
{
pub(crate) fn try_new(array: &'py PyArray<T, D>) -> Option<Self> {
let address = array as *const PyArray<T, D> as usize;

BORROW_FLAGS.with(|borrow_flags| {
// SAFETY: Called on a thread local variable in a leaf function.
let borrow_flags = unsafe { &mut *borrow_flags.get() };

match borrow_flags.entry(address) {
Entry::Occupied(entry) => {
let writers = entry.into_mut();

if *writers != 0 {
cold();
return None;
}

*writers = -1;
}
Entry::Vacant(entry) => {
entry.insert(-1);
}
}

Some(Self(array))
})
}

pub fn as_array(&self) -> ArrayView<T, D> {
// SAFETY: Thread-local borrow flags ensure aliasing discipline on this thread,
// and `PyArray` is neither `Send` nor `Sync`
unsafe { self.0.as_array() }
}

pub fn as_slice(&self) -> Result<&[T], NotContiguousError> {
// SAFETY: Thread-local borrow flags ensure aliasing discipline on this thread,
// and `PyArray` is neither `Send` nor `Sync`
unsafe { self.0.as_slice() }
}

pub fn as_array_mut(&mut self) -> ArrayViewMut<T, D> {
// SAFETY: Thread-local borrow flags ensure aliasing discipline on this thread,
// and `PyArray` is neither `Send` nor `Sync`
unsafe { self.0.as_array_mut() }
}

pub fn as_slice_mut(&self) -> Result<&mut [T], NotContiguousError> {
// SAFETY: Thread-local borrow flags ensure aliasing discipline on this thread,
// and `PyArray` is neither `Send` nor `Sync`
unsafe { self.0.as_slice_mut() }
}
}

impl<'a, T, D> Drop for PyReadwriteArray<'a, T, D> {
fn drop(&mut self) {
let address = self.0 as *const PyArray<T, D> as usize;

BORROW_FLAGS.with(|borrow_flags| {
// SAFETY: Called on a thread local variable in a leaf function.
let borrow_flags = unsafe { &mut *borrow_flags.get() };

borrow_flags.remove(&address).unwrap();
});
}
}
#[cold]
#[inline(always)]
fn cold() {}
10 changes: 5 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@
#![allow(clippy::needless_lifetimes)] // We often want to make the GIL lifetime explicit.

pub mod array;
mod borrow;
pub mod convert;
mod dtype;
mod error;
pub mod npyffi;
pub mod npyiter;
mod readonly;
mod slice_container;
mod sum_products;

Expand All @@ -46,17 +46,17 @@ pub use crate::array::{
get_array_module, PyArray, PyArray0, PyArray1, PyArray2, PyArray3, PyArray4, PyArray5,
PyArray6, PyArrayDyn,
};
pub use crate::borrow::{
PyReadonlyArray, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArrayDyn, PyReadwriteArray,
PyReadwriteArrayDyn,
};
pub use crate::convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
pub use crate::dtype::{dtype, Complex32, Complex64, Element, PyArrayDescr};
pub use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError};
pub use crate::npyffi::{PY_ARRAY_API, PY_UFUNC_API};
pub use crate::npyiter::{
IterMode, NpyIterFlag, NpyMultiIter, NpyMultiIterBuilder, NpySingleIter, NpySingleIterBuilder,
};
pub use crate::readonly::{
PyReadonlyArray, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3, PyReadonlyArray4,
PyReadonlyArray5, PyReadonlyArray6, PyReadonlyArrayDyn,
};
pub use crate::sum_products::{dot, einsum_impl, inner};
pub use ndarray::{array, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};

Expand Down
Loading

0 comments on commit 67f8600

Please sign in to comment.