diff --git a/guide/pyclass_parameters.md b/guide/pyclass_parameters.md
index edc69eb3cf5..35c54147df5 100644
--- a/guide/pyclass_parameters.md
+++ b/guide/pyclass_parameters.md
@@ -16,7 +16,7 @@
| `set_all` | Generates setters for all fields of the pyclass. |
| `subclass` | Allows other Python classes and `#[pyclass]` to inherit from this class. Enums cannot be subclassed. |
| `text_signature = "(arg1, arg2, ...)"` | Sets the text signature for the Python class' `__new__` method. |
-| `unsendable` | Required if your struct is not [`Send`][params-3]. Rather than using `unsendable`, consider implementing your struct in a threadsafe way by e.g. substituting [`Rc`][params-4] with [`Arc`][params-5]. By using `unsendable`, your class will panic when accessed by another thread.|
+| `unsendable` | Required if your struct is not [`Send`][params-3]. Rather than using `unsendable`, consider implementing your struct in a threadsafe way by e.g. substituting [`Rc`][params-4] with [`Arc`][params-5]. By using `unsendable`, your class will panic when accessed by another thread. Also note the Python's GC is multi-threaded and while unsendable classes will not be traversed on foreign threads to avoid UB, this can lead to memory leaks. |
| `weakref` | Allows this class to be [weakly referenceable][params-6]. |
All of these parameters can either be passed directly on the `#[pyclass(...)]` annotation, or as one or
diff --git a/newsfragments/3689.changed.md b/newsfragments/3689.changed.md
new file mode 100644
index 00000000000..30928e82f64
--- /dev/null
+++ b/newsfragments/3689.changed.md
@@ -0,0 +1 @@
+Calls to `__traverse__` become no-ops for unsendable pyclasses if on the wrong thread, thereby avoiding hard aborts at the cost of potential leakage.
diff --git a/src/impl_/pyclass.rs b/src/impl_/pyclass.rs
index 3941dfcb3e7..5ee67dc998d 100644
--- a/src/impl_/pyclass.rs
+++ b/src/impl_/pyclass.rs
@@ -1013,6 +1013,7 @@ impl PyClassNewTextSignature for &'_ PyClassImplCollector {
#[doc(hidden)]
pub trait PyClassThreadChecker: Sized {
fn ensure(&self);
+ fn check(&self) -> bool;
fn can_drop(&self, py: Python<'_>) -> bool;
fn new() -> Self;
private_decl! {}
@@ -1028,6 +1029,9 @@ pub struct SendablePyClass(PhantomData);
impl PyClassThreadChecker for SendablePyClass {
fn ensure(&self) {}
+ fn check(&self) -> bool {
+ true
+ }
fn can_drop(&self, _py: Python<'_>) -> bool {
true
}
@@ -1053,6 +1057,10 @@ impl ThreadCheckerImpl {
);
}
+ fn check(&self) -> bool {
+ thread::current().id() == self.0
+ }
+
fn can_drop(&self, py: Python<'_>, type_name: &'static str) -> bool {
if thread::current().id() != self.0 {
PyRuntimeError::new_err(format!(
@@ -1071,6 +1079,9 @@ impl PyClassThreadChecker for ThreadCheckerImpl {
fn ensure(&self) {
self.ensure(std::any::type_name::());
}
+ fn check(&self) -> bool {
+ self.check()
+ }
fn can_drop(&self, py: Python<'_>) -> bool {
self.can_drop(py, std::any::type_name::())
}
diff --git a/src/impl_/pymethods.rs b/src/impl_/pymethods.rs
index f2d816bba8d..e403aa23c79 100644
--- a/src/impl_/pymethods.rs
+++ b/src/impl_/pymethods.rs
@@ -269,7 +269,7 @@ where
let py = Python::assume_gil_acquired();
let slf = py.from_borrowed_ptr::>(slf);
- let borrow = slf.try_borrow();
+ let borrow = slf.try_borrow_threadsafe();
let visit = PyVisit::from_raw(visit, arg, py);
let retval = if let Ok(borrow) = borrow {
diff --git a/src/pycell.rs b/src/pycell.rs
index 3bc80a7eb07..bde95ad8313 100644
--- a/src/pycell.rs
+++ b/src/pycell.rs
@@ -351,6 +351,14 @@ impl PyCell {
.map(|_| PyRef { inner: self })
}
+ /// Variant of [`try_borrow`][Self::try_borrow] which fails instead of panicking if called from the wrong thread
+ pub(crate) fn try_borrow_threadsafe(&self) -> Result, PyBorrowError> {
+ self.check_threadsafe()?;
+ self.borrow_checker()
+ .try_borrow()
+ .map(|_| PyRef { inner: self })
+ }
+
/// Mutably borrows the value `T`, returning an error if the value is currently borrowed.
/// This borrow lasts as long as the returned `PyRefMut` exists.
///
@@ -975,6 +983,7 @@ impl From for PyErr {
#[doc(hidden)]
pub trait PyCellLayout: PyLayout {
fn ensure_threadsafe(&self);
+ fn check_threadsafe(&self) -> Result<(), PyBorrowError>;
/// Implementation of tp_dealloc.
/// # Safety
/// - slf must be a valid pointer to an instance of a T or a subclass.
@@ -988,6 +997,9 @@ where
T: PyTypeInfo,
{
fn ensure_threadsafe(&self) {}
+ fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
+ Ok(())
+ }
unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
let type_obj = T::type_object_raw(py);
// For `#[pyclass]` types which inherit from PyAny, we can just call tp_free
@@ -1025,6 +1037,12 @@ where
self.contents.thread_checker.ensure();
self.ob_base.ensure_threadsafe();
}
+ fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
+ if !self.contents.thread_checker.check() {
+ return Err(PyBorrowError { _private: () });
+ }
+ self.ob_base.check_threadsafe()
+ }
unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
// Safety: Python only calls tp_dealloc when no references to the object remain.
let cell = &mut *(slf as *mut PyCell);
diff --git a/tests/test_gc.rs b/tests/test_gc.rs
index 8fd4622f65e..54c3e1a100c 100644
--- a/tests/test_gc.rs
+++ b/tests/test_gc.rs
@@ -512,6 +512,55 @@ fn drop_during_traversal_without_gil() {
assert!(drop_called.load(Ordering::Relaxed));
}
+#[pyclass(unsendable)]
+struct UnsendableTraversal {
+ traversed: Cell,
+}
+
+#[pymethods]
+impl UnsendableTraversal {
+ fn __clear__(&mut self) {}
+
+ #[allow(clippy::unnecessary_wraps)]
+ fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
+ self.traversed.set(true);
+ Ok(())
+ }
+}
+
+#[test]
+#[cfg(not(target_arch = "wasm32"))] // We are building wasm Python with pthreads disabled
+fn unsendable_are_not_traversed_on_foreign_thread() {
+ Python::with_gil(|py| unsafe {
+ let ty = py.get_type::().as_type_ptr();
+ let traverse = get_type_traverse(ty).unwrap();
+
+ let obj = Py::new(
+ py,
+ UnsendableTraversal {
+ traversed: Cell::new(false),
+ },
+ )
+ .unwrap();
+
+ let ptr = SendablePtr(obj.as_ptr());
+
+ std::thread::spawn(move || {
+ // traversal on foreign thread is a no-op
+ assert_eq!(traverse({ ptr }.0, novisit, std::ptr::null_mut()), 0);
+ })
+ .join()
+ .unwrap();
+
+ assert!(!obj.borrow(py).traversed.get());
+
+ // traversal on home thread still works
+ assert_eq!(traverse({ ptr }.0, novisit, std::ptr::null_mut()), 0);
+
+ assert!(obj.borrow(py).traversed.get());
+ });
+}
+
// Manual traversal utilities
unsafe fn get_type_traverse(tp: *mut pyo3::ffi::PyTypeObject) -> Option {
@@ -533,3 +582,8 @@ extern "C" fn visit_error(
) -> std::os::raw::c_int {
-1
}
+
+#[derive(Clone, Copy)]
+struct SendablePtr(*mut pyo3::ffi::PyObject);
+
+unsafe impl Send for SendablePtr {}