Skip to content

Commit

Permalink
avoid creating PyRef inside __traverse__ handler (#4479)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Sep 15, 2024
1 parent acbe5d5 commit c2f8114
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 43 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.61"
rayon = "1.6.1"
futures = "0.3.28"
static_assertions = "1.1.0"

[build-dependencies]
pyo3-build-config = { path = "pyo3-build-config", version = "=0.22.2", features = ["resolve-config"] }
Expand Down
1 change: 1 addition & 0 deletions newsfragments/4479.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Remove illegal reference counting op inside implementation of `__traverse__` handlers.
81 changes: 68 additions & 13 deletions src/impl_/pymethods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,20 @@ use crate::callback::IntoPyCallbackOutput;
use crate::exceptions::PyStopAsyncIteration;
use crate::gil::LockGIL;
use crate::impl_::panic::PanicTrap;
use crate::impl_::pycell::{PyClassObject, PyClassObjectLayout};
use crate::pycell::impl_::PyClassBorrowChecker as _;
use crate::pycell::{PyBorrowError, PyBorrowMutError};
use crate::pyclass::boolean_struct::False;
use crate::types::any::PyAnyMethods;
#[cfg(feature = "gil-refs")]
use crate::types::{PyModule, PyType};
use crate::{
ffi, Borrowed, Bound, DowncastError, Py, PyAny, PyClass, PyClassInitializer, PyErr, PyObject,
PyRef, PyRefMut, PyResult, PyTraverseError, PyTypeCheck, PyVisit, Python,
ffi, Bound, DowncastError, Py, PyAny, PyClass, PyClassInitializer, PyErr, PyObject, PyRef,
PyRefMut, PyResult, PyTraverseError, PyTypeCheck, PyVisit, Python,
};
use std::ffi::CStr;
use std::fmt;
use std::marker::PhantomData;
use std::os::raw::{c_int, c_void};
use std::panic::{catch_unwind, AssertUnwindSafe};
use std::ptr::null_mut;
Expand Down Expand Up @@ -234,6 +237,40 @@ impl PySetterDef {
}

/// Calls an implementation of __traverse__ for tp_traverse
///
/// NB cannot accept `'static` visitor, this is a sanity check below:
///
/// ```rust,compile_fail
/// use pyo3::prelude::*;
/// use pyo3::pyclass::{PyTraverseError, PyVisit};
///
/// #[pyclass]
/// struct Foo;
///
/// #[pymethods]
/// impl Foo {
/// fn __traverse__(&self, _visit: PyVisit<'static>) -> Result<(), PyTraverseError> {
/// Ok(())
/// }
/// }
/// ```
///
/// Elided lifetime should compile ok:
///
/// ```rust
/// use pyo3::prelude::*;
/// use pyo3::pyclass::{PyTraverseError, PyVisit};
///
/// #[pyclass]
/// struct Foo;
///
/// #[pymethods]
/// impl Foo {
/// fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
/// Ok(())
/// }
/// }
/// ```
#[doc(hidden)]
pub unsafe fn _call_traverse<T>(
slf: *mut ffi::PyObject,
Expand All @@ -252,25 +289,43 @@ where
// Since we do not create a `GILPool` at all, it is important that our usage of the GIL
// token does not produce any owned objects thereby calling into `register_owned`.
let trap = PanicTrap::new("uncaught panic inside __traverse__ handler");
let lock = LockGIL::during_traverse();

// SAFETY: `slf` is a valid Python object pointer to a class object of type T, and
// traversal is running so no mutations can occur.
let class_object: &PyClassObject<T> = &*slf.cast();

let retval =
// `#[pyclass(unsendable)]` types can only be deallocated by their own thread, so
// do not traverse them if not on their owning thread :(
if class_object.check_threadsafe().is_ok()
// ... and we cannot traverse a type which might be being mutated by a Rust thread
&& class_object.borrow_checker().try_borrow().is_ok() {
struct TraverseGuard<'a, T: PyClass>(&'a PyClassObject<T>);
impl<'a, T: PyClass> Drop for TraverseGuard<'a, T> {
fn drop(&mut self) {
self.0.borrow_checker().release_borrow()
}
}

let py = Python::assume_gil_acquired();
let slf = Borrowed::from_ptr_unchecked(py, slf).downcast_unchecked::<T>();
let borrow = PyRef::try_borrow_threadsafe(&slf);
let visit = PyVisit::from_raw(visit, arg, py);
// `.try_borrow()` above created a borrow, we need to release it when we're done
// traversing the object. This allows us to read `instance` safely.
let _guard = TraverseGuard(class_object);
let instance = &*class_object.contents.value.get();

let retval = if let Ok(borrow) = borrow {
let _lock = LockGIL::during_traverse();
let visit = PyVisit { visit, arg, _guard: PhantomData };

match catch_unwind(AssertUnwindSafe(move || impl_(&*borrow, visit))) {
Ok(res) => match res {
Ok(()) => 0,
Err(PyTraverseError(value)) => value,
},
match catch_unwind(AssertUnwindSafe(move || impl_(instance, visit))) {
Ok(Ok(())) => 0,
Ok(Err(traverse_error)) => traverse_error.into_inner(),
Err(_err) => -1,
}
} else {
0
};

// Drop lock before trap just in case dropping lock panics
drop(lock);
trap.disarm();
retval
}
Expand Down
8 changes: 0 additions & 8 deletions src/pycell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -673,14 +673,6 @@ impl<'py, T: PyClass> PyRef<'py, T> {
.try_borrow()
.map(|_| Self { inner: obj.clone() })
}

pub(crate) fn try_borrow_threadsafe(obj: &Bound<'py, T>) -> Result<Self, PyBorrowError> {
let cell = obj.get_class_object();
cell.check_threadsafe()?;
cell.borrow_checker()
.try_borrow()
.map(|_| Self { inner: obj.clone() })
}
}

impl<'p, T, U> PyRef<'p, T>
Expand Down
1 change: 1 addition & 0 deletions src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ mod create_type_object;
mod gc;

pub(crate) use self::create_type_object::{create_type_object, PyClassTypeObject};

pub use self::gc::{PyTraverseError, PyVisit};

/// Types that can be used as Python classes.
Expand Down
71 changes: 49 additions & 22 deletions src/pyclass/gc.rs
Original file line number Diff line number Diff line change
@@ -1,48 +1,75 @@
use std::os::raw::{c_int, c_void};
use std::{
marker::PhantomData,
os::raw::{c_int, c_void},
};

use crate::{ffi, AsPyPointer, Python};
use crate::{ffi, AsPyPointer};

/// Error returned by a `__traverse__` visitor implementation.
#[repr(transparent)]
pub struct PyTraverseError(pub(crate) c_int);
pub struct PyTraverseError(NonZeroCInt);

impl PyTraverseError {
/// Returns the error code.
pub(crate) fn into_inner(self) -> c_int {
self.0.into()
}
}

/// Object visitor for GC.
#[derive(Clone)]
pub struct PyVisit<'p> {
pub struct PyVisit<'a> {
pub(crate) visit: ffi::visitproc,
pub(crate) arg: *mut c_void,
/// VisitProc contains a Python instance to ensure that
/// 1) it is cannot be moved out of the traverse() call
/// 2) it cannot be sent to other threads
pub(crate) _py: Python<'p>,
/// Prevents the `PyVisit` from outliving the `__traverse__` call.
pub(crate) _guard: PhantomData<&'a ()>,
}

