From c814078866b9f3b5bc712f37274c65b5b03d32c3 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Fri, 24 Nov 2023 04:29:24 +0000 Subject: [PATCH] refactor `OkWrap` to not call `.into_py(py)` --- pyo3-macros-backend/src/method.rs | 12 ++--- pyo3-macros-backend/src/pymethod.rs | 2 +- pyo3-macros-backend/src/quotes.rs | 4 +- src/impl_.rs | 2 - src/impl_/coroutine.rs | 19 -------- src/impl_/wrap.rs | 68 ++++++++++++++++++++++++----- tests/test_compile_error.rs | 2 +- 7 files changed, 69 insertions(+), 40 deletions(-) delete mode 100644 src/impl_/coroutine.rs diff --git a/pyo3-macros-backend/src/method.rs b/pyo3-macros-backend/src/method.rs index 034d079bb46..79d83f0451f 100644 --- a/pyo3-macros-backend/src/method.rs +++ b/pyo3-macros-backend/src/method.rs @@ -449,11 +449,13 @@ impl<'a> FnSpec<'a> { let func_name = &self.name; let rust_call = |args: Vec| { - let mut call = quote! { function(#self_arg #(#args),*) }; - if self.asyncness.is_some() { - call = quote! { _pyo3::impl_::coroutine::wrap_future(#call) }; - } - quotes::map_result_into_ptr(quotes::ok_wrap(call)) + let call = quote! { function(#self_arg #(#args),*) }; + let wrapped_call = if self.asyncness.is_some() { + quote! { _pyo3::PyResult::Ok(_pyo3::impl_::wrap::wrap_future(#call)) } + } else { + quotes::ok_wrap(call) + }; + quotes::map_result_into_ptr(wrapped_call) }; let rust_name = if let Some(cls) = cls { diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index a8fd3b41a18..9c5f2601f9c 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -458,7 +458,7 @@ fn impl_py_class_attribute(cls: &syn::Type, spec: &FnSpec<'_>) -> syn::Result) -> _pyo3::PyResult<_pyo3::PyObject> { let function = #cls::#name; // Shadow the method name to avoid #3017 - #body + _pyo3::impl_::wrap::map_result_into_py(py, #body) } }; diff --git a/pyo3-macros-backend/src/quotes.rs b/pyo3-macros-backend/src/quotes.rs index 966564b1d98..239036ef3ca 100644 --- a/pyo3-macros-backend/src/quotes.rs +++ b/pyo3-macros-backend/src/quotes.rs @@ -9,13 +9,13 @@ pub(crate) fn some_wrap(obj: TokenStream) -> TokenStream { pub(crate) fn ok_wrap(obj: TokenStream) -> TokenStream { quote! { - _pyo3::impl_::wrap::OkWrap::wrap(#obj, py) + _pyo3::impl_::wrap::OkWrap::wrap(#obj) .map_err(::core::convert::Into::<_pyo3::PyErr>::into) } } pub(crate) fn map_result_into_ptr(result: TokenStream) -> TokenStream { quote! { - #result.map(_pyo3::PyObject::into_ptr) + _pyo3::impl_::wrap::map_result_into_ptr(py, #result) } } diff --git a/src/impl_.rs b/src/impl_.rs index 77f9ff4ea1f..118d62d9dbc 100644 --- a/src/impl_.rs +++ b/src/impl_.rs @@ -6,8 +6,6 @@ //! APIs may may change at any time without documentation in the CHANGELOG and without //! breaking semver guarantees. -#[cfg(feature = "macros")] -pub mod coroutine; pub mod deprecations; pub mod extract_argument; pub mod freelist; diff --git a/src/impl_/coroutine.rs b/src/impl_/coroutine.rs deleted file mode 100644 index 843c42f169a..00000000000 --- a/src/impl_/coroutine.rs +++ /dev/null @@ -1,19 +0,0 @@ -use crate::coroutine::Coroutine; -use crate::impl_::wrap::OkWrap; -use crate::{IntoPy, PyErr, PyObject, Python}; -use std::future::Future; - -/// Used to wrap the result of async `#[pyfunction]` and `#[pymethods]`. -pub fn wrap_future(future: F) -> Coroutine -where - F: Future + Send + 'static, - R: OkWrap, - T: IntoPy, - PyErr: From, -{ - let future = async move { - // SAFETY: GIL is acquired when future is polled (see `Coroutine::poll`) - future.await.wrap(unsafe { Python::assume_gil_acquired() }) - }; - Coroutine::from_future(future) -} diff --git a/src/impl_/wrap.rs b/src/impl_/wrap.rs index a73e3597302..b41055b2863 100644 --- a/src/impl_/wrap.rs +++ b/src/impl_/wrap.rs @@ -1,17 +1,19 @@ -use crate::{IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python}; +use std::convert::Infallible; + +use crate::{ffi, IntoPy, PyObject, PyResult, Python}; /// Used to wrap values in `Option` for default arguments. pub trait SomeWrap { - fn wrap(self) -> T; + fn wrap(self) -> Option; } -impl SomeWrap> for T { +impl SomeWrap for T { fn wrap(self) -> Option { Some(self) } } -impl SomeWrap> for Option { +impl SomeWrap for Option { fn wrap(self) -> Self { self } @@ -20,7 +22,7 @@ impl SomeWrap> for Option { /// Used to wrap the result of `#[pyfunction]` and `#[pymethods]`. pub trait OkWrap { type Error; - fn wrap(self, py: Python<'_>) -> Result, Self::Error>; + fn wrap(self) -> Result; } // The T: IntoPy bound here is necessary to prevent the @@ -29,9 +31,10 @@ impl OkWrap for T where T: IntoPy, { - type Error = PyErr; - fn wrap(self, py: Python<'_>) -> PyResult> { - Ok(self.into_py(py)) + type Error = Infallible; + #[inline] + fn wrap(self) -> Result { + Ok(self) } } @@ -40,11 +43,44 @@ where T: IntoPy, { type Error = E; - fn wrap(self, py: Python<'_>) -> Result, Self::Error> { - self.map(|o| o.into_py(py)) + #[inline] + fn wrap(self) -> Result { + self } } +/// This is a follow-up function to `OkWrap::wrap` that converts the result into +/// a `*mut ffi::PyObject` pointer. +pub fn map_result_into_ptr>( + py: Python<'_>, + result: PyResult, +) -> PyResult<*mut ffi::PyObject> { + result.map(|obj| obj.into_py(py).into_ptr()) +} + +/// This is a follow-up function to `OkWrap::wrap` that converts the result into +/// a safe wrapper. +pub fn map_result_into_py>( + py: Python<'_>, + result: PyResult, +) -> PyResult { + result.map(|err| err.into_py(py)) +} + +/// Used to wrap the result of async `#[pyfunction]` and `#[pymethods]`. +#[cfg(feature = "macros")] +pub fn wrap_future(future: F) -> crate::coroutine::Coroutine +where + F: std::future::Future + Send + 'static, + R: OkWrap, + T: IntoPy, + crate::PyErr: From, +{ + crate::coroutine::Coroutine::from_future::<_, T, crate::PyErr>(async move { + OkWrap::wrap(future.await).map_err(Into::into) + }) +} + #[cfg(test)] mod tests { use super::*; @@ -57,4 +93,16 @@ mod tests { let b: Option = SomeWrap::wrap(None); assert_eq!(b, None); } + + #[test] + fn wrap_result() { + let a: Result = OkWrap::wrap(42u8); + assert!(matches!(a, Ok(42))); + + let b: PyResult = OkWrap::wrap(Ok(42u8)); + assert!(matches!(b, Ok(42))); + + let c: Result = OkWrap::wrap(Err("error")); + assert_eq!(c, Err("error")); + } } diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs index 3919886f8a5..3e01ad81ebe 100644 --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -33,7 +33,7 @@ fn test_compile_errors() { t.compile_fail("tests/ui/invalid_pymethod_receiver.rs"); t.compile_fail("tests/ui/missing_intopy.rs"); // adding extra error conversion impls changes the output - #[cfg(all(target_os = "linux", not(any(feature = "eyre", feature = "anyhow"))))] + #[cfg(not(any(windows, feature = "eyre", feature = "anyhow")))] t.compile_fail("tests/ui/invalid_result_conversion.rs"); t.compile_fail("tests/ui/not_send.rs"); t.compile_fail("tests/ui/not_send2.rs");