-
Notifications
You must be signed in to change notification settings - Fork 118
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
WIP: Add dynamic borrow checking for dereferencing NumPy arrays.
- Loading branch information
1 parent
61882e3
commit 67f8600
Showing
7 changed files
with
292 additions
and
363 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
use std::cell::UnsafeCell; | ||
use std::collections::hash_map::{Entry, HashMap}; | ||
use std::ops::Deref; | ||
|
||
use ndarray::{ArrayView, ArrayViewMut, Dimension, Ix1, Ix2, IxDyn}; | ||
use pyo3::{FromPyObject, PyAny, PyResult}; | ||
|
||
use crate::array::PyArray; | ||
use crate::dtype::Element; | ||
use crate::error::NotContiguousError; | ||
|
||
thread_local! { | ||
static BORROW_FLAGS: UnsafeCell<HashMap<usize, isize>> = UnsafeCell::new(HashMap::new()); | ||
} | ||
|
||
pub struct PyReadonlyArray<'py, T, D>(&'py PyArray<T, D>); | ||
|
||
pub type PyReadonlyArray1<'py, T> = PyReadonlyArray<'py, T, Ix1>; | ||
|
||
pub type PyReadonlyArray2<'py, T> = PyReadonlyArray<'py, T, Ix2>; | ||
|
||
pub type PyReadonlyArrayDyn<'py, T> = PyReadonlyArray<'py, T, IxDyn>; | ||
|
||
impl<'py, T, D> Deref for PyReadonlyArray<'py, T, D> { | ||
type Target = PyArray<T, D>; | ||
|
||
fn deref(&self) -> &Self::Target { | ||
self.0 | ||
} | ||
} | ||
|
||
impl<'py, T: Element, D: Dimension> FromPyObject<'py> for PyReadonlyArray<'py, T, D> { | ||
fn extract(obj: &'py PyAny) -> PyResult<Self> { | ||
let array: &'py PyArray<T, D> = obj.extract()?; | ||
Ok(array.readonly()) | ||
} | ||
} | ||
|
||
impl<'py, T, D> PyReadonlyArray<'py, T, D> | ||
where | ||
T: Element, | ||
D: Dimension, | ||
{ | ||
pub(crate) fn try_new(array: &'py PyArray<T, D>) -> Option<Self> { | ||
let address = array as *const PyArray<T, D> as usize; | ||
|
||
BORROW_FLAGS.with(|borrow_flags| { | ||
// SAFETY: Called on a thread local variable in a leaf function. | ||
let borrow_flags = unsafe { &mut *borrow_flags.get() }; | ||
|
||
match borrow_flags.entry(address) { | ||
Entry::Occupied(entry) => { | ||
let readers = entry.into_mut(); | ||
|
||
let new_readers = readers.wrapping_add(1); | ||
|
||
if new_readers <= 0 { | ||
cold(); | ||
return None; | ||
} | ||
|
||
*readers = new_readers; | ||
} | ||
Entry::Vacant(entry) => { | ||
entry.insert(1); | ||
} | ||
} | ||
|
||
Some(Self(array)) | ||
}) | ||
} | ||
|
||
pub fn as_array(&self) -> ArrayView<T, D> { | ||
// SAFETY: Thread-local borrow flags ensure aliasing discipline on this thread, | ||
// and `PyArray` is neither `Send` nor `Sync` | ||
unsafe { self.0.as_array() } | ||
} | ||
|
||
pub fn as_slice(&self) -> Result<&[T], NotContiguousError> { | ||
// SAFETY: Thread-local borrow flags ensure aliasing discipline on this thread, | ||
// and `PyArray` is neither `Send` nor `Sync` | ||
unsafe { self.0.as_slice() } | ||
} | ||
} | ||
|
||
impl<'a, T, D> Drop for PyReadonlyArray<'a, T, D> { | ||
fn drop(&mut self) { | ||
let address = self.0 as *const PyArray<T, D> as usize; | ||
|
||
BORROW_FLAGS.with(|borrow_flags| { | ||
// SAFETY: Called on a thread local variable in a leaf function. | ||
let borrow_flags = unsafe { &mut *borrow_flags.get() }; | ||
|
||
let readers = borrow_flags.get_mut(&address).unwrap(); | ||
|
||
*readers -= 1; | ||
|
||
if *readers == 0 { | ||
borrow_flags.remove(&address).unwrap(); | ||
} | ||
}); | ||
} | ||
} | ||
|
||
pub struct PyReadwriteArray<'py, T, D>(&'py PyArray<T, D>); | ||
|
||
pub type PyReadwriteArrayDyn<'py, T> = PyReadwriteArray<'py, T, IxDyn>; | ||
|
||
impl<'py, T, D> Deref for PyReadwriteArray<'py, T, D> { | ||
type Target = PyArray<T, D>; | ||
|
||
fn deref(&self) -> &Self::Target { | ||
self.0 | ||
} | ||
} | ||
|
||
impl<'py, T: Element, D: Dimension> FromPyObject<'py> for PyReadwriteArray<'py, T, D> { | ||
fn extract(obj: &'py PyAny) -> PyResult<Self> { | ||
let array: &'py PyArray<T, D> = obj.extract()?; | ||
Ok(array.readwrite()) | ||
} | ||
} | ||
|
||
impl<'py, T, D> PyReadwriteArray<'py, T, D> | ||
where | ||
T: Element, | ||
D: Dimension, | ||
{ | ||
pub(crate) fn try_new(array: &'py PyArray<T, D>) -> Option<Self> { | ||
let address = array as *const PyArray<T, D> as usize; | ||
|
||
BORROW_FLAGS.with(|borrow_flags| { | ||
// SAFETY: Called on a thread local variable in a leaf function. | ||
let borrow_flags = unsafe { &mut *borrow_flags.get() }; | ||
|
||
match borrow_flags.entry(address) { | ||
Entry::Occupied(entry) => { | ||
let writers = entry.into_mut(); | ||
|
||
if *writers != 0 { | ||
cold(); | ||
return None; | ||
} | ||
|
||
*writers = -1; | ||
} | ||
Entry::Vacant(entry) => { | ||
entry.insert(-1); | ||
} | ||
} | ||
|
||
Some(Self(array)) | ||
}) | ||
} | ||
|
||
pub fn as_array(&self) -> ArrayView<T, D> { | ||
// SAFETY: Thread-local borrow flags ensure aliasing discipline on this thread, | ||
// and `PyArray` is neither `Send` nor `Sync` | ||
unsafe { self.0.as_array() } | ||
} | ||
|
||
pub fn as_slice(&self) -> Result<&[T], NotContiguousError> { | ||
// SAFETY: Thread-local borrow flags ensure aliasing discipline on this thread, | ||
// and `PyArray` is neither `Send` nor `Sync` | ||
unsafe { self.0.as_slice() } | ||
} | ||
|
||
pub fn as_array_mut(&mut self) -> ArrayViewMut<T, D> { | ||
// SAFETY: Thread-local borrow flags ensure aliasing discipline on this thread, | ||
// and `PyArray` is neither `Send` nor `Sync` | ||
unsafe { self.0.as_array_mut() } | ||
} | ||
|
||
pub fn as_slice_mut(&self) -> Result<&mut [T], NotContiguousError> { | ||
// SAFETY: Thread-local borrow flags ensure aliasing discipline on this thread, | ||
// and `PyArray` is neither `Send` nor `Sync` | ||
unsafe { self.0.as_slice_mut() } | ||
} | ||
} | ||
|
||
impl<'a, T, D> Drop for PyReadwriteArray<'a, T, D> { | ||
fn drop(&mut self) { | ||
let address = self.0 as *const PyArray<T, D> as usize; | ||
|
||
BORROW_FLAGS.with(|borrow_flags| { | ||
// SAFETY: Called on a thread local variable in a leaf function. | ||
let borrow_flags = unsafe { &mut *borrow_flags.get() }; | ||
|
||
borrow_flags.remove(&address).unwrap(); | ||
}); | ||
} | ||
} | ||
#[cold] | ||
#[inline(always)] | ||
fn cold() {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.