diff --git a/README.md b/README.md index 271c06335..ba879bebc 100644 --- a/README.md +++ b/README.md @@ -44,20 +44,20 @@ numpy = "0.15" ``` ```rust -use numpy::ndarray::{ArrayD, ArrayViewD, ArrayViewMutD}; -use numpy::{IntoPyArray, PyArrayDyn, PyReadonlyArrayDyn}; +use numpy::ndarray::{ArrayD, ArrayViewD, ArrayViewMutD, IxDyn}; +use numpy::{IntoPyArray, PyArrayDyn, PyArrayRef, PyArrayRefMut}; use pyo3::prelude::{pymodule, PyModule, PyResult, Python}; #[pymodule] fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> { // immutable example - fn axpy(a: f64, x: ArrayViewD<'_, f64>, y: ArrayViewD<'_, f64>) -> ArrayD { - a * &x + &y + fn axpy(a: f64, x: &ArrayViewD<'_, f64>, y: &ArrayViewD<'_, f64>) -> ArrayD { + a * x + y } // mutable example (no return) - fn mult(a: f64, mut x: ArrayViewMutD<'_, f64>) { - x *= a; + fn mult(a: f64, x: &mut ArrayViewMutD<'_, f64>) { + *x *= a; } // wrapper of `axpy` @@ -65,19 +65,16 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> { fn axpy_py<'py>( py: Python<'py>, a: f64, - x: PyReadonlyArrayDyn, - y: PyReadonlyArrayDyn, + x: PyArrayRef, + y: PyArrayRef, ) -> &'py PyArrayDyn { - let x = x.as_array(); - let y = y.as_array(); - axpy(a, x, y).into_pyarray(py) + axpy(a, &x, &y).into_pyarray(py) } // wrapper of `mult` #[pyfn(m, "mult")] - fn mult_py(_py: Python<'_>, a: f64, x: &PyArrayDyn) -> PyResult<()> { - let x = unsafe { x.as_array_mut() }; - mult(a, x); + fn mult_py(_py: Python<'_>, a: f64, mut x: PyArrayRefMut) -> PyResult<()> { + mult(a, &mut x); Ok(()) } diff --git a/examples/linalg/src/lib.rs b/examples/linalg/src/lib.rs index e648d3e4c..ba8b273d7 100755 --- a/examples/linalg/src/lib.rs +++ b/examples/linalg/src/lib.rs @@ -1,12 +1,11 @@ use ndarray_linalg::solve::Inverse; -use numpy::{IntoPyArray, PyArray2, PyReadonlyArray2}; +use numpy::{IntoPyArray, Ix2, PyArray2, PyArrayRef}; use pyo3::{exceptions::PyRuntimeError, pymodule, types::PyModule, PyErr, PyResult, Python}; #[pymodule] fn rust_linalg(_py: Python<'_>, m: &PyModule) -> PyResult<()> { #[pyfn(m)] - fn inv<'py>(py: Python<'py>, x: PyReadonlyArray2<'py, f64>) -> PyResult<&'py PyArray2> { - let x = x.as_array(); + fn inv<'py>(py: Python<'py>, x: PyArrayRef<'py, f64, Ix2>) -> PyResult<&'py PyArray2> { let y = x .inv() .map_err(|e| PyErr::new::(format!("[rust_linalg] {}", e)))?; diff --git a/examples/parallel/src/lib.rs b/examples/parallel/src/lib.rs index cca4aa63a..795c89dd6 100755 --- a/examples/parallel/src/lib.rs +++ b/examples/parallel/src/lib.rs @@ -1,8 +1,8 @@ // We need to link `blas_src` directly, c.f. https://github.com/rust-ndarray/ndarray#how-to-enable-blas-integration extern crate blas_src; -use ndarray::Zip; -use numpy::{IntoPyArray, PyArray1, PyReadonlyArray1, PyReadonlyArray2}; +use numpy::ndarray::{ArrayView1, Zip}; +use numpy::{IntoPyArray, Ix1, Ix2, PyArray1, PyArrayRef}; use pyo3::{pymodule, types::PyModule, PyResult, Python}; #[pymodule] @@ -10,12 +10,11 @@ fn rust_parallel(_py: Python<'_>, m: &PyModule) -> PyResult<()> { #[pyfn(m)] fn rows_dot<'py>( py: Python<'py>, - x: PyReadonlyArray2<'py, f64>, - y: PyReadonlyArray1<'py, f64>, + x: PyArrayRef<'py, f64, Ix2>, + y: PyArrayRef<'py, f64, Ix1>, ) -> &'py PyArray1 { - let x = x.as_array(); - let y = y.as_array(); - let z = Zip::from(x.rows()).par_map_collect(|row| row.dot(&y)); + let y: &ArrayView1 = &y; + let z = Zip::from(x.rows()).par_map_collect(|row| row.dot(y)); z.into_pyarray(py) } Ok(()) diff --git a/examples/simple-extension/src/lib.rs b/examples/simple-extension/src/lib.rs index 2c1eceecd..37bea4f2c 100644 --- a/examples/simple-extension/src/lib.rs +++ b/examples/simple-extension/src/lib.rs @@ -9,17 +9,17 @@ use pyo3::{ #[pymodule] fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> { // immutable example - fn axpy(a: f64, x: ArrayViewD<'_, f64>, y: ArrayViewD<'_, f64>) -> ArrayD { - a * &x + &y + fn axpy(a: f64, x: &ArrayViewD<'_, f64>, y: &ArrayViewD<'_, f64>) -> ArrayD { + a * x + y } // mutable example (no return) - fn mult(a: f64, mut x: ArrayViewMutD<'_, f64>) { - x *= a; + fn mult(a: f64, x: &mut ArrayViewMutD<'_, f64>) { + *x *= a; } // complex example - fn conj(x: ArrayViewD<'_, Complex64>) -> ArrayD { + fn conj(x: &ArrayViewD<'_, Complex64>) -> ArrayD { x.map(|c| c.conj()) } @@ -34,7 +34,7 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> { ) -> &'py PyArrayDyn { let x = x.as_array(); let y = y.as_array(); - let z = axpy(a, x, y); + let z = axpy(a, &x, &y); z.into_pyarray(py) } @@ -42,8 +42,8 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> { #[pyfn(m)] #[pyo3(name = "mult")] fn mult_py(a: f64, x: &PyArrayDyn) { - let x = unsafe { x.as_array_mut() }; - mult(a, x); + let mut x = x.as_array_mut(); + mult(a, &mut x); } // wrapper of `conj` @@ -53,7 +53,7 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> { py: Python<'py>, x: PyReadonlyArrayDyn<'_, Complex64>, ) -> &'py PyArrayDyn { - conj(x.as_array()).into_pyarray(py) + conj(&x.as_array()).into_pyarray(py) } #[pyfn(m)] diff --git a/src/array.rs b/src/array.rs index 5f41b041c..99974e11b 100644 --- a/src/array.rs +++ b/src/array.rs @@ -18,6 +18,7 @@ use pyo3::{ Python, ToPyObject, }; +use crate::borrow::{PyArrayRef, PyArrayRefMut}; use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray}; use crate::dtype::Element; use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError}; @@ -825,27 +826,34 @@ impl PyArray { /// Get the immutable view of the internal data of `PyArray`, as /// [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html). - /// - /// Please consider the use of safe alternatives - /// ([`PyReadonlyArray::as_array`](../struct.PyReadonlyArray.html#method.as_array) - /// or [`to_array`](#method.to_array)) instead of this. + pub fn as_array(&self) -> PyArrayRef<'_, T, D> { + PyArrayRef::try_new(self).expect("NumPy array already borrowed") + } + + /// Get the immutable view of the internal data of `PyArray`, as + /// [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html). + pub fn as_array_mut(&self) -> PyArrayRefMut<'_, T, D> { + PyArrayRefMut::try_new(self).expect("NumPy array already borrowed") + } + + /// Returns the internal array as [`ArrayView`]. See also [`as_array_unchecked`]. /// /// # Safety - /// If the internal array is not readonly and can be mutated from Python code, - /// holding the `ArrayView` might cause undefined behavior. - pub unsafe fn as_array(&self) -> ArrayView<'_, T, D> { + /// + /// The existence of an exclusive reference to the internal data, e.g. `&mut [T]` or `ArrayViewMut`, implies undefined behavior. + pub unsafe fn as_array_unchecked(&self) -> ArrayView<'_, T, D> { let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr(); let mut res = ArrayView::from_shape_ptr(shape, ptr); inverted_axes.invert(&mut res); res } - /// Returns the internal array as [`ArrayViewMut`]. See also [`as_array`](#method.as_array). + /// Returns the internal array as [`ArrayViewMut`]. See also [`as_array_unchecked`]. /// /// # Safety - /// If another reference to the internal data exists(e.g., `&[T]` or `ArrayView`), - /// it might cause undefined behavior. - pub unsafe fn as_array_mut(&self) -> ArrayViewMut<'_, T, D> { + /// + /// The existence of another reference to the internal data, e.g. `&[T]` or `ArrayView`, implies undefined behavior. + pub unsafe fn as_array_mut_unchecked(&self) -> ArrayViewMut<'_, T, D> { let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr(); let mut res = ArrayViewMut::from_shape_ptr(shape, ptr); inverted_axes.invert(&mut res); @@ -884,7 +892,7 @@ impl PyArray { /// }); /// ``` pub fn to_owned_array(&self) -> Array { - unsafe { self.as_array() }.to_owned() + unsafe { self.as_array_unchecked() }.to_owned() } } diff --git a/src/borrow.rs b/src/borrow.rs new file mode 100644 index 000000000..4d2a7d6fe --- /dev/null +++ b/src/borrow.rs @@ -0,0 +1,170 @@ +use std::cell::UnsafeCell; +use std::collections::hash_map::{Entry, HashMap}; +use std::ops::{Deref, DerefMut}; + +use ndarray::{ArrayView, ArrayViewMut, Dimension}; +use pyo3::{FromPyObject, PyAny, PyResult}; + +use crate::array::PyArray; +use crate::dtype::Element; + +thread_local! { + static BORROW_FLAGS: UnsafeCell> = UnsafeCell::new(HashMap::new()); +} + +pub struct PyArrayRef<'a, T, D> { + array: &'a PyArray, + view: ArrayView<'a, T, D>, +} + +impl<'a, T, D> Deref for PyArrayRef<'a, T, D> { + type Target = ArrayView<'a, T, D>; + + fn deref(&self) -> &Self::Target { + &self.view + } +} + +impl<'py, T: Element, D: Dimension> FromPyObject<'py> for PyArrayRef<'py, T, D> { + fn extract(obj: &'py PyAny) -> PyResult { + let array: &'py PyArray = obj.extract()?; + Ok(array.as_array()) + } +} + +impl<'a, T, D> PyArrayRef<'a, T, D> +where + T: Element, + D: Dimension, +{ + pub(crate) fn try_new(array: &'a PyArray) -> Option { + let address = array as *const PyArray as usize; + + BORROW_FLAGS.with(|borrow_flags| { + // SAFETY: Called on a thread local variable in a leaf function. + let borrow_flags = unsafe { &mut *borrow_flags.get() }; + + match borrow_flags.entry(address) { + Entry::Occupied(entry) => { + let readers = entry.into_mut(); + + let new_readers = readers.wrapping_add(1); + + if new_readers <= 0 { + cold(); + return None; + } + + *readers = new_readers; + } + Entry::Vacant(entry) => { + entry.insert(1); + } + } + + // SAFETY: Thread-local borrow flags ensure aliasing discipline on this thread, + // and `PyArray` is neither `Send` nor `Sync` + let view = unsafe { array.as_array_unchecked() }; + + Some(Self { array, view }) + }) + } +} + +impl<'a, T, D> Drop for PyArrayRef<'a, T, D> { + fn drop(&mut self) { + let address = self.array as *const PyArray as usize; + + BORROW_FLAGS.with(|borrow_flags| { + // SAFETY: Called on a thread local variable in a leaf function. + let borrow_flags = unsafe { &mut *borrow_flags.get() }; + + let readers = borrow_flags.get_mut(&address).unwrap(); + + *readers -= 1; + + if *readers == 0 { + borrow_flags.remove(&address).unwrap(); + } + }); + } +} + +pub struct PyArrayRefMut<'a, T, D> { + array: &'a PyArray, + view: ArrayViewMut<'a, T, D>, +} + +impl<'a, T, D> Deref for PyArrayRefMut<'a, T, D> { + type Target = ArrayViewMut<'a, T, D>; + + fn deref(&self) -> &Self::Target { + &self.view + } +} + +impl<'a, T, D> DerefMut for PyArrayRefMut<'a, T, D> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.view + } +} + +impl<'py, T: Element, D: Dimension> FromPyObject<'py> for PyArrayRefMut<'py, T, D> { + fn extract(obj: &'py PyAny) -> PyResult { + let array: &'py PyArray = obj.extract()?; + Ok(array.as_array_mut()) + } +} + +impl<'a, T, D> PyArrayRefMut<'a, T, D> +where + T: Element, + D: Dimension, +{ + pub(crate) fn try_new(array: &'a PyArray) -> Option { + let address = array as *const PyArray as usize; + + BORROW_FLAGS.with(|borrow_flags| { + // SAFETY: Called on a thread local variable in a leaf function. + let borrow_flags = unsafe { &mut *borrow_flags.get() }; + + match borrow_flags.entry(address) { + Entry::Occupied(entry) => { + let writers = entry.into_mut(); + + if *writers != 0 { + cold(); + return None; + } + + *writers = -1; + } + Entry::Vacant(entry) => { + entry.insert(-1); + } + } + + // SAFETY: Thread-local borrow flags ensure aliasing discipline on this thread, + // and `PyArray` is neither `Send` nor `Sync` + let view = unsafe { array.as_array_mut_unchecked() }; + + Some(Self { array, view }) + }) + } +} + +impl<'a, T, D> Drop for PyArrayRefMut<'a, T, D> { + fn drop(&mut self) { + let address = self.array as *const PyArray as usize; + + BORROW_FLAGS.with(|borrow_flags| { + // SAFETY: Called on a thread local variable in a leaf function. + let borrow_flags = unsafe { &mut *borrow_flags.get() }; + + borrow_flags.remove(&address).unwrap(); + }); + } +} +#[cold] +#[inline(always)] +fn cold() {} diff --git a/src/lib.rs b/src/lib.rs index 3b049ceff..291a251d3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,6 +30,7 @@ #![allow(clippy::needless_lifetimes)] // We often want to make the GIL lifetime explicit. pub mod array; +mod borrow; pub mod convert; mod dtype; mod error; @@ -46,6 +47,7 @@ pub use crate::array::{ get_array_module, PyArray, PyArray0, PyArray1, PyArray2, PyArray3, PyArray4, PyArray5, PyArray6, PyArrayDyn, }; +pub use crate::borrow::{PyArrayRef, PyArrayRefMut}; pub use crate::convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray}; pub use crate::dtype::{dtype, Complex32, Complex64, Element, PyArrayDescr}; pub use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError}; diff --git a/src/readonly.rs b/src/readonly.rs index 7a25a0eb6..15a384d89 100644 --- a/src/readonly.rs +++ b/src/readonly.rs @@ -91,7 +91,7 @@ impl<'py, T: Element, D: Dimension> PyReadonlyArray<'py, T, D> { /// }); /// ``` pub fn as_array(&self) -> ArrayView<'_, T, D> { - unsafe { self.array.as_array() } + unsafe { self.array.as_array_unchecked() } } /// Get an immutable reference of the specified element, with checking the passed index is valid.