5959//! });
6060//! ```
6161//!
62- //! The second example shows that while non-overlapping views are supported,
63- //! interleaved views which do not touch are currently not supported
64- //! due to over-approximating which borrows are in conflict.
62+ //! The second example shows that non-overlapping and interleaved views are also supported.
6563//!
6664//! ```rust
67- //! # use std::panic::{catch_unwind, AssertUnwindSafe};
68- //! #
6965//! use numpy::PyArray1;
7066//! use pyo3::{types::IntoPyDict, Python};
7167//!
7874//! let view3 = py.eval("array[::2]", None, Some(locals)).unwrap().downcast::<PyArray1<f64>>().unwrap();
7975//! let view4 = py.eval("array[1::2]", None, Some(locals)).unwrap().downcast::<PyArray1<f64>>().unwrap();
8076//!
81- //! let _view1 = view1.readwrite();
82- //! let _view2 = view2.readwrite();
77+ //! {
78+ //! let _view1 = view1.readwrite();
79+ //! let _view2 = view2.readwrite();
80+ //! }
8381//!
84- //! // Will fail at runtime even though `view3` and `view4`
85- //! // interleave as they are based on the same array.
86- //! let res = catch_unwind(AssertUnwindSafe(|| {
82+ //! {
8783//! let _view3 = view3.readwrite();
8884//! let _view4 = view4.readwrite();
89- //! }));
90- //! assert!(res.is_err());
85+ //! }
9186//! });
9287//! ```
9388//!
125120//!
126121//! # Limitations
127122//!
123+ //! TODO: We only leave the case of aliasing, but only out of bounds. Can this actually happen for array views?
124+ //!
128125//! Note that the current implementation of this is an over-approximation: It will consider overlapping borrows
129126//! potentially conflicting if the initial arrays have the same object at the end of their [base object chain][base].
130127//! For example, creating two views of the same underlying array by slicing can yield potentially conflicting borrows
@@ -143,6 +140,7 @@ use std::collections::hash_map::{Entry, HashMap};
143140use std:: ops:: { Deref , Range } ;
144141
145142use ndarray:: { ArrayView , ArrayViewMut , Dimension , Ix1 , Ix2 , Ix3 , Ix4 , Ix5 , Ix6 , IxDyn } ;
143+ use num_integer:: gcd;
146144use pyo3:: { FromPyObject , PyAny , PyResult } ;
147145
148146use crate :: array:: PyArray ;
@@ -155,9 +153,28 @@ use crate::npyffi::{self, PyArrayObject, NPY_ARRAY_WRITEABLE};
155153#[ derive( PartialEq , Eq , Hash ) ]
156154struct BorrowKey {
157155 range : Range < usize > ,
156+ data_ptr : usize ,
157+ gcd_strides : isize ,
158158}
159159
160160impl BorrowKey {
161+ fn from_array < T , D > ( array : & PyArray < T , D > ) -> Self
162+ where
163+ T : Element ,
164+ D : Dimension ,
165+ {
166+ let range = data_range ( array) ;
167+
168+ let data_ptr = array. data ( ) as usize ;
169+ let gcd_strides = reduce ( array. strides ( ) . iter ( ) . copied ( ) , gcd) . unwrap_or ( 1 ) ;
170+
171+ Self {
172+ range,
173+ data_ptr,
174+ gcd_strides,
175+ }
176+ }
177+
161178 fn conflicts ( & self , other : & Self ) -> bool {
162179 debug_assert ! ( self . range. start <= self . range. end) ;
163180 debug_assert ! ( other. range. start <= other. range. end) ;
@@ -166,6 +183,20 @@ impl BorrowKey {
166183 return false ;
167184 }
168185
186+ // The Diophantine equation which describes whether any integers can combine the data pointers and strides of the two arrays s.t.
187+ // they yield the same element has a solution if and only if the GCD of all strides divides the difference of the data pointers.
188+ //
189+ // That solution could be out of bounds which mean that this is still an approximation,
190+ // but it seems sufficient to handle typical cases like the color channels of an image.
191+ //
192+ // https://users.rust-lang.org/t/math-for-borrow-checking-numpy-arrays/73303
193+ let ptr_diff = abs_diff ( self . data_ptr , other. data_ptr ) as isize ;
194+ let gcd_strides = gcd ( self . gcd_strides , other. gcd_strides ) ;
195+
196+ if ptr_diff % gcd_strides != 0 {
197+ return false ;
198+ }
199+
169200 true
170201 }
171202}
@@ -192,10 +223,7 @@ impl BorrowFlags {
192223 D : Dimension ,
193224 {
194225 let address = base_address ( array) ;
195-
196- let key = BorrowKey {
197- range : data_range ( array) ,
198- } ;
226+ let key = BorrowKey :: from_array ( array) ;
199227
200228 // SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
201229 // and we are not calling into user code which might re-enter this function.
@@ -242,10 +270,7 @@ impl BorrowFlags {
242270 D : Dimension ,
243271 {
244272 let address = base_address ( array) ;
245-
246- let key = BorrowKey {
247- range : data_range ( array) ,
248- } ;
273+ let key = BorrowKey :: from_array ( array) ;
249274
250275 // SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
251276 // and we are not calling into user code which might re-enter this function.
@@ -272,10 +297,7 @@ impl BorrowFlags {
272297 D : Dimension ,
273298 {
274299 let address = base_address ( array) ;
275-
276- let key = BorrowKey {
277- range : data_range ( array) ,
278- } ;
300+ let key = BorrowKey :: from_array ( array) ;
279301
280302 // SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
281303 // and we are not calling into user code which might re-enter this function.
@@ -320,10 +342,7 @@ impl BorrowFlags {
320342 D : Dimension ,
321343 {
322344 let address = base_address ( array) ;
323-
324- let key = BorrowKey {
325- range : data_range ( array) ,
326- } ;
345+ let key = BorrowKey :: from_array ( array) ;
327346
328347 // SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
329348 // and we are not calling into user code which might re-enter this function.
@@ -628,6 +647,25 @@ where
628647 Range { start, end }
629648}
630649
650+ // FIXME(adamreichold): Use `usize::abs_diff` from std when that becomes stable.
651+ fn abs_diff ( lhs : usize , rhs : usize ) -> usize {
652+ if lhs >= rhs {
653+ lhs - rhs
654+ } else {
655+ rhs - lhs
656+ }
657+ }
658+
659+ // FIXME(adamreichold): Use `Iterator::reduce` from std when our MSRV reaches 1.51.
660+ fn reduce < I , F > ( mut iter : I , f : F ) -> Option < I :: Item >
661+ where
662+ I : Iterator ,
663+ F : FnMut ( I :: Item , I :: Item ) -> I :: Item ,
664+ {
665+ let first = iter. next ( ) ?;
666+ Some ( iter. fold ( first, f) )
667+ }
668+
631669#[ cfg( test) ]
632670mod tests {
633671 use super :: * ;
@@ -650,7 +688,7 @@ mod tests {
650688 assert_eq ! ( base_address, array as * const _ as usize ) ;
651689
652690 let data_range = data_range ( array) ;
653- assert_eq ! ( data_range. start, unsafe { array. data( ) } as usize ) ;
691+ assert_eq ! ( data_range. start, array. data( ) as usize ) ;
654692 assert_eq ! ( data_range. end, unsafe { array. data( ) . add( 15 ) } as usize ) ;
655693 } ) ;
656694 }
@@ -668,7 +706,7 @@ mod tests {
668706 assert_eq ! ( base_address, base as usize ) ;
669707
670708 let data_range = data_range ( array) ;
671- assert_eq ! ( data_range. start, unsafe { array. data( ) } as usize ) ;
709+ assert_eq ! ( data_range. start, array. data( ) as usize ) ;
672710 assert_eq ! ( data_range. end, unsafe { array. data( ) . add( 15 ) } as usize ) ;
673711 } ) ;
674712 }
@@ -694,7 +732,7 @@ mod tests {
694732 assert_eq ! ( base_address, base as usize ) ;
695733
696734 let data_range = data_range ( view) ;
697- assert_eq ! ( data_range. start, unsafe { view. data( ) } as usize ) ;
735+ assert_eq ! ( data_range. start, view. data( ) as usize ) ;
698736 assert_eq ! ( data_range. end, unsafe { view. data( ) . add( 12 ) } as usize ) ;
699737 } ) ;
700738 }
@@ -724,7 +762,7 @@ mod tests {
724762 assert_eq ! ( base_address, base as usize ) ;
725763
726764 let data_range = data_range ( view) ;
727- assert_eq ! ( data_range. start, unsafe { view. data( ) } as usize ) ;
765+ assert_eq ! ( data_range. start, view. data( ) as usize ) ;
728766 assert_eq ! ( data_range. end, unsafe { view. data( ) . add( 12 ) } as usize ) ;
729767 } ) ;
730768 }
@@ -763,7 +801,7 @@ mod tests {
763801 assert_eq ! ( base_address, base as usize ) ;
764802
765803 let data_range = data_range ( view2) ;
766- assert_eq ! ( data_range. start, unsafe { view2. data( ) } as usize ) ;
804+ assert_eq ! ( data_range. start, view2. data( ) as usize ) ;
767805 assert_eq ! ( data_range. end, unsafe { view2. data( ) . add( 6 ) } as usize ) ;
768806 } ) ;
769807 }
@@ -806,7 +844,7 @@ mod tests {
806844 assert_eq ! ( base_address, base as usize ) ;
807845
808846 let data_range = data_range ( view2) ;
809- assert_eq ! ( data_range. start, unsafe { view2. data( ) } as usize ) ;
847+ assert_eq ! ( data_range. start, view2. data( ) as usize ) ;
810848 assert_eq ! ( data_range. end, unsafe { view2. data( ) . add( 6 ) } as usize ) ;
811849 } ) ;
812850 }
0 commit comments