Skip to content

Commit 9fb0101

Browse files
committed
Add dynamic borrow checking for dereferencing NumPy arrays.
1 parent 833896d commit 9fb0101

File tree

11 files changed

+1032
-398
lines changed

11 files changed

+1032
-398
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Changelog
22

33
- Unreleased
4+
- Add dynamic borrow checking to safely construct references into the interior of NumPy arrays. ([#274](https://github.com/PyO3/rust-numpy/pull/274))
45
- Deprecate `PyArray::from_exact_iter` after optimizing `PyArray::from_iter`. ([#292](https://github.com/PyO3/rust-numpy/pull/292))
56

67
- v0.16.2

benches/borrow.rs

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#![feature(test)]
2+
3+
extern crate test;
4+
use test::{black_box, Bencher};
5+
6+
use numpy::PyArray;
7+
use pyo3::Python;
8+
9+
#[bench]
10+
fn initial_shared_borrow(bencher: &mut Bencher) {
11+
Python::with_gil(|py| {
12+
let array = PyArray::<f64, _>::zeros(py, (1, 2, 3), false);
13+
14+
bencher.iter(|| {
15+
let array = black_box(array);
16+
17+
let _shared = array.readonly();
18+
});
19+
});
20+
}
21+
22+
#[bench]
23+
fn additional_shared_borrow(bencher: &mut Bencher) {
24+
Python::with_gil(|py| {
25+
let array = PyArray::<f64, _>::zeros(py, (1, 2, 3), false);
26+
27+
let _shared = (0..128).map(|_| array.readonly()).collect::<Vec<_>>();
28+
29+
bencher.iter(|| {
30+
let array = black_box(array);
31+
32+
let _shared = array.readonly();
33+
});
34+
});
35+
}
36+
37+
#[bench]
38+
fn exclusive_borrow(bencher: &mut Bencher) {
39+
Python::with_gil(|py| {
40+
let array = PyArray::<f64, _>::zeros(py, (1, 2, 3), false);
41+
42+
bencher.iter(|| {
43+
let array = black_box(array);
44+
45+
let _exclusive = array.readwrite();
46+
});
47+
});
48+
}

examples/simple/src/lib.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
use numpy::ndarray::{ArrayD, ArrayViewD, ArrayViewMutD};
2-
use numpy::{Complex64, IntoPyArray, PyArray1, PyArrayDyn, PyReadonlyArrayDyn};
2+
use numpy::{
3+
Complex64, IntoPyArray, PyArray1, PyArrayDyn, PyReadonlyArrayDyn, PyReadwriteArrayDyn,
4+
};
35
use pyo3::{
46
pymodule,
57
types::{PyDict, PyModule},
@@ -41,8 +43,8 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
4143
// wrapper of `mult`
4244
#[pyfn(m)]
4345
#[pyo3(name = "mult")]
44-
fn mult_py(a: f64, x: &PyArrayDyn<f64>) {
45-
let x = unsafe { x.as_array_mut() };
46+
fn mult_py(a: f64, mut x: PyReadwriteArrayDyn<f64>) {
47+
let x = x.as_array_mut();
4648
mult(a, x);
4749
}
4850

src/array.rs

+67-52
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,12 @@ use pyo3::{
1919
Python, ToPyObject,
2020
};
2121

22+
use crate::borrow::{PyReadonlyArray, PyReadwriteArray};
2223
use crate::cold;
2324
use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
2425
use crate::dtype::{Element, PyArrayDescr};
2526
use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError};
2627
use crate::npyffi::{self, npy_intp, NPY_ORDER, PY_ARRAY_API};
27-
#[allow(deprecated)]
28-
use crate::npyiter::{NpySingleIter, NpySingleIterBuilder, ReadWrite};
29-
use crate::readonly::PyReadonlyArray;
3028
use crate::slice_container::PySliceContainer;
3129

3230
/// A safe, static-typed interface for
@@ -195,18 +193,8 @@ impl<T, D> PyArray<T, D> {
195193
}
196194

197195
#[inline(always)]
198-
fn check_flag(&self, flag: c_int) -> bool {
199-
unsafe { *self.as_array_ptr() }.flags & flag == flag
200-
}
201-
202-
#[inline(always)]
203-
pub(crate) fn get_flag(&self) -> c_int {
204-
unsafe { *self.as_array_ptr() }.flags
205-
}
206-
207-
/// Returns a temporally unwriteable reference of the array.
208-
pub fn readonly(&self) -> PyReadonlyArray<T, D> {
209-
self.into()
196+
pub(crate) fn check_flags(&self, flags: c_int) -> bool {
197+
unsafe { *self.as_array_ptr() }.flags & flags != 0
210198
}
211199

212200
/// Returns `true` if the internal data of the array is C-style contiguous
@@ -228,18 +216,17 @@ impl<T, D> PyArray<T, D> {
228216
/// });
229217
/// ```
230218
pub fn is_contiguous(&self) -> bool {
231-
self.check_flag(npyffi::NPY_ARRAY_C_CONTIGUOUS)
232-
| self.check_flag(npyffi::NPY_ARRAY_F_CONTIGUOUS)
219+
self.check_flags(npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS)
233220
}
234221

235222
/// Returns `true` if the internal data of the array is Fortran-style contiguous.
236223
pub fn is_fortran_contiguous(&self) -> bool {
237-
self.check_flag(npyffi::NPY_ARRAY_F_CONTIGUOUS)
224+
self.check_flags(npyffi::NPY_ARRAY_F_CONTIGUOUS)
238225
}
239226

240227
/// Returns `true` if the internal data of the array is C-style contiguous.
241228
pub fn is_c_contiguous(&self) -> bool {
242-
self.check_flag(npyffi::NPY_ARRAY_C_CONTIGUOUS)
229+
self.check_flags(npyffi::NPY_ARRAY_C_CONTIGUOUS)
243230
}
244231

245232
/// Get `Py<PyArray>` from `&PyArray`, which is the owned wrapper of PyObject.
@@ -681,27 +668,61 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
681668

682669
/// Get the immutable reference of the specified element, with checking the passed index is valid.
683670
///
684-
/// Please consider the use of safe alternatives
685-
/// ([`PyReadonlyArray::get`](../struct.PyReadonlyArray.html#method.get)
686-
/// or [`get_owned`](#method.get_owned)) instead of this.
671+
/// Consider using safe alternatives like [`PyReadonlyArray::get`].
672+
///
687673
/// # Example
674+
///
688675
/// ```
689676
/// use numpy::PyArray;
690-
/// pyo3::Python::with_gil(|py| {
677+
/// use pyo3::Python;
678+
///
679+
/// Python::with_gil(|py| {
691680
/// let arr = PyArray::arange(py, 0, 16, 1).reshape([2, 2, 4]).unwrap();
692-
/// assert_eq!(*unsafe { arr.get([1, 0, 3]) }.unwrap(), 11);
681+
/// assert_eq!(unsafe { *arr.get([1, 0, 3]).unwrap() }, 11);
693682
/// });
694683
/// ```
695684
///
696685
/// # Safety
697-
/// If the internal array is not readonly and can be mutated from Python code,
698-
/// holding the slice might cause undefined behavior.
686+
///
687+
/// Calling this method is undefined behaviour if the underlying array
688+
/// is aliased mutably by other instances of `PyArray`
689+
/// or concurrently modified by Python or other native code.
699690
#[inline(always)]
700691
pub unsafe fn get(&self, index: impl NpyIndex<Dim = D>) -> Option<&T> {
701692
let offset = index.get_checked::<T>(self.shape(), self.strides())?;
702693
Some(&*self.data().offset(offset))
703694
}
704695

696+
/// Same as [`get`][Self::get], but returns `Option<&mut T>`.
697+
///
698+
/// Consider using safe alternatives like [`PyReadwriteArray::get_mut`].
699+
///
700+
/// # Example
701+
///
702+
/// ```
703+
/// use numpy::PyArray;
704+
/// use pyo3::Python;
705+
///
706+
/// Python::with_gil(|py| {
707+
/// let arr = PyArray::arange(py, 0, 16, 1).reshape([2, 2, 4]).unwrap();
708+
/// unsafe {
709+
/// *arr.get_mut([1, 0, 3]).unwrap() = 42;
710+
/// }
711+
/// assert_eq!(unsafe { *arr.get([1, 0, 3]).unwrap() }, 42);
712+
/// });
713+
/// ```
714+
///
715+
/// # Safety
716+
///
717+
/// Calling this method is undefined behaviour if the underlying array
718+
/// is aliased immutably by mutably by other instances of `PyArray`
719+
/// or concurrently modified by Python or other native code.
720+
#[inline(always)]
721+
pub unsafe fn get_mut(&self, index: impl NpyIndex<Dim = D>) -> Option<&mut T> {
722+
let offset = index.get_checked::<T>(self.shape(), self.strides())?;
723+
Some(&mut *self.data().offset(offset))
724+
}
725+
705726
/// Get the immutable reference of the specified element, without checking the
706727
/// passed index is valid.
707728
///
@@ -824,28 +845,37 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
824845
ToPyArray::to_pyarray(arr, py)
825846
}
826847

827-
/// Get the immutable view of the internal data of `PyArray`, as
828-
/// [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html).
848+
/// Get an immutable borrow of the NumPy array
849+
pub fn readonly(&self) -> PyReadonlyArray<'_, T, D> {
850+
PyReadonlyArray::try_new(self).unwrap()
851+
}
852+
853+
/// Get a mutable borrow of the NumPy array
854+
pub fn readwrite(&self) -> PyReadwriteArray<'_, T, D> {
855+
PyReadwriteArray::try_new(self).unwrap()
856+
}
857+
858+
/// Returns the internal array as [`ArrayView`].
829859
///
830-
/// Please consider the use of safe alternatives
831-
/// ([`PyReadonlyArray::as_array`](../struct.PyReadonlyArray.html#method.as_array)
832-
/// or [`to_array`](#method.to_array)) instead of this.
860+
/// See also [`PyReadonlyArray::as_array`].
833861
///
834862
/// # Safety
835-
/// If the internal array is not readonly and can be mutated from Python code,
836-
/// holding the `ArrayView` might cause undefined behavior.
863+
///
864+
/// The existence of an exclusive reference to the internal data, e.g. `&mut [T]` or `ArrayViewMut`, implies undefined behavior.
837865
pub unsafe fn as_array(&self) -> ArrayView<'_, T, D> {
838866
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
839867
let mut res = ArrayView::from_shape_ptr(shape, ptr);
840868
inverted_axes.invert(&mut res);
841869
res
842870
}
843871

844-
/// Returns the internal array as [`ArrayViewMut`]. See also [`as_array`](#method.as_array).
872+
/// Returns the internal array as [`ArrayViewMut`].
873+
///
874+
/// See also [`PyReadwriteArray::as_array_mut`].
845875
///
846876
/// # Safety
847-
/// If another reference to the internal data exists(e.g., `&[T]` or `ArrayView`),
848-
/// it might cause undefined behavior.
877+
///
878+
/// The existence of another reference to the internal data, e.g. `&[T]` or `ArrayView`, implies undefined behavior.
849879
pub unsafe fn as_array_mut(&self) -> ArrayViewMut<'_, T, D> {
850880
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
851881
let mut res = ArrayViewMut::from_shape_ptr(shape, ptr);
@@ -921,7 +951,7 @@ impl<D: Dimension> PyArray<PyObject, D> {
921951
///
922952
/// let pyarray = PyArray::from_owned_object_array(py, array);
923953
///
924-
/// assert!(pyarray.readonly().get(0).unwrap().as_ref(py).is_instance_of::<CustomElement>().unwrap());
954+
/// assert!(pyarray.readonly().as_array().get(0).unwrap().as_ref(py).is_instance_of::<CustomElement>().unwrap());
925955
/// });
926956
/// ```
927957
pub fn from_owned_object_array<'py, T>(py: Python<'py>, arr: Array<Py<T>, D>) -> &'py Self {
@@ -1043,21 +1073,6 @@ impl<T: Element> PyArray<T, Ix1> {
10431073
self.resize_(self.py(), [new_elems], 1, NPY_ORDER::NPY_ANYORDER)
10441074
}
10451075

1046-
/// Iterates all elements of this array.
1047-
/// See [NpySingleIter](../npyiter/struct.NpySingleIter.html) for more.
1048-
///
1049-
/// # Safety
1050-
///
1051-
/// The iterator will produce mutable references into the array which must not be
1052-
/// aliased by other references for the life time of the iterator.
1053-
#[deprecated(
1054-
note = "The wrappers of the array iterator API are deprecated, please use ndarray's `ArrayBase::iter_mut` instead."
1055-
)]
1056-
#[allow(deprecated)]
1057-
pub unsafe fn iter<'py>(&'py self) -> PyResult<NpySingleIter<'py, T, ReadWrite>> {
1058-
NpySingleIterBuilder::readwrite(self).build()
1059-
}
1060-
10611076
fn resize_<D: IntoDimension>(
10621077
&self,
10631078
py: Python,

0 commit comments

Comments
 (0)