Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable comparison of C-like enums by identity #3061

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
22 changes: 21 additions & 1 deletion guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -940,7 +940,9 @@ Note that `text_signature` on `#[new]` is not compatible with compilation in

## #[pyclass] enums

Currently PyO3 only supports fieldless enums. PyO3 adds a class attribute for each variant, so you can access them in Python without defining `#[new]`. PyO3 also provides default implementations of `__richcmp__` and `__int__`, so they can be compared using `==`:
Currently PyO3 only supports fieldless enums.
PyO3 adds a class attribute for each variant, so you can access them in Python without defining `#[new]`.
PyO3 also provides default implementations of `__richcmp__` and `__int__`, so they can be compared using `==`.

```rust
# use pyo3::prelude::*;
Expand All @@ -961,6 +963,24 @@ Python::with_gil(|py| {
"#)
})
```
You can also compare enums by identity using `is`:
```rust
# use pyo3::prelude::*;
#[pyclass]
enum MyEnum {
Variant,
OtherVariant,
}

Python::with_gil(|py| {
let cls = py.get_type::<MyEnum>();
pyo3::py_run!(py, cls, r#"
assert cls.Variant is cls.Variant
assert cls.OtherVariant == cls.OtherVariant
assert cls.Variant is not cls.OtherVariant
"#)
})
```

You can also convert your enums into `int`:

Expand Down
1 change: 1 addition & 0 deletions newsfragments/3061.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Enable comparison of C-like enums by identity
68 changes: 55 additions & 13 deletions pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,24 @@ fn get_class_python_name<'a>(cls: &'a syn::Ident, args: &'a PyClassArgs) -> Cow<
.unwrap_or_else(|| Cow::Owned(cls.unraw()))
}

fn impl_into_py_class(cls: &syn::Ident, attr: &PyClassArgs) -> Option<TokenStream> {
let cls = cls;
let attr = attr;

// If #cls is not extended type, we allow Self->PyObject conversion
if attr.options.extends.is_some() {
return None;
}

Some(quote! {
impl _pyo3::IntoPy<_pyo3::PyObject> for #cls {
fn into_py(self, py: _pyo3::Python) -> _pyo3::PyObject {
_pyo3::IntoPy::into_py(_pyo3::Py::new(py, self).unwrap(), py)
}
}
})
}

fn impl_class(
cls: &syn::Ident,
args: &PyClassArgs,
Expand All @@ -357,6 +375,7 @@ fn impl_class(
methods_type,
descriptors_to_items(cls, field_options)?,
vec![],
impl_into_py_class(cls, args),
)
.doc(doc)
.impl_all()?;
Expand Down Expand Up @@ -495,13 +514,41 @@ impl EnumVariantPyO3Options {
}
}

