diff --git a/Cargo.toml b/Cargo.toml index 9f7d2262e..dbeccc51f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ license = "BSD-2-Clause" [dependencies] libc = "0.2" num-complex = ">= 0.2, < 0.5" +num-integer = "0.1" num-traits = "0.2" ndarray = ">= 0.13, < 0.16" pyo3 = { version = "0.16", default-features = false, features = ["macros"] } diff --git a/src/borrow.rs b/src/borrow.rs index 1680c9d1b..2f6317c4b 100644 --- a/src/borrow.rs +++ b/src/borrow.rs @@ -59,12 +59,9 @@ //! }); //! ``` //! -//! The second example shows that non-overlapping and interleaved views which do not alias -//! are currently not supported due to over-approximating which borrows are in conflict. +//! The second example shows that non-overlapping and interleaved views are also supported. //! //! ```rust -//! # use std::panic::{catch_unwind, AssertUnwindSafe}; -//! # //! use numpy::PyArray1; //! use pyo3::{types::IntoPyDict, Python}; //! @@ -77,19 +74,37 @@ //! let view3 = py.eval("array[::2]", None, Some(locals)).unwrap().downcast::>().unwrap(); //! let view4 = py.eval("array[1::2]", None, Some(locals)).unwrap().downcast::>().unwrap(); //! -//! // Will fail at runtime even though `view1` and `view2` -//! // do not overlap as they are based on the same array. -//! let res = catch_unwind(AssertUnwindSafe(|| { +//! { //! let _view1 = view1.readwrite(); //! let _view2 = view2.readwrite(); -//! })); -//! assert!(res.is_err()); +//! } //! -//! // Will fail at runtime even though `view3` and `view4` -//! // interleave as they are based on the same array. -//! let res = catch_unwind(AssertUnwindSafe(|| { +//! { //! let _view3 = view3.readwrite(); //! let _view4 = view4.readwrite(); +//! } +//! }); +//! ``` +//! +//! The third example shows that some views are incorrectly rejected since the borrows are over-approximated. +//! +//! ```rust +//! # use std::panic::{catch_unwind, AssertUnwindSafe}; +//! # +//! use numpy::PyArray2; +//! use pyo3::{types::IntoPyDict, Python}; +//! +//! Python::with_gil(|py| { +//! let array = PyArray2::::zeros(py, (10, 10), false); +//! let locals = [("array", array)].into_py_dict(py); +//! +//! let view1 = py.eval("array[:, ::3]", None, Some(locals)).unwrap().downcast::>().unwrap(); +//! let view2 = py.eval("array[:, 1::3]", None, Some(locals)).unwrap().downcast::>().unwrap(); +//! +//! // A false conflict as the views do not actually share any elements. +//! let res = catch_unwind(AssertUnwindSafe(|| { +//! let _view1 = view1.readwrite(); +//! let _view2 = view2.readwrite(); //! })); //! assert!(res.is_err()); //! }); @@ -129,25 +144,30 @@ //! //! # Limitations //! -//! Note that the current implementation of this is an over-approximation: It will consider all borrows potentially conflicting -//! if the initial arrays have the same object at the end of their [base object chain][base]. -//! For example, creating two views of the same underlying array by slicing will always yield potentially conflicting borrows -//! even if the slice indices are chosen so that the two views do not actually share any elements by splitting the array into -//! non-overlapping parts of by interleaving along one of its axes. +//! Note that the current implementation of this is an over-approximation: It will consider borrows +//! potentially conflicting if the initial arrays have the same object at the end of their [base object chain][base]. +//! Then, multiple conditions which are sufficient but not necessary to show the absence of conflicts are checked. +//! +//! While this is sufficient to handle common situations like slicing an array with a non-unit step size which divides +//! the dimension along that axis, there are also cases which it does not handle. For example, if the step size does +//! not divide the dimension along the sliced axis. Under such conditions, borrows are rejected even though the arrays +//! do not actually share any elements. //! //! This does limit the set of programs that can be written using safe Rust in way similar to rustc itself //! which ensures that all accepted programs are memory safe but does not necessarily accept all memory safe programs. -//! The plan is to refine this checking to correctly handle more involved cases like non-overlapping and interleaved -//! views into the same array and until then the unsafe method [`PyArray::as_array_mut`] can be used as an escape hatch. +//! However, the unsafe method [`PyArray::as_array_mut`] can be used as an escape hatch. +//! More involved cases like the example from above may be supported in the future. //! //! [base]: https://numpy.org/doc/stable/reference/c-api/types-and-structures.html#c.NPY_AO.base #![deny(missing_docs)] use std::cell::UnsafeCell; use std::collections::hash_map::{Entry, HashMap}; +use std::mem::size_of; use std::ops::Deref; use ndarray::{ArrayView, ArrayViewMut, Dimension, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn}; +use num_integer::gcd; use pyo3::{FromPyObject, PyAny, PyResult}; use crate::array::PyArray; @@ -157,7 +177,65 @@ use crate::dtype::Element; use crate::error::{BorrowError, NotContiguousError}; use crate::npyffi::{self, PyArrayObject, NPY_ARRAY_WRITEABLE}; -struct BorrowFlags(UnsafeCell>>); +#[derive(PartialEq, Eq, Hash)] +struct BorrowKey { + /// exclusive range of lowest and highest address covered by array + range: (usize, usize), + /// the data address on which address computations are based + data_ptr: usize, + /// the greatest common divisor of the strides of the array + gcd_strides: isize, +} + +impl BorrowKey { + fn from_array(array: &PyArray) -> Self + where + T: Element, + D: Dimension, + { + let range = data_range(array); + + let data_ptr = array.data() as usize; + let gcd_strides = reduce(array.strides().iter().copied(), gcd).unwrap_or(1); + + Self { + range, + data_ptr, + gcd_strides, + } + } + + fn conflicts(&self, other: &Self) -> bool { + debug_assert!(self.range.0 <= self.range.1); + debug_assert!(other.range.0 <= other.range.1); + + if other.range.0 >= self.range.1 || self.range.0 >= other.range.1 { + return false; + } + + // The Diophantine equation which describes whether any integers can combine the data pointers and strides of the two arrays s.t. + // they yield the same element has a solution if and only if the GCD of all strides divides the difference of the data pointers. + // + // That solution could be out of bounds which mean that this is still an over-approximation. + // It appears sufficient to handle typical cases like the color channels of an image, + // but fails when slicing an array with a step size that does not divide the dimension along that axis. + // + // https://users.rust-lang.org/t/math-for-borrow-checking-numpy-arrays/73303 + let ptr_diff = abs_diff(self.data_ptr, other.data_ptr) as isize; + let gcd_strides = gcd(self.gcd_strides, other.gcd_strides); + + if ptr_diff % gcd_strides != 0 { + return false; + } + + // By default, a conflict is assumed as it is the safe choice without actually solving the aliasing equation. + true + } +} + +type BorrowFlagsInner = HashMap>; + +struct BorrowFlags(UnsafeCell>); unsafe impl Sync for BorrowFlags {} @@ -167,12 +245,17 @@ impl BorrowFlags { } #[allow(clippy::mut_from_ref)] - unsafe fn get(&self) -> &mut HashMap { + unsafe fn get(&self) -> &mut BorrowFlagsInner { (*self.0.get()).get_or_insert_with(HashMap::new) } - fn acquire(&self, array: &PyArray) -> Result<(), BorrowError> { + fn acquire(&self, array: &PyArray) -> Result<(), BorrowError> + where + T: Element, + D: Dimension, + { let address = base_address(array); + let key = BorrowKey::from_array(array); // SAFETY: Access to `&PyArray` implies holding the GIL // and we are not calling into user code which might re-enter this function. @@ -180,43 +263,76 @@ impl BorrowFlags { match borrow_flags.entry(address) { Entry::Occupied(entry) => { - let readers = entry.into_mut(); - - let new_readers = readers.wrapping_add(1); - - if new_readers <= 0 { - cold(); - return Err(BorrowError::AlreadyBorrowed); + let same_base_arrays = entry.into_mut(); + + if let Some(readers) = same_base_arrays.get_mut(&key) { + // Zero flags are removed during release. + assert_ne!(*readers, 0); + + let new_readers = readers.wrapping_add(1); + + if new_readers <= 0 { + cold(); + return Err(BorrowError::AlreadyBorrowed); + } + + *readers = new_readers; + } else { + if same_base_arrays + .iter() + .any(|(other, readers)| key.conflicts(other) && *readers < 0) + { + cold(); + return Err(BorrowError::AlreadyBorrowed); + } + + same_base_arrays.insert(key, 1); } - - *readers = new_readers; } Entry::Vacant(entry) => { - entry.insert(1); + let mut same_base_arrays = HashMap::with_capacity(1); + same_base_arrays.insert(key, 1); + entry.insert(same_base_arrays); } } Ok(()) } - fn release(&self, array: &PyArray) { + fn release(&self, array: &PyArray) + where + T: Element, + D: Dimension, + { let address = base_address(array); + let key = BorrowKey::from_array(array); // SAFETY: Access to `&PyArray` implies holding the GIL // and we are not calling into user code which might re-enter this function. let borrow_flags = unsafe { BORROW_FLAGS.get() }; - let readers = borrow_flags.get_mut(&address).unwrap(); + let same_base_arrays = borrow_flags.get_mut(&address).unwrap(); + + let readers = same_base_arrays.get_mut(&key).unwrap(); *readers -= 1; if *readers == 0 { - borrow_flags.remove(&address).unwrap(); + if same_base_arrays.len() > 1 { + same_base_arrays.remove(&key).unwrap(); + } else { + borrow_flags.remove(&address).unwrap(); + } } } - fn acquire_mut(&self, array: &PyArray) -> Result<(), BorrowError> { + fn acquire_mut(&self, array: &PyArray) -> Result<(), BorrowError> + where + T: Element, + D: Dimension, + { let address = base_address(array); + let key = BorrowKey::from_array(array); // SAFETY: Access to `&PyArray` implies holding the GIL // and we are not calling into user code which might re-enter this function. @@ -224,31 +340,55 @@ impl BorrowFlags { match borrow_flags.entry(address) { Entry::Occupied(entry) => { - let writers = entry.into_mut(); + let same_base_arrays = entry.into_mut(); + + if let Some(writers) = same_base_arrays.get_mut(&key) { + // Zero flags are removed during release. + assert_ne!(*writers, 0); - if *writers != 0 { cold(); return Err(BorrowError::AlreadyBorrowed); + } else { + if same_base_arrays + .iter() + .any(|(other, writers)| key.conflicts(other) && *writers != 0) + { + cold(); + return Err(BorrowError::AlreadyBorrowed); + } + + same_base_arrays.insert(key, -1); } - - *writers = -1; } Entry::Vacant(entry) => { - entry.insert(-1); + let mut same_base_arrays = HashMap::with_capacity(1); + same_base_arrays.insert(key, -1); + entry.insert(same_base_arrays); } } Ok(()) } - fn release_mut(&self, array: &PyArray) { + fn release_mut(&self, array: &PyArray) + where + T: Element, + D: Dimension, + { let address = base_address(array); + let key = BorrowKey::from_array(array); // SAFETY: Access to `&PyArray` implies holding the GIL // and we are not calling into user code which might re-enter this function. - let borrow_flags = unsafe { self.get() }; + let borrow_flags = unsafe { BORROW_FLAGS.get() }; + + let same_base_arrays = borrow_flags.get_mut(&address).unwrap(); - borrow_flags.remove(&address).unwrap(); + if same_base_arrays.len() > 1 { + same_base_arrays.remove(&key).unwrap(); + } else { + borrow_flags.remove(&address); + } } } @@ -260,7 +400,10 @@ static BORROW_FLAGS: BorrowFlags = BorrowFlags::new(); /// i.e. that only shared references into the interior of the array can be created safely. /// /// See the [module-level documentation](self) for more. -pub struct PyReadonlyArray<'py, T, D>(&'py PyArray); +pub struct PyReadonlyArray<'py, T, D>(&'py PyArray) +where + T: Element, + D: Dimension; /// Read-only borrow of a one-dimensional array. pub type PyReadonlyArray1<'py, T> = PyReadonlyArray<'py, T, Ix1>; @@ -283,7 +426,11 @@ pub type PyReadonlyArray6<'py, T> = PyReadonlyArray<'py, T, Ix6>; /// Read-only borrow of an array whose dimensionality is determined at runtime. pub type PyReadonlyArrayDyn<'py, T> = PyReadonlyArray<'py, T, IxDyn>; -impl<'py, T, D> Deref for PyReadonlyArray<'py, T, D> { +impl<'py, T, D> Deref for PyReadonlyArray<'py, T, D> +where + T: Element, + D: Dimension, +{ type Target = PyArray; fn deref(&self) -> &Self::Target { @@ -343,7 +490,11 @@ where } } -impl<'a, T, D> Drop for PyReadonlyArray<'a, T, D> { +impl<'a, T, D> Drop for PyReadonlyArray<'a, T, D> +where + T: Element, + D: Dimension, +{ fn drop(&mut self) { BORROW_FLAGS.release(self.0); } @@ -355,7 +506,10 @@ impl<'a, T, D> Drop for PyReadonlyArray<'a, T, D> { /// i.e. that only a single exclusive reference into the interior of the array can be created safely. /// /// See the [module-level documentation](self) for more. -pub struct PyReadwriteArray<'py, T, D>(&'py PyArray); +pub struct PyReadwriteArray<'py, T, D>(&'py PyArray) +where + T: Element, + D: Dimension; /// Read-write borrow of a one-dimensional array. pub type PyReadwriteArray1<'py, T> = PyReadwriteArray<'py, T, Ix1>; @@ -378,7 +532,11 @@ pub type PyReadwriteArray6<'py, T> = PyReadwriteArray<'py, T, Ix6>; /// Read-write borrow of an array whose dimensionality is determined at runtime. pub type PyReadwriteArrayDyn<'py, T> = PyReadwriteArray<'py, T, IxDyn>; -impl<'py, T, D> Deref for PyReadwriteArray<'py, T, D> { +impl<'py, T, D> Deref for PyReadwriteArray<'py, T, D> +where + T: Element, + D: Dimension, +{ type Target = PyReadonlyArray<'py, T, D>; fn deref(&self) -> &Self::Target { @@ -468,14 +626,16 @@ where } } -impl<'a, T, D> Drop for PyReadwriteArray<'a, T, D> { +impl<'a, T, D> Drop for PyReadwriteArray<'a, T, D> +where + T: Element, + D: Dimension, +{ fn drop(&mut self) { BORROW_FLAGS.release_mut(self.0); } } -// FIXME(adamreichold): This is a coarse approximation and needs to be refined, -// i.e. borrows of non-overlapping views into the same base should not be considered conflicting. fn base_address(array: &PyArray) -> usize { let py = array.py(); let mut array = array.as_array_ptr(); @@ -493,6 +653,57 @@ fn base_address(array: &PyArray) -> usize { } } +fn data_range(array: &PyArray) -> (usize, usize) +where + T: Element, + D: Dimension, +{ + let shape = array.shape(); + let strides = array.strides(); + + let mut start = 0; + let mut end = 0; + + if shape.iter().all(|dim| *dim != 0) { + for (&dim, &stride) in shape.iter().zip(strides) { + let offset = (dim - 1) as isize * stride; + + if offset >= 0 { + end += offset; + } else { + start += offset; + } + } + + end += size_of::() as isize; + } + + let data = unsafe { (*array.as_array_ptr()).data }; + let start = unsafe { data.offset(start) } as usize; + let end = unsafe { data.offset(end) } as usize; + + (start, end) +} + +// FIXME(adamreichold): Use `usize::abs_diff` from std when that becomes stable. +fn abs_diff(lhs: usize, rhs: usize) -> usize { + if lhs >= rhs { + lhs - rhs + } else { + rhs - lhs + } +} + +// FIXME(adamreichold): Use `Iterator::reduce` from std when our MSRV reaches 1.51. +fn reduce(mut iter: I, f: F) -> Option +where + I: Iterator, + F: FnMut(I::Item, I::Item) -> I::Item, +{ + let first = iter.next()?; + Some(iter.fold(first, f)) +} + #[cfg(test)] mod tests { use super::*; @@ -500,7 +711,7 @@ mod tests { use ndarray::Array; use pyo3::{types::IntoPyDict, Python}; - use crate::array::{PyArray1, PyArray2}; + use crate::array::{PyArray1, PyArray2, PyArray3}; use crate::convert::IntoPyArray; #[test] @@ -513,6 +724,10 @@ mod tests { let base_address = base_address(array); assert_eq!(base_address, array as *const _ as usize); + + let data_range = data_range(array); + assert_eq!(data_range.0, array.data() as usize); + assert_eq!(data_range.1, unsafe { array.data().add(6) } as usize); }); } @@ -527,6 +742,10 @@ mod tests { let base_address = base_address(array); assert_ne!(base_address, array as *const _ as usize); assert_eq!(base_address, base as usize); + + let data_range = data_range(array); + assert_eq!(data_range.0, array.data() as usize); + assert_eq!(data_range.1, unsafe { array.data().add(6) } as usize); }); } @@ -549,6 +768,10 @@ mod tests { let base_address = base_address(view); assert_ne!(base_address, view as *const _ as usize); assert_eq!(base_address, base as usize); + + let data_range = data_range(view); + assert_eq!(data_range.0, array.data() as usize); + assert_eq!(data_range.1, unsafe { array.data().add(4) } as usize); }); } @@ -575,6 +798,10 @@ mod tests { assert_ne!(base_address, view as *const _ as usize); assert_ne!(base_address, array as *const _ as usize); assert_eq!(base_address, base as usize); + + let data_range = data_range(view); + assert_eq!(data_range.0, array.data() as usize); + assert_eq!(data_range.1, unsafe { array.data().add(4) } as usize); }); } @@ -610,6 +837,10 @@ mod tests { assert_ne!(base_address, view2 as *const _ as usize); assert_ne!(base_address, view1 as *const _ as usize); assert_eq!(base_address, base as usize); + + let data_range = data_range(view2); + assert_eq!(data_range.0, array.data() as usize); + assert_eq!(data_range.1, unsafe { array.data().add(1) } as usize); }); } @@ -649,6 +880,334 @@ mod tests { assert_ne!(base_address, view1 as *const _ as usize); assert_ne!(base_address, array as *const _ as usize); assert_eq!(base_address, base as usize); + + let data_range = data_range(view2); + assert_eq!(data_range.0, array.data() as usize); + assert_eq!(data_range.1, unsafe { array.data().add(1) } as usize); + }); + } + + #[test] + fn view_with_negative_strides() { + Python::with_gil(|py| { + let array = PyArray::::zeros(py, (1, 2, 3), false); + + let locals = [("array", array)].into_py_dict(py); + let view = py + .eval("array[::-1,:,::-1]", None, Some(locals)) + .unwrap() + .downcast::>() + .unwrap(); + assert_ne!(view as *const _ as usize, array as *const _ as usize); + + let base = unsafe { (*view.as_array_ptr()).base }; + assert_eq!(base as usize, array as *const _ as usize); + + let base_address = base_address(view); + assert_ne!(base_address, view as *const _ as usize); + assert_eq!(base_address, base as usize); + + let data_range = data_range(view); + assert_eq!(view.data(), unsafe { array.data().offset(2) }); + assert_eq!(data_range.0, unsafe { view.data().offset(-2) } as usize); + assert_eq!(data_range.1, unsafe { view.data().offset(4) } as usize); + }); + } + + #[test] + fn array_with_zero_dimensions() { + Python::with_gil(|py| { + let array = PyArray::::zeros(py, (1, 0, 3), false); + + let base = unsafe { (*array.as_array_ptr()).base }; + assert!(base.is_null()); + + let base_address = base_address(array); + assert_eq!(base_address, array as *const _ as usize); + + let data_range = data_range(array); + assert_eq!(data_range.0, array.data() as usize); + assert_eq!(data_range.1, array.data() as usize); + }); + } + + #[test] + fn view_with_non_dividing_strides() { + Python::with_gil(|py| { + let array = PyArray::::zeros(py, (10, 10), false); + let locals = [("array", array)].into_py_dict(py); + + let view1 = py + .eval("array[:,::3]", None, Some(locals)) + .unwrap() + .downcast::>() + .unwrap(); + + let key1 = BorrowKey::from_array(view1); + + assert_eq!(view1.strides(), &[80, 24]); + assert_eq!(key1.gcd_strides, 8); + + let view2 = py + .eval("array[:,1::3]", None, Some(locals)) + .unwrap() + .downcast::>() + .unwrap(); + + let key2 = BorrowKey::from_array(view2); + + assert_eq!(view2.strides(), &[80, 24]); + assert_eq!(key2.gcd_strides, 8); + + let view3 = py + .eval("array[:,::2]", None, Some(locals)) + .unwrap() + .downcast::>() + .unwrap(); + + let key3 = BorrowKey::from_array(view3); + + assert_eq!(view3.strides(), &[80, 16]); + assert_eq!(key3.gcd_strides, 16); + + let view4 = py + .eval("array[:,1::2]", None, Some(locals)) + .unwrap() + .downcast::>() + .unwrap(); + + let key4 = BorrowKey::from_array(view4); + + assert_eq!(view4.strides(), &[80, 16]); + assert_eq!(key4.gcd_strides, 16); + + assert!(!key3.conflicts(&key4)); + assert!(key1.conflicts(&key3)); + assert!(key2.conflicts(&key4)); + + // This is a false conflict where all aliasing indices like (0,7) and (2,0) are out of bounds. + assert!(key1.conflicts(&key2)); + }); + } + + #[test] + fn borrow_multiple_arrays() { + Python::with_gil(|py| { + let array1 = PyArray::::zeros(py, 10, false); + let array2 = PyArray::::zeros(py, 10, false); + + let base1 = base_address(array1); + let base2 = base_address(array2); + + let key1 = BorrowKey::from_array(array1); + let _exclusive1 = array1.readwrite(); + + { + let borrow_flags = unsafe { BORROW_FLAGS.get() }; + assert_eq!(borrow_flags.len(), 1); + + let same_base_arrays = &borrow_flags[&base1]; + assert_eq!(same_base_arrays.len(), 1); + + let flag = same_base_arrays[&key1]; + assert_eq!(flag, -1); + } + + let key2 = BorrowKey::from_array(array2); + let _shared2 = array2.readonly(); + + { + let borrow_flags = unsafe { BORROW_FLAGS.get() }; + assert_eq!(borrow_flags.len(), 2); + + let same_base_arrays = &borrow_flags[&base1]; + assert_eq!(same_base_arrays.len(), 1); + + let flag = same_base_arrays[&key1]; + assert_eq!(flag, -1); + + let same_base_arrays = &borrow_flags[&base2]; + assert_eq!(same_base_arrays.len(), 1); + + let flag = same_base_arrays[&key2]; + assert_eq!(flag, 1); + } + }); + } + + #[test] + fn borrow_multiple_views() { + Python::with_gil(|py| { + let array = PyArray::::zeros(py, 10, false); + let base = base_address(array); + + let locals = [("array", array)].into_py_dict(py); + + let view1 = py + .eval("array[:5]", None, Some(locals)) + .unwrap() + .downcast::>() + .unwrap(); + + let key1 = BorrowKey::from_array(view1); + let exclusive1 = view1.readwrite(); + + { + let borrow_flags = unsafe { BORROW_FLAGS.get() }; + assert_eq!(borrow_flags.len(), 1); + + let same_base_arrays = &borrow_flags[&base]; + assert_eq!(same_base_arrays.len(), 1); + + let flag = same_base_arrays[&key1]; + assert_eq!(flag, -1); + } + + let view2 = py + .eval("array[5:]", None, Some(locals)) + .unwrap() + .downcast::>() + .unwrap(); + + let key2 = BorrowKey::from_array(view2); + let shared2 = view2.readonly(); + + { + let borrow_flags = unsafe { BORROW_FLAGS.get() }; + assert_eq!(borrow_flags.len(), 1); + + let same_base_arrays = &borrow_flags[&base]; + assert_eq!(same_base_arrays.len(), 2); + + let flag = same_base_arrays[&key1]; + assert_eq!(flag, -1); + + let flag = same_base_arrays[&key2]; + assert_eq!(flag, 1); + } + + let view3 = py + .eval("array[5:]", None, Some(locals)) + .unwrap() + .downcast::>() + .unwrap(); + + let key3 = BorrowKey::from_array(view3); + let shared3 = view3.readonly(); + + { + let borrow_flags = unsafe { BORROW_FLAGS.get() }; + assert_eq!(borrow_flags.len(), 1); + + let same_base_arrays = &borrow_flags[&base]; + assert_eq!(same_base_arrays.len(), 2); + + let flag = same_base_arrays[&key1]; + assert_eq!(flag, -1); + + let flag = same_base_arrays[&key2]; + assert_eq!(flag, 2); + + let flag = same_base_arrays[&key3]; + assert_eq!(flag, 2); + } + + let view4 = py + .eval("array[7:]", None, Some(locals)) + .unwrap() + .downcast::>() + .unwrap(); + + let key4 = BorrowKey::from_array(view4); + let shared4 = view4.readonly(); + + { + let borrow_flags = unsafe { BORROW_FLAGS.get() }; + assert_eq!(borrow_flags.len(), 1); + + let same_base_arrays = &borrow_flags[&base]; + assert_eq!(same_base_arrays.len(), 3); + + let flag = same_base_arrays[&key1]; + assert_eq!(flag, -1); + + let flag = same_base_arrays[&key2]; + assert_eq!(flag, 2); + + let flag = same_base_arrays[&key3]; + assert_eq!(flag, 2); + + let flag = same_base_arrays[&key4]; + assert_eq!(flag, 1); + } + + drop(shared2); + + { + let borrow_flags = unsafe { BORROW_FLAGS.get() }; + assert_eq!(borrow_flags.len(), 1); + + let same_base_arrays = &borrow_flags[&base]; + assert_eq!(same_base_arrays.len(), 3); + + let flag = same_base_arrays[&key1]; + assert_eq!(flag, -1); + + let flag = same_base_arrays[&key2]; + assert_eq!(flag, 1); + + let flag = same_base_arrays[&key3]; + assert_eq!(flag, 1); + + let flag = same_base_arrays[&key4]; + assert_eq!(flag, 1); + } + + drop(shared3); + + { + let borrow_flags = unsafe { BORROW_FLAGS.get() }; + assert_eq!(borrow_flags.len(), 1); + + let same_base_arrays = &borrow_flags[&base]; + assert_eq!(same_base_arrays.len(), 2); + + let flag = same_base_arrays[&key1]; + assert_eq!(flag, -1); + + assert!(!same_base_arrays.contains_key(&key2)); + + assert!(!same_base_arrays.contains_key(&key3)); + + let flag = same_base_arrays[&key4]; + assert_eq!(flag, 1); + } + + drop(exclusive1); + + { + let borrow_flags = unsafe { BORROW_FLAGS.get() }; + assert_eq!(borrow_flags.len(), 1); + + let same_base_arrays = &borrow_flags[&base]; + assert_eq!(same_base_arrays.len(), 1); + + assert!(!same_base_arrays.contains_key(&key1)); + + assert!(!same_base_arrays.contains_key(&key2)); + + assert!(!same_base_arrays.contains_key(&key3)); + + let flag = same_base_arrays[&key4]; + assert_eq!(flag, 1); + } + + drop(shared4); + + { + let borrow_flags = unsafe { BORROW_FLAGS.get() }; + assert_eq!(borrow_flags.len(), 0); + } }); } } diff --git a/tests/borrow.rs b/tests/borrow.rs index dac90b87d..cbbac2fc8 100644 --- a/tests/borrow.rs +++ b/tests/borrow.rs @@ -43,6 +43,17 @@ fn exclusive_and_shared_borrows() { }); } +#[test] +#[should_panic(expected = "AlreadyBorrowed")] +fn shared_and_exclusive_borrows() { + Python::with_gil(|py| { + let array = PyArray::::zeros(py, (1, 2, 3), false); + + let _shared = array.readonly(); + let _exclusive = array.readwrite(); + }); +} + #[test] #[should_panic(expected = "AlreadyBorrowed")] fn multiple_exclusive_borrows() { @@ -155,8 +166,7 @@ fn overlapping_views_conflict() { } #[test] -#[should_panic(expected = "AlreadyBorrowed")] -fn non_overlapping_views_conflict() { +fn non_overlapping_views_do_not_conflict() { Python::with_gil(|py| { let array = PyArray::::zeros(py, (1, 2, 3), false); let locals = [("array", array)].into_py_dict(py); @@ -175,34 +185,103 @@ fn non_overlapping_views_conflict() { .unwrap(); assert_eq!(view2.shape(), [1]); + let exclusive1 = view1.readwrite(); + let exclusive2 = view2.readwrite(); + + assert_eq!(exclusive2.len(), 1); + assert_eq!(exclusive1.len(), 1); + }); +} + +#[test] +#[should_panic(expected = "AlreadyBorrowed")] +fn conflict_due_to_overlapping_views() { + Python::with_gil(|py| { + let array = PyArray::::zeros(py, 3, false); + let locals = [("array", array)].into_py_dict(py); + + let view1 = py + .eval("array[0:2]", None, Some(locals)) + .unwrap() + .downcast::>() + .unwrap(); + assert_eq!(view1.shape(), [2]); + + let view2 = py + .eval("array[1:3]", None, Some(locals)) + .unwrap() + .downcast::>() + .unwrap(); + assert_eq!(view2.shape(), [2]); + let _exclusive1 = view1.readwrite(); - let _exclusive2 = view2.readwrite(); + let _shared2 = view2.readonly(); }); } #[test] #[should_panic(expected = "AlreadyBorrowed")] -fn interleaved_views_conflict() { +fn conflict_due_to_reborrow_of_overlapping_views() { Python::with_gil(|py| { - let array = PyArray::::zeros(py, (1, 2, 3), false); + let array = PyArray::::zeros(py, 3, false); let locals = [("array", array)].into_py_dict(py); let view1 = py - .eval("array[:,:,1]", None, Some(locals)) + .eval("array[0:2]", None, Some(locals)) + .unwrap() + .downcast::>() + .unwrap(); + assert_eq!(view1.shape(), [2]); + + let view2 = py + .eval("array[1:3]", None, Some(locals)) + .unwrap() + .downcast::>() + .unwrap(); + assert_eq!(view2.shape(), [2]); + + let shared1 = view1.readonly(); + let _shared2 = view2.readonly(); + + drop(shared1); + let _exclusive1 = view1.readwrite(); + }); +} + +#[test] +fn interleaved_views_do_not_conflict() { + Python::with_gil(|py| { + let array = PyArray::::zeros(py, (23, 42, 3), false); + let locals = [("array", array)].into_py_dict(py); + + let view1 = py + .eval("array[:,:,0]", None, Some(locals)) .unwrap() .downcast::>() .unwrap(); - assert_eq!(view1.shape(), [1, 2]); + assert_eq!(view1.shape(), [23, 42]); let view2 = py + .eval("array[:,:,1]", None, Some(locals)) + .unwrap() + .downcast::>() + .unwrap(); + assert_eq!(view2.shape(), [23, 42]); + + let view3 = py .eval("array[:,:,2]", None, Some(locals)) .unwrap() .downcast::>() .unwrap(); - assert_eq!(view2.shape(), [1, 2]); + assert_eq!(view2.shape(), [23, 42]); - let _exclusive1 = view1.readwrite(); - let _exclusive2 = view2.readwrite(); + let exclusive1 = view1.readwrite(); + let exclusive2 = view2.readwrite(); + let exclusive3 = view3.readwrite(); + + assert_eq!(exclusive3.len(), 23 * 42); + assert_eq!(exclusive2.len(), 23 * 42); + assert_eq!(exclusive1.len(), 23 * 42); }); }