Skip to content

Commit

Permalink
Merge pull request #3689 from PyO3/unsendable-threadsafe-traverse
Browse files Browse the repository at this point in the history
Turn calls of __traverse__ into no-ops for unsendable pyclass if on the wrong thread
  • Loading branch information
adamreichold authored Dec 23, 2023
2 parents 65f25d4 + 4dc6c16 commit 8bef6e3
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 2 deletions.
2 changes: 1 addition & 1 deletion guide/pyclass_parameters.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
| `set_all` | Generates setters for all fields of the pyclass. |
| `subclass` | Allows other Python classes and `#[pyclass]` to inherit from this class. Enums cannot be subclassed. |
| <span style="white-space: pre">`text_signature = "(arg1, arg2, ...)"`</span> | Sets the text signature for the Python class' `__new__` method. |
| `unsendable` | Required if your struct is not [`Send`][params-3]. Rather than using `unsendable`, consider implementing your struct in a threadsafe way by e.g. substituting [`Rc`][params-4] with [`Arc`][params-5]. By using `unsendable`, your class will panic when accessed by another thread.|
| `unsendable` | Required if your struct is not [`Send`][params-3]. Rather than using `unsendable`, consider implementing your struct in a threadsafe way by e.g. substituting [`Rc`][params-4] with [`Arc`][params-5]. By using `unsendable`, your class will panic when accessed by another thread. Also note the Python's GC is multi-threaded and while unsendable classes will not be traversed on foreign threads to avoid UB, this can lead to memory leaks. |
| `weakref` | Allows this class to be [weakly referenceable][params-6]. |

