Skip to content

Commit

Permalink
Index into an array instead of hashmap
Browse files Browse the repository at this point in the history
  • Loading branch information
ricohageman committed May 6, 2023
1 parent c3f628b commit 63ae240
Showing 1 changed file with 23 additions and 42 deletions.
65 changes: 23 additions & 42 deletions pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -517,51 +517,32 @@ impl EnumVariantPyO3Options {
}

fn impl_into_py_enum(cls: &syn::Ident, variants: &[PyClassEnumVariant<'_>]) -> TokenStream {
// Assuming all variants are unit variants because they are the only type we support.
match variants.len() {
0 => quote! {},
1 => {
let variant_name = variants.first().unwrap().ident;
quote! {
impl _pyo3::IntoPy<_pyo3::PyObject> for #cls {
fn into_py(self, py: _pyo3::Python) -> _pyo3::PyObject {
static SINGLETON: _pyo3::once_cell::GILOnceCell<_pyo3::Py<#cls>> = _pyo3::once_cell::GILOnceCell::new();
let number_of_variants = variants.len();

let singleton = SINGLETON.get_or_init(py, || {
_pyo3::Py::new(py, #cls::#variant_name).unwrap()
});
let variants_to_initial_value_index = variants
.iter()
.enumerate()
.map(|(index, variant)| {
let variant_name = variant.ident;
quote! { #cls::#variant_name => #index, }
});

_pyo3::IntoPy::into_py(::std::clone::Clone::clone(singleton), py)
}
}
}
}
_ => {
let initialize_values: Vec<TokenStream> = variants
.iter()
.map(|variant| {
let variant_name = variant.ident;
quote!((#cls::#variant_name as usize, _pyo3::Py::new(py, #cls::#variant_name).unwrap()))
})
.collect();
let initial_values = variants
.iter()
.map(|variant| {
let variant_name = variant.ident;
quote!(_pyo3::Py::new(py, #cls::#variant_name).unwrap())
});

quote! {
impl _pyo3::IntoPy<_pyo3::PyObject> for #cls {
fn into_py(self, py: _pyo3::Python) -> _pyo3::PyObject {
static SINGLETON_PER_VARIANT: _pyo3::once_cell::GILOnceCell<::std::collections::HashMap<usize, _pyo3::Py<#cls>>> = _pyo3::once_cell::GILOnceCell::new();

let singleton = SINGLETON_PER_VARIANT.get_or_init(py, || {
[#(#initialize_values),*]
.iter()
.cloned()
.collect::<::std::collections::HashMap<_, _>>()
})
.get(&(self as usize))
.unwrap();

_pyo3::IntoPy::into_py(::std::clone::Clone::clone(singleton), py)
}
}
quote! {
impl _pyo3::IntoPy<_pyo3::PyObject> for #cls {
fn into_py(self, py: _pyo3::Python) -> _pyo3::PyObject {
static SINGLETON_PER_VARIANT: _pyo3::once_cell::GILOnceCell<[_pyo3::Py<#cls>; #number_of_variants]> = _pyo3::once_cell::GILOnceCell::new();
let index = match self {
#(#variants_to_initial_value_index)*
};
let singleton = &SINGLETON_PER_VARIANT.get_or_init(py, || [#(#initial_values),*])[index];
_pyo3::IntoPy::into_py(::std::clone::Clone::clone(singleton), py)
}
}
}
Expand Down

0 comments on commit 63ae240

Please sign in to comment.