diff --git a/src/gil.rs b/src/gil.rs index db8311ab2fc..555e45e9105 100644 --- a/src/gil.rs +++ b/src/gil.rs @@ -9,6 +9,7 @@ use std::cell::Cell; use std::cell::RefCell; #[cfg(not(debug_assertions))] use std::cell::UnsafeCell; +use std::sync::atomic::AtomicBool; use std::{mem, ptr::NonNull, sync}; static START: sync::Once = sync::Once::new(); @@ -215,6 +216,7 @@ impl GILGuard { let gstate = unsafe { ffi::PyGILState_Ensure() }; // acquire GIL #[allow(deprecated)] let pool = unsafe { mem::ManuallyDrop::new(GILPool::new()) }; + update_deferred_reference_counts(pool.python()); Some(GILGuard { gstate, pool }) } @@ -238,6 +240,7 @@ type PyObjVec = Vec>; #[cfg(not(pyo3_disable_reference_pool))] /// Thread-safe storage for objects which were dec_ref while the GIL was not held. struct ReferencePool { + ever_used: AtomicBool, pending_decrefs: sync::Mutex, } @@ -245,15 +248,28 @@ struct ReferencePool { impl ReferencePool { const fn new() -> Self { Self { + ever_used: AtomicBool::new(false), pending_decrefs: sync::Mutex::new(Vec::new()), } } fn register_decref(&self, obj: NonNull) { + self.ever_used.store(true, sync::atomic::Ordering::Relaxed); self.pending_decrefs.lock().unwrap().push(obj); } - fn update_counts(&self, _py: Python<'_>) { + #[inline] + fn update_counts(&self, py: Python<'_>) { + // Justification for relaxed: worst case this causes already deferred drops to be + // delayed slightly later, and this is also a one-time flag, so if the program is + // using deferred drops it is highly likely that branch prediction will always + // assume this is true and we don't need the atomic overhead. + if self.ever_used.load(sync::atomic::Ordering::Relaxed) { + self.update_counts_impl(py) + } + } + + fn update_counts_impl(&self, _py: Python<'_>) { let mut pending_decrefs = self.pending_decrefs.lock().unwrap(); if pending_decrefs.is_empty() { return; @@ -268,6 +284,12 @@ impl ReferencePool { } } +#[inline] +#[cfg(not(pyo3_disable_reference_pool))] +pub(crate) fn update_deferred_reference_counts(py: Python<'_>) { + POOL.update_counts(py); +} + #[cfg(not(pyo3_disable_reference_pool))] unsafe impl Sync for ReferencePool {} @@ -370,9 +392,6 @@ impl GILPool { #[inline] pub unsafe fn new() -> GILPool { increment_gil_count(); - // Update counts of PyObjects / Py that have been cloned or dropped since last acquisition - #[cfg(not(pyo3_disable_reference_pool))] - POOL.update_counts(Python::assume_gil_acquired()); GILPool { start: OWNED_OBJECTS .try_with(|owned_objects| { diff --git a/src/impl_/trampoline.rs b/src/impl_/trampoline.rs index db493817cba..f4db0c92d83 100644 --- a/src/impl_/trampoline.rs +++ b/src/impl_/trampoline.rs @@ -12,8 +12,9 @@ use std::{ #[allow(deprecated)] use crate::gil::GILPool; use crate::{ - callback::PyCallbackOutput, ffi, ffi_ptr_ext::FfiPtrExt, impl_::panic::PanicTrap, - methods::IPowModulo, panic::PanicException, types::PyModule, Py, PyResult, Python, + callback::PyCallbackOutput, ffi, ffi_ptr_ext::FfiPtrExt, gil::update_deferred_reference_counts, + impl_::panic::PanicTrap, methods::IPowModulo, panic::PanicException, types::PyModule, Py, + PyResult, Python, }; #[inline] @@ -182,6 +183,7 @@ where #[allow(deprecated)] let pool = unsafe { GILPool::new() }; let py = pool.python(); + update_deferred_reference_counts(py); let out = panic_result_into_callback_output( py, panic::catch_unwind(move || -> PyResult<_> { body(py) }), @@ -229,6 +231,7 @@ where #[allow(deprecated)] let pool = GILPool::new(); let py = pool.python(); + update_deferred_reference_counts(py); if let Err(py_err) = panic::catch_unwind(move || body(py)) .unwrap_or_else(|payload| Err(PanicException::from_panic_payload(payload))) {