fn impl_into_py_enum(cls: &syn::Ident, variants: &[PyClassEnumVariant<'_>]) -> TokenStream {
let number_of_variants = variants.len();

let variants_to_initial_value_index = variants.iter().enumerate().map(|(index, variant)| {
let variant_name = variant.ident;
quote! { #cls::#variant_name => #index, }
});

let initial_values = variants.iter().map(|variant| {
let variant_name = variant.ident;
quote!(_pyo3::Py::new(py, #cls::#variant_name).unwrap())
});

quote! {
impl _pyo3::IntoPy<_pyo3::PyObject> for #cls {
fn into_py(self, py: _pyo3::Python) -> _pyo3::PyObject {
static SINGLETON_PER_VARIANT: _pyo3::once_cell::GILOnceCell<[_pyo3::Py<#cls>; #number_of_variants]> = _pyo3::once_cell::GILOnceCell::new();
let index = match self {
#(#variants_to_initial_value_index)*
};
let singleton = &SINGLETON_PER_VARIANT.get_or_init(py, || [#(#initial_values),*])[index];
_pyo3::IntoPy::into_py(::std::clone::Clone::clone(singleton), py)
}
}
}
}

fn impl_enum(
enum_: PyClassEnum<'_>,
args: &PyClassArgs,
doc: PythonDoc,
methods_type: PyClassMethodsType,
) -> Result<TokenStream> {
let krate = get_pyo3_crate(&args.options.krate);

let cls = enum_.ident;
let ty: syn::Type = syn::parse_quote!(#cls);
let variants = enum_.variants;
Expand Down Expand Up @@ -598,6 +645,7 @@ fn impl_enum(
methods_type,
enum_default_methods(cls, variants.iter().map(|v| (v.ident, v.python_name()))),
default_slots,
Some(impl_into_py_enum(cls, &variants)),
)
.doc(doc)
.impl_all()?;
Expand Down Expand Up @@ -758,6 +806,7 @@ struct PyClassImplsBuilder<'a> {
methods_type: PyClassMethodsType,
default_methods: Vec<MethodAndMethodDef>,
default_slots: Vec<MethodAndSlotDef>,
into_py: Option<TokenStream>,
doc: Option<PythonDoc>,
}

Expand All @@ -768,13 +817,15 @@ impl<'a> PyClassImplsBuilder<'a> {
methods_type: PyClassMethodsType,
default_methods: Vec<MethodAndMethodDef>,
default_slots: Vec<MethodAndSlotDef>,
into_py: Option<TokenStream>,
) -> Self {
Self {
cls,
attr,
methods_type,
default_methods,
default_slots,
into_py,
doc: None,
}
}
Expand Down Expand Up @@ -854,21 +905,12 @@ impl<'a> PyClassImplsBuilder<'a> {
}

fn impl_into_py(&self) -> TokenStream {
let cls = self.cls;
let attr = self.attr;
// If #cls is not extended type, we allow Self->PyObject conversion
if attr.options.extends.is_none() {
quote! {
impl _pyo3::IntoPy<_pyo3::PyObject> for #cls {
fn into_py(self, py: _pyo3::Python) -> _pyo3::PyObject {
_pyo3::IntoPy::into_py(_pyo3::Py::new(py, self).unwrap(), py)
}
}
}
} else {
quote! {}
match &self.into_py {
None => quote! {},
Some(implementation) => implementation.clone(),
}
}

fn impl_pyclassimpl(&self) -> Result<TokenStream> {
let cls = self.cls;
let doc = self.doc.as_ref().map_or(quote! {"\0"}, |doc| quote! {#doc});
Expand Down
68 changes: 68 additions & 0 deletions tests/test_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,59 @@ fn test_enum_eq_incomparable() {
})
}

#[test]
fn test_enum_compare_by_identity() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I noticed a test above test_enum_class_attr with the following line:

Py::new(py, MyEnum::Variant).unwrap();

This is potentially a problem; that is a new instance of MyEnum type and so it won't compare successfully with identity.

Now, the question is, is that a good thing? Pro could be that it gives users who need it an escape hatch. However, it is probably also a footgun.

Unfortunately, changing how Py::new works would require a refactoring of our initialization machinery. I'm sure there's plenty of scope for improvement in that area (e.g. #2384), however it would likely be a much bigger patch than this PR...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that this possibility is a footgun which we should prevent. Maybe we should put this a bit on hold until the scope for changes to Py::new is defined?

Copy link
Member

@davidhewitt davidhewitt Apr 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By "this on hold" do you mean solving this particular problem or the whole PR?

Copy link
Contributor Author

@ricohageman ricohageman Apr 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant the whole pull request but I don't know what the best approach would be. It seems that comparing enums by identity is a desired feature but there are also two concerns #3061 (comment) and #3061 (comment) that would both require quite a refactor. Also happy to help in that regard but I don't know a lot about the existing problems and where to start. But it seems that it might not be the right time to make this change as I can imagine you don't want to complicate that work further by including some changes now?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we've just had #3287 merged which might be enough of a tweak to make this possible. Given that Py::new relies on Into<PyClassInitializer<T>> we might be able to adjust that trait implementation for enums to somehow use PyClassInitializerImpl::Existing (by fetching the value from a GILOnceCell which is itself initialized by a PyClassInitializerImpl::New invocation to Py::new.

Python::with_gil(|py| {
#[allow(non_snake_case)]
let MyEnum = py.get_type::<MyEnum>();
py_run!(
py,
MyEnum,
r#"
assert MyEnum.Variant is MyEnum.Variant
assert MyEnum.Variant is not MyEnum.OtherVariant
assert (MyEnum.Variant is not MyEnum.Variant) == False
"#
);
})
}

#[pyclass]
struct MyEnumHoldingStruct {
#[pyo3(get)]
my_enum: MyEnum,
}

#[pymethods]
impl MyEnumHoldingStruct {
#[new]
fn new() -> Self {
Self {
my_enum: MyEnum::Variant,
}
}
}

#[test]
fn test_struct_holding_enum_compare_enum_by_identity() {
Python::with_gil(|py| {
#[allow(non_snake_case)]
let MyEnum = py.get_type::<MyEnum>();
#[allow(non_snake_case)]
let MyEnumHoldingStruct = py.get_type::<MyEnumHoldingStruct>();
py_run!(py, MyEnumHoldingStruct MyEnum, r#"
my_enum_holding_struct = MyEnumHoldingStruct()
assert my_enum_holding_struct.my_enum is MyEnum.Variant
assert my_enum_holding_struct.my_enum is not MyEnum.OtherVariant
"#);
})
}

#[pyclass]
enum CustomDiscriminant {
One = 1,
Two = 2,
Four = 4,
}

#[test]
Expand All @@ -93,6 +142,25 @@ fn test_custom_discriminant() {
})
}

#[test]
fn test_custom_discriminant_comparison_by_identity() {
Python::with_gil(|py| {
#[allow(non_snake_case)]
let CustomDiscriminant = py.get_type::<CustomDiscriminant>();
py_run!(
py,
CustomDiscriminant,
r#"
assert CustomDiscriminant.One is CustomDiscriminant.One
assert CustomDiscriminant.Two is CustomDiscriminant.Two
assert CustomDiscriminant.Four is CustomDiscriminant.Four
assert CustomDiscriminant.One is not CustomDiscriminant.Two
assert CustomDiscriminant.Two is not CustomDiscriminant.Four
"#
);
})
}

#[test]
fn test_enum_to_int() {
Python::with_gil(|py| {
Expand Down