From 7cb1ad049f81a825ed8c98946bff8b927b39e3e4 Mon Sep 17 00:00:00 2001 From: Kang Seonghoon Date: Mon, 22 May 2023 10:57:58 +0900 Subject: [PATCH 1/2] Rename a misleading test name and add an actual panicking test. --- tests/test_gc.rs | 95 +++++++++++++++++++++++++++++++----------------- 1 file changed, 61 insertions(+), 34 deletions(-) diff --git a/tests/test_gc.rs b/tests/test_gc.rs index e8cb65168ca..3a369c6e4a0 100644 --- a/tests/test_gc.rs +++ b/tests/test_gc.rs @@ -248,22 +248,10 @@ impl TraversableClass { } } -unsafe fn get_type_traverse(tp: *mut pyo3::ffi::PyTypeObject) -> Option { - std::mem::transmute(pyo3::ffi::PyType_GetSlot(tp, pyo3::ffi::Py_tp_traverse)) -} - #[test] fn gc_during_borrow() { Python::with_gil(|py| { unsafe { - // declare a dummy visitor function - extern "C" fn novisit( - _object: *mut pyo3::ffi::PyObject, - _arg: *mut core::ffi::c_void, - ) -> std::os::raw::c_int { - 0 - } - // get the traverse function let ty = py.get_type::().as_type_ptr(); let traverse = get_type_traverse(ty).unwrap(); @@ -290,18 +278,18 @@ fn gc_during_borrow() { } #[pyclass] -struct PanickyTraverse { +struct PartialTraverse { member: PyObject, } -impl PanickyTraverse { +impl PartialTraverse { fn new(py: Python<'_>) -> Self { Self { member: py.None() } } } #[pymethods] -impl PanickyTraverse { +impl PartialTraverse { fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { visit.call(&self.member)?; // In the test, we expect this to never be hit @@ -310,22 +298,14 @@ impl PanickyTraverse { } #[test] -fn traverse_error() { +fn traverse_partial() { Python::with_gil(|py| unsafe { - // declare a visitor function which errors (returns nonzero code) - extern "C" fn visit_error( - _object: *mut pyo3::ffi::PyObject, - _arg: *mut core::ffi::c_void, - ) -> std::os::raw::c_int { - -1 - } - // get the traverse function - let ty = py.get_type::().as_type_ptr(); + let ty = py.get_type::().as_type_ptr(); let traverse = get_type_traverse(ty).unwrap(); // confirm that traversing errors - let obj = Py::new(py, PanickyTraverse::new(py)).unwrap(); + let obj = Py::new(py, PartialTraverse::new(py)).unwrap(); assert_eq!( traverse(obj.as_ptr(), visit_error, std::ptr::null_mut()), -1 @@ -333,6 +313,38 @@ fn traverse_error() { }) } +#[pyclass] +struct PanickyTraverse { + member: PyObject, +} + +impl PanickyTraverse { + fn new(py: Python<'_>) -> Self { + Self { member: py.None() } + } +} + +#[pymethods] +impl PanickyTraverse { + fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + visit.call(&self.member)?; + panic!("at the disco"); + } +} + +#[test] +fn traverse_panic() { + Python::with_gil(|py| unsafe { + // get the traverse function + let ty = py.get_type::().as_type_ptr(); + let traverse = get_type_traverse(ty).unwrap(); + + // confirm that traversing errors + let obj = Py::new(py, PanickyTraverse::new(py)).unwrap(); + assert_eq!(traverse(obj.as_ptr(), novisit, std::ptr::null_mut()), -1); + }) +} + #[pyclass] struct TriesGILInTraverse {} @@ -346,14 +358,6 @@ impl TriesGILInTraverse { #[test] fn tries_gil_in_traverse() { Python::with_gil(|py| unsafe { - // declare a visitor function which errors (returns nonzero code) - extern "C" fn novisit( - _object: *mut pyo3::ffi::PyObject, - _arg: *mut core::ffi::c_void, - ) -> std::os::raw::c_int { - 0 - } - // get the traverse function let ty = py.get_type::().as_type_ptr(); let traverse = get_type_traverse(ty).unwrap(); @@ -363,3 +367,26 @@ fn tries_gil_in_traverse() { assert_eq!(traverse(obj.as_ptr(), novisit, std::ptr::null_mut()), -1); }) } + +// Manual traversal utilities + +unsafe fn get_type_traverse(tp: *mut pyo3::ffi::PyTypeObject) -> Option { + std::mem::transmute(pyo3::ffi::PyType_GetSlot(tp, pyo3::ffi::Py_tp_traverse)) +} + +// a dummy visitor function +extern "C" fn novisit( + _object: *mut pyo3::ffi::PyObject, + _arg: *mut core::ffi::c_void, +) -> std::os::raw::c_int { + 0 +} + +// a visitor function which errors (returns nonzero code) +extern "C" fn visit_error( + _object: *mut pyo3::ffi::PyObject, + _arg: *mut core::ffi::c_void, +) -> std::os::raw::c_int { + -1 +} + From e8843276757d4b728d2effd43eba578770ae2bb4 Mon Sep 17 00:00:00 2001 From: Kang Seonghoon Date: Mon, 22 May 2023 11:09:04 +0900 Subject: [PATCH 2/2] Add additional tests for #3165. --- tests/test_gc.rs | 140 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 139 insertions(+), 1 deletion(-) diff --git a/tests/test_gc.rs b/tests/test_gc.rs index 3a369c6e4a0..c84d6784633 100644 --- a/tests/test_gc.rs +++ b/tests/test_gc.rs @@ -4,6 +4,7 @@ use pyo3::class::PyTraverseError; use pyo3::class::PyVisit; use pyo3::prelude::*; use pyo3::{py_run, AsPyPointer, PyCell, PyTryInto}; +use std::cell::Cell; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; @@ -368,6 +369,144 @@ fn tries_gil_in_traverse() { }) } +#[pyclass] +struct HijackedTraverse { + traversed: Cell, + hijacked: Cell, +} + +impl HijackedTraverse { + fn new() -> Self { + Self { + traversed: Cell::new(false), + hijacked: Cell::new(false), + } + } + + fn traversed_and_hijacked(&self) -> (bool, bool) { + (self.traversed.get(), self.hijacked.get()) + } +} + +#[pymethods] +impl HijackedTraverse { + #[allow(clippy::unnecessary_wraps)] + fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + self.traversed.set(true); + Ok(()) + } +} + +trait Traversable { + fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError>; +} + +impl<'a> Traversable for PyRef<'a, HijackedTraverse> { + fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + self.hijacked.set(true); + Ok(()) + } +} + +#[test] +fn traverse_cannot_be_hijacked() { + Python::with_gil(|py| unsafe { + // get the traverse function + let ty = py.get_type::().as_type_ptr(); + let traverse = get_type_traverse(ty).unwrap(); + + let cell = PyCell::new(py, HijackedTraverse::new()).unwrap(); + let obj = cell.to_object(py); + assert_eq!(cell.borrow().traversed_and_hijacked(), (false, false)); + traverse(obj.as_ptr(), novisit, std::ptr::null_mut()); + assert_eq!(cell.borrow().traversed_and_hijacked(), (true, false)); + }) +} + +#[allow(dead_code)] +#[pyclass] +struct DropDuringTraversal { + cycle: Cell>>, + dropped: TestDropCall, +} + +#[pymethods] +impl DropDuringTraversal { + #[allow(clippy::unnecessary_wraps)] + fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + self.cycle.take(); + Ok(()) + } + + fn __clear__(&mut self) { + self.cycle.take(); + } +} + +#[test] +fn drop_during_traversal_with_gil() { + let drop_called = Arc::new(AtomicBool::new(false)); + + Python::with_gil(|py| { + let inst = Py::new( + py, + DropDuringTraversal { + cycle: Cell::new(None), + dropped: TestDropCall { + drop_called: Arc::clone(&drop_called), + }, + }, + ) + .unwrap(); + + inst.borrow_mut(py).cycle.set(Some(inst.clone_ref(py))); + + drop(inst); + }); + + // due to the internal GC mechanism, we may need multiple + // (but not too many) collections to get `inst` actually dropped. + for _ in 0..10 { + Python::with_gil(|py| { + py.run("import gc; gc.collect()", None, None).unwrap(); + }); + } + assert!(drop_called.load(Ordering::Relaxed)); +} + +#[test] +fn drop_during_traversal_without_gil() { + let drop_called = Arc::new(AtomicBool::new(false)); + + let inst = Python::with_gil(|py| { + let inst = Py::new( + py, + DropDuringTraversal { + cycle: Cell::new(None), + dropped: TestDropCall { + drop_called: Arc::clone(&drop_called), + }, + }, + ) + .unwrap(); + + inst.borrow_mut(py).cycle.set(Some(inst.clone_ref(py))); + + inst + }); + + drop(inst); + + // due to the internal GC mechanism, we may need multiple + // (but not too many) collections to get `inst` actually dropped. + for _ in 0..10 { + Python::with_gil(|py| { + py.run("import gc; gc.collect()", None, None).unwrap(); + }); + } + assert!(drop_called.load(Ordering::Relaxed)); +} + // Manual traversal utilities unsafe fn get_type_traverse(tp: *mut pyo3::ffi::PyTypeObject) -> Option { @@ -389,4 +528,3 @@ extern "C" fn visit_error( ) -> std::os::raw::c_int { -1 } -