Skip to content

Commit

Permalink
Merge pull request #3595 from davidhewitt/ok-wrap
Browse files Browse the repository at this point in the history
refactor `OkWrap` to not call `.into_py(py)`
  • Loading branch information
davidhewitt authored Nov 25, 2023
2 parents cbd0630 + c814078 commit 9f66846
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 40 deletions.
12 changes: 7 additions & 5 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -455,11 +455,13 @@ impl<'a> FnSpec<'a> {
let func_name = &self.name;

let rust_call = |args: Vec<TokenStream>| {
let mut call = quote! { function(#self_arg #(#args),*) };
if self.asyncness.is_some() {
call = quote! { _pyo3::impl_::coroutine::wrap_future(#call) };
}
quotes::map_result_into_ptr(quotes::ok_wrap(call))
let call = quote! { function(#self_arg #(#args),*) };
let wrapped_call = if self.asyncness.is_some() {
quote! { _pyo3::PyResult::Ok(_pyo3::impl_::wrap::wrap_future(#call)) }
} else {
quotes::ok_wrap(call)
};
quotes::map_result_into_ptr(wrapped_call)
};

let rust_name = if let Some(cls) = cls {
Expand Down
2 changes: 1 addition & 1 deletion pyo3-macros-backend/src/pymethod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ fn impl_py_class_attribute(cls: &syn::Type, spec: &FnSpec<'_>) -> syn::Result<Me
let associated_method = quote! {
fn #wrapper_ident(py: _pyo3::Python<'_>) -> _pyo3::PyResult<_pyo3::PyObject> {
let function = #cls::#name; // Shadow the method name to avoid #3017
#body
_pyo3::impl_::wrap::map_result_into_py(py, #body)
}
};

Expand Down
4 changes: 2 additions & 2 deletions pyo3-macros-backend/src/quotes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ pub(crate) fn some_wrap(obj: TokenStream) -> TokenStream {

pub(crate) fn ok_wrap(obj: TokenStream) -> TokenStream {
quote! {
_pyo3::impl_::wrap::OkWrap::wrap(#obj, py)
_pyo3::impl_::wrap::OkWrap::wrap(#obj)
.map_err(::core::convert::Into::<_pyo3::PyErr>::into)
}
}

pub(crate) fn map_result_into_ptr(result: TokenStream) -> TokenStream {
quote! {
#result.map(_pyo3::PyObject::into_ptr)
_pyo3::impl_::wrap::map_result_into_ptr(py, #result)
}
}
2 changes: 0 additions & 2 deletions src/impl_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
//! APIs may may change at any time without documentation in the CHANGELOG and without
//! breaking semver guarantees.
#[cfg(feature = "macros")]
pub mod coroutine;
pub mod deprecations;
pub mod extract_argument;
pub mod freelist;
Expand Down
19 changes: 0 additions & 19 deletions src/impl_/coroutine.rs

This file was deleted.

68 changes: 58 additions & 10 deletions src/impl_/wrap.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
use crate::{IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python};
use std::convert::Infallible;

use crate::{ffi, IntoPy, PyObject, PyResult, Python};

/// Used to wrap values in `Option<T>` for default arguments.
pub trait SomeWrap<T> {
fn wrap(self) -> T;
fn wrap(self) -> Option<T>;
}

impl<T> SomeWrap<Option<T>> for T {
impl<T> SomeWrap<T> for T {
fn wrap(self) -> Option<T> {
Some(self)
}
}

impl<T> SomeWrap<Option<T>> for Option<T> {
impl<T> SomeWrap<T> for Option<T> {
fn wrap(self) -> Self {
self
}
Expand All @@ -20,7 +22,7 @@ impl<T> SomeWrap<Option<T>> for Option<T> {
/// Used to wrap the result of `#[pyfunction]` and `#[pymethods]`.
pub trait OkWrap<T> {
type Error;
fn wrap(self, py: Python<'_>) -> Result<Py<PyAny>, Self::Error>;
fn wrap(self) -> Result<T, Self::Error>;
}

// The T: IntoPy<PyObject> bound here is necessary to prevent the
Expand All @@ -29,9 +31,10 @@ impl<T> OkWrap<T> for T
where
T: IntoPy<PyObject>,
{
type Error = PyErr;
fn wrap(self, py: Python<'_>) -> PyResult<Py<PyAny>> {
Ok(self.into_py(py))
type Error = Infallible;
#[inline]
fn wrap(self) -> Result<T, Infallible> {
Ok(self)
}
}

Expand All @@ -40,11 +43,44 @@ where
T: IntoPy<PyObject>,
{
type Error = E;
fn wrap(self, py: Python<'_>) -> Result<Py<PyAny>, Self::Error> {
self.map(|o| o.into_py(py))
#[inline]
fn wrap(self) -> Result<T, Self::Error> {
self
}
}

/// This is a follow-up function to `OkWrap::wrap` that converts the result into
/// a `*mut ffi::PyObject` pointer.
pub fn map_result_into_ptr<T: IntoPy<PyObject>>(
py: Python<'_>,
result: PyResult<T>,
) -> PyResult<*mut ffi::PyObject> {
result.map(|obj| obj.into_py(py).into_ptr())
}

/// This is a follow-up function to `OkWrap::wrap` that converts the result into
/// a safe wrapper.
pub fn map_result_into_py<T: IntoPy<PyObject>>(
py: Python<'_>,
result: PyResult<T>,
) -> PyResult<PyObject> {
result.map(|err| err.into_py(py))
}

/// Used to wrap the result of async `#[pyfunction]` and `#[pymethods]`.
#[cfg(feature = "macros")]
pub fn wrap_future<F, R, T>(future: F) -> crate::coroutine::Coroutine
where
F: std::future::Future<Output = R> + Send + 'static,
R: OkWrap<T>,
T: IntoPy<PyObject>,
crate::PyErr: From<R::Error>,
{
crate::coroutine::Coroutine::from_future::<_, T, crate::PyErr>(async move {
OkWrap::wrap(future.await).map_err(Into::into)
})
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -57,4 +93,16 @@ mod tests {
let b: Option<u8> = SomeWrap::wrap(None);
assert_eq!(b, None);
}

#[test]
fn wrap_result() {
let a: Result<u8, _> = OkWrap::wrap(42u8);
assert!(matches!(a, Ok(42)));

let b: PyResult<u8> = OkWrap::wrap(Ok(42u8));
assert!(matches!(b, Ok(42)));

let c: Result<u8, &str> = OkWrap::wrap(Err("error"));
assert_eq!(c, Err("error"));
}
}
2 changes: 1 addition & 1 deletion tests/test_compile_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ fn test_compile_errors() {
t.compile_fail("tests/ui/invalid_pymethod_receiver.rs");
t.compile_fail("tests/ui/missing_intopy.rs");
// adding extra error conversion impls changes the output
#[cfg(all(target_os = "linux", not(any(feature = "eyre", feature = "anyhow"))))]
#[cfg(not(any(windows, feature = "eyre", feature = "anyhow")))]
t.compile_fail("tests/ui/invalid_result_conversion.rs");
t.compile_fail("tests/ui/not_send.rs");
t.compile_fail("tests/ui/not_send2.rs");
Expand Down

0 comments on commit 9f66846

Please sign in to comment.