From ac9fcf9ffd386cb4a0213f4bbc9f58726d9c6436 Mon Sep 17 00:00:00 2001 From: Bas Schoenmaeckers Date: Sun, 29 Sep 2024 22:06:33 +0200 Subject: [PATCH] Use critical section wrapper --- src/types/dict.rs | 267 ++++++++++++++++++++-------------------------- 1 file changed, 118 insertions(+), 149 deletions(-) diff --git a/src/types/dict.rs b/src/types/dict.rs index 5446008945e..c424ddfab70 100644 --- a/src/types/dict.rs +++ b/src/types/dict.rs @@ -366,23 +366,22 @@ impl<'py> PyDictMethods<'py> for Bound<'py, PyDict> { BoundDictIterator::new(self.clone()) } - #[cfg(Py_GIL_DISABLED)] - fn locked_for_each(&self, closure: F) -> PyResult<()> + #[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))] + fn locked_for_each(&self, f: F) -> PyResult<()> where F: Fn(Bound<'py, PyAny>, Bound<'py, PyAny>) -> PyResult<()>, { - let mut section = unsafe { std::mem::zeroed() }; - unsafe { ffi::PyCriticalSection_Begin(&mut section, self.as_ptr()) }; - - for (key, value) in self { - if let Err(err) = closure(key, value) { - unsafe { ffi::PyCriticalSection_End(&mut section) }; - return Err(err); - } + #[cfg(feature = "nightly")] + { + self.iter().try_for_each(|(key, value)| f(key, value)) } - unsafe { ffi::PyCriticalSection_End(&mut section) }; - Ok(()) + #[cfg(not(feature = "nightly"))] + { + crate::sync::with_critical_section(self, || { + self.iter().try_for_each(|(key, value)| f(key, value)) + }) + } } fn as_mapping(&self) -> &Bound<'py, PyMapping> { @@ -452,10 +451,13 @@ impl<'py> Iterator for BoundDictIterator<'py> { #[inline] fn next(&mut self) -> Option { - match self { - BoundDictIterator::ItemIter { iter, remaining } => { + self.with_critical_section(|iter| match iter { + BoundDictIterator::ItemIter { + iter: ref mut py_iter, + ref mut remaining, + } => { *remaining = remaining.saturating_sub(1); - iter.next().map(Result::unwrap).map(|tuple| { + py_iter.next().map(Result::unwrap).map(|tuple| { let tuple = tuple.downcast::().unwrap(); let key = tuple.get_item(0).unwrap(); let value = tuple.get_item(1).unwrap(); @@ -463,18 +465,11 @@ impl<'py> Iterator for BoundDictIterator<'py> { }) } BoundDictIterator::DictIter { - dict, - ppos, - di_used, - remaining, + ref mut dict, + ref mut ppos, + ref mut di_used, + ref mut remaining, } => { - #[cfg(Py_GIL_DISABLED)] - let mut section = unsafe { std::mem::zeroed() }; - #[cfg(Py_GIL_DISABLED)] - unsafe { - ffi::PyCriticalSection_Begin(&mut section, dict.as_ptr()); - }; - let ma_used = dict_len(dict); // These checks are similar to what CPython does. @@ -504,10 +499,7 @@ impl<'py> Iterator for BoundDictIterator<'py> { let mut key: *mut ffi::PyObject = std::ptr::null_mut(); let mut value: *mut ffi::PyObject = std::ptr::null_mut(); - let result = if unsafe { - ffi::PyDict_Next(dict.as_ptr(), ppos, &mut key, &mut value) - } != 0 - { + if unsafe { ffi::PyDict_Next(dict.as_ptr(), ppos, &mut key, &mut value) } != 0 { *remaining -= 1; let py = dict.py(); // Safety: @@ -519,16 +511,9 @@ impl<'py> Iterator for BoundDictIterator<'py> { )) } else { None - }; - - #[cfg(Py_GIL_DISABLED)] - unsafe { - ffi::PyCriticalSection_End(&mut section); } - - result } - } + }) } #[inline] @@ -544,15 +529,13 @@ impl<'py> Iterator for BoundDictIterator<'py> { Self: Sized, F: FnMut(B, Self::Item) -> B, { - let mut section = unsafe { std::mem::zeroed() }; - unsafe { ffi::PyCriticalSection_Begin(&mut section, self.as_ptr()) }; - - let mut accum = init; - for x in &mut self { - accum = f(accum, x); - } - unsafe { ffi::PyCriticalSection_End(&mut section) }; - accum + self.with_critical_section(|mut iter| { + let mut accum = init; + for x in &mut iter { + accum = f(accum, x); + } + accum + }) } #[inline] @@ -563,22 +546,13 @@ impl<'py> Iterator for BoundDictIterator<'py> { F: FnMut(B, Self::Item) -> R, R: std::ops::Try, { - let mut section = unsafe { std::mem::zeroed() }; - unsafe { ffi::PyCriticalSection_Begin(&mut section, self.as_ptr()) }; - - let mut accum = init; - - for x in &mut self { - match f(accum, x).branch() { - ControlFlow::Continue(a) => accum = a, - ControlFlow::Break(err) => { - unsafe { ffi::PyCriticalSection_End(&mut section) } - return R::from_residual(err); - } + self.with_critical_section(|mut iter| { + let mut accum = init; + for x in &mut iter { + accum = f(accum, x)? } - } - unsafe { ffi::PyCriticalSection_End(&mut section) }; - R::from_output(accum) + R::from_output(accum) + }) } #[inline] @@ -588,22 +562,19 @@ impl<'py> Iterator for BoundDictIterator<'py> { Self: Sized, F: FnMut(Self::Item) -> bool, { - let mut section = unsafe { std::mem::zeroed() }; - unsafe { ffi::PyCriticalSection_Begin(&mut section, self.as_ptr()) }; - - #[inline] - fn check(mut f: impl FnMut(T) -> bool) -> impl FnMut((), T) -> ControlFlow<()> { - move |(), x| { - if f(x) { - ControlFlow::Continue(()) - } else { - ControlFlow::Break(()) + self.with_critical_section(|iter| { + #[inline] + fn check(mut f: impl FnMut(T) -> bool) -> impl FnMut((), T) -> ControlFlow<()> { + move |(), x| { + if f(x) { + ControlFlow::Continue(()) + } else { + ControlFlow::Break(()) + } } } - } - let result = self.try_fold((), check(f)) == ControlFlow::Continue(()); - unsafe { ffi::PyCriticalSection_End(&mut section) }; - result + iter.try_fold((), check(f)) == ControlFlow::Continue(()) + }) } #[inline] @@ -613,23 +584,20 @@ impl<'py> Iterator for BoundDictIterator<'py> { Self: Sized, F: FnMut(Self::Item) -> bool, { - let mut section = unsafe { std::mem::zeroed() }; - unsafe { ffi::PyCriticalSection_Begin(&mut section, self.as_ptr()) }; - - #[inline] - fn check(mut f: impl FnMut(T) -> bool) -> impl FnMut((), T) -> ControlFlow<()> { - move |(), x| { - if f(x) { - ControlFlow::Break(()) - } else { - ControlFlow::Continue(()) + self.with_critical_section(|iter| { + #[inline] + fn check(mut f: impl FnMut(T) -> bool) -> impl FnMut((), T) -> ControlFlow<()> { + move |(), x| { + if f(x) { + ControlFlow::Break(()) + } else { + ControlFlow::Continue(()) + } } } - } - let result = self.try_fold((), check(f)) == ControlFlow::Break(()); - unsafe { ffi::PyCriticalSection_End(&mut section) }; - result + iter.try_fold((), check(f)) == ControlFlow::Break(()) + }) } #[inline] @@ -639,26 +607,25 @@ impl<'py> Iterator for BoundDictIterator<'py> { Self: Sized, P: FnMut(&Self::Item) -> bool, { - let mut section = unsafe { std::mem::zeroed() }; - unsafe { ffi::PyCriticalSection_Begin(&mut section, self.as_ptr()) }; - - #[inline] - fn check(mut predicate: impl FnMut(&T) -> bool) -> impl FnMut((), T) -> ControlFlow { - move |(), x| { - if predicate(&x) { - ControlFlow::Break(x) - } else { - ControlFlow::Continue(()) + self.with_critical_section(|iter| { + #[inline] + fn check( + mut predicate: impl FnMut(&T) -> bool, + ) -> impl FnMut((), T) -> ControlFlow { + move |(), x| { + if predicate(&x) { + ControlFlow::Break(x) + } else { + ControlFlow::Continue(()) + } } } - } - let result = match self.try_fold((), check(predicate)) { - ControlFlow::Continue(_) => None, - ControlFlow::Break(x) => Some(x), - }; - unsafe { ffi::PyCriticalSection_End(&mut section) }; - result + match iter.try_fold((), check(predicate)) { + ControlFlow::Continue(_) => None, + ControlFlow::Break(x) => Some(x), + } + }) } #[inline] @@ -668,23 +635,22 @@ impl<'py> Iterator for BoundDictIterator<'py> { Self: Sized, F: FnMut(Self::Item) -> Option, { - let mut section = unsafe { std::mem::zeroed() }; - unsafe { ffi::PyCriticalSection_Begin(&mut section, self.as_ptr()) }; - - #[inline] - fn check(mut f: impl FnMut(T) -> Option) -> impl FnMut((), T) -> ControlFlow { - move |(), x| match f(x) { - Some(x) => ControlFlow::Break(x), - None => ControlFlow::Continue(()), + self.with_critical_section(|iter| { + #[inline] + fn check( + mut f: impl FnMut(T) -> Option, + ) -> impl FnMut((), T) -> ControlFlow { + move |(), x| match f(x) { + Some(x) => ControlFlow::Break(x), + None => ControlFlow::Continue(()), + } } - } - let result = match self.try_fold((), check(f)) { - ControlFlow::Continue(_) => None, - ControlFlow::Break(x) => Some(x), - }; - unsafe { ffi::PyCriticalSection_End(&mut section) }; - result + match iter.try_fold((), check(f)) { + ControlFlow::Continue(_) => None, + ControlFlow::Break(x) => Some(x), + } + }) } #[inline] @@ -694,32 +660,28 @@ impl<'py> Iterator for BoundDictIterator<'py> { Self: Sized, P: FnMut(Self::Item) -> bool, { - let mut section = unsafe { std::mem::zeroed() }; - unsafe { ffi::PyCriticalSection_Begin(&mut section, self.as_ptr()) }; - - #[inline] - fn check<'a, T>( - mut predicate: impl FnMut(T) -> bool + 'a, - acc: &'a mut usize, - ) -> impl FnMut((), T) -> ControlFlow + 'a { - move |_, x| { - if predicate(x) { - ControlFlow::Break(*acc) - } else { - *acc += 1; - ControlFlow::Continue(()) + self.with_critical_section(|iter| { + #[inline] + fn check<'a, T>( + mut predicate: impl FnMut(T) -> bool + 'a, + acc: &'a mut usize, + ) -> impl FnMut((), T) -> ControlFlow + 'a { + move |_, x| { + if predicate(x) { + ControlFlow::Break(*acc) + } else { + *acc += 1; + ControlFlow::Continue(()) + } } } - } - let mut acc = 0; - let result = match self.try_fold((), check(predicate, &mut acc)) { - ControlFlow::Continue(_) => None, - ControlFlow::Break(x) => Some(x), - }; - - unsafe { ffi::PyCriticalSection_End(&mut section) }; - result + let mut acc = 0; + match iter.try_fold((), check(predicate, &mut acc)) { + ControlFlow::Continue(_) => None, + ControlFlow::Break(x) => Some(x), + } + }) } } @@ -751,11 +713,18 @@ impl<'py> BoundDictIterator<'py> { } #[inline] - #[cfg(Py_GIL_DISABLED)] - fn as_ptr(&self) -> *mut ffi::PyObject { + fn with_critical_section(&mut self, f: F) -> R + where + F: FnOnce(&mut Self) -> R, + { match self { - BoundDictIterator::ItemIter { ref iter, .. } => iter.as_ptr(), - BoundDictIterator::DictIter { ref dict, .. } => dict.as_ptr(), + BoundDictIterator::ItemIter { .. } => f(self), + #[cfg(not(Py_GIL_DISABLED))] + BoundDictIterator::DictIter { .. } => f(self), + #[cfg(Py_GIL_DISABLED)] + BoundDictIterator::DictIter { ref dict, .. } => { + crate::sync::with_critical_section(dict.clone().as_ref(), || f(self)) + } } } }