From cae3bbb8e1ed929dae642b33cc07fd24f4647a29 Mon Sep 17 00:00:00 2001 From: Daniil Konovalenko Date: Sun, 31 Jan 2021 23:30:02 +0300 Subject: [PATCH] add #[pyo3(from_py_with="...")] attribute * allow from_py_with inside #[derive(FromPyObject)] * split up FnSpec::parse --- CHANGELOG.md | 1 + pyo3-macros-backend/src/attrs.rs | 22 ++ pyo3-macros-backend/src/from_pyobject.rs | 110 ++++++---- pyo3-macros-backend/src/lib.rs | 1 + pyo3-macros-backend/src/method.rs | 227 ++++++++++++-------- pyo3-macros-backend/src/module.rs | 9 +- pyo3-macros-backend/src/pyfunction.rs | 50 +++++ pyo3-macros-backend/src/pymethod.rs | 11 +- pyo3-macros-backend/src/pyproto.rs | 4 +- tests/test_class_basics.rs | 44 ++++ tests/test_compile_error.rs | 1 + tests/test_frompyobject.rs | 27 +++ tests/test_pyfunction.rs | 55 ++++- tests/ui/invalid_argument_attributes.rs | 15 ++ tests/ui/invalid_argument_attributes.stderr | 23 ++ tests/ui/invalid_frompy_derive.rs | 12 ++ tests/ui/invalid_frompy_derive.stderr | 18 +- 17 files changed, 485 insertions(+), 145 deletions(-) create mode 100644 pyo3-macros-backend/src/attrs.rs create mode 100644 tests/ui/invalid_argument_attributes.rs create mode 100644 tests/ui/invalid_argument_attributes.stderr diff --git a/CHANGELOG.md b/CHANGELOG.md index 0765c3ad45d..712115f2e21 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased] ### Added - Add conversions between `OsStr`/`OsString`/`Path`/`PathBuf` and Python strings. [#1379](https://github.com/PyO3/pyo3/pull/1379) +- Add #[pyo3(from_py_with = "...")]` attribute for function arguments and struct fields to override the default from-Python conversion. [#1411](https://github.com/PyO3/pyo3/pull/1411) - Add FFI definition `PyCFunction_CheckExact` for Python 3.9 and later. [#1425](https://github.com/PyO3/pyo3/pull/1425) ### Changed diff --git a/pyo3-macros-backend/src/attrs.rs b/pyo3-macros-backend/src/attrs.rs new file mode 100644 index 00000000000..486eff393bb --- /dev/null +++ b/pyo3-macros-backend/src/attrs.rs @@ -0,0 +1,22 @@ +use syn::spanned::Spanned; +use syn::{ExprPath, Lit, Meta, MetaNameValue, Result}; + +#[derive(Clone, Debug, PartialEq)] +pub struct FromPyWithAttribute(pub ExprPath); + +impl FromPyWithAttribute { + pub fn from_meta(meta: Meta) -> Result { + let string_literal = match meta { + Meta::NameValue(MetaNameValue { + lit: Lit::Str(string_literal), + .. + }) => string_literal, + meta => { + bail_spanned!(meta.span() => "expected a name-value: `pyo3(from_py_with = \"func\")`") + } + }; + + let expr_path = string_literal.parse::()?; + Ok(FromPyWithAttribute(expr_path)) + } +} diff --git a/pyo3-macros-backend/src/from_pyobject.rs b/pyo3-macros-backend/src/from_pyobject.rs index 23ac5755999..c143ca93428 100644 --- a/pyo3-macros-backend/src/from_pyobject.rs +++ b/pyo3-macros-backend/src/from_pyobject.rs @@ -1,3 +1,4 @@ +use crate::attrs::FromPyWithAttribute; use proc_macro2::TokenStream; use quote::quote; use syn::punctuated::Punctuated; @@ -85,7 +86,7 @@ enum ContainerType<'a> { /// Struct Container, e.g. `struct Foo { a: String }` /// /// Variant contains the list of field identifiers and the corresponding extraction call. - Struct(Vec<(&'a Ident, FieldAttribute)>), + Struct(Vec<(&'a Ident, FieldAttributes)>), /// Newtype struct container, e.g. `#[transparent] struct Foo { a: String }` /// /// The field specified by the identifier is extracted directly from the object. @@ -156,9 +157,8 @@ impl<'a> Container<'a> { .ident .as_ref() .expect("Named fields should have identifiers"); - let attr = FieldAttribute::parse_attrs(&field.attrs)? - .unwrap_or(FieldAttribute::GetAttr(None)); - fields.push((ident, attr)) + let attrs = FieldAttributes::parse_attrs(&field.attrs)?; + fields.push((ident, attrs)) } ContainerType::Struct(fields) } @@ -235,17 +235,24 @@ impl<'a> Container<'a> { ) } - fn build_struct(&self, tups: &[(&Ident, FieldAttribute)]) -> TokenStream { + fn build_struct(&self, tups: &[(&Ident, FieldAttributes)]) -> TokenStream { let self_ty = &self.path; let mut fields: Punctuated = Punctuated::new(); - for (ident, attr) in tups { - let ext_fn = match attr { - FieldAttribute::GetAttr(Some(name)) => quote!(getattr(#name)), - FieldAttribute::GetAttr(None) => quote!(getattr(stringify!(#ident))), - FieldAttribute::GetItem(Some(key)) => quote!(get_item(#key)), - FieldAttribute::GetItem(None) => quote!(get_item(stringify!(#ident))), + for (ident, attrs) in tups { + let getter = match &attrs.getter { + FieldGetter::GetAttr(Some(name)) => quote!(getattr(#name)), + FieldGetter::GetAttr(None) => quote!(getattr(stringify!(#ident))), + FieldGetter::GetItem(Some(key)) => quote!(get_item(#key)), + FieldGetter::GetItem(None) => quote!(get_item(stringify!(#ident))), }; - fields.push(quote!(#ident: obj.#ext_fn?.extract()?)); + + let get_field = quote!(obj.#getter?); + let extractor = match &attrs.from_py_with { + None => quote!(#get_field.extract()?), + Some(FromPyWithAttribute(expr_path)) => quote! (#expr_path(#get_field)?), + }; + + fields.push(quote!(#ident: #extractor)); } quote!(Ok(#self_ty{#fields})) } @@ -309,40 +316,59 @@ impl ContainerAttribute { /// Attributes for deriving FromPyObject scoped on fields. #[derive(Clone, Debug)] -enum FieldAttribute { +struct FieldAttributes { + getter: FieldGetter, + from_py_with: Option, +} + +#[derive(Clone, Debug)] +enum FieldGetter { GetItem(Option), GetAttr(Option), } -impl FieldAttribute { - /// Extract the field attribute. +impl FieldAttributes { + /// Extract the field attributes. /// - /// Currently fails if more than 1 attribute is passed in `pyo3` - fn parse_attrs(attrs: &[Attribute]) -> Result> { + fn parse_attrs(attrs: &[Attribute]) -> Result { + let mut getter = None; + let mut from_py_with = None; + let list = get_pyo3_meta_list(attrs)?; - let metaitem = match list.nested.len() { - 0 => return Ok(None), - 1 => list.nested.into_iter().next().unwrap(), - _ => bail_spanned!( - list.nested.span() => - "only one of `attribute` or `item` can be provided" - ), - }; - let meta = match metaitem { - syn::NestedMeta::Meta(meta) => meta, - syn::NestedMeta::Lit(lit) => bail_spanned!( - lit.span() => - "expected `attribute` or `item`, got a literal" - ), - }; - let path = meta.path(); - if path.is_ident("attribute") { - Ok(Some(FieldAttribute::GetAttr(Self::attribute_arg(meta)?))) - } else if path.is_ident("item") { - Ok(Some(FieldAttribute::GetItem(Self::item_arg(meta)?))) - } else { - bail_spanned!(meta.span() => "expected `attribute` or `item`"); + + for meta_item in list.nested { + let meta = match meta_item { + syn::NestedMeta::Meta(meta) => meta, + syn::NestedMeta::Lit(lit) => bail_spanned!( + lit.span() => + "expected `attribute`, `item` or `from_py_with`, got a literal" + ), + }; + let path = meta.path(); + + if path.is_ident("attribute") { + ensure_spanned!( + getter.is_none(), + meta.span() => "only one of `attribute` or `item` can be provided" + ); + getter = Some(FieldGetter::GetAttr(Self::attribute_arg(meta)?)) + } else if path.is_ident("item") { + ensure_spanned!( + getter.is_none(), + meta.span() => "only one of `attribute` or `item` can be provided" + ); + getter = Some(FieldGetter::GetItem(Self::item_arg(meta)?)) + } else if path.is_ident("from_py_with") { + from_py_with = Some(Self::from_py_with_arg(meta)?) + } else { + bail_spanned!(meta.span() => "expected `attribute`, `item` or `from_py_with`") + }; } + + Ok(FieldAttributes { + getter: getter.unwrap_or(FieldGetter::GetAttr(None)), + from_py_with, + }) } fn attribute_arg(meta: Meta) -> syn::Result> { @@ -389,6 +415,10 @@ impl FieldAttribute { bail_spanned!(arg_list.span() => "expected a single literal argument"); } + + fn from_py_with_arg(meta: Meta) -> syn::Result { + FromPyWithAttribute::from_meta(meta) + } } /// Extract pyo3 metalist, flattens multiple lists into a single one. @@ -426,7 +456,7 @@ fn verify_and_get_lifetime(generics: &syn::Generics) -> Result Foo(T)` /// adds `T: FromPyObject` on the derived implementation. pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result { diff --git a/pyo3-macros-backend/src/lib.rs b/pyo3-macros-backend/src/lib.rs index 6b8af63d91a..9738e01f46a 100644 --- a/pyo3-macros-backend/src/lib.rs +++ b/pyo3-macros-backend/src/lib.rs @@ -7,6 +7,7 @@ #[macro_use] mod utils; +mod attrs; mod defs; mod from_pyobject; mod konst; diff --git a/pyo3-macros-backend/src/method.rs b/pyo3-macros-backend/src/method.rs index f7c252c4036..44bde1ee6a4 100644 --- a/pyo3-macros-backend/src/method.rs +++ b/pyo3-macros-backend/src/method.rs @@ -1,11 +1,12 @@ // Copyright (c) 2017-present PyO3 Project and Contributors use crate::pyfunction::Argument; -use crate::pyfunction::{parse_name_attribute, PyFunctionAttr}; +use crate::pyfunction::{parse_name_attribute, PyFunctionArgAttrs, PyFunctionAttr}; use crate::utils; use proc_macro2::TokenStream; use quote::ToTokens; use quote::{quote, quote_spanned}; +use std::ops::Deref; use syn::ext::IdentExt; use syn::spanned::Spanned; @@ -17,6 +18,7 @@ pub struct FnArg<'a> { pub ty: &'a syn::Type, pub optional: Option<&'a syn::Type>, pub py: bool, + pub attrs: PyFunctionArgAttrs, } #[derive(Clone, PartialEq, Debug, Copy, Eq)] @@ -115,146 +117,183 @@ pub fn parse_method_receiver(arg: &syn::FnArg) -> syn::Result { impl<'a> FnSpec<'a> { /// Parser function signature and function attributes - #[allow(clippy::manual_strip)] // for strip_prefix replacement supporting rust < 1.45 pub fn parse( - sig: &'a syn::Signature, + sig: &'a mut syn::Signature, meth_attrs: &mut Vec, allow_custom_name: bool, ) -> syn::Result> { - let name = &sig.ident; let MethodAttributes { ty: fn_type_attr, args: fn_attrs, mut python_name, } = parse_method_attributes(meth_attrs, allow_custom_name)?; - let mut arguments = Vec::new(); - let mut inputs_iter = sig.inputs.iter(); + let (fn_type, skip_first_arg) = Self::parse_fn_type(sig, fn_type_attr, &mut python_name)?; + + let name = &sig.ident; + let ty = get_return_info(&sig.output); + let python_name = python_name.as_ref().unwrap_or(name).unraw(); + + let text_signature = Self::parse_text_signature(meth_attrs, &fn_type, &python_name)?; + let doc = utils::get_doc(&meth_attrs, text_signature, true)?; - let mut parse_receiver = |msg: &'static str| { - inputs_iter - .next() - .ok_or_else(|| err_spanned!(sig.span() => msg)) - .and_then(parse_method_receiver) + let arguments = if skip_first_arg { + Self::parse_arguments(&mut sig.inputs.iter_mut().skip(1))? + } else { + Self::parse_arguments(&mut sig.inputs.iter_mut())? }; - // strip get_ or set_ - let strip_fn_name = |prefix: &'static str| { - let ident = name.unraw().to_string(); - if ident.starts_with(prefix) { - Some(syn::Ident::new(&ident[prefix.len()..], ident.span())) + Ok(FnSpec { + tp: fn_type, + name, + python_name, + attrs: fn_attrs, + args: arguments, + output: ty, + doc, + }) + } + + fn parse_text_signature( + meth_attrs: &mut Vec, + fn_type: &FnType, + python_name: &syn::Ident, + ) -> syn::Result> { + let mut parse_erroneous_text_signature = |error_msg: &str| { + // try to parse anyway to give better error messages + if let Some(text_signature) = + utils::parse_text_signature_attrs(meth_attrs, &python_name)? + { + bail_spanned!(text_signature.span() => error_msg) } else { - None + Ok(None) } }; - // Parse receiver & function type for various method types - let fn_type = match fn_type_attr { - Some(MethodTypeAttribute::StaticMethod) => FnType::FnStatic, - Some(MethodTypeAttribute::ClassAttribute) => { - ensure_spanned!( - sig.inputs.is_empty(), - sig.inputs.span() => "class attribute methods cannot take arguments" - ); - FnType::ClassAttribute - } - Some(MethodTypeAttribute::New) => FnType::FnNew, - Some(MethodTypeAttribute::ClassMethod) => { - // Skip first argument for classmethod - always &PyType - let _ = inputs_iter.next(); - FnType::FnClass - } - Some(MethodTypeAttribute::Call) => { - FnType::FnCall(parse_receiver("expected receiver for #[call]")?) - } - Some(MethodTypeAttribute::Getter) => { - // Strip off "get_" prefix if needed - if python_name.is_none() { - python_name = strip_fn_name("get_"); - } - - FnType::Getter(parse_receiver("expected receiver for #[getter]")?) + let text_signature = match &fn_type { + FnType::Fn(_) | FnType::FnClass | FnType::FnStatic => { + utils::parse_text_signature_attrs(&mut *meth_attrs, &python_name)? } - Some(MethodTypeAttribute::Setter) => { - // Strip off "set_" prefix if needed - if python_name.is_none() { - python_name = strip_fn_name("set_"); - } - - FnType::Setter(parse_receiver("expected receiver for #[setter]")?) + FnType::FnNew => parse_erroneous_text_signature( + "text_signature not allowed on __new__; if you want to add a signature on \ + __new__, put it on the struct definition instead", + )?, + FnType::FnCall(_) | FnType::Getter(_) | FnType::Setter(_) | FnType::ClassAttribute => { + parse_erroneous_text_signature("text_signature not allowed with this method type")? } - None => FnType::Fn(parse_receiver( - "static method needs #[staticmethod] attribute", - )?), }; - // parse rest of arguments + Ok(text_signature) + } + + fn parse_arguments( + // inputs: &'a mut [syn::FnArg], + inputs_iter: impl Iterator, + ) -> syn::Result>> { + let mut arguments = vec![]; for input in inputs_iter { match input { syn::FnArg::Receiver(recv) => { bail_spanned!(recv.span() => "unexpected receiver for method") - } - syn::FnArg::Typed(syn::PatType { pat, ty, .. }) => { - let (ident, by_ref, mutability) = match &**pat { + } // checked in parse_fn_type + syn::FnArg::Typed(cap) => { + let arg_attrs = PyFunctionArgAttrs::from_attrs(&mut cap.attrs)?; + let (ident, by_ref, mutability) = match *cap.pat { syn::Pat::Ident(syn::PatIdent { - ident, - by_ref, - mutability, + ref ident, + ref by_ref, + ref mutability, .. }) => (ident, by_ref, mutability), - _ => bail_spanned!(pat.span() => "unsupported argument"), + _ => bail_spanned!(cap.pat.span() => "unsupported argument"), }; arguments.push(FnArg { name: ident, by_ref, mutability, - ty, - optional: utils::option_type_argument(ty), - py: utils::is_python(ty), + ty: cap.ty.deref(), + optional: utils::option_type_argument(cap.ty.deref()), + py: utils::is_python(cap.ty.deref()), + attrs: arg_attrs, }); } } } - let ty = get_return_info(&sig.output); - let python_name = python_name.as_ref().unwrap_or(name).unraw(); + Ok(arguments) + } - let mut parse_erroneous_text_signature = |error_msg: &str| { - // try to parse anyway to give better error messages - if let Some(text_signature) = - utils::parse_text_signature_attrs(meth_attrs, &python_name)? - { - bail_spanned!(text_signature.span() => error_msg) + fn parse_fn_type( + sig: &syn::Signature, + fn_type_attr: Option, + python_name: &mut Option, + ) -> syn::Result<(FnType, bool)> { + let name = &sig.ident; + let parse_receiver = |msg: &'static str| { + let first_arg = sig + .inputs + .first() + .ok_or_else(|| err_spanned!(sig.span() => msg))?; + parse_method_receiver(first_arg) + }; + + #[allow(clippy::manual_strip)] // for strip_prefix replacement supporting rust < 1.45 + // strip get_ or set_ + let strip_fn_name = |prefix: &'static str| { + let ident = name.unraw().to_string(); + if ident.starts_with(prefix) { + Some(syn::Ident::new(&ident[prefix.len()..], ident.span())) } else { - Ok(None) + None } }; - let text_signature = match &fn_type { - FnType::Fn(_) | FnType::FnClass | FnType::FnStatic => { - utils::parse_text_signature_attrs(&mut *meth_attrs, &python_name)? - } - FnType::FnNew => parse_erroneous_text_signature( - "text_signature not allowed on __new__; if you want to add a signature on \ - __new__, put it on the struct definition instead", - )?, - FnType::FnCall(_) | FnType::Getter(_) | FnType::Setter(_) | FnType::ClassAttribute => { - parse_erroneous_text_signature("text_signature not allowed with this method type")? + let (fn_type, skip_first_arg) = match fn_type_attr { + Some(MethodTypeAttribute::StaticMethod) => (FnType::FnStatic, false), + Some(MethodTypeAttribute::ClassAttribute) => { + ensure_spanned!( + sig.inputs.is_empty(), + sig.inputs.span() => "class attribute methods cannot take arguments" + ); + (FnType::ClassAttribute, false) } - }; + Some(MethodTypeAttribute::New) => (FnType::FnNew, false), + Some(MethodTypeAttribute::ClassMethod) => (FnType::FnClass, true), + Some(MethodTypeAttribute::Call) => ( + FnType::FnCall(parse_receiver("expected receiver for #[call]")?), + true, + ), + Some(MethodTypeAttribute::Getter) => { + // Strip off "get_" prefix if needed + if python_name.is_none() { + *python_name = strip_fn_name("get_"); + } - let doc = utils::get_doc(&meth_attrs, text_signature, true)?; + ( + FnType::Getter(parse_receiver("expected receiver for #[getter]")?), + true, + ) + } + Some(MethodTypeAttribute::Setter) => { + // Strip off "set_" prefix if needed + if python_name.is_none() { + *python_name = strip_fn_name("set_"); + } - Ok(FnSpec { - tp: fn_type, - name, - python_name, - attrs: fn_attrs, - args: arguments, - output: ty, - doc, - }) + ( + FnType::Setter(parse_receiver("expected receiver for #[setter]")?), + true, + ) + } + None => ( + FnType::Fn(parse_receiver( + "static method needs #[staticmethod] attribute", + )?), + true, + ), + }; + Ok((fn_type, skip_first_arg)) } pub fn is_args(&self, name: &syn::Ident) -> bool { diff --git a/pyo3-macros-backend/src/module.rs b/pyo3-macros-backend/src/module.rs index 087ce2307b5..baec0f20c44 100644 --- a/pyo3-macros-backend/src/module.rs +++ b/pyo3-macros-backend/src/module.rs @@ -2,7 +2,7 @@ //! Code generation for the function that initializes a python module and adds classes and function. use crate::method; -use crate::pyfunction::PyFunctionAttr; +use crate::pyfunction::{PyFunctionArgAttrs, PyFunctionAttr}; use crate::pymethod; use crate::pymethod::get_arg_names; use crate::utils; @@ -58,7 +58,9 @@ pub fn process_functions_in_module(func: &mut syn::ItemFn) -> syn::Result<()> { } /// Transforms a rust fn arg parsed with syn into a method::FnArg -fn wrap_fn_argument(cap: &syn::PatType) -> syn::Result { +fn wrap_fn_argument(cap: &mut syn::PatType) -> syn::Result { + let arg_attrs = PyFunctionArgAttrs::from_attrs(&mut cap.attrs)?; + let (mutability, by_ref, ident) = match &*cap.pat { syn::Pat::Ident(patid) => (&patid.mutability, &patid.by_ref, &patid.ident), _ => bail_spanned!(cap.pat.span() => "unsupported argument"), @@ -71,6 +73,7 @@ fn wrap_fn_argument(cap: &syn::PatType) -> syn::Result { ty: &cap.ty, optional: utils::option_type_argument(&cap.ty), py: utils::is_python(&cap.ty), + attrs: arg_attrs, }) } @@ -142,7 +145,7 @@ pub fn add_fn_to_module( ) -> syn::Result { let mut arguments = Vec::new(); - for (i, input) in func.sig.inputs.iter().enumerate() { + for (i, input) in func.sig.inputs.iter_mut().enumerate() { match input { syn::FnArg::Receiver(_) => { bail_spanned!(input.span() => "unexpected receiver for #[pyfn]"); diff --git a/pyo3-macros-backend/src/pyfunction.rs b/pyo3-macros-backend/src/pyfunction.rs index 4a316f8ec62..15aeedd3b84 100644 --- a/pyo3-macros-backend/src/pyfunction.rs +++ b/pyo3-macros-backend/src/pyfunction.rs @@ -1,5 +1,6 @@ // Copyright (c) 2017-present PyO3 Project and Contributors +use crate::attrs::FromPyWithAttribute; use crate::module::add_fn_to_module; use proc_macro2::TokenStream; use syn::ext::IdentExt; @@ -27,6 +28,11 @@ pub struct PyFunctionAttr { pub pass_module: bool, } +#[derive(Clone, PartialEq, Debug)] +pub struct PyFunctionArgAttrs { + pub from_py_with: Option, +} + impl syn::parse::Parse for PyFunctionAttr { fn parse(input: &ParseBuffer) -> syn::Result { let attr = Punctuated::::parse_terminated(input)?; @@ -186,6 +192,50 @@ pub fn build_py_function(ast: &mut syn::ItemFn, args: PyFunctionAttr) -> syn::Re add_fn_to_module(ast, python_name, args) } +fn extract_pyo3_metas(attrs: &mut Vec) -> syn::Result> { + let mut new_attrs = Vec::new(); + let mut metas = Vec::new(); + + for attr in attrs.drain(..) { + if let syn::Meta::List(meta_list) = attr.parse_meta()? { + if meta_list.path.is_ident("pyo3") { + for meta in meta_list.nested { + metas.push(meta); + } + } else { + new_attrs.push(attr) + } + } + } + *attrs = new_attrs; + + Ok(metas) +} + +impl PyFunctionArgAttrs { + /// Parses #[pyo3(from_python_with = "func")] + pub fn from_attrs(attrs: &mut Vec) -> syn::Result { + let mut from_py_with = None; + + for meta in extract_pyo3_metas(attrs)? { + let meta = match meta { + NestedMeta::Meta(meta) => meta, + NestedMeta::Lit(lit) => { + bail_spanned!(lit.span() => "expected `from_py_with`, got a literal") + } + }; + + if meta.path().is_ident("from_py_with") { + from_py_with = Some(FromPyWithAttribute::from_meta(meta)?); + } else { + bail_spanned!(meta.span() => "only `from_py_with` is supported") + } + } + + Ok(PyFunctionArgAttrs { from_py_with }) + } +} + #[cfg(test)] mod test { use super::{Argument, PyFunctionAttr}; diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index f78d3607718..6c275b28f5f 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -1,4 +1,5 @@ // Copyright (c) 2017-present PyO3 Project and Contributors +use crate::attrs::FromPyWithAttribute; use crate::konst::ConstSpec; use crate::method::{FnArg, FnSpec, FnType, SelfType}; use crate::utils; @@ -491,6 +492,12 @@ fn impl_arg_param( (None, false) => quote! { panic!("Failed to extract required method argument") }, }; + let extract = if let Some(FromPyWithAttribute(expr_path)) = &arg.attrs.from_py_with { + quote! {#expr_path(_obj).map_err(#transform_error)?} + } else { + quote! {_obj.extract().map_err(#transform_error)?} + }; + return if let syn::Type::Reference(tref) = arg.optional.as_ref().unwrap_or(&ty) { let (tref, mut_) = preprocess_tref(tref, self_); let (target_ty, borrow_tmp) = if arg.optional.is_some() { @@ -513,7 +520,7 @@ fn impl_arg_param( quote! { let #mut_ _tmp: #target_ty = match #arg_value { - Some(_obj) => _obj.extract().map_err(#transform_error)?, + Some(_obj) => #extract, None => #default, }; let #arg_name = #borrow_tmp; @@ -521,7 +528,7 @@ fn impl_arg_param( } else { quote! { let #arg_name = match #arg_value { - Some(_obj) => _obj.extract().map_err(#transform_error)?, + Some(_obj) => #extract, None => #default, }; } diff --git a/pyo3-macros-backend/src/pyproto.rs b/pyo3-macros-backend/src/pyproto.rs index d804bd0b7af..a56afbc6e87 100644 --- a/pyo3-macros-backend/src/pyproto.rs +++ b/pyo3-macros-backend/src/pyproto.rs @@ -62,8 +62,7 @@ fn impl_proto_impl( } // Add non-slot methods to inventory like `#[pymethods]` if let Some(m) = proto.get_method(&met.sig.ident) { - let name = &met.sig.ident; - let fn_spec = FnSpec::parse(&met.sig, &mut met.attrs, false)?; + let fn_spec = FnSpec::parse(&mut met.sig, &mut met.attrs, false)?; let method = if let FnType::Fn(self_ty) = &fn_spec.tp { pymethod::impl_proto_wrap(ty, &fn_spec, &self_ty) @@ -79,6 +78,7 @@ fn impl_proto_impl( } else { quote!(0) }; + let name = &met.sig.ident; // TODO(kngwyu): Set ml_doc py_methods.push(quote! { pyo3::class::PyMethodDefType::Method({ diff --git a/tests/test_class_basics.rs b/tests/test_class_basics.rs index fcb49214040..cbb2d59c91c 100644 --- a/tests/test_class_basics.rs +++ b/tests/test_class_basics.rs @@ -244,3 +244,47 @@ fn panic_unsendable_base() { fn panic_unsendable_child() { test_unsendable::().unwrap(); } + +fn get_length(obj: &PyAny) -> PyResult { + let length = obj.len()?; + + Ok(length) +} + +#[pyclass] +struct ClassWithFromPyWithMethods {} + +#[pymethods] +impl ClassWithFromPyWithMethods { + fn instance_method(&self, #[pyo3(from_py_with = "get_length")] argument: usize) -> usize { + argument + } + #[classmethod] + fn classmethod(_cls: &PyType, #[pyo3(from_py_with = "PyAny::len")] argument: usize) -> usize { + argument + } + + #[staticmethod] + fn staticmethod(#[pyo3(from_py_with = "get_length")] argument: usize) -> usize { + argument + } +} + +#[test] +fn test_pymethods_from_py_with() { + Python::with_gil(|py| { + let instance = Py::new(py, ClassWithFromPyWithMethods {}).unwrap(); + + py_run!( + py, + instance, + r#" + arg = {1: 1, 2: 3} + + assert instance.instance_method(arg) == 2 + assert instance.classmethod(arg) == 2 + assert instance.staticmethod(arg) == 2 + "# + ); + }) +} diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs index c50f4be830a..b816150da03 100644 --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -8,6 +8,7 @@ fn test_compile_errors() { t.compile_fail("tests/ui/invalid_pyclass_args.rs"); t.compile_fail("tests/ui/invalid_pymethods.rs"); t.compile_fail("tests/ui/invalid_pymethod_names.rs"); + t.compile_fail("tests/ui/invalid_argument_attributes.rs"); t.compile_fail("tests/ui/reject_generics.rs"); tests_rust_1_45(&t); diff --git a/tests/test_frompyobject.rs b/tests/test_frompyobject.rs index 24a7322b01b..bb8112204f2 100644 --- a/tests/test_frompyobject.rs +++ b/tests/test_frompyobject.rs @@ -299,3 +299,30 @@ fn test_err_rename() { "TypeError: 'dict' object cannot be converted to 'Union[str, uint, int]'" ); } + +#[derive(Debug, FromPyObject)] +pub struct Zap { + #[pyo3(item)] + name: String, + + #[pyo3(from_py_with = "PyAny::len", item("my_object"))] + some_object_length: usize, +} + +#[test] +fn test_from_py_with() { + Python::with_gil(|py| { + let py_zap = py + .eval( + r#"{"name": "whatever", "my_object": [1, 2, 3]}"#, + None, + None, + ) + .expect("failed to create dict"); + + let zap = Zap::extract(py_zap).unwrap(); + + assert_eq!(zap.name, "whatever"); + assert_eq!(zap.some_object_length, 3usize); + }); +} diff --git a/tests/test_pyfunction.rs b/tests/test_pyfunction.rs index 099e8cbf5ad..b59e6e932c4 100644 --- a/tests/test_pyfunction.rs +++ b/tests/test_pyfunction.rs @@ -3,7 +3,7 @@ use pyo3::buffer::PyBuffer; use pyo3::prelude::*; use pyo3::types::PyCFunction; #[cfg(not(Py_LIMITED_API))] -use pyo3::types::PyFunction; +use pyo3::types::{PyDateTime, PyFunction}; use pyo3::{raw_pycfunction, wrap_pyfunction}; mod common; @@ -111,6 +111,59 @@ fn test_functions_with_function_args() { } } +#[cfg(not(Py_LIMITED_API))] +fn datetime_to_timestamp(dt: &PyAny) -> PyResult { + let dt: &PyDateTime = dt.extract()?; + let ts: f64 = dt.call_method0("timestamp")?.extract()?; + + Ok(ts as i64) +} + +#[cfg(not(Py_LIMITED_API))] +#[pyfunction] +fn function_with_custom_conversion( + #[pyo3(from_py_with = "datetime_to_timestamp")] timestamp: i64, +) -> i64 { + timestamp +} + +#[cfg(not(Py_LIMITED_API))] +#[test] +fn test_function_with_custom_conversion() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let custom_conv_func = wrap_pyfunction!(function_with_custom_conversion)(py).unwrap(); + + pyo3::py_run!( + py, + custom_conv_func, + r#" + import datetime + + dt = datetime.datetime.fromtimestamp(1612040400) + assert custom_conv_func(dt) == 1612040400 + "# + ) +} + +#[cfg(not(Py_LIMITED_API))] +#[test] +fn test_function_with_custom_conversion_error() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let custom_conv_func = wrap_pyfunction!(function_with_custom_conversion)(py).unwrap(); + + py_expect_exception!( + py, + custom_conv_func, + "custom_conv_func(['a'])", + PyTypeError, + "argument 'timestamp': 'list' object cannot be converted to 'PyDateTime'" + ); +} + #[test] fn test_raw_function() { let gil = Python::acquire_gil(); diff --git a/tests/ui/invalid_argument_attributes.rs b/tests/ui/invalid_argument_attributes.rs new file mode 100644 index 00000000000..798d428ac86 --- /dev/null +++ b/tests/ui/invalid_argument_attributes.rs @@ -0,0 +1,15 @@ +use pyo3::prelude::*; + +#[pyfunction] +fn invalid_attribute(#[pyo3(get)] param: String) {} + +#[pyfunction] +fn from_py_with_no_value(#[pyo3(from_py_with)] param: String) {} + +#[pyfunction] +fn from_py_with_string(#[pyo3("from_py_with")] param: String) {} + +#[pyfunction] +fn from_py_with_value_not_a_string(#[pyo3(from_py_with = func)] param: String) {} + +fn main() {} diff --git a/tests/ui/invalid_argument_attributes.stderr b/tests/ui/invalid_argument_attributes.stderr new file mode 100644 index 00000000000..f1e3751873d --- /dev/null +++ b/tests/ui/invalid_argument_attributes.stderr @@ -0,0 +1,23 @@ +error: only `from_py_with` is supported + --> $DIR/invalid_argument_attributes.rs:4:29 + | +4 | fn invalid_attribute(#[pyo3(get)] param: String) {} + | ^^^ + +error: expected a name-value: `pyo3(from_py_with = "func")` + --> $DIR/invalid_argument_attributes.rs:7:33 + | +7 | fn from_py_with_no_value(#[pyo3(from_py_with)] param: String) {} + | ^^^^^^^^^^^^ + +error: expected `from_py_with`, got a literal + --> $DIR/invalid_argument_attributes.rs:10:31 + | +10 | fn from_py_with_string(#[pyo3("from_py_with")] param: String) {} + | ^^^^^^^^^^^^^^ + +error: expected literal + --> $DIR/invalid_argument_attributes.rs:13:58 + | +13 | fn from_py_with_value_not_a_string(#[pyo3(from_py_with = func)] param: String) {} + | ^^^^ diff --git a/tests/ui/invalid_frompy_derive.rs b/tests/ui/invalid_frompy_derive.rs index 0a54aeca0fb..ad21001f12f 100644 --- a/tests/ui/invalid_frompy_derive.rs +++ b/tests/ui/invalid_frompy_derive.rs @@ -153,4 +153,16 @@ enum UnitEnum { Unit, } +#[derive(FromPyObject)] +struct InvalidFromPyWith { + #[pyo3(from_py_with)] + field: String, +} + +#[derive(FromPyObject)] +struct InvalidFromPyWithLiteral { + #[pyo3(from_py_with = func)] + field: String, +} + fn main() {} diff --git a/tests/ui/invalid_frompy_derive.stderr b/tests/ui/invalid_frompy_derive.stderr index b3c153b9cb1..ea8a96df0de 100644 --- a/tests/ui/invalid_frompy_derive.stderr +++ b/tests/ui/invalid_frompy_derive.stderr @@ -84,7 +84,7 @@ error: transparent structs and variants can only have 1 field 70 | | }, | |_____^ -error: expected `attribute` or `item` +error: expected `attribute`, `item` or `from_py_with` --> $DIR/invalid_frompy_derive.rs:76:12 | 76 | #[pyo3(attr)] @@ -127,10 +127,10 @@ error: expected a single literal argument | ^^^^ error: only one of `attribute` or `item` can be provided - --> $DIR/invalid_frompy_derive.rs:118:12 + --> $DIR/invalid_frompy_derive.rs:118:18 | 118 | #[pyo3(item, attribute)] - | ^^^^ + | ^^^^^^^^^ error: unknown `pyo3` container attribute --> $DIR/invalid_frompy_derive.rs:123:8 @@ -169,3 +169,15 @@ error: cannot derive FromPyObject for empty structs and variants | ^^^^^^^^^^^^ | = note: this error originates in a derive macro (in Nightly builds, run with -Z macro-backtrace for more info) + +error: expected a name-value: `pyo3(from_py_with = "func")` + --> $DIR/invalid_frompy_derive.rs:158:12 + | +158 | #[pyo3(from_py_with)] + | ^^^^^^^^^^^^ + +error: expected literal + --> $DIR/invalid_frompy_derive.rs:164:27 + | +164 | #[pyo3(from_py_with = func)] + | ^^^^