Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow constructor customization of complex enum variants #4158

Merged
merged 5 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions guide/pyclass-parameters.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

| Parameter | Description |
| :- | :- |
| `constructor` | This is currently only allowed on [variants of complex enums][params-constructor]. It allows customization of the generated class constructor for each variant. It uses the same syntax and supports the same options as the `signature` attribute of functions and methods. |
| <span style="white-space: pre">`crate = "some::path"`</span> | Path to import the `pyo3` crate, if it's not accessible at `::pyo3`. |
| `dict` | Gives instances of this class an empty `__dict__` to store custom attributes. |
| <span style="white-space: pre">`extends = BaseType`</span> | Use a custom baseclass. Defaults to [`PyAny`][params-1] |
Expand Down Expand Up @@ -39,5 +40,6 @@ struct MyClass {}
[params-4]: https://doc.rust-lang.org/std/rc/struct.Rc.html
[params-5]: https://doc.rust-lang.org/std/sync/struct.Arc.html
[params-6]: https://docs.python.org/3/library/weakref.html
[params-constructor]: https://pyo3.rs/latest/class.html#complex-enums
[params-mapping]: https://pyo3.rs/latest/class/protocols.html#mapping--sequence-types
[params-sequence]: https://pyo3.rs/latest/class/protocols.html#mapping--sequence-types
40 changes: 40 additions & 0 deletions guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -1243,6 +1243,46 @@ Python::with_gil(|py| {
})
```

The constructor of each generated class can be customized using the `#[pyo3(constructor = (...))]` attribute. This uses the same syntax as the [`#[pyo3(signature = (...))]`](function/signature.md)
attribute on function and methods and supports the same options. To apply this attribute simply place it on top of a variant in a `#[pyclass]` complex enum as shown below:

```rust
# use pyo3::prelude::*;
#[pyclass]
enum Shape {
#[pyo3(constructor = (radius=1.0))]
Circle { radius: f64 },
#[pyo3(constructor = (*, width, height))]
Rectangle { width: f64, height: f64 },
#[pyo3(constructor = (side_count, radius=1.0))]
RegularPolygon { side_count: u32, radius: f64 },
Nothing { },
}

# #[cfg(Py_3_10)]
Python::with_gil(|py| {
let cls = py.get_type_bound::<Shape>();
pyo3::py_run!(py, cls, r#"
circle = cls.Circle()
assert isinstance(circle, cls)
assert isinstance(circle, cls.Circle)
assert circle.radius == 1.0

square = cls.Rectangle(width = 1, height = 1)
assert isinstance(square, cls)
assert isinstance(square, cls.Rectangle)
assert square.width == 1
assert square.height == 1

hexagon = cls.RegularPolygon(6)
assert isinstance(hexagon, cls)
assert isinstance(hexagon, cls.RegularPolygon)
assert hexagon.side_count == 6
assert hexagon.radius == 1
"#)
})
```

## Implementation details

