Skip to content

Commit

Permalink
Replace IterNextOutput by autoref-based specialization to allow retur…
Browse files Browse the repository at this point in the history
…ning arbitrary values
  • Loading branch information
adamreichold committed Dec 19, 2023
1 parent ff50285 commit 5121b65
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 132 deletions.
31 changes: 24 additions & 7 deletions pyo3-macros-backend/src/pymethod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -787,9 +787,11 @@ pub const __RICHCMP__: SlotDef = SlotDef::new("Py_tp_richcompare", "richcmpfunc"
const __GET__: SlotDef = SlotDef::new("Py_tp_descr_get", "descrgetfunc")
.arguments(&[Ty::MaybeNullObject, Ty::MaybeNullObject]);
const __ITER__: SlotDef = SlotDef::new("Py_tp_iter", "getiterfunc");
const __NEXT__: SlotDef = SlotDef::new("Py_tp_iternext", "iternextfunc").return_conversion(
TokenGenerator(|| quote! { _pyo3::class::iter::IterNextOutput::<_, _> }),
);
const __NEXT__: SlotDef = SlotDef::new("Py_tp_iternext", "iternextfunc")
.return_specialized_conversion(
TokenGenerator(|| quote! { IterBaseKind, IterOptionKind, IterResultOptionKind }),
TokenGenerator(|| quote! { iter_tag }),
);
const __AWAIT__: SlotDef = SlotDef::new("Py_am_await", "unaryfunc");
const __AITER__: SlotDef = SlotDef::new("Py_am_aiter", "unaryfunc");
const __ANEXT__: SlotDef = SlotDef::new("Py_am_anext", "unaryfunc").return_conversion(
Expand Down Expand Up @@ -987,17 +989,23 @@ fn extract_object(
enum ReturnMode {
ReturnSelf,
Conversion(TokenGenerator),
SpecializedConversion(TokenGenerator, TokenGenerator),
}

impl ReturnMode {
fn return_call_output(&self, call: TokenStream) -> TokenStream {
match self {
ReturnMode::Conversion(conversion) => quote! {
let _result: _pyo3::PyResult<#conversion> = #call;
let _result: _pyo3::PyResult<#conversion> = _pyo3::callback::convert(py, #call);
_pyo3::callback::convert(py, _result)
},
ReturnMode::SpecializedConversion(traits, tag) => quote! {
let _result = #call;
use _pyo3::callback::{#traits};
(&_result).#tag().convert(py, _result)
},
ReturnMode::ReturnSelf => quote! {
let _result: _pyo3::PyResult<()> = #call;
let _result: _pyo3::PyResult<()> = _pyo3::callback::convert(py, #call);
_result?;
_pyo3::ffi::Py_XINCREF(_raw_slf);
::std::result::Result::Ok(_raw_slf)
Expand Down Expand Up @@ -1046,6 +1054,15 @@ impl SlotDef {
self
}

const fn return_specialized_conversion(
mut self,
traits: TokenGenerator,
tag: TokenGenerator,
) -> Self {
self.return_mode = Some(ReturnMode::SpecializedConversion(traits, tag));
self
}

const fn extract_error_mode(mut self, extract_error_mode: ExtractErrorMode) -> Self {
self.extract_error_mode = extract_error_mode;
self
Expand Down Expand Up @@ -1142,11 +1159,11 @@ fn generate_method_body(
let self_arg = spec.tp.self_arg(Some(cls), extract_error_mode);
let rust_name = spec.name;
let args = extract_proto_arguments(spec, arguments, extract_error_mode)?;
let call = quote! { _pyo3::callback::convert(py, #cls::#rust_name(#self_arg #(#args),*)) };
let call = quote! { #cls::#rust_name(#self_arg #(#args),*) };
Ok(if let Some(return_mode) = return_mode {
return_mode.return_call_output(call)
} else {
call
quote! { _pyo3::callback::convert(py, #call) }
})
}

Expand Down
17 changes: 8 additions & 9 deletions pytests/src/awaitable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
//! when awaited, see guide examples related to pyo3-asyncio for ways
//! to suspend tasks and await results.
use pyo3::{prelude::*, pyclass::IterNextOutput};
use pyo3::exceptions::PyStopIteration;
use pyo3::prelude::*;

#[pyclass]
#[derive(Debug)]
Expand All @@ -30,13 +31,13 @@ impl IterAwaitable {
pyself
}

fn __next__(&mut self, py: Python<'_>) -> PyResult<IterNextOutput<PyObject, PyObject>> {
fn __next__(&mut self, py: Python<'_>) -> PyResult<PyObject> {
match self.result.take() {
Some(res) => match res {
Ok(v) => Ok(IterNextOutput::Return(v)),
Ok(v) => Err(PyStopIteration::new_err(v)),
Err(err) => Err(err),
},
_ => Ok(IterNextOutput::Yield(py.None().into())),
_ => Ok(py.None().into()),
}
}
}
Expand Down Expand Up @@ -66,15 +67,13 @@ impl FutureAwaitable {
pyself
}

fn __next__(
mut pyself: PyRefMut<'_, Self>,
) -> PyResult<IterNextOutput<PyRefMut<'_, Self>, PyObject>> {
fn __next__(mut pyself: PyRefMut<'_, Self>) -> PyResult<PyRefMut<'_, Self>> {
match pyself.result {
Some(_) => match pyself.result.take().unwrap() {
Ok(v) => Ok(IterNextOutput::Return(v)),
Ok(v) => Err(PyStopIteration::new_err(v)),
Err(err) => Err(err),
},
_ => Ok(IterNextOutput::Yield(pyself)),
_ => Ok(pyself),
}
}
}
Expand Down
9 changes: 4 additions & 5 deletions pytests/src/pyclasses.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use pyo3::exceptions::PyValueError;
use pyo3::iter::IterNextOutput;
use pyo3::exceptions::{PyStopIteration, PyValueError};
use pyo3::prelude::*;
use pyo3::types::PyType;

Expand Down Expand Up @@ -28,12 +27,12 @@ impl PyClassIter {
Default::default()
}

fn __next__(&mut self) -> IterNextOutput<usize, &'static str> {
fn __next__(&mut self) -> PyResult<usize> {
if self.count < 5 {
self.count += 1;
IterNextOutput::Yield(self.count)
Ok(self.count)
} else {
IterNextOutput::Return("Ended")
Err(PyStopIteration::new_err("Ended"))
}
}
}
Expand Down
84 changes: 84 additions & 0 deletions src/callback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::ffi::{self, Py_hash_t};
use crate::{IntoPy, PyObject, Python};
use std::isize;
use std::os::raw::c_int;
use std::ptr::null_mut;

/// A type which can be the return type of a python C-API callback
pub trait PyCallbackOutput: Copy {
Expand Down Expand Up @@ -176,3 +177,86 @@ where
{
value.convert(py)
}

// Autoref-based specialization to allow deprecation of __next__ returning `Option`

#[doc(hidden)]
pub struct IterBaseTag;

impl IterBaseTag {
#[inline]
pub fn convert<Value, Target>(self, py: Python<'_>, value: Value) -> PyResult<Target>
where
Value: IntoPyCallbackOutput<Target>,
{
value.convert(py)
}
}

#[doc(hidden)]
pub trait IterBaseKind {
fn iter_tag(&self) -> IterBaseTag {
IterBaseTag
}
}

impl<Value> IterBaseKind for &Value {}

#[doc(hidden)]
pub struct IterOptionTag;

impl IterOptionTag {
#[inline]
pub fn convert<Value>(
self,
py: Python<'_>,
value: Option<Value>,
) -> PyResult<*mut ffi::PyObject>
where
Value: IntoPyCallbackOutput<*mut ffi::PyObject>,
{
match value {
Some(value) => value.convert(py),
None => Ok(null_mut()),
}
}
}

#[doc(hidden)]
pub trait IterOptionKind {
fn iter_tag(&self) -> IterOptionTag {
IterOptionTag
}
}

impl<Value> IterOptionKind for Option<Value> {}

#[doc(hidden)]
pub struct IterResultOptionTag;

impl IterResultOptionTag {
#[inline]
pub fn convert<Value>(
self,
py: Python<'_>,
value: PyResult<Option<Value>>,
) -> PyResult<*mut ffi::PyObject>
where
Value: IntoPyCallbackOutput<*mut ffi::PyObject>,
{
match value {
Ok(Some(value)) => value.convert(py),
Ok(None) => Ok(null_mut()),
Err(err) => Err(err),
}
}
}

#[doc(hidden)]
pub trait IterResultOptionKind {
fn iter_tag(&self) -> IterResultOptionTag {
IterResultOptionTag
}
}

impl<Value> IterResultOptionKind for PyResult<Option<Value>> {}
26 changes: 7 additions & 19 deletions src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ use crate::{
coroutine::{cancel::ThrowCallback, waker::AsyncioWaker},
exceptions::{PyAttributeError, PyRuntimeError, PyStopIteration},
panic::PanicException,
pyclass::IterNextOutput,
types::{PyIterator, PyString},
IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python,
};
Expand Down Expand Up @@ -68,11 +67,7 @@ impl Coroutine {
}
}

fn poll(
&mut self,
py: Python<'_>,
throw: Option<PyObject>,
) -> PyResult<IterNextOutput<PyObject, PyObject>> {
fn poll(&mut self, py: Python<'_>, throw: Option<PyObject>) -> PyResult<PyObject> {
// raise if the coroutine has already been run to completion
let future_rs = match self.future {
Some(ref mut fut) => fut,
Expand Down Expand Up @@ -100,7 +95,7 @@ impl Coroutine {
match panic::catch_unwind(panic::AssertUnwindSafe(poll)) {
Ok(Poll::Ready(res)) => {
self.close();
return Ok(IterNextOutput::Return(res?));
return Err(PyStopIteration::new_err(res?));
}
Err(err) => {
self.close();
Expand All @@ -115,19 +110,12 @@ impl Coroutine {
if let Some(future) = PyIterator::from_object(future).unwrap().next() {
// future has not been leaked into Python for now, and Rust code can only call
// `set_result(None)` in `Wake` implementation, so it's safe to unwrap
return Ok(IterNextOutput::Yield(future.unwrap().into()));
return Ok(future.unwrap().into());
}
}
// if waker has been waken during future polling, this is roughly equivalent to
// `await asyncio.sleep(0)`, so just yield `None`.
Ok(IterNextOutput::Yield(py.None().into()))
}
}

pub(crate) fn iter_result(result: IterNextOutput<PyObject, PyObject>) -> PyResult<PyObject> {
match result {
IterNextOutput::Yield(ob) => Ok(ob),
IterNextOutput::Return(ob) => Err(PyStopIteration::new_err(ob)),
Ok(py.None().into())
}
}

Expand All @@ -153,11 +141,11 @@ impl Coroutine {
}

fn send(&mut self, py: Python<'_>, _value: &PyAny) -> PyResult<PyObject> {
iter_result(self.poll(py, None)?)
self.poll(py, None)
}

fn throw(&mut self, py: Python<'_>, exc: PyObject) -> PyResult<PyObject> {
iter_result(self.poll(py, Some(exc))?)
self.poll(py, Some(exc))
}

fn close(&mut self) {
Expand All @@ -170,7 +158,7 @@ impl Coroutine {
self_
}

fn __next__(&mut self, py: Python<'_>) -> PyResult<IterNextOutput<PyObject, PyObject>> {
fn __next__(&mut self, py: Python<'_>) -> PyResult<PyObject> {
self.poll(py, None)
}
}
11 changes: 0 additions & 11 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,17 +354,6 @@ pub mod class {
pub use crate::pyclass::{IterANextOutput, PyIterANextOutput};
}

/// Old module which contained some implementation details of the `#[pyproto]` module.
///
/// Prefer using the same content from `pyo3::pyclass`, e.g. `use pyo3::pyclass::IterNextOutput` instead
/// of `use pyo3::class::pyasync::IterNextOutput`.
///
/// For compatibility reasons this has not yet been removed, however will be done so
/// once <https://github.com/rust-lang/rust/issues/30827> is resolved.
pub mod iter {
pub use crate::pyclass::{IterNextOutput, PyIterNextOutput};
}

/// Old module which contained some implementation details of the `#[pyproto]` module.
///
/// Prefer using the same content from `pyo3::pyclass`, e.g. `use pyo3::pyclass::PyTraverseError` instead
Expand Down
Loading

0 comments on commit 5121b65

Please sign in to comment.