diff --git a/guide/src/async-await.md b/guide/src/async-await.md index 3649a9a0fed..847fed2f47d 100644 --- a/guide/src/async-await.md +++ b/guide/src/async-await.md @@ -30,8 +30,7 @@ Resulting future of an `async fn` decorated by `#[pyfunction]` must be `Send + ' As a consequence, `async fn` parameters and return types must also be `Send + 'static`, so it is not possible to have a signature like `async fn does_not_compile(arg: &PyAny, py: Python<'_>) -> &PyAny`. -It also means that methods cannot use `&self`/`&mut self`, *but this restriction should be dropped in the future.* - +However, there is an exception for method receiver, so async methods can accept `&self`/`&mut self` ## Implicit GIL holding diff --git a/newsfragments/3609.changed.md b/newsfragments/3609.changed.md new file mode 100644 index 00000000000..7979ea71960 --- /dev/null +++ b/newsfragments/3609.changed.md @@ -0,0 +1 @@ +Allow async methods to accept `&self`/`&mut self` \ No newline at end of file diff --git a/pyo3-macros-backend/src/method.rs b/pyo3-macros-backend/src/method.rs index 1cbb9304cda..9b8aaaea656 100644 --- a/pyo3-macros-backend/src/method.rs +++ b/pyo3-macros-backend/src/method.rs @@ -473,8 +473,7 @@ impl<'a> FnSpec<'a> { } let rust_call = |args: Vec| { - let mut call = quote! { function(#self_arg #(#args),*) }; - if self.asyncness.is_some() { + let call = if self.asyncness.is_some() { let throw_callback = if cancel_handle.is_some() { quote! { Some(__throw_callback) } } else { @@ -485,8 +484,23 @@ impl<'a> FnSpec<'a> { Some(cls) => quote!(Some(<#cls as _pyo3::PyTypeInfo>::NAME)), None => quote!(None), }; - call = quote! {{ - let future = #call; + let future = match self.tp { + FnType::Fn(SelfType::Receiver { mutable: false, .. }) => quote! { + _pyo3::impl_::coroutine::ref_method_future( + py.from_borrowed_ptr::<_pyo3::types::PyAny>(_slf), + move |__self| function(__self, #(#args),*) + )? + }, + FnType::Fn(SelfType::Receiver { mutable: true, .. }) => quote! { + _pyo3::impl_::coroutine::mut_method_future( + py.from_borrowed_ptr::<_pyo3::types::PyAny>(_slf), + move |__self| function(__self, #(#args),*) + )? + }, + _ => quote! { function(#self_arg #(#args),*) }, + }; + let mut call = quote! {{ + let future = #future; _pyo3::impl_::coroutine::new_coroutine( _pyo3::intern!(py, stringify!(#python_name)), #qualname_prefix, @@ -501,7 +515,10 @@ impl<'a> FnSpec<'a> { #call }}; } - } + call + } else { + quote! { function(#self_arg #(#args),*) } + }; quotes::map_result_into_ptr(quotes::ok_wrap(call)) }; diff --git a/src/impl_/coroutine.rs b/src/impl_/coroutine.rs index c8b2cdcce49..49a04fde828 100644 --- a/src/impl_/coroutine.rs +++ b/src/impl_/coroutine.rs @@ -1,7 +1,12 @@ use std::future::Future; +use std::mem; use crate::coroutine::cancel::ThrowCallback; -use crate::{coroutine::Coroutine, types::PyString, IntoPy, PyErr, PyObject}; +use crate::pyclass::boolean_struct::False; +use crate::{ + coroutine::Coroutine, types::PyString, IntoPy, Py, PyAny, PyCell, PyClass, PyErr, PyObject, + PyRef, PyRefMut, PyResult, Python, +}; pub fn new_coroutine( name: &PyString, @@ -16,3 +21,46 @@ where { Coroutine::new(Some(name.into()), qualname_prefix, throw_callback, future) } + +fn get_ptr(obj: &Py) -> *mut T { + // SAFETY: Py can be casted as *const PyCell + unsafe { &*(obj.as_ptr() as *const PyCell) }.get_ptr() +} + +struct RefGuard(Py); + +impl Drop for RefGuard { + fn drop(&mut self) { + Python::with_gil(|gil| self.0.as_ref(gil).release_ref()) + } +} + +pub unsafe fn ref_method_future<'a, T: PyClass, F: Future + 'a>( + self_: &PyAny, + fut: impl FnOnce(&'a T) -> F, +) -> PyResult> { + let ref_: PyRef<'_, T> = self_.extract()?; + // SAFETY: `PyRef::as_ptr` returns a borrowed reference + let guard = RefGuard(unsafe { Py::::from_borrowed_ptr(self_.py(), ref_.as_ptr()) }); + mem::forget(ref_); + Ok(async move { fut(unsafe { &*get_ptr(&guard.0) }).await }) +} + +struct RefMutGuard(Py); + +impl Drop for RefMutGuard { + fn drop(&mut self) { + Python::with_gil(|gil| self.0.as_ref(gil).release_mut()) + } +} + +pub fn mut_method_future<'a, T: PyClass, F: Future + 'a>( + self_: &PyAny, + fut: impl FnOnce(&'a mut T) -> F, +) -> PyResult> { + let mut_: PyRefMut<'_, T> = self_.extract()?; + // SAFETY: `PyRefMut::as_ptr` returns a borrowed reference + let guard = RefMutGuard(unsafe { Py::::from_borrowed_ptr(self_.py(), mut_.as_ptr()) }); + mem::forget(mut_); + Ok(async move { fut(unsafe { &mut *get_ptr(&guard.0) }).await }) +} diff --git a/src/pycell.rs b/src/pycell.rs index 8a4ceb6b374..8b85bfec8e2 100644 --- a/src/pycell.rs +++ b/src/pycell.rs @@ -516,6 +516,16 @@ impl PyCell { #[allow(clippy::useless_conversion)] offset.try_into().expect("offset should fit in Py_ssize_t") } + + #[cfg(feature = "macros")] + pub(crate) fn release_ref(&self) { + self.borrow_checker().release_borrow(); + } + + #[cfg(feature = "macros")] + pub(crate) fn release_mut(&self) { + self.borrow_checker().release_borrow_mut(); + } } impl PyCell { diff --git a/tests/test_coroutine.rs b/tests/test_coroutine.rs index cf975423c25..01b84ca8e94 100644 --- a/tests/test_coroutine.rs +++ b/tests/test_coroutine.rs @@ -234,3 +234,56 @@ fn coroutine_panic() { py_run!(gil, panic, &handle_windows(test)); }) } + +#[test] +fn test_async_method_receiver() { + #[pyclass] + struct Counter(usize); + #[pymethods] + impl Counter { + #[new] + fn new() -> Self { + Self(0) + } + async fn get(&self) -> usize { + self.0 + } + async fn incr(&mut self) -> usize { + self.0 += 1; + self.0 + } + } + Python::with_gil(|gil| { + let test = r#" + import asyncio + + obj = Counter() + coro1 = obj.get() + coro2 = obj.get() + try: + obj.incr() # borrow checking should fail + except RuntimeError as err: + pass + else: + assert False + assert asyncio.run(coro1) == 0 + coro2.close() + coro3 = obj.incr() + try: + obj.incr() # borrow checking should fail + except RuntimeError as err: + pass + else: + assert False + try: + obj.get() == 42 # borrow checking should fail + except RuntimeError as err: + pass + else: + assert False + assert asyncio.run(coro3) == 1 + "#; + let locals = [("Counter", gil.get_type::())].into_py_dict(gil); + py_run!(gil, *locals, test); + }) +}