diff --git a/guide/src/class.md b/guide/src/class.md index 8d567315239..78a9cee1cd3 100644 --- a/guide/src/class.md +++ b/guide/src/class.md @@ -808,7 +808,10 @@ It includes two methods `__iter__` and `__next__`: * `fn __iter__(slf: PyRefMut) -> PyResult>` * `fn __next__(slf: PyRefMut) -> PyResult>>` - Returning `Ok(None)` from `__next__` indicates that that there are no further items. +Returning `Ok(None)` from `__next__` indicates that that there are no further items. +These two methods can be take either `PyRef` or `PyRefMut` as their +first argument, so that mutable borrow can be avoided if needed. + Example: @@ -823,7 +826,7 @@ struct MyIterator { #[pyproto] impl PyIterProtocol for MyIterator { - fn __iter__(mut slf: PyRefMut) -> PyResult> { + fn __iter__(slf: PyRef) -> PyResult> { Ok(slf.into()) } fn __next__(mut slf: PyRefMut) -> PyResult> { diff --git a/pyo3-derive-backend/src/defs.rs b/pyo3-derive-backend/src/defs.rs index 02a5ca74df7..01240d44546 100644 --- a/pyo3-derive-backend/src/defs.rs +++ b/pyo3-derive-backend/src/defs.rs @@ -260,13 +260,15 @@ pub const ITER: Proto = Proto { name: "Iter", py_methods: &[], methods: &[ - MethodProto::Unary { + MethodProto::UnaryS { name: "__iter__", + arg: "Receiver", pyres: true, proto: "pyo3::class::iter::PyIterIterProtocol", }, - MethodProto::Unary { + MethodProto::UnaryS { name: "__next__", + arg: "Receiver", pyres: true, proto: "pyo3::class::iter::PyIterNextProtocol", }, diff --git a/pyo3-derive-backend/src/func.rs b/pyo3-derive-backend/src/func.rs index fbc48cb7b36..95539d8499f 100644 --- a/pyo3-derive-backend/src/func.rs +++ b/pyo3-derive-backend/src/func.rs @@ -18,6 +18,12 @@ pub enum MethodProto { pyres: bool, proto: &'static str, }, + UnaryS { + name: &'static str, + arg: &'static str, + pyres: bool, + proto: &'static str, + }, Binary { name: &'static str, arg: &'static str, @@ -60,6 +66,7 @@ impl MethodProto { match *self { MethodProto::Free { ref name, .. } => name, MethodProto::Unary { ref name, .. } => name, + MethodProto::UnaryS { ref name, .. } => name, MethodProto::Binary { ref name, .. } => name, MethodProto::BinaryS { ref name, .. } => name, MethodProto::Ternary { ref name, .. } => name, @@ -114,6 +121,58 @@ pub(crate) fn impl_method_proto( } } } + MethodProto::UnaryS { + pyres, proto, arg, .. + } => { + let p: syn::Path = syn::parse_str(proto).unwrap(); + let (ty, succ) = get_res_success(ty); + + let slf_name = syn::Ident::new(arg, Span::call_site()); + let mut slf_ty = get_arg_ty(sig, 0); + + // update the type if no lifetime was given: + // PyRef --> PyRef<'p, Self> + if let syn::Type::Path(ref mut path) = slf_ty { + if let syn::PathArguments::AngleBracketed(ref mut args) = + path.path.segments[0].arguments + { + if let syn::GenericArgument::Lifetime(_) = args.args[0] { + } else { + let lt = syn::parse_quote! {'p}; + args.args.insert(0, lt); + } + } + } + + let tmp: syn::ItemFn = syn::parse_quote! { + fn test(&self) -> <#cls as #p<'p>>::Result {} + }; + sig.output = tmp.sig.output; + modify_self_ty(sig); + + if let syn::FnArg::Typed(ref mut arg) = sig.inputs[0] { + arg.ty = Box::new(syn::parse_quote! { + <#cls as #p<'p>>::#slf_name + }); + } + + if pyres { + quote! { + impl<'p> #p<'p> for #cls { + type #slf_name = #slf_ty; + type Success = #succ; + type Result = #ty; + } + } + } else { + quote! { + impl<'p> #p<'p> for #cls { + type #slf_name = #slf_ty; + type Result = #ty; + } + } + } + } MethodProto::Binary { name, arg, diff --git a/src/class/iter.rs b/src/class/iter.rs index 0d011d2e529..08c12963df3 100644 --- a/src/class/iter.rs +++ b/src/class/iter.rs @@ -3,8 +3,9 @@ //! Trait and support implementation for implementing iterators use crate::callback::IntoPyCallbackOutput; +use crate::derive_utils::TryFromPyCell; use crate::err::PyResult; -use crate::{ffi, IntoPy, IntoPyPointer, PyClass, PyObject, PyRefMut, Python}; +use crate::{ffi, IntoPy, IntoPyPointer, PyClass, PyObject, Python}; /// Python Iterator Interface. /// @@ -12,14 +13,14 @@ use crate::{ffi, IntoPy, IntoPyPointer, PyClass, PyObject, PyRefMut, Python}; /// for more. #[allow(unused_variables)] pub trait PyIterProtocol<'p>: PyClass { - fn __iter__(slf: PyRefMut) -> Self::Result + fn __iter__(slf: Self::Receiver) -> Self::Result where Self: PyIterIterProtocol<'p>, { unimplemented!() } - fn __next__(slf: PyRefMut) -> Self::Result + fn __next__(slf: Self::Receiver) -> Self::Result where Self: PyIterNextProtocol<'p>, { @@ -28,11 +29,13 @@ pub trait PyIterProtocol<'p>: PyClass { } pub trait PyIterIterProtocol<'p>: PyIterProtocol<'p> { + type Receiver: TryFromPyCell<'p, Self>; type Success: crate::IntoPy; type Result: Into>; } pub trait PyIterNextProtocol<'p>: PyIterProtocol<'p> { + type Receiver: TryFromPyCell<'p, Self>; type Success: crate::IntoPy; type Result: Into>>; } @@ -76,7 +79,7 @@ where { #[inline] fn tp_iter() -> Option { - py_unary_refmut_func!(PyIterIterProtocol, T::__iter__) + py_unarys_func!(PyIterIterProtocol, T::__iter__) } } @@ -99,7 +102,7 @@ where { #[inline] fn tp_iternext() -> Option { - py_unary_refmut_func!(PyIterNextProtocol, T::__next__, IterNextConverter) + py_unarys_func!(PyIterNextProtocol, T::__next__, IterNextConverter) } } diff --git a/src/class/macros.rs b/src/class/macros.rs index 0d570e79112..f96611b7fdb 100644 --- a/src/class/macros.rs +++ b/src/class/macros.rs @@ -28,7 +28,7 @@ macro_rules! py_unary_func { #[macro_export] #[doc(hidden)] -macro_rules! py_unary_refmut_func { +macro_rules! py_unarys_func { ($trait:ident, $class:ident :: $f:ident $(, $conv:expr)?) => {{ unsafe extern "C" fn wrap(slf: *mut $crate::ffi::PyObject) -> *mut $crate::ffi::PyObject where @@ -38,7 +38,9 @@ macro_rules! py_unary_refmut_func { let py = pool.python(); $crate::run_callback(py, || { let slf = py.from_borrowed_ptr::<$crate::PyCell>(slf); - let res = $class::$f(slf.borrow_mut()).into(); + let borrow = ::try_from_pycell(slf) + .map_err(|e| e.into())?; + let res = $class::$f(borrow).into(); $crate::callback::convert(py, res $(.map($conv))?) }) } diff --git a/src/derive_utils.rs b/src/derive_utils.rs index ab180e8776f..e70d5842dea 100644 --- a/src/derive_utils.rs +++ b/src/derive_utils.rs @@ -4,13 +4,13 @@ //! Functionality for the code generated by the derive backend -use crate::err::PyResult; +use crate::err::{PyErr, PyResult}; use crate::exceptions::TypeError; use crate::instance::PyNativeType; use crate::pyclass::PyClass; use crate::pyclass_init::PyClassInitializer; use crate::types::{PyAny, PyDict, PyModule, PyTuple}; -use crate::{ffi, GILPool, IntoPy, PyObject, Python}; +use crate::{ffi, GILPool, IntoPy, PyCell, PyObject, Python}; use std::cell::UnsafeCell; /// Description of a python parameter; used for `parse_args()`. @@ -243,3 +243,25 @@ where { type Target = T; } + +/// A trait for types that can be borrowed from a cell. +/// +/// This serves to unify the use of `PyRef` and `PyRefMut` in automatically +/// derived code, since both types can be obtained from a `PyCell`. +#[doc(hidden)] +pub trait TryFromPyCell<'a, T: PyClass>: Sized { + type Error: Into; + fn try_from_pycell(cell: &'a crate::PyCell) -> Result; +} + +impl<'a, T, R> TryFromPyCell<'a, T> for R +where + T: 'a + PyClass, + R: std::convert::TryFrom<&'a PyCell>, + R::Error: Into, +{ + type Error = R::Error; + fn try_from_pycell(cell: &'a crate::PyCell) -> Result { + >>::try_from(cell) + } +} diff --git a/tests/test_dunder.rs b/tests/test_dunder.rs index 2ca49b303ea..3ae14908b52 100644 --- a/tests/test_dunder.rs +++ b/tests/test_dunder.rs @@ -53,11 +53,11 @@ struct Iterator { #[pyproto] impl<'p> PyIterProtocol for Iterator { - fn __iter__(slf: PyRefMut) -> PyResult> { + fn __iter__(slf: PyRef<'p, Self>) -> PyResult> { Ok(slf.into()) } - fn __next__(mut slf: PyRefMut) -> PyResult> { + fn __next__(mut slf: PyRefMut<'p, Self>) -> PyResult> { Ok(slf.iter.next()) } } diff --git a/tests/test_pyself.rs b/tests/test_pyself.rs index 66fa1f4e222..aabc5d9a4b5 100644 --- a/tests/test_pyself.rs +++ b/tests/test_pyself.rs @@ -55,7 +55,7 @@ struct Iter { #[pyproto] impl PyIterProtocol for Iter { - fn __iter__(slf: PyRefMut) -> PyResult { + fn __iter__(slf: PyRef) -> PyResult { let py = unsafe { Python::assume_gil_acquired() }; Ok(slf.into_py(py)) }