impl<'p> PyVisit<'p> {
impl<'a> PyVisit<'a> {
/// Visit `obj`.
pub fn call<T>(&self, obj: &T) -> Result<(), PyTraverseError>
where
T: AsPyPointer,
{
let ptr = obj.as_ptr();
if !ptr.is_null() {
let r = unsafe { (self.visit)(ptr, self.arg) };
if r == 0 {
Ok(())
} else {
Err(PyTraverseError(r))
match NonZeroCInt::new(unsafe { (self.visit)(ptr, self.arg) }) {
None => Ok(()),
Some(r) => Err(PyTraverseError(r)),
}
} else {
Ok(())
}
}
}

/// Creates the PyVisit from the arguments to tp_traverse
#[doc(hidden)]
pub unsafe fn from_raw(visit: ffi::visitproc, arg: *mut c_void, py: Python<'p>) -> Self {
Self {
visit,
arg,
_py: py,
}
/// Workaround for `NonZero<c_int>` not being available until MSRV 1.79
mod get_nonzero_c_int {
pub struct GetNonZeroCInt<const WIDTH: usize>();

pub trait NonZeroCIntType {
type Type;
}
impl NonZeroCIntType for GetNonZeroCInt<16> {
type Type = std::num::NonZeroI16;
}
impl NonZeroCIntType for GetNonZeroCInt<32> {
type Type = std::num::NonZeroI32;
}

pub type Type =
<GetNonZeroCInt<{ std::mem::size_of::<std::os::raw::c_int>() * 8 }> as NonZeroCIntType>::Type;
}

use get_nonzero_c_int::Type as NonZeroCInt;

#[cfg(test)]
mod tests {
use super::PyVisit;
use static_assertions::assert_not_impl_any;

#[test]
fn py_visit_not_send_sync() {
assert_not_impl_any!(PyVisit<'_>: Send, Sync);
}
}

0 comments on commit c2f8114

Please sign in to comment.