From c2f8114d8679d023f2fb5f1f734e169252f5033c Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Wed, 28 Aug 2024 22:37:18 +0100 Subject: [PATCH] avoid creating `PyRef` inside `__traverse__` handler (#4479) --- Cargo.toml | 1 + newsfragments/4479.fixed.md | 1 + src/impl_/pymethods.rs | 81 +++++++++++++++++++++++++++++++------ src/pycell.rs | 8 ---- src/pyclass.rs | 1 + src/pyclass/gc.rs | 71 ++++++++++++++++++++++---------- 6 files changed, 120 insertions(+), 43 deletions(-) create mode 100644 newsfragments/4479.fixed.md diff --git a/Cargo.toml b/Cargo.toml index c7446c7da5d..29e131dd827 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,6 +61,7 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0.61" rayon = "1.6.1" futures = "0.3.28" +static_assertions = "1.1.0" [build-dependencies] pyo3-build-config = { path = "pyo3-build-config", version = "=0.22.2", features = ["resolve-config"] } diff --git a/newsfragments/4479.fixed.md b/newsfragments/4479.fixed.md new file mode 100644 index 00000000000..15d634543af --- /dev/null +++ b/newsfragments/4479.fixed.md @@ -0,0 +1 @@ +Remove illegal reference counting op inside implementation of `__traverse__` handlers. diff --git a/src/impl_/pymethods.rs b/src/impl_/pymethods.rs index 60b655e5647..76b71a3e188 100644 --- a/src/impl_/pymethods.rs +++ b/src/impl_/pymethods.rs @@ -2,17 +2,20 @@ use crate::callback::IntoPyCallbackOutput; use crate::exceptions::PyStopAsyncIteration; use crate::gil::LockGIL; use crate::impl_::panic::PanicTrap; +use crate::impl_::pycell::{PyClassObject, PyClassObjectLayout}; +use crate::pycell::impl_::PyClassBorrowChecker as _; use crate::pycell::{PyBorrowError, PyBorrowMutError}; use crate::pyclass::boolean_struct::False; use crate::types::any::PyAnyMethods; #[cfg(feature = "gil-refs")] use crate::types::{PyModule, PyType}; use crate::{ - ffi, Borrowed, Bound, DowncastError, Py, PyAny, PyClass, PyClassInitializer, PyErr, PyObject, - PyRef, PyRefMut, PyResult, PyTraverseError, PyTypeCheck, PyVisit, Python, + ffi, Bound, DowncastError, Py, PyAny, PyClass, PyClassInitializer, PyErr, PyObject, PyRef, + PyRefMut, PyResult, PyTraverseError, PyTypeCheck, PyVisit, Python, }; use std::ffi::CStr; use std::fmt; +use std::marker::PhantomData; use std::os::raw::{c_int, c_void}; use std::panic::{catch_unwind, AssertUnwindSafe}; use std::ptr::null_mut; @@ -234,6 +237,40 @@ impl PySetterDef { } /// Calls an implementation of __traverse__ for tp_traverse +/// +/// NB cannot accept `'static` visitor, this is a sanity check below: +/// +/// ```rust,compile_fail +/// use pyo3::prelude::*; +/// use pyo3::pyclass::{PyTraverseError, PyVisit}; +/// +/// #[pyclass] +/// struct Foo; +/// +/// #[pymethods] +/// impl Foo { +/// fn __traverse__(&self, _visit: PyVisit<'static>) -> Result<(), PyTraverseError> { +/// Ok(()) +/// } +/// } +/// ``` +/// +/// Elided lifetime should compile ok: +/// +/// ```rust +/// use pyo3::prelude::*; +/// use pyo3::pyclass::{PyTraverseError, PyVisit}; +/// +/// #[pyclass] +/// struct Foo; +/// +/// #[pymethods] +/// impl Foo { +/// fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> { +/// Ok(()) +/// } +/// } +/// ``` #[doc(hidden)] pub unsafe fn _call_traverse( slf: *mut ffi::PyObject, @@ -252,25 +289,43 @@ where // Since we do not create a `GILPool` at all, it is important that our usage of the GIL // token does not produce any owned objects thereby calling into `register_owned`. let trap = PanicTrap::new("uncaught panic inside __traverse__ handler"); + let lock = LockGIL::during_traverse(); + + // SAFETY: `slf` is a valid Python object pointer to a class object of type T, and + // traversal is running so no mutations can occur. + let class_object: &PyClassObject = &*slf.cast(); + + let retval = + // `#[pyclass(unsendable)]` types can only be deallocated by their own thread, so + // do not traverse them if not on their owning thread :( + if class_object.check_threadsafe().is_ok() + // ... and we cannot traverse a type which might be being mutated by a Rust thread + && class_object.borrow_checker().try_borrow().is_ok() { + struct TraverseGuard<'a, T: PyClass>(&'a PyClassObject); + impl<'a, T: PyClass> Drop for TraverseGuard<'a, T> { + fn drop(&mut self) { + self.0.borrow_checker().release_borrow() + } + } - let py = Python::assume_gil_acquired(); - let slf = Borrowed::from_ptr_unchecked(py, slf).downcast_unchecked::(); - let borrow = PyRef::try_borrow_threadsafe(&slf); - let visit = PyVisit::from_raw(visit, arg, py); + // `.try_borrow()` above created a borrow, we need to release it when we're done + // traversing the object. This allows us to read `instance` safely. + let _guard = TraverseGuard(class_object); + let instance = &*class_object.contents.value.get(); - let retval = if let Ok(borrow) = borrow { - let _lock = LockGIL::during_traverse(); + let visit = PyVisit { visit, arg, _guard: PhantomData }; - match catch_unwind(AssertUnwindSafe(move || impl_(&*borrow, visit))) { - Ok(res) => match res { - Ok(()) => 0, - Err(PyTraverseError(value)) => value, - }, + match catch_unwind(AssertUnwindSafe(move || impl_(instance, visit))) { + Ok(Ok(())) => 0, + Ok(Err(traverse_error)) => traverse_error.into_inner(), Err(_err) => -1, } } else { 0 }; + + // Drop lock before trap just in case dropping lock panics + drop(lock); trap.disarm(); retval } diff --git a/src/pycell.rs b/src/pycell.rs index 77d174cb9e1..c9fe9aad4ed 100644 --- a/src/pycell.rs +++ b/src/pycell.rs @@ -673,14 +673,6 @@ impl<'py, T: PyClass> PyRef<'py, T> { .try_borrow() .map(|_| Self { inner: obj.clone() }) } - - pub(crate) fn try_borrow_threadsafe(obj: &Bound<'py, T>) -> Result { - let cell = obj.get_class_object(); - cell.check_threadsafe()?; - cell.borrow_checker() - .try_borrow() - .map(|_| Self { inner: obj.clone() }) - } } impl<'p, T, U> PyRef<'p, T> diff --git a/src/pyclass.rs b/src/pyclass.rs index 29cd1251974..4ba30b5bbaf 100644 --- a/src/pyclass.rs +++ b/src/pyclass.rs @@ -9,6 +9,7 @@ mod create_type_object; mod gc; pub(crate) use self::create_type_object::{create_type_object, PyClassTypeObject}; + pub use self::gc::{PyTraverseError, PyVisit}; /// Types that can be used as Python classes. diff --git a/src/pyclass/gc.rs b/src/pyclass/gc.rs index 7878ccf5ca8..b6747a63f89 100644 --- a/src/pyclass/gc.rs +++ b/src/pyclass/gc.rs @@ -1,23 +1,31 @@ -use std::os::raw::{c_int, c_void}; +use std::{ + marker::PhantomData, + os::raw::{c_int, c_void}, +}; -use crate::{ffi, AsPyPointer, Python}; +use crate::{ffi, AsPyPointer}; /// Error returned by a `__traverse__` visitor implementation. #[repr(transparent)] -pub struct PyTraverseError(pub(crate) c_int); +pub struct PyTraverseError(NonZeroCInt); + +impl PyTraverseError { + /// Returns the error code. + pub(crate) fn into_inner(self) -> c_int { + self.0.into() + } +} /// Object visitor for GC. #[derive(Clone)] -pub struct PyVisit<'p> { +pub struct PyVisit<'a> { pub(crate) visit: ffi::visitproc, pub(crate) arg: *mut c_void, - /// VisitProc contains a Python instance to ensure that - /// 1) it is cannot be moved out of the traverse() call - /// 2) it cannot be sent to other threads - pub(crate) _py: Python<'p>, + /// Prevents the `PyVisit` from outliving the `__traverse__` call. + pub(crate) _guard: PhantomData<&'a ()>, } -impl<'p> PyVisit<'p> { +impl<'a> PyVisit<'a> { /// Visit `obj`. pub fn call(&self, obj: &T) -> Result<(), PyTraverseError> where @@ -25,24 +33,43 @@ impl<'p> PyVisit<'p> { { let ptr = obj.as_ptr(); if !ptr.is_null() { - let r = unsafe { (self.visit)(ptr, self.arg) }; - if r == 0 { - Ok(()) - } else { - Err(PyTraverseError(r)) + match NonZeroCInt::new(unsafe { (self.visit)(ptr, self.arg) }) { + None => Ok(()), + Some(r) => Err(PyTraverseError(r)), } } else { Ok(()) } } +} - /// Creates the PyVisit from the arguments to tp_traverse - #[doc(hidden)] - pub unsafe fn from_raw(visit: ffi::visitproc, arg: *mut c_void, py: Python<'p>) -> Self { - Self { - visit, - arg, - _py: py, - } +/// Workaround for `NonZero` not being available until MSRV 1.79 +mod get_nonzero_c_int { + pub struct GetNonZeroCInt(); + + pub trait NonZeroCIntType { + type Type; + } + impl NonZeroCIntType for GetNonZeroCInt<16> { + type Type = std::num::NonZeroI16; + } + impl NonZeroCIntType for GetNonZeroCInt<32> { + type Type = std::num::NonZeroI32; + } + + pub type Type = + () * 8 }> as NonZeroCIntType>::Type; +} + +use get_nonzero_c_int::Type as NonZeroCInt; + +#[cfg(test)] +mod tests { + use super::PyVisit; + use static_assertions::assert_not_impl_any; + + #[test] + fn py_visit_not_send_sync() { + assert_not_impl_any!(PyVisit<'_>: Send, Sync); } }