From b3bb6674d66e296d6b9c536d78527ac19ac28d6e Mon Sep 17 00:00:00 2001 From: Icxolu <10486322+Icxolu@users.noreply.github.com> Date: Fri, 25 Oct 2024 11:09:18 +0200 Subject: [PATCH] fix `#[derive(FromPyObject)]` expansion with trait bounds (#4645) * fix `#[derive(FromPyObject)]` expansion with trait bounds * add newsfragment --- newsfragments/4645.fixed.md | 1 + pyo3-macros-backend/src/frompyobject.rs | 21 ++++++++++++--------- tests/test_frompyobject.rs | 17 ++++++++++++++++- 3 files changed, 29 insertions(+), 10 deletions(-) create mode 100644 newsfragments/4645.fixed.md diff --git a/newsfragments/4645.fixed.md b/newsfragments/4645.fixed.md new file mode 100644 index 00000000000..ec4352d6693 --- /dev/null +++ b/newsfragments/4645.fixed.md @@ -0,0 +1 @@ +fix `#[derive(FromPyObject)]` expansion on generic with trait bounds \ No newline at end of file diff --git a/pyo3-macros-backend/src/frompyobject.rs b/pyo3-macros-backend/src/frompyobject.rs index a20eeec9ffd..14c8755e9be 100644 --- a/pyo3-macros-backend/src/frompyobject.rs +++ b/pyo3-macros-backend/src/frompyobject.rs @@ -572,24 +572,27 @@ fn verify_and_get_lifetime(generics: &syn::Generics) -> Result Foo(T)` /// adds `T: FromPyObject` on the derived implementation. pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result { + let options = ContainerOptions::from_attrs(&tokens.attrs)?; + let ctx = &Ctx::new(&options.krate, None); + let Ctx { pyo3_path, .. } = &ctx; + + let (_, ty_generics, _) = tokens.generics.split_for_impl(); let mut trait_generics = tokens.generics.clone(); - let generics = &tokens.generics; - let lt_param = if let Some(lt) = verify_and_get_lifetime(generics)? { + let lt_param = if let Some(lt) = verify_and_get_lifetime(&trait_generics)? { lt.clone() } else { trait_generics.params.push(parse_quote!('py)); parse_quote!('py) }; - let mut where_clause: syn::WhereClause = parse_quote!(where); - for param in generics.type_params() { + let (impl_generics, _, where_clause) = trait_generics.split_for_impl(); + + let mut where_clause = where_clause.cloned().unwrap_or_else(|| parse_quote!(where)); + for param in trait_generics.type_params() { let gen_ident = ¶m.ident; where_clause .predicates - .push(parse_quote!(#gen_ident: FromPyObject<#lt_param>)) + .push(parse_quote!(#gen_ident: #pyo3_path::FromPyObject<'py>)) } - let options = ContainerOptions::from_attrs(&tokens.attrs)?; - let ctx = &Ctx::new(&options.krate, None); - let Ctx { pyo3_path, .. } = &ctx; let derives = match &tokens.data { syn::Data::Enum(en) => { @@ -616,7 +619,7 @@ pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result { let ident = &tokens.ident; Ok(quote!( #[automatically_derived] - impl #trait_generics #pyo3_path::FromPyObject<#lt_param> for #ident #generics #where_clause { + impl #impl_generics #pyo3_path::FromPyObject<#lt_param> for #ident #ty_generics #where_clause { fn extract_bound(obj: &#pyo3_path::Bound<#lt_param, #pyo3_path::PyAny>) -> #pyo3_path::PyResult { #derives } diff --git a/tests/test_frompyobject.rs b/tests/test_frompyobject.rs index a1b91c25128..6093b774733 100644 --- a/tests/test_frompyobject.rs +++ b/tests/test_frompyobject.rs @@ -2,7 +2,7 @@ use pyo3::exceptions::PyValueError; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyList, PyString, PyTuple}; +use pyo3::types::{IntoPyDict, PyDict, PyList, PyString, PyTuple}; #[macro_use] #[path = "../src/tests/common.rs"] @@ -109,6 +109,21 @@ fn test_generic_transparent_named_field_struct() { }); } +#[derive(Debug, FromPyObject)] +pub struct GenericWithBound(std::collections::HashMap); + +#[test] +fn test_generic_with_bound() { + Python::with_gil(|py| { + let dict = [("1", 1), ("2", 2)].into_py_dict(py).unwrap(); + let map = dict.extract::>().unwrap().0; + assert_eq!(map.len(), 2); + assert_eq!(map["1"], 1); + assert_eq!(map["2"], 2); + assert!(!map.contains_key("3")); + }); +} + #[derive(Debug, FromPyObject)] pub struct E { test: T,