All of these parameters can either be passed directly on the `#[pyclass(...)]` annotation, or as one or
Expand Down
1 change: 1 addition & 0 deletions newsfragments/3689.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Calls to `__traverse__` become no-ops for unsendable pyclasses if on the wrong thread, thereby avoiding hard aborts at the cost of potential leakage.
11 changes: 11 additions & 0 deletions src/impl_/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1013,6 +1013,7 @@ impl<T> PyClassNewTextSignature<T> for &'_ PyClassImplCollector<T> {
#[doc(hidden)]
pub trait PyClassThreadChecker<T>: Sized {
fn ensure(&self);
fn check(&self) -> bool;
fn can_drop(&self, py: Python<'_>) -> bool;
fn new() -> Self;
private_decl! {}
Expand All @@ -1028,6 +1029,9 @@ pub struct SendablePyClass<T: Send>(PhantomData<T>);

impl<T: Send> PyClassThreadChecker<T> for SendablePyClass<T> {
fn ensure(&self) {}
fn check(&self) -> bool {
true
}
fn can_drop(&self, _py: Python<'_>) -> bool {
true
}
Expand All @@ -1053,6 +1057,10 @@ impl ThreadCheckerImpl {
);
}

fn check(&self) -> bool {
thread::current().id() == self.0
}

fn can_drop(&self, py: Python<'_>, type_name: &'static str) -> bool {
if thread::current().id() != self.0 {
PyRuntimeError::new_err(format!(
Expand All @@ -1071,6 +1079,9 @@ impl<T> PyClassThreadChecker<T> for ThreadCheckerImpl {
fn ensure(&self) {
self.ensure(std::any::type_name::<T>());
}
fn check(&self) -> bool {
self.check()
}
fn can_drop(&self, py: Python<'_>) -> bool {
self.can_drop(py, std::any::type_name::<T>())
}
Expand Down
2 changes: 1 addition & 1 deletion src/impl_/pymethods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ where

let py = Python::assume_gil_acquired();
let slf = py.from_borrowed_ptr::<PyCell<T>>(slf);
let borrow = slf.try_borrow();
let borrow = slf.try_borrow_threadsafe();
let visit = PyVisit::from_raw(visit, arg, py);

let retval = if let Ok(borrow) = borrow {
Expand Down
18 changes: 18 additions & 0 deletions src/pycell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,14 @@ impl<T: PyClass> PyCell<T> {
.map(|_| PyRef { inner: self })
}

/// Variant of [`try_borrow`][Self::try_borrow] which fails instead of panicking if called from the wrong thread
pub(crate) fn try_borrow_threadsafe(&self) -> Result<PyRef<'_, T>, PyBorrowError> {
self.check_threadsafe()?;
self.borrow_checker()
.try_borrow()
.map(|_| PyRef { inner: self })
}

/// Mutably borrows the value `T`, returning an error if the value is currently borrowed.
/// This borrow lasts as long as the returned `PyRefMut` exists.
///
Expand Down Expand Up @@ -975,6 +983,7 @@ impl From<PyBorrowMutError> for PyErr {
#[doc(hidden)]
pub trait PyCellLayout<T>: PyLayout<T> {
fn ensure_threadsafe(&self);
fn check_threadsafe(&self) -> Result<(), PyBorrowError>;
/// Implementation of tp_dealloc.
/// # Safety
/// - slf must be a valid pointer to an instance of a T or a subclass.
Expand All @@ -988,6 +997,9 @@ where
T: PyTypeInfo,
{
fn ensure_threadsafe(&self) {}
fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
Ok(())
}
unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
let type_obj = T::type_object_raw(py);
// For `#[pyclass]` types which inherit from PyAny, we can just call tp_free
Expand Down Expand Up @@ -1025,6 +1037,12 @@ where
self.contents.thread_checker.ensure();
self.ob_base.ensure_threadsafe();
}
fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
if !self.contents.thread_checker.check() {
return Err(PyBorrowError { _private: () });
}
self.ob_base.check_threadsafe()
}
unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
// Safety: Python only calls tp_dealloc when no references to the object remain.
let cell = &mut *(slf as *mut PyCell<T>);
Expand Down
54 changes: 54 additions & 0 deletions tests/test_gc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,55 @@ fn drop_during_traversal_without_gil() {
assert!(drop_called.load(Ordering::Relaxed));
}

#[pyclass(unsendable)]
struct UnsendableTraversal {
traversed: Cell<bool>,
}

#[pymethods]
impl UnsendableTraversal {
fn __clear__(&mut self) {}

#[allow(clippy::unnecessary_wraps)]
fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
self.traversed.set(true);
Ok(())
}
}

#[test]
#[cfg(not(target_arch = "wasm32"))] // We are building wasm Python with pthreads disabled
fn unsendable_are_not_traversed_on_foreign_thread() {
Python::with_gil(|py| unsafe {
let ty = py.get_type::<UnsendableTraversal>().as_type_ptr();
let traverse = get_type_traverse(ty).unwrap();

let obj = Py::new(
py,
UnsendableTraversal {
traversed: Cell::new(false),
},
)
.unwrap();

let ptr = SendablePtr(obj.as_ptr());

std::thread::spawn(move || {
// traversal on foreign thread is a no-op
assert_eq!(traverse({ ptr }.0, novisit, std::ptr::null_mut()), 0);
})
.join()
.unwrap();

assert!(!obj.borrow(py).traversed.get());

// traversal on home thread still works
assert_eq!(traverse({ ptr }.0, novisit, std::ptr::null_mut()), 0);

assert!(obj.borrow(py).traversed.get());
});
}

// Manual traversal utilities

unsafe fn get_type_traverse(tp: *mut pyo3::ffi::PyTypeObject) -> Option<pyo3::ffi::traverseproc> {
Expand All @@ -533,3 +582,8 @@ extern "C" fn visit_error(
) -> std::os::raw::c_int {
-1
}

#[derive(Clone, Copy)]
struct SendablePtr(*mut pyo3::ffi::PyObject);

unsafe impl Send for SendablePtr {}

0 comments on commit 8bef6e3

Please sign in to comment.