Skip to content

Commit

Permalink
Get rid of the different implementations depending on the number of v…
Browse files Browse the repository at this point in the history
…ariants
  • Loading branch information
ricohageman committed Mar 26, 2023
1 parent 286ee1e commit f2a89b1
Showing 1 changed file with 22 additions and 43 deletions.
65 changes: 22 additions & 43 deletions pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -533,50 +533,29 @@ fn impl_enum(

fn impl_into_py_enum(cls: &syn::Ident, variants: &[PyClassEnumVariant<'_>]) -> TokenStream {
// Assuming all variants are unit variants because they are the only type we support.
match variants.len() {
0 => quote! {},
1 => {
let variant_name = variants.first().unwrap().ident;
quote! {
impl _pyo3::IntoPy<_pyo3::PyObject> for #cls {
fn into_py(self, py: _pyo3::Python) -> _pyo3::PyObject {
static SINGLETON: _pyo3::once_cell::GILOnceCell<_pyo3::Py<#cls>> = _pyo3::once_cell::GILOnceCell::new();

let singleton = SINGLETON.get_or_init(py, || {
_pyo3::Py::new(py, #cls::#variant_name).unwrap()
});

_pyo3::IntoPy::into_py(::std::clone::Clone::clone(singleton), py)
}
}
}
}
_ => {
let initialize_values: Vec<TokenStream> = variants
.iter()
.map(|variant| {
let variant_name = variant.ident;
quote!((#cls::#variant_name as usize, _pyo3::Py::new(py, #cls::#variant_name).unwrap()))
})
.collect();
let initialize_values: Vec<TokenStream> = variants
.iter()
.map(|variant| {
let variant_name = variant.ident;
quote!((#cls::#variant_name as usize, _pyo3::Py::new(py, #cls::#variant_name).unwrap()))
})
.collect();

quote! {
impl _pyo3::IntoPy<_pyo3::PyObject> for #cls {
fn into_py(self, py: _pyo3::Python) -> _pyo3::PyObject {
static SINGLETON_PER_VARIANT: _pyo3::once_cell::GILOnceCell<::std::collections::HashMap<usize, _pyo3::Py<#cls>>> = _pyo3::once_cell::GILOnceCell::new();

let singleton = SINGLETON_PER_VARIANT.get_or_init(py, || {
[#(#initialize_values),*]
.iter()
.cloned()
.collect::<::std::collections::HashMap<_, _>>()
})
.get(&(self as usize))
.unwrap();

_pyo3::IntoPy::into_py(::std::clone::Clone::clone(singleton), py)
}
}
quote! {
impl _pyo3::IntoPy<_pyo3::PyObject> for #cls {
fn into_py(self, py: _pyo3::Python) -> _pyo3::PyObject {
static SINGLETON_PER_VARIANT: _pyo3::once_cell::GILOnceCell<::std::collections::HashMap<usize, _pyo3::Py<#cls>>> = _pyo3::once_cell::GILOnceCell::new();

let singleton = SINGLETON_PER_VARIANT.get_or_init(py, || {
vec![#(#initialize_values),*]
.iter()
.cloned()
.collect::<::std::collections::HashMap<_, _>>()
})
.get(&(self as usize))
.unwrap();

_pyo3::IntoPy::into_py(::std::clone::Clone::clone(singleton), py)
}
}
}
Expand Down

0 comments on commit f2a89b1

Please sign in to comment.