diff --git a/CHANGELOG.md b/CHANGELOG.md index fadf9d99cde..faced376a6a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## Unreleased +* Support for `#[name = "foo"]` attribute for `#[pyfunction]` and in `#[pymethods]`. [#692](https://github.com/PyO3/pyo3/pull/692) + ## [0.8.4] ### Added diff --git a/pyo3-derive-backend/src/func.rs b/pyo3-derive-backend/src/func.rs index 2b7f69aca7b..c8e3c83b542 100644 --- a/pyo3-derive-backend/src/func.rs +++ b/pyo3-derive-backend/src/func.rs @@ -55,16 +55,16 @@ pub enum MethodProto { }, } -impl PartialEq for MethodProto { - fn eq(&self, name: &str) -> bool { +impl MethodProto { + pub fn name(&self) -> &str { match *self { - MethodProto::Free { name: n, .. } => n == name, - MethodProto::Unary { name: n, .. } => n == name, - MethodProto::Binary { name: n, .. } => n == name, - MethodProto::BinaryS { name: n, .. } => n == name, - MethodProto::Ternary { name: n, .. } => n == name, - MethodProto::TernaryS { name: n, .. } => n == name, - MethodProto::Quaternary { name: n, .. } => n == name, + MethodProto::Free { ref name, .. } => name, + MethodProto::Unary { ref name, .. } => name, + MethodProto::Binary { ref name, .. } => name, + MethodProto::BinaryS { ref name, .. } => name, + MethodProto::Ternary { ref name, .. } => name, + MethodProto::TernaryS { ref name, .. } => name, + MethodProto::Quaternary { ref name, .. } => name, } } } diff --git a/pyo3-derive-backend/src/lib.rs b/pyo3-derive-backend/src/lib.rs index 2ca6b87a9d3..9571a52cb73 100644 --- a/pyo3-derive-backend/src/lib.rs +++ b/pyo3-derive-backend/src/lib.rs @@ -16,7 +16,7 @@ mod utils; pub use module::{add_fn_to_module, process_functions_in_module, py_init}; pub use pyclass::{build_py_class, PyClassArgs}; -pub use pyfunction::PyFunctionAttr; +pub use pyfunction::{build_py_function, PyFunctionAttr}; pub use pyimpl::{build_py_methods, impl_methods}; pub use pyproto::build_py_proto; pub use utils::get_doc; diff --git a/pyo3-derive-backend/src/method.rs b/pyo3-derive-backend/src/method.rs index f95fe820f25..e8d25ca7591 100644 --- a/pyo3-derive-backend/src/method.rs +++ b/pyo3-derive-backend/src/method.rs @@ -1,10 +1,12 @@ // Copyright (c) 2017-present PyO3 Project and Contributors use crate::pyfunction::Argument; -use crate::pyfunction::PyFunctionAttr; +use crate::pyfunction::{parse_name_attribute, PyFunctionAttr}; +use crate::utils; use proc_macro2::TokenStream; use quote::quote; use quote::ToTokens; +use syn::ext::IdentExt; use syn::spanned::Spanned; #[derive(Clone, PartialEq, Debug)] @@ -20,8 +22,8 @@ pub struct FnArg<'a> { #[derive(Clone, PartialEq, Debug)] pub enum FnType { - Getter(Option), - Setter(Option), + Getter, + Setter, Fn, FnNew, FnCall, @@ -33,9 +35,15 @@ pub enum FnType { #[derive(Clone, PartialEq, Debug)] pub struct FnSpec<'a> { pub tp: FnType, + // Rust function name + pub name: &'a syn::Ident, + // Wrapped python name. This should not have any leading r#. + // r# can be removed by syn::ext::IdentExt::unraw() + pub python_name: syn::Ident, pub attrs: Vec, pub args: Vec>, pub output: syn::Type, + pub doc: syn::LitStr, } pub fn get_return_info(output: &syn::ReturnType) -> syn::Type { @@ -48,11 +56,16 @@ pub fn get_return_info(output: &syn::ReturnType) -> syn::Type { impl<'a> FnSpec<'a> { /// Parser function signature and function attributes pub fn parse( - name: &'a syn::Ident, sig: &'a syn::Signature, meth_attrs: &mut Vec, + allow_custom_name: bool, ) -> syn::Result> { - let (mut fn_type, fn_attrs) = parse_attributes(meth_attrs)?; + let name = &sig.ident; + let MethodAttributes { + ty: mut fn_type, + args: fn_attrs, + mut python_name, + } = parse_method_attributes(meth_attrs, allow_custom_name)?; let mut has_self = false; let mut arguments = Vec::new(); @@ -112,11 +125,58 @@ impl<'a> FnSpec<'a> { fn_type = FnType::PySelf(tp); } + // "Tweak" getter / setter names: strip off set_ and get_ if needed + if let FnType::Getter | FnType::Setter = &fn_type { + if python_name.is_none() { + let prefix = match &fn_type { + FnType::Getter => "get_", + FnType::Setter => "set_", + _ => unreachable!(), + }; + + let ident = sig.ident.unraw().to_string(); + if ident.starts_with(prefix) { + python_name = Some(syn::Ident::new(&ident[prefix.len()..], ident.span())) + } + } + } + + let python_name = python_name.unwrap_or_else(|| name.unraw()); + + let mut parse_erroneous_text_signature = |error_msg: &str| { + // try to parse anyway to give better error messages + if let Some(text_signature) = + utils::parse_text_signature_attrs(meth_attrs, &python_name)? + { + Err(syn::Error::new_spanned(text_signature, error_msg)) + } else { + Ok(None) + } + }; + + let text_signature = match &fn_type { + FnType::Fn | FnType::PySelf(_) | FnType::FnClass | FnType::FnStatic => { + utils::parse_text_signature_attrs(&mut *meth_attrs, name)? + } + FnType::FnNew => parse_erroneous_text_signature( + "text_signature not allowed on __new__; if you want to add a signature on \ + __new__, put it on the struct definition instead", + )?, + FnType::FnCall | FnType::Getter | FnType::Setter => { + parse_erroneous_text_signature("text_signature not allowed with this attribute")? + } + }; + + let doc = utils::get_doc(&meth_attrs, text_signature, true)?; + Ok(FnSpec { tp: fn_type, + name, + python_name, attrs: fn_attrs, args: arguments, output: ty, + doc, }) } @@ -279,10 +339,21 @@ pub fn check_arg_ty_and_optional<'a>( } } -fn parse_attributes(attrs: &mut Vec) -> syn::Result<(FnType, Vec)> { +#[derive(Clone, PartialEq, Debug)] +struct MethodAttributes { + ty: FnType, + args: Vec, + python_name: Option, +} + +fn parse_method_attributes( + attrs: &mut Vec, + allow_custom_name: bool, +) -> syn::Result { let mut new_attrs = Vec::new(); - let mut spec = Vec::new(); + let mut args = Vec::new(); let mut res: Option = None; + let mut property_name = None; for attr in attrs.iter() { match attr.parse_meta()? { @@ -302,15 +373,21 @@ fn parse_attributes(attrs: &mut Vec) -> syn::Result<(FnType, Vec res = Some(FnType::FnStatic) } else if name.is_ident("setter") || name.is_ident("getter") { if let syn::AttrStyle::Inner(_) = attr.style { - panic!("Inner style attribute is not supported for setter and getter"); + return Err(syn::Error::new_spanned( + attr, + "Inner style attribute is not supported for setter and getter", + )); } if res != None { - panic!("setter/getter attribute can not be used mutiple times"); + return Err(syn::Error::new_spanned( + attr, + "setter/getter attribute can not be used mutiple times", + )); } if name.is_ident("setter") { - res = Some(FnType::Setter(None)) + res = Some(FnType::Setter) } else { - res = Some(FnType::Getter(None)) + res = Some(FnType::Getter) } } else { new_attrs.push(attr.clone()) @@ -332,44 +409,53 @@ fn parse_attributes(attrs: &mut Vec) -> syn::Result<(FnType, Vec res = Some(FnType::FnCall) } else if path.is_ident("setter") || path.is_ident("getter") { if let syn::AttrStyle::Inner(_) = attr.style { - panic!( - "Inner style attribute is not - supported for setter and getter" - ); + return Err(syn::Error::new_spanned( + attr, + "Inner style attribute is not supported for setter and getter", + )); } if res != None { - panic!("setter/getter attribute can not be used mutiple times"); + return Err(syn::Error::new_spanned( + attr, + "setter/getter attribute can not be used mutiple times", + )); } if nested.len() != 1 { - panic!("setter/getter requires one value"); + return Err(syn::Error::new_spanned( + attr, + "setter/getter requires one value", + )); } - match nested.first().unwrap() { - syn::NestedMeta::Meta(syn::Meta::Path(ref w)) => { - if path.is_ident("setter") { - res = Some(FnType::Setter(Some(w.segments[0].ident.to_string()))) - } else { - res = Some(FnType::Getter(Some(w.segments[0].ident.to_string()))) - } + + res = if path.is_ident("setter") { + Some(FnType::Setter) + } else { + Some(FnType::Getter) + }; + + property_name = match nested.first().unwrap() { + syn::NestedMeta::Meta(syn::Meta::Path(ref w)) if w.segments.len() == 1 => { + Some(w.segments[0].ident.clone()) } syn::NestedMeta::Lit(ref lit) => match *lit { - syn::Lit::Str(ref s) => { - if path.is_ident("setter") { - res = Some(FnType::Setter(Some(s.value()))) - } else { - res = Some(FnType::Getter(Some(s.value()))) - } - } + syn::Lit::Str(ref s) => Some(s.parse()?), _ => { - panic!("setter/getter attribute requires str value"); + return Err(syn::Error::new_spanned( + lit, + "setter/getter attribute requires str value", + )) } }, _ => { - println!("cannot parse {:?} attribute: {:?}", path, nested); + return Err(syn::Error::new_spanned( + nested.first().unwrap(), + "expected ident or string literal for property name", + )) } - } + }; } else if path.is_ident("args") { let attrs = PyFunctionAttr::from_meta(nested)?; - spec.extend(attrs.arguments) + args.extend(attrs.arguments) } else { new_attrs.push(attr.clone()) } @@ -377,13 +463,51 @@ fn parse_attributes(attrs: &mut Vec) -> syn::Result<(FnType, Vec syn::Meta::NameValue(_) => new_attrs.push(attr.clone()), } } + attrs.clear(); attrs.extend(new_attrs); - match res { - Some(tp) => Ok((tp, spec)), - None => Ok((FnType::Fn, spec)), + let ty = res.unwrap_or(FnType::Fn); + let python_name = if allow_custom_name { + parse_method_name_attribute(&ty, attrs, property_name)? + } else { + property_name + }; + + Ok(MethodAttributes { + ty, + args, + python_name, + }) +} + +fn parse_method_name_attribute( + ty: &FnType, + attrs: &mut Vec, + property_name: Option, +) -> syn::Result> { + let name = parse_name_attribute(attrs)?; + + // Reject some invalid combinations + if let Some(name) = &name { + match ty { + FnType::FnNew | FnType::FnCall | FnType::Getter | FnType::Setter => { + return Err(syn::Error::new_spanned( + name, + "name not allowed with this attribute", + )) + } + _ => {} + } } + + // Thanks to check above we can be sure that this generates the right python name + Ok(match ty { + FnType::FnNew => Some(syn::Ident::new("__new__", proc_macro2::Span::call_site())), + FnType::FnCall => Some(syn::Ident::new("__call__", proc_macro2::Span::call_site())), + FnType::Getter | FnType::Setter => property_name, + _ => name, + }) } // Replace A with A<_> diff --git a/pyo3-derive-backend/src/module.rs b/pyo3-derive-backend/src/module.rs index 3ee1433b5eb..badaf37f0a3 100644 --- a/pyo3-derive-backend/src/module.rs +++ b/pyo3-derive-backend/src/module.rs @@ -13,7 +13,7 @@ use syn::Ident; /// Generates the function that is called by the python interpreter to initialize the native /// module -pub fn py_init(fnname: &Ident, name: &Ident, doc: syn::Lit) -> TokenStream { +pub fn py_init(fnname: &Ident, name: &Ident, doc: syn::LitStr) -> TokenStream { let cb_name = Ident::new(&format!("PyInit_{}", name), Span::call_site()); quote! { @@ -36,7 +36,7 @@ pub fn process_functions_in_module(func: &mut syn::ItemFn) { if let Some((module_name, python_name, pyfn_attrs)) = extract_pyfn_attrs(&mut func.attrs) { - let function_to_python = add_fn_to_module(func, &python_name, pyfn_attrs); + let function_to_python = add_fn_to_module(func, python_name, pyfn_attrs); let function_wrapper_ident = function_wrapper_ident(&func.sig.ident); let item: syn::ItemFn = syn::parse_quote! { fn block_wrapper() { @@ -134,7 +134,7 @@ fn function_wrapper_ident(name: &Ident) -> Ident { /// function pub fn add_fn_to_module( func: &mut syn::ItemFn, - python_name: &Ident, + python_name: Ident, pyfn_attrs: Vec, ) -> TokenStream { let mut arguments = Vec::new(); @@ -147,24 +147,32 @@ pub fn add_fn_to_module( let ty = method::get_return_info(&func.sig.output); + let text_signature = match utils::parse_text_signature_attrs(&mut func.attrs, &python_name) { + Ok(text_signature) => text_signature, + Err(err) => return err.to_compile_error(), + }; + let doc = match utils::get_doc(&func.attrs, text_signature, true) { + Ok(doc) => doc, + Err(err) => return err.to_compile_error(), + }; + + let function_wrapper_ident = function_wrapper_ident(&func.sig.ident); + let spec = method::FnSpec { tp: method::FnType::Fn, + name: &function_wrapper_ident, + python_name, attrs: pyfn_attrs, args: arguments, output: ty, + doc, }; - let function_wrapper_ident = function_wrapper_ident(&func.sig.ident); + let doc = &spec.doc; + + let python_name = &spec.python_name; let wrapper = function_c_wrapper(&func.sig.ident, &spec); - let text_signature = match utils::parse_text_signature_attrs(&mut func.attrs, python_name) { - Ok(text_signature) => text_signature, - Err(err) => return err.to_compile_error(), - }; - let doc = match utils::get_doc(&func.attrs, text_signature, true) { - Ok(doc) => doc, - Err(err) => return err.to_compile_error(), - }; let tokens = quote! { fn #function_wrapper_ident(py: pyo3::Python) -> pyo3::PyObject { diff --git a/pyo3-derive-backend/src/pyclass.rs b/pyo3-derive-backend/src/pyclass.rs index b4a8d6fe346..78118380b01 100644 --- a/pyo3-derive-backend/src/pyclass.rs +++ b/pyo3-derive-backend/src/pyclass.rs @@ -5,6 +5,7 @@ use crate::pymethod::{impl_py_getter_def, impl_py_setter_def, impl_wrap_getter, use crate::utils; use proc_macro2::{Span, TokenStream}; use quote::quote; +use syn::ext::IdentExt; use syn::parse::{Parse, ParseStream}; use syn::punctuated::Punctuated; use syn::{parse_quote, Expr, Token}; @@ -179,7 +180,7 @@ pub fn build_py_class(class: &mut syn::ItemStruct, attr: &PyClassArgs) -> syn::R )); } - Ok(impl_class(&class.ident, &attr, doc, descriptors)) + impl_class(&class.ident, &attr, doc, descriptors) } /// Parses `#[pyo3(get, set)]` @@ -192,9 +193,9 @@ fn parse_descriptors(item: &mut syn::Field) -> syn::Result> { for meta in list.nested.iter() { if let syn::NestedMeta::Meta(ref metaitem) = meta { if metaitem.path().is_ident("get") { - descs.push(FnType::Getter(None)); + descs.push(FnType::Getter); } else if metaitem.path().is_ident("set") { - descs.push(FnType::Setter(None)); + descs.push(FnType::Setter); } else { return Err(syn::Error::new_spanned( metaitem, @@ -259,9 +260,9 @@ fn get_class_python_name(cls: &syn::Ident, attr: &PyClassArgs) -> TokenStream { fn impl_class( cls: &syn::Ident, attr: &PyClassArgs, - doc: syn::Lit, + doc: syn::LitStr, descriptors: Vec<(syn::Field, Vec)>, -) -> TokenStream { +) -> syn::Result { let cls_name = get_class_python_name(cls, attr).to_string(); let extra = { @@ -293,7 +294,7 @@ fn impl_class( let extra = if !descriptors.is_empty() { let path = syn::Path::from(syn::PathSegment::from(cls.clone())); let ty = syn::Type::from(syn::TypePath { path, qself: None }); - let desc_impls = impl_descriptors(&ty, descriptors); + let desc_impls = impl_descriptors(&ty, descriptors)?; quote! { #desc_impls #extra @@ -354,7 +355,7 @@ fn impl_class( let base = &attr.base; let flags = &attr.flags; - quote! { + Ok(quote! { impl pyo3::type_object::PyTypeInfo for #cls { type Type = #cls; type BaseType = #base; @@ -396,10 +397,13 @@ fn impl_class( #gc_impl - } + }) } -fn impl_descriptors(cls: &syn::Type, descriptors: Vec<(syn::Field, Vec)>) -> TokenStream { +fn impl_descriptors( + cls: &syn::Type, + descriptors: Vec<(syn::Field, Vec)>, +) -> syn::Result { let methods: Vec = descriptors .iter() .flat_map(|&(ref field, ref fns)| { @@ -408,7 +412,7 @@ fn impl_descriptors(cls: &syn::Type, descriptors: Vec<(syn::Field, Vec)> let name = field.ident.clone().unwrap(); let field_ty = &field.ty; match *desc { - FnType::Getter(_) => { + FnType::Getter => { quote! { impl #cls { fn #name(&self) -> pyo3::PyResult<#field_ty> { @@ -417,7 +421,7 @@ fn impl_descriptors(cls: &syn::Type, descriptors: Vec<(syn::Field, Vec)> } } } - FnType::Setter(_) => { + FnType::Setter => { let setter_name = syn::Ident::new(&format!("set_{}", name), Span::call_site()); quote! { @@ -444,21 +448,29 @@ fn impl_descriptors(cls: &syn::Type, descriptors: Vec<(syn::Field, Vec)> let name = field.ident.clone().unwrap(); // FIXME better doc? - let doc = syn::Lit::from(syn::LitStr::new(&name.to_string(), name.span())); + let doc = syn::LitStr::new(&name.to_string(), name.span()); let field_ty = &field.ty; match *desc { - FnType::Getter(ref getter) => impl_py_getter_def( - &name, - doc, - getter, - &impl_wrap_getter(&cls, &name, false), - ), - FnType::Setter(ref setter) => { + FnType::Getter => { + let spec = FnSpec { + tp: FnType::Getter, + name: &name, + python_name: name.unraw(), + attrs: Vec::new(), + args: Vec::new(), + output: parse_quote!(PyResult<#field_ty>), + doc, + }; + Ok(impl_py_getter_def(&spec, &impl_wrap_getter(&cls, &spec)?)) + } + FnType::Setter => { let setter_name = syn::Ident::new(&format!("set_{}", name), Span::call_site()); let spec = FnSpec { - tp: FnType::Setter(None), + tp: FnType::Setter, + name: &setter_name, + python_name: name.unraw(), attrs: Vec::new(), args: vec![FnArg { name: &name, @@ -470,22 +482,18 @@ fn impl_descriptors(cls: &syn::Type, descriptors: Vec<(syn::Field, Vec)> reference: false, }], output: parse_quote!(PyResult<()>), - }; - impl_py_setter_def( - &name, doc, - setter, - &impl_wrap_setter(&cls, &setter_name, &spec), - ) + }; + Ok(impl_py_setter_def(&spec, &impl_wrap_setter(&cls, &spec)?)) } _ => unreachable!(), } }) - .collect::>() + .collect::>>() }) - .collect(); + .collect::>()?; - quote! { + Ok(quote! { #(#methods)* pyo3::inventory::submit! { @@ -494,7 +502,7 @@ fn impl_descriptors(cls: &syn::Type, descriptors: Vec<(syn::Field, Vec)> ::new(&[#(#py_methods),*]) } } - } + }) } fn check_generics(class: &mut syn::ItemStruct) -> syn::Result<()> { diff --git a/pyo3-derive-backend/src/pyfunction.rs b/pyo3-derive-backend/src/pyfunction.rs index 24a0eb26bff..588deb6a9de 100644 --- a/pyo3-derive-backend/src/pyfunction.rs +++ b/pyo3-derive-backend/src/pyfunction.rs @@ -1,7 +1,11 @@ // Copyright (c) 2017-present PyO3 Project and Contributors +use crate::module::add_fn_to_module; +use proc_macro2::TokenStream; +use syn::ext::IdentExt; use syn::parse::ParseBuffer; use syn::punctuated::Punctuated; +use syn::spanned::Spanned; use syn::{NestedMeta, Path}; #[derive(Debug, Clone, PartialEq)] @@ -193,6 +197,46 @@ impl PyFunctionAttr { } } +pub fn parse_name_attribute(attrs: &mut Vec) -> syn::Result> { + let mut name_attrs = Vec::new(); + + // Using retain will extract all name attributes from the attribute list + attrs.retain(|attr| match attr.parse_meta() { + Ok(syn::Meta::NameValue(ref nv)) if nv.path.is_ident("name") => { + name_attrs.push((nv.lit.clone(), attr.span())); + false + } + _ => true, + }); + + match &*name_attrs { + [] => Ok(None), + [(syn::Lit::Str(s), span)] => { + let mut ident: syn::Ident = s.parse()?; + // This span is the whole attribute span, which is nicer for reporting errors. + ident.set_span(*span); + Ok(Some(ident)) + } + [(_, span)] => Err(syn::Error::new( + *span, + "Expected string literal for #[name] argument", + )), + // TODO: The below pattern is unstable, so instead we match the wildcard. + // slice_patterns due to be stable soon: https://github.com/rust-lang/rust/issues/62254 + // [(_, span), _, ..] => { + _ => Err(syn::Error::new( + name_attrs[0].1, + "#[name] can not be specified multiple times", + )), + } +} + +pub fn build_py_function(ast: &mut syn::ItemFn, args: PyFunctionAttr) -> syn::Result { + let python_name = + parse_name_attribute(&mut ast.attrs)?.unwrap_or_else(|| ast.sig.ident.unraw()); + Ok(add_fn_to_module(ast, python_name, args.arguments)) +} + #[cfg(test)] mod test { use super::{Argument, PyFunctionAttr}; diff --git a/pyo3-derive-backend/src/pyimpl.rs b/pyo3-derive-backend/src/pyimpl.rs index 64abab02637..89e09c532db 100644 --- a/pyo3-derive-backend/src/pyimpl.rs +++ b/pyo3-derive-backend/src/pyimpl.rs @@ -25,13 +25,7 @@ pub fn impl_methods(ty: &syn::Type, impls: &mut Vec) -> syn::Resu let mut methods = Vec::new(); for iimpl in impls.iter_mut() { if let syn::ImplItem::Method(ref mut meth) = iimpl { - let name = meth.sig.ident.clone(); - methods.push(pymethod::gen_py_method( - ty, - &name, - &mut meth.sig, - &mut meth.attrs, - )?); + methods.push(pymethod::gen_py_method(ty, &mut meth.sig, &mut meth.attrs)?); } } diff --git a/pyo3-derive-backend/src/pymethod.rs b/pyo3-derive-backend/src/pymethod.rs index ba32667c606..ce051d29142 100644 --- a/pyo3-derive-backend/src/pymethod.rs +++ b/pyo3-derive-backend/src/pymethod.rs @@ -6,97 +6,29 @@ use quote::quote; pub fn gen_py_method( cls: &syn::Type, - name: &syn::Ident, sig: &mut syn::Signature, meth_attrs: &mut Vec, ) -> syn::Result { - check_generic(name, sig)?; + check_generic(sig)?; - let spec = FnSpec::parse(name, sig, &mut *meth_attrs)?; - - let mut parse_erroneous_text_signature = |alt_name: Option<&str>, error_msg: &str| { - let python_name; - let python_name = match alt_name { - None => name, - Some(alt_name) => { - python_name = syn::Ident::new(alt_name, name.span()); - &python_name - } - }; - // try to parse anyway to give better error messages - if let Some(text_signature) = - utils::parse_text_signature_attrs(&mut *meth_attrs, python_name)? - { - Err(syn::Error::new_spanned(text_signature, error_msg)) - } else { - Ok(None) - } - }; - - let text_signature = match &spec.tp { - FnType::Fn | FnType::PySelf(_) | FnType::FnClass | FnType::FnStatic => { - utils::parse_text_signature_attrs(&mut *meth_attrs, name)? - } - FnType::FnNew => parse_erroneous_text_signature( - Some("__new__"), - "text_signature not allowed on __new__; if you want to add a signature on \ - __new__, put it on the struct definition instead", - )?, - FnType::FnCall => parse_erroneous_text_signature( - Some("__call__"), - "text_signature not allowed on __call__", - )?, - FnType::Getter(getter_name) => parse_erroneous_text_signature( - getter_name.as_ref().map(|v| &**v), - "text_signature not allowed on getter", - )?, - FnType::Setter(setter_name) => parse_erroneous_text_signature( - setter_name.as_ref().map(|v| &**v), - "text_signature not allowed on setter", - )?, - }; - let doc = utils::get_doc(&meth_attrs, text_signature, true)?; + let spec = FnSpec::parse(sig, &mut *meth_attrs, true)?; Ok(match spec.tp { - FnType::Fn => impl_py_method_def(name, doc, &spec, &impl_wrap(cls, name, &spec, true)), - FnType::PySelf(ref self_ty) => impl_py_method_def( - name, - doc, - &spec, - &impl_wrap_pyslf(cls, name, &spec, self_ty, true), - ), - FnType::FnNew => impl_py_method_def_new(name, doc, &impl_wrap_new(cls, name, &spec)), - FnType::FnCall => impl_py_method_def_call(name, doc, &impl_wrap(cls, name, &spec, false)), - FnType::FnClass => impl_py_method_def_class(name, doc, &impl_wrap_class(cls, name, &spec)), - FnType::FnStatic => { - impl_py_method_def_static(name, doc, &impl_wrap_static(cls, name, &spec)) - } - FnType::Getter(ref getter) => { - let takes_py = match &*spec.args { - [] => false, - [arg] if utils::if_type_is_python(arg.ty) => true, - _ => { - return Err(syn::Error::new_spanned( - spec.args[0].ty, - "Getter function can only have one argument of type pyo3::Python!", - )); - } - }; - impl_py_getter_def(name, doc, getter, &impl_wrap_getter(cls, name, takes_py)) - } - FnType::Setter(ref setter) => { - impl_py_setter_def(name, doc, setter, &impl_wrap_setter(cls, name, &spec)) + FnType::Fn => impl_py_method_def(&spec, &impl_wrap(cls, &spec, true)), + FnType::PySelf(ref self_ty) => { + impl_py_method_def(&spec, &impl_wrap_pyslf(cls, &spec, self_ty, true)) } + FnType::FnNew => impl_py_method_def_new(&spec, &impl_wrap_new(cls, &spec)), + FnType::FnCall => impl_py_method_def_call(&spec, &impl_wrap(cls, &spec, false)), + FnType::FnClass => impl_py_method_def_class(&spec, &impl_wrap_class(cls, &spec)), + FnType::FnStatic => impl_py_method_def_static(&spec, &impl_wrap_static(cls, &spec)), + FnType::Getter => impl_py_getter_def(&spec, &impl_wrap_getter(cls, &spec)?), + FnType::Setter => impl_py_setter_def(&spec, &impl_wrap_setter(cls, &spec)?), }) } -fn check_generic(name: &syn::Ident, sig: &syn::Signature) -> syn::Result<()> { - let err_msg = |typ| { - format!( - "A Python method can't have a generic {} parameter: {}", - name, typ - ) - }; +fn check_generic(sig: &syn::Signature) -> syn::Result<()> { + let err_msg = |typ| format!("A Python method can't have a generic {} parameter", typ); for param in &sig.generics.params { match param { syn::GenericParam::Lifetime(_) => {} @@ -112,40 +44,35 @@ fn check_generic(name: &syn::Ident, sig: &syn::Signature) -> syn::Result<()> { } /// Generate function wrapper (PyCFunction, PyCFunctionWithKeywords) -pub fn impl_wrap( - cls: &syn::Type, - name: &syn::Ident, - spec: &FnSpec<'_>, - noargs: bool, -) -> TokenStream { - let body = impl_call(cls, name, &spec); +pub fn impl_wrap(cls: &syn::Type, spec: &FnSpec<'_>, noargs: bool) -> TokenStream { + let body = impl_call(cls, &spec); let slf = impl_self("e! { &mut #cls }); - impl_wrap_common(cls, name, spec, noargs, slf, body) + impl_wrap_common(cls, spec, noargs, slf, body) } pub fn impl_wrap_pyslf( cls: &syn::Type, - name: &syn::Ident, spec: &FnSpec<'_>, self_ty: &syn::TypePath, noargs: bool, ) -> TokenStream { let names = get_arg_names(spec); + let name = &spec.name; let body = quote! { #cls::#name(_slf, #(#names),*) }; let slf = impl_self(self_ty); - impl_wrap_common(cls, name, spec, noargs, slf, body) + impl_wrap_common(cls, spec, noargs, slf, body) } fn impl_wrap_common( cls: &syn::Type, - name: &syn::Ident, spec: &FnSpec<'_>, noargs: bool, slf: TokenStream, body: TokenStream, ) -> TokenStream { + let python_name = &spec.python_name; if spec.args.is_empty() && noargs { quote! { unsafe extern "C" fn __wrap( @@ -153,7 +80,7 @@ fn impl_wrap_common( ) -> *mut pyo3::ffi::PyObject { const _LOCATION: &'static str = concat!( - stringify!(#cls), ".", stringify!(#name), "()"); + stringify!(#cls), ".", stringify!(#python_name), "()"); let _py = pyo3::Python::assume_gil_acquired(); let _pool = pyo3::GILPool::new(_py); #slf @@ -175,7 +102,7 @@ fn impl_wrap_common( _kwargs: *mut pyo3::ffi::PyObject) -> *mut pyo3::ffi::PyObject { const _LOCATION: &'static str = concat!( - stringify!(#cls), ".", stringify!(#name), "()"); + stringify!(#cls), ".", stringify!(#python_name), "()"); let _py = pyo3::Python::assume_gil_acquired(); let _pool = pyo3::GILPool::new(_py); #slf @@ -192,8 +119,9 @@ fn impl_wrap_common( } /// Generate function wrapper for protocol method (PyCFunction, PyCFunctionWithKeywords) -pub fn impl_proto_wrap(cls: &syn::Type, name: &syn::Ident, spec: &FnSpec<'_>) -> TokenStream { - let cb = impl_call(cls, name, &spec); +pub fn impl_proto_wrap(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream { + let python_name = &spec.python_name; + let cb = impl_call(cls, &spec); let body = impl_arg_params(&spec, cb); quote! { @@ -203,7 +131,7 @@ pub fn impl_proto_wrap(cls: &syn::Type, name: &syn::Ident, spec: &FnSpec<'_>) -> _args: *mut pyo3::ffi::PyObject, _kwargs: *mut pyo3::ffi::PyObject) -> *mut pyo3::ffi::PyObject { - const _LOCATION: &'static str = concat!(stringify!(#cls),".",stringify!(#name),"()"); + const _LOCATION: &'static str = concat!(stringify!(#cls),".",stringify!(#python_name),"()"); let _py = pyo3::Python::assume_gil_acquired(); let _pool = pyo3::GILPool::new(_py); let _slf = _py.mut_from_borrowed_ptr::<#cls>(_slf); @@ -219,7 +147,9 @@ pub fn impl_proto_wrap(cls: &syn::Type, name: &syn::Ident, spec: &FnSpec<'_>) -> } /// Generate class method wrapper (PyCFunction, PyCFunctionWithKeywords) -pub fn impl_wrap_new(cls: &syn::Type, name: &syn::Ident, spec: &FnSpec<'_>) -> TokenStream { +pub fn impl_wrap_new(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream { + let name = &spec.name; + let python_name = &spec.python_name; let names: Vec = get_arg_names(&spec); let cb = quote! { #cls::#name(&_obj, #(#names),*) }; @@ -234,7 +164,7 @@ pub fn impl_wrap_new(cls: &syn::Type, name: &syn::Ident, spec: &FnSpec<'_>) -> T { use pyo3::type_object::PyTypeInfo; - const _LOCATION: &'static str = concat!(stringify!(#cls),".",stringify!(#name),"()"); + const _LOCATION: &'static str = concat!(stringify!(#cls),".",stringify!(#python_name),"()"); let _py = pyo3::Python::assume_gil_acquired(); let _pool = pyo3::GILPool::new(_py); match pyo3::type_object::PyRawObject::new(_py, #cls::type_object(), _cls) { @@ -262,7 +192,9 @@ pub fn impl_wrap_new(cls: &syn::Type, name: &syn::Ident, spec: &FnSpec<'_>) -> T } /// Generate class method wrapper (PyCFunction, PyCFunctionWithKeywords) -pub fn impl_wrap_class(cls: &syn::Type, name: &syn::Ident, spec: &FnSpec<'_>) -> TokenStream { +pub fn impl_wrap_class(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream { + let name = &spec.name; + let python_name = &spec.python_name; let names: Vec = get_arg_names(&spec); let cb = quote! { #cls::#name(&_cls, #(#names),*) }; @@ -275,7 +207,7 @@ pub fn impl_wrap_class(cls: &syn::Type, name: &syn::Ident, spec: &FnSpec<'_>) -> _args: *mut pyo3::ffi::PyObject, _kwargs: *mut pyo3::ffi::PyObject) -> *mut pyo3::ffi::PyObject { - const _LOCATION: &'static str = concat!(stringify!(#cls),".",stringify!(#name),"()"); + const _LOCATION: &'static str = concat!(stringify!(#cls),".",stringify!(#python_name),"()"); let _py = pyo3::Python::assume_gil_acquired(); let _pool = pyo3::GILPool::new(_py); let _cls = pyo3::types::PyType::from_type_ptr(_py, _cls as *mut pyo3::ffi::PyTypeObject); @@ -291,7 +223,9 @@ pub fn impl_wrap_class(cls: &syn::Type, name: &syn::Ident, spec: &FnSpec<'_>) -> } /// Generate static method wrapper (PyCFunction, PyCFunctionWithKeywords) -pub fn impl_wrap_static(cls: &syn::Type, name: &syn::Ident, spec: &FnSpec<'_>) -> TokenStream { +pub fn impl_wrap_static(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream { + let name = &spec.name; + let python_name = &spec.python_name; let names: Vec = get_arg_names(&spec); let cb = quote! { #cls::#name(#(#names),*) }; @@ -304,7 +238,7 @@ pub fn impl_wrap_static(cls: &syn::Type, name: &syn::Ident, spec: &FnSpec<'_>) - _args: *mut pyo3::ffi::PyObject, _kwargs: *mut pyo3::ffi::PyObject) -> *mut pyo3::ffi::PyObject { - const _LOCATION: &'static str = concat!(stringify!(#cls),".",stringify!(#name),"()"); + const _LOCATION: &'static str = concat!(stringify!(#cls),".",stringify!(#python_name),"()"); let _py = pyo3::Python::assume_gil_acquired(); let _pool = pyo3::GILPool::new(_py); let _args = _py.from_borrowed_ptr::(_args); @@ -319,17 +253,32 @@ pub fn impl_wrap_static(cls: &syn::Type, name: &syn::Ident, spec: &FnSpec<'_>) - } /// Generate functiona wrapper (PyCFunction, PyCFunctionWithKeywords) -pub(crate) fn impl_wrap_getter(cls: &syn::Type, name: &syn::Ident, takes_py: bool) -> TokenStream { +pub(crate) fn impl_wrap_getter(cls: &syn::Type, spec: &FnSpec) -> syn::Result { + let takes_py = match &*spec.args { + [] => false, + [arg] if utils::if_type_is_python(arg.ty) => true, + _ => { + return Err(syn::Error::new_spanned( + spec.args[0].ty, + "Getter function can only have one argument of type pyo3::Python!", + )); + } + }; + + let name = &spec.name; + let python_name = &spec.python_name; + let fncall = if takes_py { quote! { _slf.#name(_py) } } else { quote! { _slf.#name() } }; - quote! { + + Ok(quote! { unsafe extern "C" fn __wrap( _slf: *mut pyo3::ffi::PyObject, _: *mut ::std::os::raw::c_void) -> *mut pyo3::ffi::PyObject { - const _LOCATION: &'static str = concat!(stringify!(#cls),".",stringify!(#name),"()"); + const _LOCATION: &'static str = concat!(stringify!(#cls),".",stringify!(#python_name),"()"); let _py = pyo3::Python::assume_gil_acquired(); let _pool = pyo3::GILPool::new(_py); @@ -347,31 +296,37 @@ pub(crate) fn impl_wrap_getter(cls: &syn::Type, name: &syn::Ident, takes_py: boo } } } - } + }) } /// Generate functiona wrapper (PyCFunction, PyCFunctionWithKeywords) -pub(crate) fn impl_wrap_setter( - cls: &syn::Type, - name: &syn::Ident, - spec: &FnSpec<'_>, -) -> TokenStream { - if spec.args.is_empty() { - println!( - "Not enough arguments for setter {}::{}", - quote! {#cls}, - name - ); - } - let val_ty = spec.args[0].ty; +pub(crate) fn impl_wrap_setter(cls: &syn::Type, spec: &FnSpec<'_>) -> syn::Result { + let name = &spec.name; + let python_name = &spec.python_name; + + let val_ty = match &*spec.args { + [] => { + return Err(syn::Error::new_spanned( + &spec.name, + "Not enough arguments for setter {}::{}", + )) + } + [arg] => &arg.ty, + _ => { + return Err(syn::Error::new_spanned( + spec.args[0].ty, + "Setter function must have exactly one argument", + )) + } + }; - quote! { + Ok(quote! { #[allow(unused_mut)] unsafe extern "C" fn __wrap( _slf: *mut pyo3::ffi::PyObject, _value: *mut pyo3::ffi::PyObject, _: *mut ::std::os::raw::c_void) -> pyo3::libc::c_int { - const _LOCATION: &'static str = concat!(stringify!(#cls),".",stringify!(#name),"()"); + const _LOCATION: &'static str = concat!(stringify!(#cls),".",stringify!(#python_name),"()"); let _py = pyo3::Python::assume_gil_acquired(); let _pool = pyo3::GILPool::new(_py); let _slf = _py.mut_from_borrowed_ptr::<#cls>(_slf); @@ -391,7 +346,7 @@ pub(crate) fn impl_wrap_setter( } } } - } + }) } /// This function abstracts away some copied code and can propably be simplified itself @@ -401,7 +356,8 @@ pub fn get_arg_names(spec: &FnSpec) -> Vec { .collect() } -fn impl_call(_cls: &syn::Type, fname: &syn::Ident, spec: &FnSpec<'_>) -> TokenStream { +fn impl_call(_cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream { + let fname = &spec.name; let names = get_arg_names(spec); quote! { _slf.#fname(#(#names),*) } } @@ -560,19 +516,16 @@ fn impl_arg_param( } } -pub fn impl_py_method_def( - name: &syn::Ident, - doc: syn::Lit, - spec: &FnSpec<'_>, - wrapper: &TokenStream, -) -> TokenStream { +pub fn impl_py_method_def(spec: &FnSpec, wrapper: &TokenStream) -> TokenStream { + let python_name = &spec.python_name; + let doc = &spec.doc; if spec.args.is_empty() { quote! { pyo3::class::PyMethodDefType::Method({ #wrapper pyo3::class::PyMethodDef { - ml_name: stringify!(#name), + ml_name: stringify!(#python_name), ml_meth: pyo3::class::PyMethodType::PyNoArgsFunction(__wrap), ml_flags: pyo3::ffi::METH_NOARGS, ml_doc: #doc, @@ -585,7 +538,7 @@ pub fn impl_py_method_def( #wrapper pyo3::class::PyMethodDef { - ml_name: stringify!(#name), + ml_name: stringify!(#python_name), ml_meth: pyo3::class::PyMethodType::PyCFunctionWithKeywords(__wrap), ml_flags: pyo3::ffi::METH_VARARGS | pyo3::ffi::METH_KEYWORDS, ml_doc: #doc, @@ -595,17 +548,15 @@ pub fn impl_py_method_def( } } -pub fn impl_py_method_def_new( - name: &syn::Ident, - doc: syn::Lit, - wrapper: &TokenStream, -) -> TokenStream { +pub fn impl_py_method_def_new(spec: &FnSpec, wrapper: &TokenStream) -> TokenStream { + let python_name = &spec.python_name; + let doc = &spec.doc; quote! { pyo3::class::PyMethodDefType::New({ #wrapper pyo3::class::PyMethodDef { - ml_name: stringify!(#name), + ml_name: stringify!(#python_name), ml_meth: pyo3::class::PyMethodType::PyNewFunc(__wrap), ml_flags: pyo3::ffi::METH_VARARGS | pyo3::ffi::METH_KEYWORDS, ml_doc: #doc, @@ -614,17 +565,15 @@ pub fn impl_py_method_def_new( } } -pub fn impl_py_method_def_class( - name: &syn::Ident, - doc: syn::Lit, - wrapper: &TokenStream, -) -> TokenStream { +pub fn impl_py_method_def_class(spec: &FnSpec, wrapper: &TokenStream) -> TokenStream { + let python_name = &spec.python_name; + let doc = &spec.doc; quote! { pyo3::class::PyMethodDefType::Class({ #wrapper pyo3::class::PyMethodDef { - ml_name: stringify!(#name), + ml_name: stringify!(#python_name), ml_meth: pyo3::class::PyMethodType::PyCFunctionWithKeywords(__wrap), ml_flags: pyo3::ffi::METH_VARARGS | pyo3::ffi::METH_KEYWORDS | pyo3::ffi::METH_CLASS, @@ -634,17 +583,15 @@ pub fn impl_py_method_def_class( } } -pub fn impl_py_method_def_static( - name: &syn::Ident, - doc: syn::Lit, - wrapper: &TokenStream, -) -> TokenStream { +pub fn impl_py_method_def_static(spec: &FnSpec, wrapper: &TokenStream) -> TokenStream { + let python_name = &spec.python_name; + let doc = &spec.doc; quote! { pyo3::class::PyMethodDefType::Static({ #wrapper pyo3::class::PyMethodDef { - ml_name: stringify!(#name), + ml_name: stringify!(#python_name), ml_meth: pyo3::class::PyMethodType::PyCFunctionWithKeywords(__wrap), ml_flags: pyo3::ffi::METH_VARARGS | pyo3::ffi::METH_KEYWORDS | pyo3::ffi::METH_STATIC, ml_doc: #doc, @@ -653,17 +600,15 @@ pub fn impl_py_method_def_static( } } -pub fn impl_py_method_def_call( - name: &syn::Ident, - doc: syn::Lit, - wrapper: &TokenStream, -) -> TokenStream { +pub fn impl_py_method_def_call(spec: &FnSpec, wrapper: &TokenStream) -> TokenStream { + let python_name = &spec.python_name; + let doc = &spec.doc; quote! { pyo3::class::PyMethodDefType::Call({ #wrapper pyo3::class::PyMethodDef { - ml_name: stringify!(#name), + ml_name: stringify!(#python_name), ml_meth: pyo3::class::PyMethodType::PyCFunctionWithKeywords(__wrap), ml_flags: pyo3::ffi::METH_VARARGS | pyo3::ffi::METH_KEYWORDS, ml_doc: #doc, @@ -672,29 +617,16 @@ pub fn impl_py_method_def_call( } } -pub(crate) fn impl_py_setter_def( - name: &syn::Ident, - doc: syn::Lit, - setter: &Option, - wrapper: &TokenStream, -) -> TokenStream { - let n = if let Some(ref name) = setter { - name.to_string() - } else { - let n = name.to_string(); - if n.starts_with("set_") { - n[4..].to_string() - } else { - n - } - }; +pub(crate) fn impl_py_setter_def(spec: &FnSpec, wrapper: &TokenStream) -> TokenStream { + let python_name = &&spec.python_name; + let doc = &spec.doc; quote! { pyo3::class::PyMethodDefType::Setter({ #wrapper pyo3::class::PySetterDef { - name: #n, + name: stringify!(#python_name), meth: __wrap, doc: #doc, } @@ -702,29 +634,16 @@ pub(crate) fn impl_py_setter_def( } } -pub(crate) fn impl_py_getter_def( - name: &syn::Ident, - doc: syn::Lit, - getter: &Option, - wrapper: &TokenStream, -) -> TokenStream { - let n = if let Some(ref name) = getter { - name.to_string() - } else { - let n = name.to_string(); - if n.starts_with("get_") { - n[4..].to_string() - } else { - n - } - }; +pub(crate) fn impl_py_getter_def(spec: &FnSpec, wrapper: &TokenStream) -> TokenStream { + let python_name = &&spec.python_name; + let doc = &spec.doc; quote! { pyo3::class::PyMethodDefType::Getter({ #wrapper pyo3::class::PyGetterDef { - name: #n, + name: stringify!(#python_name), meth: __wrap, doc: #doc, } diff --git a/pyo3-derive-backend/src/pyproto.rs b/pyo3-derive-backend/src/pyproto.rs index c0921bcbbfa..5cd32c6e3a6 100644 --- a/pyo3-derive-backend/src/pyproto.rs +++ b/pyo3-derive-backend/src/pyproto.rs @@ -4,7 +4,6 @@ use crate::defs; use crate::func::impl_method_proto; use crate::method::FnSpec; use crate::pymethod; -use proc_macro2::Span; use proc_macro2::TokenStream; use quote::quote; use quote::ToTokens; @@ -65,21 +64,20 @@ fn impl_proto_impl( for iimpl in impls.iter_mut() { if let syn::ImplItem::Method(ref mut met) = iimpl { for m in proto.methods { - if m == met.sig.ident.to_string().as_str() { + if met.sig.ident == m.name() { impl_method_proto(ty, &mut met.sig, m).to_tokens(&mut tokens); } } for m in proto.py_methods { - let ident = met.sig.ident.clone(); - if m.name == ident.to_string().as_str() { - let name = syn::Ident::new(m.name, Span::call_site()); + if met.sig.ident == m.name { + let name = &met.sig.ident; let proto: syn::Path = syn::parse_str(m.proto).unwrap(); - let fn_spec = match FnSpec::parse(&ident, &met.sig, &mut met.attrs) { + let fn_spec = match FnSpec::parse(&met.sig, &mut met.attrs, false) { Ok(fn_spec) => fn_spec, Err(err) => return err.to_compile_error(), }; - let meth = pymethod::impl_proto_wrap(ty, &ident, &fn_spec); + let meth = pymethod::impl_proto_wrap(ty, &fn_spec); py_methods.push(quote! { impl #proto for #ty diff --git a/pyo3-derive-backend/src/utils.rs b/pyo3-derive-backend/src/utils.rs index 84b445b6f1f..a06c31ffd07 100644 --- a/pyo3-derive-backend/src/utils.rs +++ b/pyo3-derive-backend/src/utils.rs @@ -98,7 +98,7 @@ pub fn get_doc( attrs: &[syn::Attribute], text_signature: Option, null_terminated: bool, -) -> syn::Result { +) -> syn::Result { let mut doc = String::new(); let mut span = Span::call_site(); @@ -139,5 +139,5 @@ pub fn get_doc( doc.push('\0'); } - Ok(syn::Lit::Str(syn::LitStr::new(&doc, span))) + Ok(syn::LitStr::new(&doc, span)) } diff --git a/pyo3cls/src/lib.rs b/pyo3cls/src/lib.rs index e85a645adc4..0ba93c998ce 100644 --- a/pyo3cls/src/lib.rs +++ b/pyo3cls/src/lib.rs @@ -4,13 +4,11 @@ extern crate proc_macro; use proc_macro::TokenStream; -use proc_macro2::Span; use pyo3_derive_backend::{ - add_fn_to_module, build_py_class, build_py_methods, build_py_proto, get_doc, + build_py_class, build_py_function, build_py_methods, build_py_proto, get_doc, process_functions_in_module, py_init, PyClassArgs, PyFunctionAttr, }; use quote::quote; -use syn::ext::IdentExt; use syn::parse_macro_input; /// Internally, this proc macro create a new c function called `PyInit_{my_module}` @@ -83,8 +81,7 @@ pub fn pyfunction(attr: TokenStream, input: TokenStream) -> TokenStream { let mut ast = parse_macro_input!(input as syn::ItemFn); let args = parse_macro_input!(attr as PyFunctionAttr); - let python_name = syn::Ident::new(&ast.sig.ident.unraw().to_string(), Span::call_site()); - let expanded = add_fn_to_module(&mut ast, &python_name, args.arguments); + let expanded = build_py_function(&mut ast, args).unwrap_or_else(|e| e.to_compile_error()); quote!( #ast diff --git a/tests/test_class_basics.rs b/tests/test_class_basics.rs index 00e2674e64a..c7b56940858 100644 --- a/tests/test_class_basics.rs +++ b/tests/test_class_basics.rs @@ -42,12 +42,47 @@ fn class_with_docstr() { #[pyclass(name=CustomName)] struct EmptyClass2 {} +#[pymethods] +impl EmptyClass2 { + #[name = "custom_fn"] + fn bar(&self) {} + + #[staticmethod] + #[name = "custom_static"] + fn bar_static() {} +} + #[test] -fn custom_class_name() { +fn custom_names() { let gil = Python::acquire_gil(); let py = gil.python(); let typeobj = py.get_type::(); py_assert!(py, typeobj, "typeobj.__name__ == 'CustomName'"); + py_assert!(py, typeobj, "typeobj.custom_fn.__name__ == 'custom_fn'"); + py_assert!( + py, + typeobj, + "typeobj.custom_static.__name__ == 'custom_static'" + ); + py_assert!(py, typeobj, "not hasattr(typeobj, 'bar')"); + py_assert!(py, typeobj, "not hasattr(typeobj, 'bar_static')"); +} + +#[pyclass] +struct RawIdents {} + +#[pymethods] +impl RawIdents { + fn r#fn(&self) {} +} + +#[test] +fn test_raw_idents() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let typeobj = py.get_type::(); + py_assert!(py, typeobj, "not hasattr(typeobj, 'r#fn')"); + py_assert!(py, typeobj, "hasattr(typeobj, 'fn')"); } #[pyclass] diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs index 5524ec41cb2..634b4932c40 100755 --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -3,4 +3,5 @@ fn test_compile_errors() { let t = trybuild::TestCases::new(); t.compile_fail("tests/ui/reject_generics.rs"); t.compile_fail("tests/ui/too_many_args_to_getter.rs"); + t.compile_fail("tests/ui/invalid_pymethod_names.rs"); } diff --git a/tests/test_module.rs b/tests/test_module.rs index cda0d8356c2..96c8f6f9cea 100644 --- a/tests/test_module.rs +++ b/tests/test_module.rs @@ -148,6 +148,32 @@ fn test_raw_idents() { py_assert!(py, module, "module.move() == 42"); } +#[pyfunction] +#[name = "foobar"] +fn custom_named_fn() -> usize { + 42 +} + +#[pymodule] +fn foobar_module(_py: Python, module: &PyModule) -> PyResult<()> { + use pyo3::wrap_pyfunction; + + module.add_wrapped(wrap_pyfunction!(custom_named_fn)) +} + +#[test] +fn test_custom_names() { + use pyo3::wrap_pymodule; + + let gil = Python::acquire_gil(); + let py = gil.python(); + + let module = wrap_pymodule!(foobar_module)(py); + + py_assert!(py, module, "not hasattr(module, 'custom_named_fn')"); + py_assert!(py, module, "module.foobar() == 42"); +} + #[pyfunction] fn subfunction() -> String { "Subfunction".to_string() diff --git a/tests/ui/invalid_pymethod_names.rs b/tests/ui/invalid_pymethod_names.rs new file mode 100644 index 00000000000..41834fa1940 --- /dev/null +++ b/tests/ui/invalid_pymethod_names.rs @@ -0,0 +1,29 @@ +use pyo3::prelude::*; + +#[pyclass] +struct TestClass { + num: u32, +} + +#[pymethods] +impl TestClass { + #[name = "num"] + #[getter(number)] + fn get_num(&self) -> u32 { self.num } +} + +#[pymethods] +impl TestClass { + #[name = "foo"] + #[name = "bar"] + fn qux(&self) -> u32 { self.num } +} + +#[pymethods] +impl TestClass { + #[name = "makenew"] + #[new] + fn new(&self) -> Self { Self { num: 0 } } +} + +fn main() {} diff --git a/tests/ui/invalid_pymethod_names.stderr b/tests/ui/invalid_pymethod_names.stderr new file mode 100644 index 00000000000..fc1911ac191 --- /dev/null +++ b/tests/ui/invalid_pymethod_names.stderr @@ -0,0 +1,17 @@ +error: name not allowed with this attribute + --> $DIR/invalid_pymethod_names.rs:10:5 + | +10 | #[name = "num"] + | ^^^^^^^^^^^^^^^ + +error: #[name] can not be specified multiple times + --> $DIR/invalid_pymethod_names.rs:17:5 + | +17 | #[name = "foo"] + | ^^^^^^^^^^^^^^^ + +error: name not allowed with this attribute + --> $DIR/invalid_pymethod_names.rs:24:5 + | +24 | #[name = "makenew"] + | ^^^^^^^^^^^^^^^^^^^