Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing non-mutable references to self in PyIterProtocol #856

Merged
merged 7 commits into from
Apr 19, 2020
7 changes: 5 additions & 2 deletions guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,10 @@ It includes two methods `__iter__` and `__next__`:
* `fn __iter__(slf: PyRefMut<Self>) -> PyResult<impl IntoPy<PyObject>>`
* `fn __next__(slf: PyRefMut<Self>) -> PyResult<Option<impl IntoPy<PyObject>>>`

Returning `Ok(None)` from `__next__` indicates that that there are no further items.
Returning `Ok(None)` from `__next__` indicates that that there are no further items.
These two methods can be take either `PyRef<Self>` or `PyRefMut<Self>` as their
first argument, so that mutable borrow can be avoided if needed.


Example:

Expand All @@ -823,7 +826,7 @@ struct MyIterator {

#[pyproto]
impl PyIterProtocol for MyIterator {
fn __iter__(mut slf: PyRefMut<Self>) -> PyResult<Py<MyIterator>> {
fn __iter__(slf: PyRef<Self>) -> PyResult<Py<MyIterator>> {
Ok(slf.into())
}
fn __next__(mut slf: PyRefMut<Self>) -> PyResult<Option<PyObject>> {
Expand Down
6 changes: 4 additions & 2 deletions pyo3-derive-backend/src/defs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,13 +260,15 @@ pub const ITER: Proto = Proto {
name: "Iter",
py_methods: &[],
methods: &[
MethodProto::Unary {
MethodProto::UnaryS {
name: "__iter__",
arg: "Receiver",
pyres: true,
proto: "pyo3::class::iter::PyIterIterProtocol",
},
MethodProto::Unary {
MethodProto::UnaryS {
name: "__next__",
arg: "Receiver",
pyres: true,
proto: "pyo3::class::iter::PyIterNextProtocol",
},
Expand Down
59 changes: 59 additions & 0 deletions pyo3-derive-backend/src/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ pub enum MethodProto {
pyres: bool,
proto: &'static str,
},
UnaryS {
name: &'static str,
arg: &'static str,
pyres: bool,
proto: &'static str,
},
Binary {
name: &'static str,
arg: &'static str,
Expand Down Expand Up @@ -60,6 +66,7 @@ impl MethodProto {
match *self {
MethodProto::Free { ref name, .. } => name,
MethodProto::Unary { ref name, .. } => name,
MethodProto::UnaryS { ref name, .. } => name,
MethodProto::Binary { ref name, .. } => name,
MethodProto::BinaryS { ref name, .. } => name,
MethodProto::Ternary { ref name, .. } => name,
Expand Down Expand Up @@ -114,6 +121,58 @@ pub(crate) fn impl_method_proto(
}
}
}
MethodProto::UnaryS {
pyres, proto, arg, ..
} => {
let p: syn::Path = syn::parse_str(proto).unwrap();
let (ty, succ) = get_res_success(ty);

let slf_name = syn::Ident::new(arg, Span::call_site());
let mut slf_ty = get_arg_ty(sig, 0);

// update the type if no lifetime was given:
// PyRef<Self> --> PyRef<'p, Self>
if let syn::Type::Path(ref mut path) = slf_ty {
if let syn::PathArguments::AngleBracketed(ref mut args) =
path.path.segments[0].arguments
{
if let syn::GenericArgument::Lifetime(_) = args.args[0] {
} else {
let lt = syn::parse_quote! {'p};
args.args.insert(0, lt);
}
}
}

let tmp: syn::ItemFn = syn::parse_quote! {
fn test(&self) -> <#cls as #p<'p>>::Result {}
};
sig.output = tmp.sig.output;
modify_self_ty(sig);

if let syn::FnArg::Typed(ref mut arg) = sig.inputs[0] {
arg.ty = Box::new(syn::parse_quote! {
<#cls as #p<'p>>::#slf_name
});
}

if pyres {
quote! {
impl<'p> #p<'p> for #cls {
type #slf_name = #slf_ty;
type Success = #succ;
type Result = #ty;
}
}
} else {
quote! {
impl<'p> #p<'p> for #cls {
type #slf_name = #slf_ty;
type Result = #ty;
}
}
}
}
MethodProto::Binary {
name,
arg,
Expand Down
13 changes: 8 additions & 5 deletions src/class/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,24 @@
//! Trait and support implementation for implementing iterators

use crate::callback::IntoPyCallbackOutput;
use crate::derive_utils::TryFromPyCell;
use crate::err::PyResult;
use crate::{ffi, IntoPy, IntoPyPointer, PyClass, PyObject, PyRefMut, Python};
use crate::{ffi, IntoPy, IntoPyPointer, PyClass, PyObject, Python};

/// Python Iterator Interface.
///
/// Check [CPython doc](https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_iter)
/// for more.
#[allow(unused_variables)]
pub trait PyIterProtocol<'p>: PyClass {
fn __iter__(slf: PyRefMut<Self>) -> Self::Result
fn __iter__(slf: Self::Receiver) -> Self::Result
where
Self: PyIterIterProtocol<'p>,
{
unimplemented!()
}

fn __next__(slf: PyRefMut<Self>) -> Self::Result
fn __next__(slf: Self::Receiver) -> Self::Result
where
Self: PyIterNextProtocol<'p>,
{
Expand All @@ -28,11 +29,13 @@ pub trait PyIterProtocol<'p>: PyClass {
}

pub trait PyIterIterProtocol<'p>: PyIterProtocol<'p> {
type Receiver: TryFromPyCell<'p, Self>;
type Success: crate::IntoPy<PyObject>;
type Result: Into<PyResult<Self::Success>>;
}

pub trait PyIterNextProtocol<'p>: PyIterProtocol<'p> {
type Receiver: TryFromPyCell<'p, Self>;
type Success: crate::IntoPy<PyObject>;
type Result: Into<PyResult<Option<Self::Success>>>;
}
Expand Down Expand Up @@ -76,7 +79,7 @@ where
{
#[inline]
fn tp_iter() -> Option<ffi::getiterfunc> {
py_unary_refmut_func!(PyIterIterProtocol, T::__iter__)
py_unarys_func!(PyIterIterProtocol, T::__iter__)
}
}

Expand All @@ -99,7 +102,7 @@ where
{
#[inline]
fn tp_iternext() -> Option<ffi::iternextfunc> {
py_unary_refmut_func!(PyIterNextProtocol, T::__next__, IterNextConverter)
py_unarys_func!(PyIterNextProtocol, T::__next__, IterNextConverter)
}
}

Expand Down
6 changes: 4 additions & 2 deletions src/class/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ macro_rules! py_unary_func {

#[macro_export]
#[doc(hidden)]
macro_rules! py_unary_refmut_func {
macro_rules! py_unarys_func {
($trait:ident, $class:ident :: $f:ident $(, $conv:expr)?) => {{
unsafe extern "C" fn wrap<T>(slf: *mut $crate::ffi::PyObject) -> *mut $crate::ffi::PyObject
where
Expand All @@ -38,7 +38,9 @@ macro_rules! py_unary_refmut_func {
let py = pool.python();
$crate::run_callback(py, || {
let slf = py.from_borrowed_ptr::<$crate::PyCell<T>>(slf);
let res = $class::$f(slf.borrow_mut()).into();
let borrow = <T::Receiver>::try_from_pycell(slf)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess just try_from_pycell(slf)? work?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without the associated <T::Receiver>, the compiler complains about the next line and can't guess the type of e in the map_err method, so I prefer to leave it like that.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, OK 👍

.map_err(|e| e.into())?;
let res = $class::$f(borrow).into();
$crate::callback::convert(py, res $(.map($conv))?)
})
}
Expand Down
26 changes: 24 additions & 2 deletions src/derive_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

//! Functionality for the code generated by the derive backend

use crate::err::PyResult;
use crate::err::{PyErr, PyResult};
use crate::exceptions::TypeError;
use crate::instance::PyNativeType;
use crate::pyclass::PyClass;
use crate::pyclass_init::PyClassInitializer;
use crate::types::{PyAny, PyDict, PyModule, PyTuple};
use crate::{ffi, GILPool, IntoPy, PyObject, Python};
use crate::{ffi, GILPool, IntoPy, PyCell, PyObject, Python};
use std::cell::UnsafeCell;

/// Description of a python parameter; used for `parse_args()`.
Expand Down Expand Up @@ -243,3 +243,25 @@ where
{
type Target = T;
}

/// A trait for types that can be borrowed from a cell.
///
/// This serves to unify the use of `PyRef` and `PyRefMut` in automatically
/// derived code, since both types can be obtained from a `PyCell`.
#[doc(hidden)]
pub trait TryFromPyCell<'a, T: PyClass>: Sized {
type Error: Into<PyErr>;
fn try_from_pycell(cell: &'a crate::PyCell<T>) -> Result<Self, Self::Error>;
}

impl<'a, T, R> TryFromPyCell<'a, T> for R
where
T: 'a + PyClass,
R: std::convert::TryFrom<&'a PyCell<T>>,
R::Error: Into<PyErr>,
{
type Error = R::Error;
fn try_from_pycell(cell: &'a crate::PyCell<T>) -> Result<Self, Self::Error> {
<R as std::convert::TryFrom<&'a PyCell<T>>>::try_from(cell)
}
}
4 changes: 2 additions & 2 deletions tests/test_dunder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ struct Iterator {

#[pyproto]
impl<'p> PyIterProtocol for Iterator {
fn __iter__(slf: PyRefMut<Self>) -> PyResult<Py<Iterator>> {
fn __iter__(slf: PyRef<'p, Self>) -> PyResult<Py<Iterator>> {
Ok(slf.into())
}

fn __next__(mut slf: PyRefMut<Self>) -> PyResult<Option<i32>> {
fn __next__(mut slf: PyRefMut<'p, Self>) -> PyResult<Option<i32>> {
Ok(slf.iter.next())
}
}
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pyself.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ struct Iter {

#[pyproto]
impl PyIterProtocol for Iter {
fn __iter__(slf: PyRefMut<Self>) -> PyResult<PyObject> {
fn __iter__(slf: PyRef<Self>) -> PyResult<PyObject> {
let py = unsafe { Python::assume_gil_acquired() };
Ok(slf.into_py(py))
}
Expand Down