Skip to content

Commit

Permalink
Implement unit variants as class attributes.
Browse files Browse the repository at this point in the history
  • Loading branch information
jovenlin0527 committed Nov 18, 2021
1 parent b9f8a6f commit 0ffc436
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 28 deletions.
59 changes: 36 additions & 23 deletions pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use crate::attributes::{
TextSignatureAttribute,
};
use crate::deprecations::Deprecations;
use crate::pyimpl::PyClassMethodsType;
use crate::konst::{ConstAttributes, ConstSpec};
use crate::pyimpl::{gen_py_const, PyClassMethodsType};
use crate::pymethod::{impl_py_getter_def, impl_py_setter_def, PropertyType};
use crate::utils::{self, unwrap_group, PythonDoc};
use proc_macro2::{Span, TokenStream};
Expand Down Expand Up @@ -413,54 +414,66 @@ fn impl_enum(
) -> syn::Result<TokenStream> {
let enum_name = &enum_.ident;
let doc = utils::get_doc(&enum_.attrs, None);
let enum_cls = impl_enum_class(enum_name, &attrs, doc, methods_type)?;

let variant_consts = variants
.iter()
.map(|v| impl_const(enum_name, v.ident))
.collect::<syn::Result<Vec<_>>>()?;
let enum_cls = impl_enum_class(enum_name, &attrs, variants, doc, methods_type)?;

Ok(quote! {

#enum_cls

#[pymethods]
impl #enum_name {
#(
#[allow(non_upper_case_globals)]
#variant_consts
)*
}
})
}

fn impl_const(enum_: &syn::Ident, cls: &syn::Ident) -> syn::Result<TokenStream> {
Ok(quote! {
#[classattr]
const #cls: #enum_ = #enum_::#cls;
})
}

fn impl_enum_class(
cls: &syn::Ident,
attr: &PyClassArgs,
variants: Vec<VariantPyO3>,
doc: PythonDoc,
methods_type: PyClassMethodsType,
) -> syn::Result<TokenStream> {
let pytypeinfo = impl_pytypeinfo(cls, attr, None);
let pyclass_impls = PyClassImplsBuilder::new(cls, attr, methods_type)
.doc(doc)
.impl_all();
let descriptors = unit_variants_as_descriptors(cls, variants.iter().map(|v| v.ident));

Ok(quote! {

#pytypeinfo

#pyclass_impls

#descriptors

})
}

fn unit_variants_as_descriptors<'a>(
cls: &'a syn::Ident,
variant_names: impl IntoIterator<Item = &'a syn::Ident>,
) -> TokenStream {
let cls_type = syn::parse_quote!(#cls);
let variant_to_attribute = |ident: &syn::Ident| ConstSpec {
rust_ident: ident.clone(),
attributes: ConstAttributes {
is_class_attr: true,
name: Some(NameAttribute(ident.clone())),
deprecations: Default::default(),
},
};
let py_methods = variant_names
.into_iter()
.map(|var| gen_py_const(&cls_type, &variant_to_attribute(var)));

quote! {
impl ::pyo3::class::impl_::PyClassDescriptors<#cls>
for ::pyo3::class::impl_::PyClassImplCollector<#cls>
{
fn py_class_descriptors(self) -> &'static [::pyo3::class::methods::PyMethodDefType] {
static METHODS: &[::pyo3::class::methods::PyMethodDefType] = &[#(#py_methods),*];
METHODS
}
}
}
}

fn extract_variant_data(variant: &syn::Variant) -> syn::Result<VariantPyO3> {
use syn::Fields;
let ident = match variant.fields {
Expand Down
2 changes: 1 addition & 1 deletion pyo3-macros-backend/src/pyimpl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ pub fn impl_methods(
})
}

fn gen_py_const(cls: &syn::Type, spec: &ConstSpec) -> TokenStream {
pub fn gen_py_const(cls: &syn::Type, spec: &ConstSpec) -> TokenStream {
let member = &spec.rust_ident;
let deprecations = &spec.attributes.deprecations;
let python_name = &spec.null_terminated_python_name();
Expand Down
9 changes: 5 additions & 4 deletions tests/test_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@ pub enum MyEnum {
}

#[test]
fn test_reflexive() {
fn test_enum_class_attr() {
let gil = Python::acquire_gil();
let py = gil.python();
let mynum = py.get_type::<MyEnum>();
py_assert!(py, mynum, "mynum.Variant == mynum.Variant");
py_assert!(py, mynum, "mynum.OtherVariant == mynum.OtherVariant");
let my_enum = py.get_type::<MyEnum>();
py_assert!(py, my_enum, "getattr(my_enum, 'Variant', None) is not None");
py_assert!(py, my_enum, "getattr(my_enum, 'foobar', None) is None");
py_run!(py, my_enum, "my_enum.Variant = None");
}

#[pyfunction]
Expand Down

0 comments on commit 0ffc436

Please sign in to comment.