Skip to content

Commit

Permalink
Allow PyRef<Self> as return type by inserting 'p
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Aug 8, 2020
1 parent c66bc54 commit be76f8c
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 21 deletions.
10 changes: 5 additions & 5 deletions guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -923,8 +923,8 @@ struct MyIterator {

#[pyproto]
impl PyIterProtocol for MyIterator {
fn __iter__(slf: PyRef<Self>) -> Py<MyIterator> {
slf.into()
fn __iter__(slf: PyRef<Self>) -> PyRef<Self> {
slf
}
fn __next__(mut slf: PyRefMut<Self>) -> Option<PyObject> {
slf.iter.next()
Expand All @@ -948,8 +948,8 @@ struct Iter {

#[pyproto]
impl PyIterProtocol for Iter {
fn __iter__(slf: PyRefMut<Self>) -> Py<Iter> {
slf.into()
fn __iter__(slf: PyRef<Self>) -> PyRef<Self> {
slf
}

fn __next__(mut slf: PyRefMut<Self>) -> Option<usize> {
Expand All @@ -964,7 +964,7 @@ struct Container {

#[pyproto]
impl PyIterProtocol for Container {
fn __iter__(slf: PyRefMut<Self>) -> PyResult<Py<Iter>> {
fn __iter__(slf: PyRef<Self>) -> PyResult<Py<Iter>> {
let iter = Iter {
inner: slf.iter.clone().into_iter(),
};
Expand Down
40 changes: 26 additions & 14 deletions pyo3-derive-backend/src/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,12 @@ pub(crate) fn impl_method_proto(
) -> TokenStream {
let ret_ty = match &sig.output {
syn::ReturnType::Default => quote! { () },
syn::ReturnType::Type(_, ty) => ty.to_token_stream(),
syn::ReturnType::Type(_, ty) => {
let mut ty = ty.clone();
// Insert lifetime to PyRef (i.e., PyRef<Self> -> PyRef<'p, Self>)
insert_lifetime_to_pyref(&mut ty);
ty.to_token_stream()
}
};

match *meth {
Expand Down Expand Up @@ -108,19 +113,8 @@ pub(crate) fn impl_method_proto(
let slf_name = syn::Ident::new(arg, Span::call_site());
let mut slf_ty = get_arg_ty(sig, 0);

// update the type if no lifetime was given:
// PyRef<Self> --> PyRef<'p, Self>
if let syn::Type::Path(ref mut path) = slf_ty {
if let syn::PathArguments::AngleBracketed(ref mut args) =
path.path.segments[0].arguments
{
if let syn::GenericArgument::Lifetime(_) = args.args[0] {
} else {
let lt = syn::parse_quote! {'p};
args.args.insert(0, lt);
}
}
}
// Insert lifetime to PyRef (i.e., PyRef<Self> -> PyRef<'p, Self>)
insert_lifetime_to_pyref(&mut slf_ty);

let tmp: syn::ItemFn = syn::parse_quote! {
fn test(&self) -> <#cls as #p<'p>>::Result {}
Expand Down Expand Up @@ -370,6 +364,24 @@ fn get_arg_ty(sig: &syn::Signature, idx: usize) -> syn::Type {
ty
}

fn insert_lifetime_to_pyref(ty: &mut syn::Type) {
if let syn::Type::Path(ref mut path) = ty {
if let Some(seg) = path.path.segments.last_mut() {
if seg.ident != "PyRef" && seg.ident != "PyRefMut" {
return;
}
if let syn::PathArguments::AngleBracketed(ref mut args) = seg.arguments {
match args.args.first() {
Some(syn::GenericArgument::Lifetime(_)) => {}
_ => {
args.args.insert(0, syn::parse_quote! {'p});
}
}
}
}
}
}

fn extract_decl(spec: syn::Item) -> syn::Signature {
match spec {
syn::Item::Fn(f) => f.sig,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_dunder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ struct Iterator {

#[pyproto]
impl<'p> PyIterProtocol for Iterator {
fn __iter__(slf: PyRef<'p, Self>) -> Py<Iterator> {
slf.into()
fn __iter__(slf: PyRef<'p, Self>) -> PyRef<'p, Self> {
slf
}

fn __next__(mut slf: PyRefMut<'p, Self>) -> Option<i32> {
Expand Down

0 comments on commit be76f8c

Please sign in to comment.