Skip to content

Commit

Permalink
feat: allow async methods to accept &self/&mut self
Browse files Browse the repository at this point in the history
  • Loading branch information
wyfo committed Dec 4, 2023
1 parent 2ca9f59 commit 17182bb
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 8 deletions.
3 changes: 1 addition & 2 deletions guide/src/async-await.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ Resulting future of an `async fn` decorated by `#[pyfunction]` must be `Send + '

As a consequence, `async fn` parameters and return types must also be `Send + 'static`, so it is not possible to have a signature like `async fn does_not_compile(arg: &PyAny, py: Python<'_>) -> &PyAny`.

It also means that methods cannot use `&self`/`&mut self`, *but this restriction should be dropped in the future.*

However, there is an exception for method receiver, so async methods can accept `&self`/`&mut self`

## Implicit GIL holding

Expand Down
1 change: 1 addition & 0 deletions newsfragments/3609.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow async methods to accept `&self`/`&mut self`
27 changes: 22 additions & 5 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,8 +473,7 @@ impl<'a> FnSpec<'a> {
}

let rust_call = |args: Vec<TokenStream>| {
let mut call = quote! { function(#self_arg #(#args),*) };
if self.asyncness.is_some() {
let call = if self.asyncness.is_some() {
let throw_callback = if cancel_handle.is_some() {
quote! { Some(__throw_callback) }
} else {
Expand All @@ -485,8 +484,23 @@ impl<'a> FnSpec<'a> {
Some(cls) => quote!(Some(<#cls as _pyo3::PyTypeInfo>::NAME)),
None => quote!(None),
};
call = quote! {{
let future = #call;
let future = match self.tp {
FnType::Fn(SelfType::Receiver { mutable: false, .. }) => quote! {
_pyo3::impl_::coroutine::ref_method_future(
py.from_borrowed_ptr::<_pyo3::types::PyAny>(_slf),
move |__self| function(__self, #(#args),*)
)?
},
FnType::Fn(SelfType::Receiver { mutable: true, .. }) => quote! {
_pyo3::impl_::coroutine::mut_method_future(
py.from_borrowed_ptr::<_pyo3::types::PyAny>(_slf),
move |__self| function(__self, #(#args),*)
)?
},
_ => quote! { function(#self_arg #(#args),*) },
};
let mut call = quote! {{
let future = #future;
_pyo3::impl_::coroutine::new_coroutine(
_pyo3::intern!(py, stringify!(#python_name)),
#qualname_prefix,
Expand All @@ -501,7 +515,10 @@ impl<'a> FnSpec<'a> {
#call
}};
}
}
call
} else {
quote! { function(#self_arg #(#args),*) }
};
quotes::map_result_into_ptr(quotes::ok_wrap(call))
};

Expand Down
50 changes: 49 additions & 1 deletion src/impl_/coroutine.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
use std::future::Future;
use std::mem;

use crate::coroutine::cancel::ThrowCallback;
use crate::{coroutine::Coroutine, types::PyString, IntoPy, PyErr, PyObject};
use crate::pyclass::boolean_struct::False;
use crate::{
coroutine::Coroutine, types::PyString, IntoPy, Py, PyAny, PyCell, PyClass, PyErr, PyObject,
PyRef, PyRefMut, PyResult, Python,
};

pub fn new_coroutine<F, T, E>(
name: &PyString,
Expand All @@ -16,3 +21,46 @@ where
{
Coroutine::new(Some(name.into()), qualname_prefix, throw_callback, future)
}

fn get_ptr<T: PyClass>(obj: &Py<T>) -> *mut T {
// SAFETY: Py<T> can be casted as *const PyCell<T>
unsafe { &*(obj.as_ptr() as *const PyCell<T>) }.get_ptr()
}

struct RefGuard<T: PyClass>(Py<T>);

impl<T: PyClass> Drop for RefGuard<T> {
fn drop(&mut self) {
Python::with_gil(|gil| self.0.as_ref(gil).release_ref())
}
}

pub unsafe fn ref_method_future<'a, T: PyClass, F: Future + 'a>(
self_: &PyAny,
fut: impl FnOnce(&'a T) -> F,
) -> PyResult<impl Future<Output = F::Output>> {
let ref_: PyRef<'_, T> = self_.extract()?;
// SAFETY: `PyRef::as_ptr` returns a borrowed reference
let guard = RefGuard(unsafe { Py::<T>::from_borrowed_ptr(self_.py(), ref_.as_ptr()) });
mem::forget(ref_);
Ok(async move { fut(unsafe { &*get_ptr(&guard.0) }).await })
}

struct RefMutGuard<T: PyClass>(Py<T>);

impl<T: PyClass> Drop for RefMutGuard<T> {
fn drop(&mut self) {
Python::with_gil(|gil| self.0.as_ref(gil).release_mut())
}
}

pub fn mut_method_future<'a, T: PyClass<Frozen = False>, F: Future + 'a>(
self_: &PyAny,
fut: impl FnOnce(&'a mut T) -> F,
) -> PyResult<impl Future<Output = F::Output>> {
let mut_: PyRefMut<'_, T> = self_.extract()?;
// SAFETY: `PyRefMut::as_ptr` returns a borrowed reference
let guard = RefMutGuard(unsafe { Py::<T>::from_borrowed_ptr(self_.py(), mut_.as_ptr()) });
mem::forget(mut_);
Ok(async move { fut(unsafe { &mut *get_ptr(&guard.0) }).await })
}
10 changes: 10 additions & 0 deletions src/pycell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,16 @@ impl<T: PyClass> PyCell<T> {
#[allow(clippy::useless_conversion)]
offset.try_into().expect("offset should fit in Py_ssize_t")
}

#[cfg(feature = "macros")]
pub(crate) fn release_ref(&self) {
self.borrow_checker().release_borrow();
}

#[cfg(feature = "macros")]
pub(crate) fn release_mut(&self) {
self.borrow_checker().release_borrow_mut();
}
}

impl<T: PyClassImpl> PyCell<T> {
Expand Down
53 changes: 53 additions & 0 deletions tests/test_coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,56 @@ fn coroutine_panic() {
py_run!(gil, panic, &handle_windows(test));
})
}

#[test]
fn test_async_method_receiver() {
#[pyclass]
struct Counter(usize);
#[pymethods]
impl Counter {
#[new]
fn new() -> Self {
Self(0)
}
async fn get(&self) -> usize {
self.0
}
async fn incr(&mut self) -> usize {
self.0 += 1;
self.0
}
}
Python::with_gil(|gil| {
let test = r#"
import asyncio
obj = Counter()
coro1 = obj.get()
coro2 = obj.get()
try:
obj.incr() # borrow checking should fail
except RuntimeError as err:
pass
else:
assert False
assert asyncio.run(coro1) == 0
coro2.close()
coro3 = obj.incr()
try:
obj.incr() # borrow checking should fail
except RuntimeError as err:
pass
else:
assert False
try:
obj.get() == 42 # borrow checking should fail
except RuntimeError as err:
pass
else:
assert False
assert asyncio.run(coro3) == 1
"#;
let locals = [("Counter", gil.get_type::<Counter>())].into_py_dict(gil);
py_run!(gil, *locals, test);
})
}

0 comments on commit 17182bb

Please sign in to comment.