diff --git a/guide/pyclass_parameters.md b/guide/pyclass_parameters.md index edc69eb3cf5..35c54147df5 100644 --- a/guide/pyclass_parameters.md +++ b/guide/pyclass_parameters.md @@ -16,7 +16,7 @@ | `set_all` | Generates setters for all fields of the pyclass. | | `subclass` | Allows other Python classes and `#[pyclass]` to inherit from this class. Enums cannot be subclassed. | | `text_signature = "(arg1, arg2, ...)"` | Sets the text signature for the Python class' `__new__` method. | -| `unsendable` | Required if your struct is not [`Send`][params-3]. Rather than using `unsendable`, consider implementing your struct in a threadsafe way by e.g. substituting [`Rc`][params-4] with [`Arc`][params-5]. By using `unsendable`, your class will panic when accessed by another thread.| +| `unsendable` | Required if your struct is not [`Send`][params-3]. Rather than using `unsendable`, consider implementing your struct in a threadsafe way by e.g. substituting [`Rc`][params-4] with [`Arc`][params-5]. By using `unsendable`, your class will panic when accessed by another thread. Also note the Python's GC is multi-threaded and while unsendable classes will not be traversed on foreign threads to avoid UB, this can lead to memory leaks. | | `weakref` | Allows this class to be [weakly referenceable][params-6]. | All of these parameters can either be passed directly on the `#[pyclass(...)]` annotation, or as one or diff --git a/newsfragments/3689.changed.md b/newsfragments/3689.changed.md new file mode 100644 index 00000000000..30928e82f64 --- /dev/null +++ b/newsfragments/3689.changed.md @@ -0,0 +1 @@ +Calls to `__traverse__` become no-ops for unsendable pyclasses if on the wrong thread, thereby avoiding hard aborts at the cost of potential leakage. diff --git a/src/impl_/pyclass.rs b/src/impl_/pyclass.rs index 3941dfcb3e7..5ee67dc998d 100644 --- a/src/impl_/pyclass.rs +++ b/src/impl_/pyclass.rs @@ -1013,6 +1013,7 @@ impl PyClassNewTextSignature for &'_ PyClassImplCollector { #[doc(hidden)] pub trait PyClassThreadChecker: Sized { fn ensure(&self); + fn check(&self) -> bool; fn can_drop(&self, py: Python<'_>) -> bool; fn new() -> Self; private_decl! {} @@ -1028,6 +1029,9 @@ pub struct SendablePyClass(PhantomData); impl PyClassThreadChecker for SendablePyClass { fn ensure(&self) {} + fn check(&self) -> bool { + true + } fn can_drop(&self, _py: Python<'_>) -> bool { true } @@ -1053,6 +1057,10 @@ impl ThreadCheckerImpl { ); } + fn check(&self) -> bool { + thread::current().id() == self.0 + } + fn can_drop(&self, py: Python<'_>, type_name: &'static str) -> bool { if thread::current().id() != self.0 { PyRuntimeError::new_err(format!( @@ -1071,6 +1079,9 @@ impl PyClassThreadChecker for ThreadCheckerImpl { fn ensure(&self) { self.ensure(std::any::type_name::()); } + fn check(&self) -> bool { + self.check() + } fn can_drop(&self, py: Python<'_>) -> bool { self.can_drop(py, std::any::type_name::()) } diff --git a/src/impl_/pymethods.rs b/src/impl_/pymethods.rs index f2d816bba8d..e403aa23c79 100644 --- a/src/impl_/pymethods.rs +++ b/src/impl_/pymethods.rs @@ -269,7 +269,7 @@ where let py = Python::assume_gil_acquired(); let slf = py.from_borrowed_ptr::>(slf); - let borrow = slf.try_borrow(); + let borrow = slf.try_borrow_threadsafe(); let visit = PyVisit::from_raw(visit, arg, py); let retval = if let Ok(borrow) = borrow { diff --git a/src/pycell.rs b/src/pycell.rs index 3bc80a7eb07..bde95ad8313 100644 --- a/src/pycell.rs +++ b/src/pycell.rs @@ -351,6 +351,14 @@ impl PyCell { .map(|_| PyRef { inner: self }) } + /// Variant of [`try_borrow`][Self::try_borrow] which fails instead of panicking if called from the wrong thread + pub(crate) fn try_borrow_threadsafe(&self) -> Result, PyBorrowError> { + self.check_threadsafe()?; + self.borrow_checker() + .try_borrow() + .map(|_| PyRef { inner: self }) + } + /// Mutably borrows the value `T`, returning an error if the value is currently borrowed. /// This borrow lasts as long as the returned `PyRefMut` exists. /// @@ -975,6 +983,7 @@ impl From for PyErr { #[doc(hidden)] pub trait PyCellLayout: PyLayout { fn ensure_threadsafe(&self); + fn check_threadsafe(&self) -> Result<(), PyBorrowError>; /// Implementation of tp_dealloc. /// # Safety /// - slf must be a valid pointer to an instance of a T or a subclass. @@ -988,6 +997,9 @@ where T: PyTypeInfo, { fn ensure_threadsafe(&self) {} + fn check_threadsafe(&self) -> Result<(), PyBorrowError> { + Ok(()) + } unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) { let type_obj = T::type_object_raw(py); // For `#[pyclass]` types which inherit from PyAny, we can just call tp_free @@ -1025,6 +1037,12 @@ where self.contents.thread_checker.ensure(); self.ob_base.ensure_threadsafe(); } + fn check_threadsafe(&self) -> Result<(), PyBorrowError> { + if !self.contents.thread_checker.check() { + return Err(PyBorrowError { _private: () }); + } + self.ob_base.check_threadsafe() + } unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) { // Safety: Python only calls tp_dealloc when no references to the object remain. let cell = &mut *(slf as *mut PyCell); diff --git a/tests/test_gc.rs b/tests/test_gc.rs index 8fd4622f65e..54c3e1a100c 100644 --- a/tests/test_gc.rs +++ b/tests/test_gc.rs @@ -512,6 +512,55 @@ fn drop_during_traversal_without_gil() { assert!(drop_called.load(Ordering::Relaxed)); } +#[pyclass(unsendable)] +struct UnsendableTraversal { + traversed: Cell, +} + +#[pymethods] +impl UnsendableTraversal { + fn __clear__(&mut self) {} + + #[allow(clippy::unnecessary_wraps)] + fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + self.traversed.set(true); + Ok(()) + } +} + +#[test] +#[cfg(not(target_arch = "wasm32"))] // We are building wasm Python with pthreads disabled +fn unsendable_are_not_traversed_on_foreign_thread() { + Python::with_gil(|py| unsafe { + let ty = py.get_type::().as_type_ptr(); + let traverse = get_type_traverse(ty).unwrap(); + + let obj = Py::new( + py, + UnsendableTraversal { + traversed: Cell::new(false), + }, + ) + .unwrap(); + + let ptr = SendablePtr(obj.as_ptr()); + + std::thread::spawn(move || { + // traversal on foreign thread is a no-op + assert_eq!(traverse({ ptr }.0, novisit, std::ptr::null_mut()), 0); + }) + .join() + .unwrap(); + + assert!(!obj.borrow(py).traversed.get()); + + // traversal on home thread still works + assert_eq!(traverse({ ptr }.0, novisit, std::ptr::null_mut()), 0); + + assert!(obj.borrow(py).traversed.get()); + }); +} + // Manual traversal utilities unsafe fn get_type_traverse(tp: *mut pyo3::ffi::PyTypeObject) -> Option { @@ -533,3 +582,8 @@ extern "C" fn visit_error( ) -> std::os::raw::c_int { -1 } + +#[derive(Clone, Copy)] +struct SendablePtr(*mut pyo3::ffi::PyObject); + +unsafe impl Send for SendablePtr {}