Skip to content

Commit

Permalink
feat: add coroutine::CancelHandle
Browse files Browse the repository at this point in the history
  • Loading branch information
wyfo committed Nov 30, 2023
1 parent 81ad2e8 commit d66be8b
Show file tree
Hide file tree
Showing 14 changed files with 286 additions and 21 deletions.
21 changes: 19 additions & 2 deletions guide/src/async-await.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,27 @@ where

## Cancellation

*To be implemented*
Cancellation on the Python side can be caught using [`CancelHandle`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.CancelHandle.html) type, by annotating a function parameter with `#[pyo3(cancel_handle)].

```rust
# #![allow(dead_code)]
use futures::FutureExt;
use pyo3::prelude::*;
use pyo3::coroutine::CancelHandle;

#[pyfunction]
async fn cancellable(#[pyo3(cancel_handle)]mut cancel: CancelHandle) {
futures::select! {
/* _ = ... => println!("done"), */
_ = cancel.cancelled().fuse() => println!("cancelled"),
}
}
```

## The `Coroutine` type

To make a Rust future awaitable in Python, PyO3 defines a [`Coroutine`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.Coroutine.html) type, which implements the Python [coroutine protocol](https://docs.python.org/3/library/collections.abc.html#collections.abc.Coroutine). Each `coroutine.send` call is translated to `Future::poll` call, while `coroutine.throw` call reraise the exception *(this behavior will be configurable with cancellation support)*.
To make a Rust future awaitable in Python, PyO3 defines a [`Coroutine`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.Coroutine.html) type, which implements the Python [coroutine protocol](https://docs.python.org/3/library/collections.abc.html#collections.abc.Coroutine).

Each `coroutine.send` call is translated to `Future::poll` call. If a [`CancelHandle`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.CancelHandle.html) parameter is declared, the exception passed to `coroutine.throw` call is stored in it and can be retrieved with [`CancelHandle::cancelled`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.CancelHandle.html#method.cancelled); otherwise, it cancels the Rust future, and the exception is reraised;

*The type does not yet have a public constructor until the design is finalized.*
1 change: 1 addition & 0 deletions newsfragments/3599.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `coroutine::CancelHandle` to catch coroutine cancellation
1 change: 1 addition & 0 deletions pyo3-macros-backend/src/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use syn::{
pub mod kw {
syn::custom_keyword!(annotation);
syn::custom_keyword!(attribute);
syn::custom_keyword!(cancel_handle);
syn::custom_keyword!(dict);
syn::custom_keyword!(extends);
syn::custom_keyword!(freelist);
Expand Down
52 changes: 45 additions & 7 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pub struct FnArg<'a> {
pub attrs: PyFunctionArgPyO3Attributes,
pub is_varargs: bool,
pub is_kwargs: bool,
pub is_cancel_handle: bool,
}

impl<'a> FnArg<'a> {
Expand All @@ -44,6 +45,7 @@ impl<'a> FnArg<'a> {
other => return Err(handle_argument_error(other)),
};

let is_cancel_handle = arg_attrs.cancel_handle.is_some();
Ok(FnArg {
name: ident,
ty: &cap.ty,
Expand All @@ -53,6 +55,7 @@ impl<'a> FnArg<'a> {
attrs: arg_attrs,
is_varargs: false,
is_kwargs: false,
is_cancel_handle,
})
}
}
Expand Down Expand Up @@ -455,9 +458,27 @@ impl<'a> FnSpec<'a> {
let self_arg = self.tp.self_arg(cls, ExtractErrorMode::Raise);
let func_name = &self.name;

let mut cancel_handle_iter = self
.signature
.arguments
.iter()
.filter(|arg| arg.is_cancel_handle);
let cancel_handle = cancel_handle_iter.next();
if let Some(arg) = cancel_handle {
ensure_spanned!(self.asyncness.is_some(), arg.name.span() => "`cancel_handle` attribute can only be used with `async fn`");
if let Some(arg2) = cancel_handle_iter.next() {
bail_spanned!(arg2.name.span() => "`cancel_handle` may only be specified once");
}
}

let rust_call = |args: Vec<TokenStream>| {
let mut call = quote! { function(#self_arg #(#args),*) };
if self.asyncness.is_some() {
let throw_callback = if cancel_handle.is_some() {
quote! { Some(__throw_callback) }
} else {
quote! { None }
};
let python_name = &self.python_name;
let qualname_prefix = match cls {
Some(cls) => quote!(Some(<#cls as _pyo3::PyTypeInfo>::NAME)),
Expand All @@ -468,9 +489,17 @@ impl<'a> FnSpec<'a> {
_pyo3::impl_::coroutine::new_coroutine(
_pyo3::intern!(py, stringify!(#python_name)),
#qualname_prefix,
async move { _pyo3::impl_::wrap::OkWrap::wrap(future.await) }
#throw_callback,
async move { _pyo3::impl_::wrap::OkWrap::wrap(future.await) },
)
}};
if cancel_handle.is_some() {
call = quote! {{
let __cancel_handle = _pyo3::coroutine::CancelHandle::new();
let __throw_callback = __cancel_handle.throw_callback();
#call
}};
}
}
quotes::map_result_into_ptr(quotes::ok_wrap(call))
};
Expand All @@ -483,12 +512,21 @@ impl<'a> FnSpec<'a> {

Ok(match self.convention {
CallingConvention::Noargs => {
let call = if !self.signature.arguments.is_empty() {
// Only `py` arg can be here
rust_call(vec![quote!(py)])
} else {
rust_call(vec![])
};
let args = self
.signature
.arguments
.iter()
.map(|arg| {
if arg.py {
quote!(py)
} else if arg.is_cancel_handle {
quote!(__cancel_handle)
} else {
unreachable!()
}
})
.collect();
let call = rust_call(args);

quote! {
unsafe fn #ident<'py>(
Expand Down
4 changes: 4 additions & 0 deletions pyo3-macros-backend/src/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ fn impl_arg_param(
return Ok(quote! { py });
}

if arg.is_cancel_handle {
return Ok(quote! { __cancel_handle });
}

let name = arg.name;
let name_str = name.to_string();

Expand Down
22 changes: 20 additions & 2 deletions pyo3-macros-backend/src/pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,20 @@ pub use self::signature::{FunctionSignature, SignatureAttribute};
#[derive(Clone, Debug)]
pub struct PyFunctionArgPyO3Attributes {
pub from_py_with: Option<FromPyWithAttribute>,
pub cancel_handle: Option<attributes::kw::cancel_handle>,
}

enum PyFunctionArgPyO3Attribute {
FromPyWith(FromPyWithAttribute),
CancelHandle(attributes::kw::cancel_handle),
}

impl Parse for PyFunctionArgPyO3Attribute {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::from_py_with) {
if lookahead.peek(attributes::kw::cancel_handle) {
input.parse().map(PyFunctionArgPyO3Attribute::CancelHandle)
} else if lookahead.peek(attributes::kw::from_py_with) {
input.parse().map(PyFunctionArgPyO3Attribute::FromPyWith)
} else {
Err(lookahead.error())
Expand All @@ -43,7 +47,10 @@ impl Parse for PyFunctionArgPyO3Attribute {
impl PyFunctionArgPyO3Attributes {
/// Parses #[pyo3(from_python_with = "func")]
pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
let mut attributes = PyFunctionArgPyO3Attributes { from_py_with: None };
let mut attributes = PyFunctionArgPyO3Attributes {
from_py_with: None,
cancel_handle: None,
};
take_attributes(attrs, |attr| {
if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
for attr in pyo3_attrs {
Expand All @@ -55,7 +62,18 @@ impl PyFunctionArgPyO3Attributes {
);
attributes.from_py_with = Some(from_py_with);
}
PyFunctionArgPyO3Attribute::CancelHandle(cancel_handle) => {
ensure_spanned!(
attributes.cancel_handle.is_none(),
cancel_handle.span() => "`cancel_handle` may only be specified once per argument"
);
attributes.cancel_handle = Some(cancel_handle);
}
}
ensure_spanned!(
attributes.from_py_with.is_none() || attributes.cancel_handle.is_none(),
attributes.cancel_handle.unwrap().span() => "`from_py_with` and `cancel_handle` cannot be specified together"
);
}
Ok(true)
} else {
Expand Down
14 changes: 12 additions & 2 deletions pyo3-macros-backend/src/pyfunction/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,16 @@ impl<'a> FunctionSignature<'a> {
// Otherwise try next argument.
continue;
}
if fn_arg.is_cancel_handle {
// If the user incorrectly tried to include cancel: CoroutineCancel in the
// signature, give a useful error as a hint.
ensure_spanned!(
name != fn_arg.name,
name.span() => "`cancel_handle` argument must not be part of the signature"
);
// Otherwise try next argument.
continue;
}

ensure_spanned!(
name == fn_arg.name,
Expand Down Expand Up @@ -411,7 +421,7 @@ impl<'a> FunctionSignature<'a> {
}

// Ensure no non-py arguments remain
if let Some(arg) = args_iter.find(|arg| !arg.py) {
if let Some(arg) = args_iter.find(|arg| !arg.py && !arg.is_cancel_handle) {
bail_spanned!(
attribute.kw.span() => format!("missing signature entry for argument `{}`", arg.name)
);
Expand All @@ -429,7 +439,7 @@ impl<'a> FunctionSignature<'a> {
let mut python_signature = PythonSignature::default();
for arg in &arguments {
// Python<'_> arguments don't show in Python signature
if arg.py {
if arg.py || arg.is_cancel_handle {
continue;
}

Expand Down
17 changes: 14 additions & 3 deletions src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@ use crate::{
IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python,
};

pub(crate) mod cancel;
mod waker;

use crate::coroutine::cancel::ThrowCallback;
pub use cancel::CancelHandle;

const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine";

type FutureOutput = Result<PyResult<PyObject>, Box<dyn Any + Send>>;
Expand All @@ -32,6 +36,7 @@ type FutureOutput = Result<PyResult<PyObject>, Box<dyn Any + Send>>;
pub struct Coroutine {
name: Option<Py<PyString>>,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
future: Option<Pin<Box<dyn Future<Output = FutureOutput> + Send>>>,
waker: Option<Arc<AsyncioWaker>>,
}
Expand All @@ -46,6 +51,7 @@ impl Coroutine {
pub(crate) fn new<F, T, E>(
name: Option<Py<PyString>>,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
future: F,
) -> Self
where
Expand All @@ -61,6 +67,7 @@ impl Coroutine {
Self {
name,
qualname_prefix,
throw_callback,
future: Some(Box::pin(panic::AssertUnwindSafe(wrap).catch_unwind())),
waker: None,
}
Expand All @@ -77,9 +84,13 @@ impl Coroutine {
None => return Err(PyRuntimeError::new_err(COROUTINE_REUSED_ERROR)),
};
// reraise thrown exception it
if let Some(exc) = throw {
self.close();
return Err(PyErr::from_value(exc.as_ref(py)));
match (throw, &self.throw_callback) {
(Some(exc), Some(cb)) => cb.throw(exc.as_ref(py)),
(Some(exc), None) => {
self.close();
return Err(PyErr::from_value(exc.as_ref(py)));
}
_ => {}
}
// create a new waker, or try to reset it in place
if let Some(waker) = self.waker.as_mut().and_then(Arc::get_mut) {
Expand Down
74 changes: 74 additions & 0 deletions src/coroutine/cancel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
use crate::{ffi, Py, PyAny, PyObject};
use futures_util::future::poll_fn;
use futures_util::task::AtomicWaker;
use std::ptr;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicPtr, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};

#[derive(Debug, Default)]
struct Inner {
exception: AtomicPtr<ffi::PyObject>,
waker: AtomicWaker,
}

/// Helper used to wait and retrieve exception thrown in [`Coroutine`](super::Coroutine).
///
/// Only the last exception thrown can be retrieved.
#[derive(Debug, Default)]
pub struct CancelHandle(Arc<Inner>);

impl CancelHandle {
/// Create a new `CoroutineCancel`.
pub fn new() -> Self {
Default::default()
}

/// Returns whether the associated coroutine has been cancelled.
pub fn is_cancelled(&self) -> bool {
!self.0.exception.load(Ordering::Relaxed).is_null()
}

/// Poll to retrieve the exception thrown in the associated coroutine.
pub fn poll_cancelled(&mut self, cx: &mut Context<'_>) -> Poll<PyObject> {
// SAFETY: only valid owned pointer are set in `ThrowCallback::throw`
let take = || unsafe {
// pointer cannot be null because it is checked the line before,
// and the swap is protected by `&mut self`
Py::from_non_null(
NonNull::new(self.0.exception.swap(ptr::null_mut(), Ordering::Relaxed)).unwrap(),
)
};
if self.is_cancelled() {
return Poll::Ready(take());
}
self.0.waker.register(cx.waker());
if self.is_cancelled() {
return Poll::Ready(take());
}
Poll::Pending
}

/// Retrieve the exception thrown in the associated coroutine.
pub async fn cancelled(&mut self) -> PyObject {
poll_fn(|cx| self.poll_cancelled(cx)).await
}

#[doc(hidden)]
pub fn throw_callback(&self) -> ThrowCallback {
ThrowCallback(self.0.clone())
}
}

#[doc(hidden)]
pub struct ThrowCallback(Arc<Inner>);

impl ThrowCallback {
pub(super) fn throw(&self, exc: &PyAny) {
let ptr = self.0.exception.swap(exc.into_ptr(), Ordering::Relaxed);
// SAFETY: non-null pointers set in `self.0.exceptions` are valid owned pointers
drop(unsafe { PyObject::from_owned_ptr_or_opt(exc.py(), ptr) });
self.0.waker.wake();
}
}
4 changes: 3 additions & 1 deletion src/impl_/coroutine.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
use std::future::Future;

use crate::coroutine::cancel::ThrowCallback;
use crate::{coroutine::Coroutine, types::PyString, IntoPy, PyErr, PyObject};

pub fn new_coroutine<F, T, E>(
name: &PyString,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
future: F,
) -> Coroutine
where
F: Future<Output = Result<T, E>> + Send + 'static,
T: IntoPy<PyObject>,
E: Into<PyErr>,
{
Coroutine::new(Some(name.into()), qualname_prefix, future)
Coroutine::new(Some(name.into()), qualname_prefix, throw_callback, future)
}
2 changes: 1 addition & 1 deletion src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1038,7 +1038,7 @@ impl<T> Py<T> {
/// # Safety
/// `ptr` must point to a Python object of type T.
#[inline]
unsafe fn from_non_null(ptr: NonNull<ffi::PyObject>) -> Self {
pub(crate) unsafe fn from_non_null(ptr: NonNull<ffi::PyObject>) -> Self {
Self(ptr, PhantomData)
}

Expand Down
Loading

0 comments on commit d66be8b

Please sign in to comment.