Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve lifetime insertions for #[pyproto] #1093

Merged
merged 2 commits into from
Aug 11, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Don't rely on the order of structmembers to compute offsets in PyCell. Related to
[#1058](https://github.com/PyO3/pyo3/pull/1058). [#1059](https://github.com/PyO3/pyo3/pull/1059)
- Allows `&Self` as a `#[pymethods]` argument again. [#1071](https://github.com/PyO3/pyo3/pull/1071)
- Improve lifetime insertion in `#[pyproto]`. [#1093](https://github.com/PyO3/pyo3/pull/1093)
kngwyu marked this conversation as resolved.
Show resolved Hide resolved

## [0.11.1] - 2020-06-30
### Added
Expand Down
10 changes: 5 additions & 5 deletions guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -923,8 +923,8 @@ struct MyIterator {

#[pyproto]
impl PyIterProtocol for MyIterator {
fn __iter__(slf: PyRef<Self>) -> Py<MyIterator> {
slf.into()
fn __iter__(slf: PyRef<Self>) -> PyRef<Self> {
slf
}
fn __next__(mut slf: PyRefMut<Self>) -> Option<PyObject> {
slf.iter.next()
Expand All @@ -948,8 +948,8 @@ struct Iter {

#[pyproto]
impl PyIterProtocol for Iter {
fn __iter__(slf: PyRefMut<Self>) -> Py<Iter> {
slf.into()
fn __iter__(slf: PyRef<Self>) -> PyRef<Self> {
slf
}

fn __next__(mut slf: PyRefMut<Self>) -> Option<usize> {
Expand All @@ -964,7 +964,7 @@ struct Container {

#[pyproto]
impl PyIterProtocol for Container {
fn __iter__(slf: PyRefMut<Self>) -> PyResult<Py<Iter>> {
fn __iter__(slf: PyRef<Self>) -> PyResult<Py<Iter>> {
let iter = Iter {
inner: slf.iter.clone().into_iter(),
};
Expand Down
2 changes: 1 addition & 1 deletion pyo3-derive-backend/src/defs.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Copyright (c) 2017-present PyO3 Project and Contributors
use crate::func::MethodProto;
use crate::proto_method::MethodProto;

/// Predicates for `#[pyproto]`.
pub struct Proto {
Expand Down
2 changes: 1 addition & 1 deletion pyo3-derive-backend/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
#![recursion_limit = "1024"]

mod defs;
mod func;
mod konst;
mod method;
mod module;
mod proto_method;
mod pyclass;
mod pyfunction;
mod pyimpl;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use syn::Token;

// TODO:
// Add lifetime support for args with Rptr

#[derive(Debug)]
pub enum MethodProto {
Free {
Expand Down Expand Up @@ -77,7 +76,11 @@ pub(crate) fn impl_method_proto(
) -> TokenStream {
let ret_ty = match &sig.output {
syn::ReturnType::Default => quote! { () },
syn::ReturnType::Type(_, ty) => ty.to_token_stream(),
syn::ReturnType::Type(_, ty) => {
let mut ty = ty.clone();
insert_lifetime(&mut ty);
ty.to_token_stream()
}
};

match *meth {
Expand Down Expand Up @@ -106,22 +109,7 @@ pub(crate) fn impl_method_proto(
let p: syn::Path = syn::parse_str(proto).unwrap();

let slf_name = syn::Ident::new(arg, Span::call_site());
let mut slf_ty = get_arg_ty(sig, 0);

// update the type if no lifetime was given:
// PyRef<Self> --> PyRef<'p, Self>
if let syn::Type::Path(ref mut path) = slf_ty {
if let syn::PathArguments::AngleBracketed(ref mut args) =
path.path.segments[0].arguments
{
if let syn::GenericArgument::Lifetime(_) = args.args[0] {
} else {
let lt = syn::parse_quote! {'p};
args.args.insert(0, lt);
}
}
}

let slf_ty = get_arg_ty(sig, 0);
let tmp: syn::ItemFn = syn::parse_quote! {
fn test(&self) -> <#cls as #p<'p>>::Result {}
};
Expand Down Expand Up @@ -336,38 +324,62 @@ pub(crate) fn impl_method_proto(
}
}

// TODO: better arg ty detection
/// Some hacks for arguments: get `T` from `Option<T>` and insert lifetime
fn get_arg_ty(sig: &syn::Signature, idx: usize) -> syn::Type {
let mut ty = match sig.inputs[idx] {
syn::FnArg::Typed(ref cap) => {
match *cap.ty {
syn::Type::Path(ref ty) => {
// use only last path segment for Option<>
let seg = ty.path.segments.last().unwrap().clone();
if seg.ident == "Option" {
if let syn::PathArguments::AngleBracketed(ref data) = seg.arguments {
if let Some(pair) = data.args.last() {
match pair {
syn::GenericArgument::Type(ref ty) => return ty.clone(),
_ => panic!("Option only accepted for concrete types"),
}
};
}
}
*cap.ty.clone()
fn get_option_ty(path: &syn::Path) -> Option<syn::Type> {
let seg = path.segments.last()?;
if seg.ident == "Option" {
if let syn::PathArguments::AngleBracketed(ref data) = seg.arguments {
if let Some(syn::GenericArgument::Type(ref ty)) = data.args.last() {
return Some(ty.to_owned());
}
_ => *cap.ty.clone(),
}
}
_ => panic!("fn arg type is not supported"),
None
}

let mut ty = match &sig.inputs[idx] {
syn::FnArg::Typed(ref cap) => match &*cap.ty {
// For `Option<T>`, we use `T` as an associated type for the protocol.
syn::Type::Path(ref ty) => get_option_ty(&ty.path).unwrap_or_else(|| *cap.ty.clone()),
_ => *cap.ty.clone(),
},
ty => panic!("Unsupported argument type: {:?}", ty),
};
insert_lifetime(&mut ty);
ty
}

// Add a lifetime if there is none
if let syn::Type::Reference(ref mut r) = ty {
r.lifetime.get_or_insert(syn::parse_quote! {'p});
/// Insert lifetime `'p` to `PyRef<Self>` or references (e.g., `&PyType`).
fn insert_lifetime(ty: &mut syn::Type) {
fn insert_lifetime_for_path(path: &mut syn::TypePath) {
if let Some(seg) = path.path.segments.last_mut() {
if let syn::PathArguments::AngleBracketed(ref mut args) = seg.arguments {
let mut has_lifetime = false;
for arg in &mut args.args {
match arg {
// Insert `'p` recursively for `Option<PyRef<Self>>` or so.
syn::GenericArgument::Type(ref mut ty) => insert_lifetime(ty),
syn::GenericArgument::Lifetime(_) => has_lifetime = true,
_ => {}
}
}
// Insert lifetime to PyRef (i.e., PyRef<Self> -> PyRef<'p, Self>)
if !has_lifetime && (seg.ident == "PyRef" || seg.ident == "PyRefMut") {
args.args.insert(0, syn::parse_quote! {'p});
}
}
}
}

ty
match ty {
syn::Type::Reference(ref mut r) => {
r.lifetime.get_or_insert(syn::parse_quote! {'p});
insert_lifetime(&mut *r.elem);
}
syn::Type::Path(ref mut path) => insert_lifetime_for_path(path),
_ => {}
}
}

fn extract_decl(spec: syn::Item) -> syn::Signature {
Expand Down
2 changes: 1 addition & 1 deletion pyo3-derive-backend/src/pyproto.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// Copyright (c) 2017-present PyO3 Project and Contributors

use crate::defs;
use crate::func::impl_method_proto;
use crate::method::{FnSpec, FnType};
use crate::proto_method::impl_method_proto;
use crate::pymethod;
use proc_macro2::{Span, TokenStream};
use quote::quote;
Expand Down
32 changes: 16 additions & 16 deletions tests/test_dunder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ struct Iterator {
}

#[pyproto]
impl<'p> PyIterProtocol for Iterator {
fn __iter__(slf: PyRef<'p, Self>) -> Py<Iterator> {
slf.into()
impl PyIterProtocol for Iterator {
fn __iter__(slf: PyRef<Self>) -> PyRef<Self> {
slf
}

fn __next__(mut slf: PyRefMut<'p, Self>) -> Option<i32> {
fn __next__(mut slf: PyRefMut<Self>) -> Option<i32> {
slf.iter.next()
}
}
Expand All @@ -81,7 +81,7 @@ fn iterator() {
struct StringMethods {}

#[pyproto]
impl<'p> PyObjectProtocol<'p> for StringMethods {
impl PyObjectProtocol for StringMethods {
fn __str__(&self) -> &'static str {
"str"
}
Expand Down Expand Up @@ -236,7 +236,7 @@ struct SetItem {
}

#[pyproto]
impl PyMappingProtocol<'a> for SetItem {
impl PyMappingProtocol for SetItem {
fn __setitem__(&mut self, key: i32, val: i32) {
self.key = key;
self.val = val;
Expand Down Expand Up @@ -362,16 +362,16 @@ struct ContextManager {
}

#[pyproto]
impl<'p> PyContextProtocol<'p> for ContextManager {
impl PyContextProtocol for ContextManager {
fn __enter__(&mut self) -> i32 {
42
}

fn __exit__(
&mut self,
ty: Option<&'p PyType>,
_value: Option<&'p PyAny>,
_traceback: Option<&'p PyAny>,
ty: Option<&PyType>,
_value: Option<&PyAny>,
_traceback: Option<&PyAny>,
) -> bool {
let gil = Python::acquire_gil();
self.exit_called = true;
Expand Down Expand Up @@ -564,14 +564,14 @@ impl OnceFuture {

#[pyproto]
impl PyAsyncProtocol for OnceFuture {
fn __await__(slf: PyRef<'p, Self>) -> PyRef<'p, Self> {
fn __await__(slf: PyRef<Self>) -> PyRef<Self> {
slf
}
}

#[pyproto]
impl PyIterProtocol for OnceFuture {
fn __iter__(slf: PyRef<'p, Self>) -> PyRef<'p, Self> {
fn __iter__(slf: PyRef<Self>) -> PyRef<Self> {
slf
}
fn __next__(mut slf: PyRefMut<Self>) -> Option<PyObject> {
Expand Down Expand Up @@ -632,14 +632,14 @@ impl DescrCounter {
#[pyproto]
impl PyDescrProtocol for DescrCounter {
fn __get__(
mut slf: PyRefMut<'p, Self>,
mut slf: PyRefMut<Self>,
_instance: &PyAny,
_owner: Option<&'p PyType>,
) -> PyRefMut<'p, Self> {
_owner: Option<&PyType>,
) -> PyRefMut<Self> {
slf.count += 1;
slf
}
fn __set__(_slf: PyRef<'p, Self>, _instance: &PyAny, mut new_value: PyRefMut<'p, Self>) {
fn __set__(_slf: PyRef<Self>, _instance: &PyAny, mut new_value: PyRefMut<Self>) {
new_value.count = _slf.count;
}
}
Expand Down