Skip to content

Commit

Permalink
allow #[pyo3(signature = ...)] on complex enum variants to specify …
Browse files Browse the repository at this point in the history
…constructor signature
  • Loading branch information
Icxolu committed Apr 25, 2024
1 parent 8734b76 commit 3e2c41b
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 15 deletions.
43 changes: 32 additions & 11 deletions pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::attributes::{
use crate::deprecations::Deprecations;
use crate::konst::{ConstAttributes, ConstSpec};
use crate::method::{FnArg, FnSpec, PyArg, RegularArg};
use crate::pyfunction::SignatureAttribute;
use crate::pyimpl::{gen_py_const, PyClassMethodsType};
use crate::pymethod::{
impl_py_getter_def, impl_py_setter_def, MethodAndMethodDef, MethodAndSlotDef, PropertyType,
Expand Down Expand Up @@ -622,17 +623,21 @@ struct PyClassEnumVariantNamedField<'a> {
/// `#[pyo3()]` options for pyclass enum variants
struct EnumVariantPyO3Options {
name: Option<NameAttribute>,
signature: Option<SignatureAttribute>,
}

enum EnumVariantPyO3Option {
Name(NameAttribute),
Signature(SignatureAttribute),
}

impl Parse for EnumVariantPyO3Option {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::name) {
input.parse().map(EnumVariantPyO3Option::Name)
} else if lookahead.peek(attributes::kw::signature) {
input.parse().map(EnumVariantPyO3Option::Signature)
} else {
Err(lookahead.error())
}
Expand All @@ -641,7 +646,10 @@ impl Parse for EnumVariantPyO3Option {

impl EnumVariantPyO3Options {
fn take_pyo3_options(attrs: &mut Vec<syn::Attribute>) -> Result<Self> {
let mut options = EnumVariantPyO3Options { name: None };
let mut options = EnumVariantPyO3Options {
name: None,
signature: None,
};

for option in take_pyo3_options(attrs)? {
match option {
Expand All @@ -652,6 +660,13 @@ impl EnumVariantPyO3Options {
);
options.name = Some(name);
}
EnumVariantPyO3Option::Signature(signature) => {
ensure_spanned!(
options.signature.is_none(),
signature.span() => "`signature` may only be specified once"
);
options.signature = Some(signature);
}
}
}

Expand Down Expand Up @@ -691,19 +706,20 @@ fn impl_simple_enum(

let (default_repr, default_repr_slot) = {
let variants_repr = variants.iter().map(|variant| {
ensure_spanned!(variant.options.signature.is_none(), variant.options.signature.span() => "`signature` can't be used on a simple enum variant");
let variant_name = variant.ident;
// Assuming all variants are unit variants because they are the only type we support.
let repr = format!(
"{}.{}",
get_class_python_name(cls, args),
variant.get_python_name(args),
);
quote! { #cls::#variant_name => #repr, }
});
Ok(quote! { #cls::#variant_name => #repr, })
}).collect::<Result<TokenStream>>()?;
let mut repr_impl: syn::ImplItemFn = syn::parse_quote! {
fn __pyo3__repr__(&self) -> &'static str {
match self {
#(#variants_repr)*
#variants_repr
}
}
};
Expand Down Expand Up @@ -889,7 +905,7 @@ fn impl_complex_enum(
let mut variant_cls_pytypeinfos = vec![];
let mut variant_cls_pyclass_impls = vec![];
let mut variant_cls_impls = vec![];
for variant in &variants {
for variant in variants {
let variant_cls = gen_complex_enum_variant_class_ident(cls, variant.get_ident());

let variant_cls_zst = quote! {
Expand All @@ -908,11 +924,11 @@ fn impl_complex_enum(
let variant_cls_pytypeinfo = impl_pytypeinfo(&variant_cls, &variant_args, None, ctx);
variant_cls_pytypeinfos.push(variant_cls_pytypeinfo);

let variant_new = complex_enum_variant_new(cls, variant, ctx)?;

let (variant_cls_impl, field_getters) = impl_complex_enum_variant_cls(cls, variant, ctx)?;
let (variant_cls_impl, field_getters) = impl_complex_enum_variant_cls(cls, &variant, ctx)?;
variant_cls_impls.push(variant_cls_impl);

let variant_new = complex_enum_variant_new(cls, variant, ctx)?;

let pyclass_impl = PyClassImplsBuilder::new(
&variant_cls,
&variant_args,
Expand Down Expand Up @@ -1120,7 +1136,7 @@ pub fn gen_complex_enum_variant_attr(

fn complex_enum_variant_new<'a>(
cls: &'a syn::Ident,
variant: &'a PyClassEnumVariant<'a>,
variant: PyClassEnumVariant<'a>,
ctx: &Ctx,
) -> Result<MethodAndSlotDef> {
match variant {
Expand All @@ -1132,7 +1148,7 @@ fn complex_enum_variant_new<'a>(

fn complex_enum_struct_variant_new<'a>(
cls: &'a syn::Ident,
variant: &'a PyClassEnumStructVariant<'a>,
variant: PyClassEnumStructVariant<'a>,
ctx: &Ctx,
) -> Result<MethodAndSlotDef> {
let Ctx { pyo3_path } = ctx;
Expand Down Expand Up @@ -1162,7 +1178,12 @@ fn complex_enum_struct_variant_new<'a>(
}
args
};
let signature = crate::pyfunction::FunctionSignature::from_arguments(args)?;

let signature = if let Some(signature) = variant.options.signature {
crate::pyfunction::FunctionSignature::from_arguments_and_attribute(args, signature)?
} else {
crate::pyfunction::FunctionSignature::from_arguments(args)?
};

let spec = FnSpec {
tp: crate::method::FnType::FnNew,
Expand Down
27 changes: 23 additions & 4 deletions pytests/src/enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,26 @@ pub fn do_simple_stuff(thing: &SimpleEnum) -> SimpleEnum {

#[pyclass]
pub enum ComplexEnum {
Int { i: i32 },
Float { f: f64 },
Str { s: String },
Int {
i: i32,
},
Float {
f: f64,
},
Str {
s: String,
},
EmptyStruct {},
MultiFieldStruct { a: i32, b: f64, c: bool },
MultiFieldStruct {
a: i32,
b: f64,
c: bool,
},
#[pyo3(signature = (a = 42, b = None))]
VariantWithDefault {
a: i32,
b: Option<String>,
},
}

#[pyfunction]
Expand All @@ -58,5 +73,9 @@ pub fn do_complex_stuff(thing: &ComplexEnum) -> ComplexEnum {
b: *b,
c: *c,
},
ComplexEnum::VariantWithDefault { a, b } => ComplexEnum::VariantWithDefault {
a: 2 * a,
b: b.as_ref().map(|s| s.to_uppercase()),
},
}
}
21 changes: 21 additions & 0 deletions pytests/tests/test_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ def test_complex_enum_variant_constructors():
multi_field_struct_variant = enums.ComplexEnum.MultiFieldStruct(42, 3.14, True)
assert isinstance(multi_field_struct_variant, enums.ComplexEnum.MultiFieldStruct)

variant_with_default_1 = enums.ComplexEnum.VariantWithDefault()
assert isinstance(variant_with_default_1, enums.ComplexEnum.VariantWithDefault)

variant_with_default_2 = enums.ComplexEnum.VariantWithDefault(25, "Hello")
assert isinstance(variant_with_default_2, enums.ComplexEnum.VariantWithDefault)

@pytest.mark.parametrize(
"variant",
Expand All @@ -27,6 +32,7 @@ def test_complex_enum_variant_constructors():
enums.ComplexEnum.Str("hello"),
enums.ComplexEnum.EmptyStruct(),
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
enums.ComplexEnum.VariantWithDefault(),
],
)
def test_complex_enum_variant_subclasses(variant: enums.ComplexEnum):
Expand All @@ -48,6 +54,9 @@ def test_complex_enum_field_getters():
assert multi_field_struct_variant.b == 3.14
assert multi_field_struct_variant.c is True

variant_with_default = enums.ComplexEnum.VariantWithDefault()
assert variant_with_default.a == 42
assert variant_with_default.b is None

@pytest.mark.parametrize(
"variant",
Expand All @@ -57,6 +66,7 @@ def test_complex_enum_field_getters():
enums.ComplexEnum.Str("hello"),
enums.ComplexEnum.EmptyStruct(),
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
enums.ComplexEnum.VariantWithDefault(),
],
)
def test_complex_enum_desugared_match(variant: enums.ComplexEnum):
Expand All @@ -78,6 +88,11 @@ def test_complex_enum_desugared_match(variant: enums.ComplexEnum):
assert x == 42
assert y == 3.14
assert z is True
elif isinstance(variant, enums.ComplexEnum.VariantWithDefault):
x = variant.a
y = variant.b
assert x == 42
assert y is None
else:
assert False

Expand All @@ -90,6 +105,7 @@ def test_complex_enum_desugared_match(variant: enums.ComplexEnum):
enums.ComplexEnum.Str("hello"),
enums.ComplexEnum.EmptyStruct(),
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
enums.ComplexEnum.VariantWithDefault(b = "hello"),
],
)
def test_complex_enum_pyfunction_in_out_desugared_match(variant: enums.ComplexEnum):
Expand All @@ -112,5 +128,10 @@ def test_complex_enum_pyfunction_in_out_desugared_match(variant: enums.ComplexEn
assert x == 42
assert y == 3.14
assert z is True
elif isinstance(variant, enums.ComplexEnum.VariantWithDefault):
x = variant.a
y = variant.b
assert x == 84
assert y == "HELLO"
else:
assert False
7 changes: 7 additions & 0 deletions tests/ui/invalid_pyclass_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,11 @@ enum NoTupleVariants {
TupleVariant(i32),
}

#[pyclass]
enum SimpleNoSignature {
#[pyo3(signature = (a, b))]
A,
B,
}

fn main() {}
6 changes: 6 additions & 0 deletions tests/ui/invalid_pyclass_enum.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,9 @@ error: Tuple variant `TupleVariant` is not yet supported in a complex enum
|
27 | TupleVariant(i32),
| ^^^^^^^^^^^^

error: `signature` can't be used on a simple enum variant
--> tests/ui/invalid_pyclass_enum.rs:32:12
|
32 | #[pyo3(signature = (a, b))]
| ^^^^^^^^^

0 comments on commit 3e2c41b

Please sign in to comment.