From 62c7fd0db345654bff0e8ddfe25511f30d065122 Mon Sep 17 00:00:00 2001 From: Icxolu <10486322+Icxolu@users.noreply.github.com> Date: Thu, 25 Apr 2024 22:25:09 +0200 Subject: [PATCH] allow `#[pyo3(signature = ...)]` on complex enum variants to specify constructor signature --- pyo3-macros-backend/src/pyclass.rs | 43 +++++++++++++++++++++------- pytests/src/enums.rs | 27 ++++++++++++++--- pytests/tests/test_enums.py | 23 +++++++++++++++ tests/ui/invalid_pyclass_enum.rs | 7 +++++ tests/ui/invalid_pyclass_enum.stderr | 6 ++++ 5 files changed, 91 insertions(+), 15 deletions(-) diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index d9c84655b42..ff87075b376 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -8,6 +8,7 @@ use crate::attributes::{ use crate::deprecations::Deprecations; use crate::konst::{ConstAttributes, ConstSpec}; use crate::method::{FnArg, FnSpec, PyArg, RegularArg}; +use crate::pyfunction::SignatureAttribute; use crate::pyimpl::{gen_py_const, PyClassMethodsType}; use crate::pymethod::{ impl_py_getter_def, impl_py_setter_def, MethodAndMethodDef, MethodAndSlotDef, PropertyType, @@ -622,10 +623,12 @@ struct PyClassEnumVariantNamedField<'a> { /// `#[pyo3()]` options for pyclass enum variants struct EnumVariantPyO3Options { name: Option, + signature: Option, } enum EnumVariantPyO3Option { Name(NameAttribute), + Signature(SignatureAttribute), } impl Parse for EnumVariantPyO3Option { @@ -633,6 +636,8 @@ impl Parse for EnumVariantPyO3Option { let lookahead = input.lookahead1(); if lookahead.peek(attributes::kw::name) { input.parse().map(EnumVariantPyO3Option::Name) + } else if lookahead.peek(attributes::kw::signature) { + input.parse().map(EnumVariantPyO3Option::Signature) } else { Err(lookahead.error()) } @@ -641,7 +646,10 @@ impl Parse for EnumVariantPyO3Option { impl EnumVariantPyO3Options { fn take_pyo3_options(attrs: &mut Vec) -> Result { - let mut options = EnumVariantPyO3Options { name: None }; + let mut options = EnumVariantPyO3Options { + name: None, + signature: None, + }; for option in take_pyo3_options(attrs)? { match option { @@ -652,6 +660,13 @@ impl EnumVariantPyO3Options { ); options.name = Some(name); } + EnumVariantPyO3Option::Signature(signature) => { + ensure_spanned!( + options.signature.is_none(), + signature.span() => "`signature` may only be specified once" + ); + options.signature = Some(signature); + } } } @@ -691,6 +706,7 @@ fn impl_simple_enum( let (default_repr, default_repr_slot) = { let variants_repr = variants.iter().map(|variant| { + ensure_spanned!(variant.options.signature.is_none(), variant.options.signature.span() => "`signature` can't be used on a simple enum variant"); let variant_name = variant.ident; // Assuming all variants are unit variants because they are the only type we support. let repr = format!( @@ -698,12 +714,12 @@ fn impl_simple_enum( get_class_python_name(cls, args), variant.get_python_name(args), ); - quote! { #cls::#variant_name => #repr, } - }); + Ok(quote! { #cls::#variant_name => #repr, }) + }).collect::>()?; let mut repr_impl: syn::ImplItemFn = syn::parse_quote! { fn __pyo3__repr__(&self) -> &'static str { match self { - #(#variants_repr)* + #variants_repr } } }; @@ -889,7 +905,7 @@ fn impl_complex_enum( let mut variant_cls_pytypeinfos = vec![]; let mut variant_cls_pyclass_impls = vec![]; let mut variant_cls_impls = vec![]; - for variant in &variants { + for variant in variants { let variant_cls = gen_complex_enum_variant_class_ident(cls, variant.get_ident()); let variant_cls_zst = quote! { @@ -908,11 +924,11 @@ fn impl_complex_enum( let variant_cls_pytypeinfo = impl_pytypeinfo(&variant_cls, &variant_args, None, ctx); variant_cls_pytypeinfos.push(variant_cls_pytypeinfo); - let variant_new = complex_enum_variant_new(cls, variant, ctx)?; - - let (variant_cls_impl, field_getters) = impl_complex_enum_variant_cls(cls, variant, ctx)?; + let (variant_cls_impl, field_getters) = impl_complex_enum_variant_cls(cls, &variant, ctx)?; variant_cls_impls.push(variant_cls_impl); + let variant_new = complex_enum_variant_new(cls, variant, ctx)?; + let pyclass_impl = PyClassImplsBuilder::new( &variant_cls, &variant_args, @@ -1120,7 +1136,7 @@ pub fn gen_complex_enum_variant_attr( fn complex_enum_variant_new<'a>( cls: &'a syn::Ident, - variant: &'a PyClassEnumVariant<'a>, + variant: PyClassEnumVariant<'a>, ctx: &Ctx, ) -> Result { match variant { @@ -1132,7 +1148,7 @@ fn complex_enum_variant_new<'a>( fn complex_enum_struct_variant_new<'a>( cls: &'a syn::Ident, - variant: &'a PyClassEnumStructVariant<'a>, + variant: PyClassEnumStructVariant<'a>, ctx: &Ctx, ) -> Result { let Ctx { pyo3_path } = ctx; @@ -1162,7 +1178,12 @@ fn complex_enum_struct_variant_new<'a>( } args }; - let signature = crate::pyfunction::FunctionSignature::from_arguments(args)?; + + let signature = if let Some(signature) = variant.options.signature { + crate::pyfunction::FunctionSignature::from_arguments_and_attribute(args, signature)? + } else { + crate::pyfunction::FunctionSignature::from_arguments(args)? + }; let spec = FnSpec { tp: crate::method::FnType::FnNew, diff --git a/pytests/src/enums.rs b/pytests/src/enums.rs index 0a1bc49bb63..bef0e0bc7f4 100644 --- a/pytests/src/enums.rs +++ b/pytests/src/enums.rs @@ -39,11 +39,26 @@ pub fn do_simple_stuff(thing: &SimpleEnum) -> SimpleEnum { #[pyclass] pub enum ComplexEnum { - Int { i: i32 }, - Float { f: f64 }, - Str { s: String }, + Int { + i: i32, + }, + Float { + f: f64, + }, + Str { + s: String, + }, EmptyStruct {}, - MultiFieldStruct { a: i32, b: f64, c: bool }, + MultiFieldStruct { + a: i32, + b: f64, + c: bool, + }, + #[pyo3(signature = (a = 42, b = None))] + VariantWithDefault { + a: i32, + b: Option, + }, } #[pyfunction] @@ -58,5 +73,9 @@ pub fn do_complex_stuff(thing: &ComplexEnum) -> ComplexEnum { b: *b, c: *c, }, + ComplexEnum::VariantWithDefault { a, b } => ComplexEnum::VariantWithDefault { + a: 2 * a, + b: b.as_ref().map(|s| s.to_uppercase()), + }, } } diff --git a/pytests/tests/test_enums.py b/pytests/tests/test_enums.py index 04b0cdca431..cd1d7aedaf8 100644 --- a/pytests/tests/test_enums.py +++ b/pytests/tests/test_enums.py @@ -18,6 +18,12 @@ def test_complex_enum_variant_constructors(): multi_field_struct_variant = enums.ComplexEnum.MultiFieldStruct(42, 3.14, True) assert isinstance(multi_field_struct_variant, enums.ComplexEnum.MultiFieldStruct) + variant_with_default_1 = enums.ComplexEnum.VariantWithDefault() + assert isinstance(variant_with_default_1, enums.ComplexEnum.VariantWithDefault) + + variant_with_default_2 = enums.ComplexEnum.VariantWithDefault(25, "Hello") + assert isinstance(variant_with_default_2, enums.ComplexEnum.VariantWithDefault) + @pytest.mark.parametrize( "variant", @@ -27,6 +33,7 @@ def test_complex_enum_variant_constructors(): enums.ComplexEnum.Str("hello"), enums.ComplexEnum.EmptyStruct(), enums.ComplexEnum.MultiFieldStruct(42, 3.14, True), + enums.ComplexEnum.VariantWithDefault(), ], ) def test_complex_enum_variant_subclasses(variant: enums.ComplexEnum): @@ -48,6 +55,10 @@ def test_complex_enum_field_getters(): assert multi_field_struct_variant.b == 3.14 assert multi_field_struct_variant.c is True + variant_with_default = enums.ComplexEnum.VariantWithDefault() + assert variant_with_default.a == 42 + assert variant_with_default.b is None + @pytest.mark.parametrize( "variant", @@ -57,6 +68,7 @@ def test_complex_enum_field_getters(): enums.ComplexEnum.Str("hello"), enums.ComplexEnum.EmptyStruct(), enums.ComplexEnum.MultiFieldStruct(42, 3.14, True), + enums.ComplexEnum.VariantWithDefault(), ], ) def test_complex_enum_desugared_match(variant: enums.ComplexEnum): @@ -78,6 +90,11 @@ def test_complex_enum_desugared_match(variant: enums.ComplexEnum): assert x == 42 assert y == 3.14 assert z is True + elif isinstance(variant, enums.ComplexEnum.VariantWithDefault): + x = variant.a + y = variant.b + assert x == 42 + assert y is None else: assert False @@ -90,6 +107,7 @@ def test_complex_enum_desugared_match(variant: enums.ComplexEnum): enums.ComplexEnum.Str("hello"), enums.ComplexEnum.EmptyStruct(), enums.ComplexEnum.MultiFieldStruct(42, 3.14, True), + enums.ComplexEnum.VariantWithDefault(b="hello"), ], ) def test_complex_enum_pyfunction_in_out_desugared_match(variant: enums.ComplexEnum): @@ -112,5 +130,10 @@ def test_complex_enum_pyfunction_in_out_desugared_match(variant: enums.ComplexEn assert x == 42 assert y == 3.14 assert z is True + elif isinstance(variant, enums.ComplexEnum.VariantWithDefault): + x = variant.a + y = variant.b + assert x == 84 + assert y == "HELLO" else: assert False diff --git a/tests/ui/invalid_pyclass_enum.rs b/tests/ui/invalid_pyclass_enum.rs index 95879c2fbd1..f057bbf5873 100644 --- a/tests/ui/invalid_pyclass_enum.rs +++ b/tests/ui/invalid_pyclass_enum.rs @@ -27,4 +27,11 @@ enum NoTupleVariants { TupleVariant(i32), } +#[pyclass] +enum SimpleNoSignature { + #[pyo3(signature = (a, b))] + A, + B, +} + fn main() {} diff --git a/tests/ui/invalid_pyclass_enum.stderr b/tests/ui/invalid_pyclass_enum.stderr index a03a0ae2814..75aafa7e0f7 100644 --- a/tests/ui/invalid_pyclass_enum.stderr +++ b/tests/ui/invalid_pyclass_enum.stderr @@ -31,3 +31,9 @@ error: Tuple variant `TupleVariant` is not yet supported in a complex enum | 27 | TupleVariant(i32), | ^^^^^^^^^^^^ + +error: `signature` can't be used on a simple enum variant + --> tests/ui/invalid_pyclass_enum.rs:32:12 + | +32 | #[pyo3(signature = (a, b))] + | ^^^^^^^^^