The `#[pyclass]` macros rely on a lot of conditional code generation: each `#[pyclass]` can optionally have a `#[pymethods]` block.
Expand Down
1 change: 1 addition & 0 deletions newsfragments/4158.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added `#[pyo3(constructor = (...))]` to customize the generated constructors for complex enum variants
1 change: 1 addition & 0 deletions pyo3-macros-backend/src/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub mod kw {
syn::custom_keyword!(annotation);
syn::custom_keyword!(attribute);
syn::custom_keyword!(cancel_handle);
syn::custom_keyword!(constructor);
syn::custom_keyword!(dict);
syn::custom_keyword!(extends);
syn::custom_keyword!(freelist);
Expand Down
62 changes: 46 additions & 16 deletions pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::attributes::{
use crate::deprecations::Deprecations;
use crate::konst::{ConstAttributes, ConstSpec};
use crate::method::{FnArg, FnSpec, PyArg, RegularArg};
use crate::pyfunction::ConstructorAttribute;
use crate::pyimpl::{gen_py_const, PyClassMethodsType};
use crate::pymethod::{
impl_py_getter_def, impl_py_setter_def, MethodAndMethodDef, MethodAndSlotDef, PropertyType,
Expand Down Expand Up @@ -620,19 +621,24 @@ struct PyClassEnumVariantNamedField<'a> {
}

/// `#[pyo3()]` options for pyclass enum variants
#[derive(Default)]
struct EnumVariantPyO3Options {
name: Option<NameAttribute>,
constructor: Option<ConstructorAttribute>,
}

enum EnumVariantPyO3Option {
Name(NameAttribute),
Constructor(ConstructorAttribute),
}

impl Parse for EnumVariantPyO3Option {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::name) {
input.parse().map(EnumVariantPyO3Option::Name)
} else if lookahead.peek(attributes::kw::constructor) {
input.parse().map(EnumVariantPyO3Option::Constructor)
} else {
Err(lookahead.error())
}
Expand All @@ -641,21 +647,33 @@ impl Parse for EnumVariantPyO3Option {

impl EnumVariantPyO3Options {
fn take_pyo3_options(attrs: &mut Vec<syn::Attribute>) -> Result<Self> {
let mut options = EnumVariantPyO3Options { name: None };
let mut options = EnumVariantPyO3Options::default();

for option in take_pyo3_options(attrs)? {
match option {
EnumVariantPyO3Option::Name(name) => {
take_pyo3_options(attrs)?
.into_iter()
.try_for_each(|option| options.set_option(option))?;

Ok(options)
}

fn set_option(&mut self, option: EnumVariantPyO3Option) -> syn::Result<()> {
macro_rules! set_option {
($key:ident) => {
{
ensure_spanned!(
options.name.is_none(),
name.span() => "`name` may only be specified once"
self.$key.is_none(),
$key.span() => concat!("`", stringify!($key), "` may only be specified once")
);
options.name = Some(name);
self.$key = Some($key);
}
}
};
}

Ok(options)
match option {
EnumVariantPyO3Option::Constructor(constructor) => set_option!(constructor),
EnumVariantPyO3Option::Name(name) => set_option!(name),
}
Ok(())
}
}

Expand Down Expand Up @@ -689,6 +707,10 @@ fn impl_simple_enum(
let variants = simple_enum.variants;
let pytypeinfo = impl_pytypeinfo(cls, args, None, ctx);

for variant in &variants {
ensure_spanned!(variant.options.constructor.is_none(), variant.options.constructor.span() => "`constructor` can't be used on a simple enum variant");
}

let (default_repr, default_repr_slot) = {
let variants_repr = variants.iter().map(|variant| {
let variant_name = variant.ident;
Expand Down Expand Up @@ -889,7 +911,7 @@ fn impl_complex_enum(
let mut variant_cls_pytypeinfos = vec![];
let mut variant_cls_pyclass_impls = vec![];
let mut variant_cls_impls = vec![];
for variant in &variants {
for variant in variants {
let variant_cls = gen_complex_enum_variant_class_ident(cls, variant.get_ident());

let variant_cls_zst = quote! {
Expand All @@ -908,11 +930,11 @@ fn impl_complex_enum(
let variant_cls_pytypeinfo = impl_pytypeinfo(&variant_cls, &variant_args, None, ctx);
variant_cls_pytypeinfos.push(variant_cls_pytypeinfo);

let variant_new = complex_enum_variant_new(cls, variant, ctx)?;

let (variant_cls_impl, field_getters) = impl_complex_enum_variant_cls(cls, variant, ctx)?;
let (variant_cls_impl, field_getters) = impl_complex_enum_variant_cls(cls, &variant, ctx)?;
variant_cls_impls.push(variant_cls_impl);

let variant_new = complex_enum_variant_new(cls, variant, ctx)?;

let pyclass_impl = PyClassImplsBuilder::new(
&variant_cls,
&variant_args,
Expand Down Expand Up @@ -1120,7 +1142,7 @@ pub fn gen_complex_enum_variant_attr(

fn complex_enum_variant_new<'a>(
cls: &'a syn::Ident,
variant: &'a PyClassEnumVariant<'a>,
variant: PyClassEnumVariant<'a>,
ctx: &Ctx,
) -> Result<MethodAndSlotDef> {
match variant {
Expand All @@ -1132,7 +1154,7 @@ fn complex_enum_variant_new<'a>(

fn complex_enum_struct_variant_new<'a>(
cls: &'a syn::Ident,
variant: &'a PyClassEnumStructVariant<'a>,
variant: PyClassEnumStructVariant<'a>,
ctx: &Ctx,
) -> Result<MethodAndSlotDef> {
let Ctx { pyo3_path } = ctx;
Expand Down Expand Up @@ -1162,7 +1184,15 @@ fn complex_enum_struct_variant_new<'a>(
}
args
};
let signature = crate::pyfunction::FunctionSignature::from_arguments(args)?;

let signature = if let Some(constructor) = variant.options.constructor {
crate::pyfunction::FunctionSignature::from_arguments_and_attribute(
args,
constructor.into_signature(),
)?
} else {
crate::pyfunction::FunctionSignature::from_arguments(args)?
};
Icxolu marked this conversation as resolved.
Show resolved Hide resolved

let spec = FnSpec {
tp: crate::method::FnType::FnNew,
Expand Down
2 changes: 1 addition & 1 deletion pyo3-macros-backend/src/pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use syn::{

mod signature;

pub use self::signature::{FunctionSignature, SignatureAttribute};
pub use self::signature::{ConstructorAttribute, FunctionSignature, SignatureAttribute};

#[derive(Clone, Debug)]
pub struct PyFunctionArgPyO3Attributes {
Expand Down
10 changes: 10 additions & 0 deletions pyo3-macros-backend/src/pyfunction/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,16 @@ impl ToTokens for SignatureItemPosargsSep {
}

pub type SignatureAttribute = KeywordAttribute<kw::signature, Signature>;
pub type ConstructorAttribute = KeywordAttribute<kw::constructor, Signature>;

impl ConstructorAttribute {
pub fn into_signature(self) -> SignatureAttribute {
SignatureAttribute {
kw: kw::signature(self.kw.span),
value: self.value,
}
}
}

#[derive(Default)]
pub struct PythonSignature {
Expand Down
27 changes: 23 additions & 4 deletions pytests/src/enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,26 @@ pub fn do_simple_stuff(thing: &SimpleEnum) -> SimpleEnum {

#[pyclass]
pub enum ComplexEnum {
Int { i: i32 },
Float { f: f64 },
Str { s: String },
Int {
i: i32,
},
Float {
f: f64,
},
Str {
s: String,
},
EmptyStruct {},
MultiFieldStruct { a: i32, b: f64, c: bool },
MultiFieldStruct {
a: i32,
b: f64,
c: bool,
},
#[pyo3(constructor = (a = 42, b = None))]
Icxolu marked this conversation as resolved.
Show resolved Hide resolved
VariantWithDefault {
a: i32,
b: Option<String>,
},
}

#[pyfunction]
Expand All @@ -58,5 +73,9 @@ pub fn do_complex_stuff(thing: &ComplexEnum) -> ComplexEnum {
b: *b,
c: *c,
},
ComplexEnum::VariantWithDefault { a, b } => ComplexEnum::VariantWithDefault {
a: 2 * a,
b: b.as_ref().map(|s| s.to_uppercase()),
},
}
}
23 changes: 23 additions & 0 deletions pytests/tests/test_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ def test_complex_enum_variant_constructors():
multi_field_struct_variant = enums.ComplexEnum.MultiFieldStruct(42, 3.14, True)
assert isinstance(multi_field_struct_variant, enums.ComplexEnum.MultiFieldStruct)

variant_with_default_1 = enums.ComplexEnum.VariantWithDefault()
assert isinstance(variant_with_default_1, enums.ComplexEnum.VariantWithDefault)

variant_with_default_2 = enums.ComplexEnum.VariantWithDefault(25, "Hello")
assert isinstance(variant_with_default_2, enums.ComplexEnum.VariantWithDefault)


@pytest.mark.parametrize(
"variant",
Expand All @@ -27,6 +33,7 @@ def test_complex_enum_variant_constructors():
enums.ComplexEnum.Str("hello"),
enums.ComplexEnum.EmptyStruct(),
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
enums.ComplexEnum.VariantWithDefault(),
],
)
def test_complex_enum_variant_subclasses(variant: enums.ComplexEnum):
Expand All @@ -48,6 +55,10 @@ def test_complex_enum_field_getters():
assert multi_field_struct_variant.b == 3.14
assert multi_field_struct_variant.c is True

variant_with_default = enums.ComplexEnum.VariantWithDefault()
assert variant_with_default.a == 42
assert variant_with_default.b is None


@pytest.mark.parametrize(
"variant",
Expand All @@ -57,6 +68,7 @@ def test_complex_enum_field_getters():
enums.ComplexEnum.Str("hello"),
enums.ComplexEnum.EmptyStruct(),
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
enums.ComplexEnum.VariantWithDefault(),
],
)
def test_complex_enum_desugared_match(variant: enums.ComplexEnum):
Expand All @@ -78,6 +90,11 @@ def test_complex_enum_desugared_match(variant: enums.ComplexEnum):
assert x == 42
assert y == 3.14
assert z is True
elif isinstance(variant, enums.ComplexEnum.VariantWithDefault):
x = variant.a
y = variant.b
assert x == 42
assert y is None
else:
assert False

Expand All @@ -90,6 +107,7 @@ def test_complex_enum_desugared_match(variant: enums.ComplexEnum):
enums.ComplexEnum.Str("hello"),
enums.ComplexEnum.EmptyStruct(),
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
enums.ComplexEnum.VariantWithDefault(b="hello"),
],
)
def test_complex_enum_pyfunction_in_out_desugared_match(variant: enums.ComplexEnum):
Expand All @@ -112,5 +130,10 @@ def test_complex_enum_pyfunction_in_out_desugared_match(variant: enums.ComplexEn
assert x == 42
assert y == 3.14
assert z is True
elif isinstance(variant, enums.ComplexEnum.VariantWithDefault):
x = variant.a
y = variant.b
assert x == 84
assert y == "HELLO"
else:
assert False
7 changes: 7 additions & 0 deletions tests/ui/invalid_pyclass_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,11 @@ enum NoTupleVariants {
TupleVariant(i32),
}

#[pyclass]
enum SimpleNoSignature {
#[pyo3(constructor = (a, b))]
A,
B,
}

fn main() {}
6 changes: 6 additions & 0 deletions tests/ui/invalid_pyclass_enum.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,9 @@ error: Tuple variant `TupleVariant` is not yet supported in a complex enum
|
27 | TupleVariant(i32),
| ^^^^^^^^^^^^

error: `constructor` can't be used on a simple enum variant
--> tests/ui/invalid_pyclass_enum.rs:32:12
|
32 | #[pyo3(constructor = (a, b))]
| ^^^^^^^^^^^
Loading