Enable comparison of C-like enums by identity #3061

Open
wants to merge 9 commits into
base: main
from
1 change: 1 addition & 0 deletions newsfragments/3061.changed.md
@@ -0,0 +1 @@
Enable comparison of C-like enums by identity
91 changes: 78 additions & 13 deletions pyo3-macros-backend/src/pyclass.rs
@@ -343,6 +343,24 @@ fn get_class_python_name<'a>(cls: &'a syn::Ident, args: &'a PyClassArgs) -> Cow<
.unwrap_or_else(|| Cow::Owned(cls.unraw()))
}

fn impl_into_py_class(cls: &syn::Ident, attr: &PyClassArgs) -> Option<TokenStream> {
let cls = cls;
let attr = attr;

// Only classes that do not extend another type get the Self -> PyObject conversion.
if attr.options.extends.is_some() {
return None;
}

Some(quote! {
impl _pyo3::IntoPy<_pyo3::PyObject> for #cls {
fn into_py(self, py: _pyo3::Python) -> _pyo3::PyObject {
_pyo3::IntoPy::into_py(_pyo3::Py::new(py, self).unwrap(), py)
}
}
})
}

fn impl_class(
cls: &syn::Ident,
args: &PyClassArgs,
@@ -359,6 +377,7 @@ fn impl_class(
methods_type,
descriptors_to_items(cls, field_options)?,
vec![],
impl_into_py_class(cls, args),
)
.doc(doc)
.impl_all()?;
@@ -512,6 +531,57 @@ fn impl_enum(
impl_enum_class(enum_, args, doc, methods_type, krate)
}

fn impl_into_py_enum(cls: &syn::Ident, variants: &[PyClassEnumVariant<'_>]) -> TokenStream {
// Assuming all variants are unit variants because they are the only type we support.
match variants.len() {
0 => quote! {},
1 => {
let variant_name = variants.first().unwrap().ident;
quote! {
impl _pyo3::IntoPy<_pyo3::PyObject> for #cls {
fn into_py(self, py: _pyo3::Python) -> _pyo3::PyObject {
static SINGLETON: _pyo3::once_cell::GILOnceCell<_pyo3::Py<#cls>> = _pyo3::once_cell::GILOnceCell::new();
Member

Use of GILOnceCell seems reasonable, however at the same time if we ever want to have a stab at #2274 (which is admittedly a long way off) then it'd be better to not introduce new statics containing Python objects.

I think an alternative implementation could look up variants as attributes from the enum's Python class object, however I think there's potentially a chicken-or-egg problem of how we create those attributes (as I think that currently uses IntoPy).

Member

I think given that we're still no closer to solving the statics problem, and it'll require coming up with new patterns which will probably apply to all statics equally, let's just stick with a static GILOnceCell for now and leave migration off that for the future 😆


let singleton = SINGLETON.get_or_init(py, || {
_pyo3::Py::new(py, #cls::#variant_name).unwrap()
});

_pyo3::IntoPy::into_py(::std::clone::Clone::clone(singleton), py)
}
}
}
}
_ => {
let initialize_values: Vec<TokenStream> = variants
.iter()
.map(|variant| {
let variant_name = variant.ident;
quote!((#cls::#variant_name as usize, _pyo3::Py::new(py, #cls::#variant_name).unwrap()))
})
.collect();

quote! {
impl _pyo3::IntoPy<_pyo3::PyObject> for #cls {
fn into_py(self, py: _pyo3::Python) -> _pyo3::PyObject {
static SINGLETON_PER_VARIANT: _pyo3::once_cell::GILOnceCell<::std::collections::HashMap<usize, _pyo3::Py<#cls>>> = _pyo3::once_cell::GILOnceCell::new();
Contributor

Is there a reason to use a HashMap here instead of a fixed-size array?

Member

@adamreichold, Apr 23, 2023

I think #cls::#variant_name as usize could be sparse, e.g.

enum Foo {
  Bar = 1,
  Baz = 999999,
}

so while we could store those as sorted pairs in a [(usize, Py<#cls>); 2], we would need a binary search over the variants instead of a hash table look-up, because direct access into a [(usize, Option<Py<#cls>>); 1000000] seems wasteful.

Personally, I see the simplicity of using the hash table but would probably opt for the sorted array in a follow-up. (But let's see whether this lands using the statics at all.)
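
A minimal sketch of the sorted-pairs idea from this comment, in plain Rust without PyO3 so that it runs on its own. `SORTED_SINGLETONS` and `lookup_sorted` are hypothetical names, and a `&'static str` stands in for the `Py<#cls>` the generated code would store:

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Foo {
    Bar = 1,
    Baz = 999_999,
}

// One entry per variant, sorted by discriminant.
// The string is only a stand-in for a cached Python object.
const SORTED_SINGLETONS: [(usize, &str); 2] = [
    (Foo::Bar as usize, "Bar singleton"),
    (Foo::Baz as usize, "Baz singleton"),
];

// Find the entry whose discriminant matches `value` via binary search.
fn lookup_sorted(value: Foo) -> &'static str {
    let key = value as usize;
    let index = SORTED_SINGLETONS
        .binary_search_by_key(&key, |&(discriminant, _)| discriminant)
        .expect("every variant has an entry in the table");
    SORTED_SINGLETONS[index].1
}
```

The table has one entry per variant regardless of how far apart the discriminants are, at the cost of a logarithmic search per conversion.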

Contributor Author

The discriminant can indeed be sparse, which is also tested by test_custom_discriminant_comparison_by_identity. But using binary search in a sorted array is definitely a possible replacement for the hashmap.

Member

I think we can know from parsing the enum whether it's dense or sparse?

For the dense case we can just index into an array, no binary search needed, and for the sparse case we can generate a match in front to convert into a dense set of indices first.

Member

> and for the sparse case we can generate a match in front to convert into a dense set of indices first.

This would really be nice as the compiler is pretty clever in how to handle match AFAIK, e.g. producing linear or binary searches depending on cardinality estimates.
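
The dense-index idea could look roughly like the following sketch, using the `CustomDiscriminant` enum from the tests. It is plain Rust without PyO3; `dense_index`, `SINGLETONS` and `lookup_dense` are hypothetical names, and `&'static str` stands in for `Py<#cls>`. In the macro, the `match` arms would be generated from the parsed variants:

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CustomDiscriminant {
    One = 1,
    Two = 2,
    Four = 4,
}

impl CustomDiscriminant {
    // Map each (possibly sparse) variant to its position in declaration order.
    const fn dense_index(self) -> usize {
        match self {
            CustomDiscriminant::One => 0,
            CustomDiscriminant::Two => 1,
            CustomDiscriminant::Four => 2,
        }
    }
}

// Fixed-size table with exactly one slot per variant.
// The strings are only stand-ins for cached Python objects.
const SINGLETONS: [&str; 3] = ["One", "Two", "Four"];

// Direct array access, no hashing and no search in user code.
fn lookup_dense(value: CustomDiscriminant) -> &'static str {
    SINGLETONS[value.dense_index()]
}
```

For an enum whose discriminants are already dense, the `match` could be skipped and the discriminant used as the index directly.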


let singleton = SINGLETON_PER_VARIANT.get_or_init(py, || {
[#(#initialize_values),*]
.iter()
.cloned()
.collect::<::std::collections::HashMap<_, _>>()
})
.get(&(self as usize))
.unwrap();

_pyo3::IntoPy::into_py(::std::clone::Clone::clone(singleton), py)
}
}
}
}
}
}

fn impl_enum_class(
enum_: PyClassEnum<'_>,
args: &PyClassArgs,
@@ -615,6 +685,7 @@ fn impl_enum_class(
methods_type,
enum_default_methods(cls, variants.iter().map(|v| (v.ident, v.python_name()))),
default_slots,
Some(impl_into_py_enum(cls, &variants)),
)
.doc(doc)
.impl_all()?;
@@ -775,6 +846,7 @@ struct PyClassImplsBuilder<'a> {
methods_type: PyClassMethodsType,
default_methods: Vec<MethodAndMethodDef>,
default_slots: Vec<MethodAndSlotDef>,
into_py: Option<TokenStream>,
doc: Option<PythonDoc>,
}

@@ -785,13 +857,15 @@ impl<'a> PyClassImplsBuilder<'a> {
methods_type: PyClassMethodsType,
default_methods: Vec<MethodAndMethodDef>,
default_slots: Vec<MethodAndSlotDef>,
into_py: Option<TokenStream>,
) -> Self {
Self {
cls,
attr,
methods_type,
default_methods,
default_slots,
into_py,
doc: None,
}
}
@@ -871,21 +945,12 @@ impl<'a> PyClassImplsBuilder<'a> {
}

fn impl_into_py(&self) -> TokenStream {
-let cls = self.cls;
-let attr = self.attr;
-// If #cls is not extended type, we allow Self->PyObject conversion
-if attr.options.extends.is_none() {
-quote! {
-impl _pyo3::IntoPy<_pyo3::PyObject> for #cls {
-fn into_py(self, py: _pyo3::Python) -> _pyo3::PyObject {
-_pyo3::IntoPy::into_py(_pyo3::Py::new(py, self).unwrap(), py)
-}
-}
-}
-} else {
-quote! {}
+match &self.into_py {
+None => quote! {},
+Some(implementation) => implementation.clone(),
+}
}

fn impl_pyclassimpl(&self) -> Result<TokenStream> {
let cls = self.cls;
let doc = self.doc.as_ref().map_or(quote! {"\0"}, |doc| quote! {#doc});
37 changes: 37 additions & 0 deletions tests/test_enum.rs
@@ -72,10 +72,28 @@ fn test_enum_eq_incomparable() {
})
}

#[test]
fn test_enum_compare_by_identity() {
Member

So I noticed a test above test_enum_class_attr with the following line:

Py::new(py, MyEnum::Variant).unwrap();

This is potentially a problem; that is a new instance of MyEnum type and so it won't compare successfully with identity.

Now, the question is, is that a good thing? Pro could be that it gives users who need it an escape hatch. However, it is probably also a footgun.

Unfortunately, changing how Py::new works would require a refactoring of our initialization machinery. I'm sure there's plenty of scope for improvement in that area (e.g. #2384), however it would likely be a much bigger patch than this PR...
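
The identity concern can be illustrated with a std-only analogy (this is not PyO3 code): `Arc<MyEnum>` stands in for `Py<MyEnum>`, `Arc::ptr_eq` for Python's `is`, and `std::sync::OnceLock` for `GILOnceCell`. The `cached` and `fresh` helpers are hypothetical names; `fresh` plays the role of a direct `Py::new` call that bypasses the cache:

```rust
use std::sync::{Arc, OnceLock};

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MyEnum {
    Variant,
    OtherVariant,
}

// Returns the shared handle for a variant, creating all handles on first use.
fn cached(value: MyEnum) -> Arc<MyEnum> {
    static CACHE: OnceLock<[Arc<MyEnum>; 2]> = OnceLock::new();
    let cache = CACHE.get_or_init(|| {
        [Arc::new(MyEnum::Variant), Arc::new(MyEnum::OtherVariant)]
    });
    let index = match value {
        MyEnum::Variant => 0,
        MyEnum::OtherVariant => 1,
    };
    Arc::clone(&cache[index])
}

// Always allocates a new object, so it never shares identity with the cache.
fn fresh(value: MyEnum) -> Arc<MyEnum> {
    Arc::new(value)
}
```

Two handles from `cached` for the same variant point at the same allocation, while a handle from `fresh` is equal by value but distinct by identity, which is the escape hatch (or footgun) described in this comment.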

Contributor Author

I agree that this possibility is a footgun which we should prevent. Maybe we should put this a bit on hold until the scope for changes to Py::new is defined?

Member

@davidhewitt, Apr 1, 2023

By "this on hold" do you mean solving this particular problem or the whole PR?

Contributor Author

@ricohageman, Apr 2, 2023

I meant the whole pull request, but I don't know what the best approach would be. Comparing enums by identity seems to be a desired feature, but there are also two concerns, #3061 (comment) and #3061 (comment), that would both require quite a refactor. I'm happy to help in that regard, but I don't know much about the existing problems or where to start. It may not be the right time to make this change, as I can imagine you don't want to complicate that work further by including some changes now.

Member

So we've just had #3287 merged, which might be enough of a tweak to make this possible. Given that Py::new relies on Into<PyClassInitializer<T>>, we might be able to adjust that trait implementation for enums to somehow use PyClassInitializerImpl::Existing (by fetching the value from a GILOnceCell which is itself initialized by a PyClassInitializerImpl::New invocation to Py::new).

Python::with_gil(|py| {
#[allow(non_snake_case)]
let MyEnum = py.get_type::<MyEnum>();
py_run!(
py,
MyEnum,
r#"
assert MyEnum.Variant is MyEnum.Variant
assert MyEnum.Variant is not MyEnum.OtherVariant
assert (MyEnum.Variant is not MyEnum.Variant) == False
"#
);
})
}

#[pyclass]
enum CustomDiscriminant {
One = 1,
Two = 2,
Four = 4,
}

#[test]
@@ -93,6 +111,25 @@ fn test_custom_discriminant() {
})
}

#[test]
fn test_custom_discriminant_comparison_by_identity() {
Python::with_gil(|py| {
#[allow(non_snake_case)]
let CustomDiscriminant = py.get_type::<CustomDiscriminant>();
py_run!(
py,
CustomDiscriminant,
r#"
assert CustomDiscriminant.One is CustomDiscriminant.One
assert CustomDiscriminant.Two is CustomDiscriminant.Two
assert CustomDiscriminant.Four is CustomDiscriminant.Four
assert CustomDiscriminant.One is not CustomDiscriminant.Two
assert CustomDiscriminant.Two is not CustomDiscriminant.Four
"#
);
})
}

#[test]
fn test_enum_to_int() {
Python::with_gil(|py| {