From cda5dd173cf221d6ed367d201ee8ae6a170515b3 Mon Sep 17 00:00:00 2001 From: Ellie Frost Date: Fri, 17 Jul 2020 16:33:12 -0400 Subject: [PATCH 1/6] enum WIP --- pyo3-derive-backend/src/lib.rs | 2 + pyo3-derive-backend/src/pyclass.rs | 2 +- pyo3-derive-backend/src/pyenum.rs | 90 ++++++++++++++++++++++++++++++ pyo3cls/src/lib.rs | 14 ++++- src/prelude.rs | 2 +- src/types/enum.rs | 3 + tests/test_enum.rs | 7 +++ 7 files changed, 117 insertions(+), 3 deletions(-) create mode 100644 pyo3-derive-backend/src/pyenum.rs create mode 100644 src/types/enum.rs create mode 100644 tests/test_enum.rs diff --git a/pyo3-derive-backend/src/lib.rs b/pyo3-derive-backend/src/lib.rs index cd1b4c3ba54..178a5fcf52d 100644 --- a/pyo3-derive-backend/src/lib.rs +++ b/pyo3-derive-backend/src/lib.rs @@ -9,6 +9,7 @@ mod konst; mod method; mod module; mod pyclass; +mod pyenum; mod pyfunction; mod pyimpl; mod pymethod; @@ -17,6 +18,7 @@ mod utils; pub use module::{add_fn_to_module, process_functions_in_module, py_init}; pub use pyclass::{build_py_class, PyClassArgs}; +pub use pyenum::build_py_enum; pub use pyfunction::{build_py_function, PyFunctionAttr}; pub use pyimpl::{build_py_methods, impl_methods}; pub use pyproto::build_py_proto; diff --git a/pyo3-derive-backend/src/pyclass.rs b/pyo3-derive-backend/src/pyclass.rs index dffb346af05..41b31448675 100644 --- a/pyo3-derive-backend/src/pyclass.rs +++ b/pyo3-derive-backend/src/pyclass.rs @@ -207,7 +207,7 @@ fn parse_descriptors(item: &mut syn::Field) -> syn::Result> { } /// To allow multiple #[pymethods]/#[pyproto] block, we define inventory types. -fn impl_methods_inventory(cls: &syn::Ident) -> TokenStream { +pub fn impl_methods_inventory(cls: &syn::Ident) -> TokenStream { // Try to build a unique type for better error messages let name = format!("Pyo3MethodsInventoryFor{}", cls); let inventory_cls = syn::Ident::new(&name, Span::call_site()); diff --git a/pyo3-derive-backend/src/pyenum.rs b/pyo3-derive-backend/src/pyenum.rs new file mode 100644 index 00000000000..68a5bc5981b --- /dev/null +++ b/pyo3-derive-backend/src/pyenum.rs @@ -0,0 +1,90 @@ +// Copyright (c) 2017-present PyO3 Project and Contributors + +use crate::pyclass::impl_methods_inventory; +use proc_macro2::TokenStream; +use quote::quote; + +pub fn build_py_enum(enum_: &syn::ItemEnum) -> syn::Result { + let mut variants = Vec::new(); + + for variant in enum_.variants.iter() { + if !variant.fields.is_empty() { + return Err(syn::Error::new_spanned( + variant, + "#[pyenum] only supports unit enums", + )); + } + if let Some((_, syn::Expr::Lit(lit))) = &variant.discriminant { + variants.push((variant.ident.clone(), lit.clone())) + } else { + return Err(syn::Error::new_spanned( + variant, + "#[pyenum] requires explicit discriminant (MyVal = 4)", + )); + } + } + + impl_enum(&enum_.ident, variants) +} + +fn impl_enum( + enum_: &syn::Ident, + _variants: Vec<(syn::Ident, syn::ExprLit)>, +) -> syn::Result { + let inventory = impl_methods_inventory(enum_); + + let enum_name = enum_.to_string(); + + Ok(quote! { + unsafe impl pyo3::type_object::PyTypeInfo for #enum_ { + type Type = #enum_; + type BaseType = pyo3::PyAny; + type Layout = pyo3::PyCell; + type BaseLayout = pyo3::pycell::PyCellBase; + + type Initializer = pyo3::pyclass_init::PyClassInitializer; + type AsRefTarget = pyo3::PyCell; + + const NAME: &'static str = #enum_name; + const MODULE: Option<&'static str> = None; + const DESCRIPTION: &'static str = "y'know, an enum\0"; // TODO + const FLAGS: usize = 0; + + #[inline] + fn type_object_raw(py: pyo3::Python) -> *mut pyo3::ffi::PyTypeObject { + use pyo3::type_object::LazyStaticType; + static TYPE_OBJECT: LazyStaticType = LazyStaticType::new(); + TYPE_OBJECT.get_or_init::(py) + } + + } + + impl pyo3::PyClass for #enum_ { + type Dict = pyo3::pyclass_slots::PyClassDummySlot ; + type WeakRef = pyo3::pyclass_slots::PyClassDummySlot; + type BaseNativeType = pyo3::PyAny; + } + + impl<'a> pyo3::derive_utils::ExtractExt<'a> for &'a #enum_ + { + type Target = pyo3::PyRef<'a, #enum_>; + } + + impl<'a> pyo3::derive_utils::ExtractExt<'a> for &'a mut #enum_ + { + type Target = pyo3::PyRefMut<'a, #enum_>; + } + + impl pyo3::class::proto_methods::HasProtoRegistry for #enum_ { + fn registry() -> &'static pyo3::class::proto_methods::PyProtoRegistry { + static REGISTRY: pyo3::class::proto_methods::PyProtoRegistry + = pyo3::class::proto_methods::PyProtoRegistry::new(); + ®ISTRY + } + } + + impl pyo3::pyclass::PyClassAlloc for #enum_ {} + + #inventory + }) +} diff --git a/pyo3cls/src/lib.rs b/pyo3cls/src/lib.rs index 795423cbff3..fca110cfe53 100644 --- a/pyo3cls/src/lib.rs +++ b/pyo3cls/src/lib.rs @@ -5,7 +5,7 @@ extern crate proc_macro; use proc_macro::TokenStream; use pyo3_derive_backend::{ - build_py_class, build_py_function, build_py_methods, build_py_proto, get_doc, + build_py_class, build_py_enum, build_py_function, build_py_methods, build_py_proto, get_doc, process_functions_in_module, py_init, PyClassArgs, PyFunctionAttr, }; use quote::quote; @@ -66,6 +66,18 @@ pub fn pyclass(attr: TokenStream, input: TokenStream) -> TokenStream { .into() } +#[proc_macro_attribute] +pub fn pyenum(_: TokenStream, input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as syn::ItemEnum); + let expanded = build_py_enum(&ast).unwrap_or_else(|e| e.to_compile_error()); + + quote!( + #ast + #expanded + ) + .into() +} + #[proc_macro_attribute] pub fn pymethods(_: TokenStream, input: TokenStream) -> TokenStream { let mut ast = parse_macro_input!(input as syn::ItemImpl); diff --git a/src/prelude.rs b/src/prelude.rs index b41b4b9cae2..8eb45e6b5d2 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -21,4 +21,4 @@ pub use crate::{FromPy, FromPyObject, IntoPy, IntoPyPointer, PyTryFrom, PyTryInt // PyModule is only part of the prelude because we need it for the pymodule function pub use crate::types::{PyAny, PyModule}; #[cfg(feature = "macros")] -pub use pyo3cls::{pyclass, pyfunction, pymethods, pymodule, pyproto}; +pub use pyo3cls::{pyclass, pyenum, pyfunction, pymethods, pymodule, pyproto}; diff --git a/src/types/enum.rs b/src/types/enum.rs new file mode 100644 index 00000000000..15f62a86819 --- /dev/null +++ b/src/types/enum.rs @@ -0,0 +1,3 @@ +pub struct Enum; + +unsafe impl pyo3::type_object::PyTypeInfo for Enum {} diff --git a/tests/test_enum.rs b/tests/test_enum.rs new file mode 100644 index 00000000000..625bfd15527 --- /dev/null +++ b/tests/test_enum.rs @@ -0,0 +1,7 @@ +use pyo3::prelude::*; + +#[pyenum] +pub enum MyEnum { + Variant = 1, + OtherVariant = 2, +} From dd6741b82e2aa665fbeee932127ea6e45c1f5366 Mon Sep 17 00:00:00 2001 From: Ellie Frost Date: Sun, 19 Jul 2020 16:47:01 -0400 Subject: [PATCH 2/6] Add PyClassSend implementation --- pyo3-derive-backend/src/pyenum.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pyo3-derive-backend/src/pyenum.rs b/pyo3-derive-backend/src/pyenum.rs index 68a5bc5981b..3219e75a87e 100644 --- a/pyo3-derive-backend/src/pyenum.rs +++ b/pyo3-derive-backend/src/pyenum.rs @@ -85,6 +85,11 @@ fn impl_enum( impl pyo3::pyclass::PyClassAlloc for #enum_ {} + // TODO: handle not in send + impl pyo3::pyclass::PyClassSend for #enum_ { + type ThreadChecker = pyo3::pyclass::ThreadCheckerStub<#enum_>; + } + #inventory }) } From 9d0d578350150361b0ead227e4d82a30ad5894b0 Mon Sep 17 00:00:00 2001 From: Ellie Frost Date: Sat, 25 Jul 2020 22:39:21 -0400 Subject: [PATCH 3/6] Derive classes for enum types as well --- pyo3-derive-backend/src/common.rs | 57 ++++++++++++++++++++++++++++++ pyo3-derive-backend/src/lib.rs | 1 + pyo3-derive-backend/src/pyclass.rs | 56 ++++------------------------- pyo3-derive-backend/src/pyenum.rs | 55 ++++++++++++++-------------- tests/test_enum.rs | 12 +++++++ 5 files changed, 104 insertions(+), 77 deletions(-) create mode 100644 pyo3-derive-backend/src/common.rs diff --git a/pyo3-derive-backend/src/common.rs b/pyo3-derive-backend/src/common.rs new file mode 100644 index 00000000000..1559c563da4 --- /dev/null +++ b/pyo3-derive-backend/src/common.rs @@ -0,0 +1,57 @@ +use proc_macro2::{Span, TokenStream}; +use quote::quote; + +/// To allow multiple #[pymethods]/#[pyproto] block, we define inventory types. +pub fn impl_methods_inventory(cls: &syn::Ident) -> TokenStream { + // Try to build a unique type for better error messages + let name = format!("Pyo3MethodsInventoryFor{}", cls); + let inventory_cls = syn::Ident::new(&name, Span::call_site()); + + quote! { + #[doc(hidden)] + pub struct #inventory_cls { + methods: &'static [pyo3::class::PyMethodDefType], + } + impl pyo3::class::methods::PyMethodsInventory for #inventory_cls { + fn new(methods: &'static [pyo3::class::PyMethodDefType]) -> Self { + Self { methods } + } + fn get(&self) -> &'static [pyo3::class::PyMethodDefType] { + self.methods + } + } + + impl pyo3::class::methods::HasMethodsInventory for #cls { + type Methods = #inventory_cls; + } + + pyo3::inventory::collect!(#inventory_cls); + } +} + +pub fn impl_extractext(cls: &syn::Ident) -> TokenStream { + quote! { + impl<'a> pyo3::derive_utils::ExtractExt<'a> for &'a #cls + { + type Target = pyo3::PyRef<'a, #cls>; + } + + impl<'a> pyo3::derive_utils::ExtractExt<'a> for &'a mut #cls + { + type Target = pyo3::PyRefMut<'a, #cls>; + } + } +} + +/// Implement `HasProtoRegistry` for the class for lazy protocol initialization. +pub fn impl_proto_registry(cls: &syn::Ident) -> TokenStream { + quote! { + impl pyo3::class::proto_methods::HasProtoRegistry for #cls { + fn registry() -> &'static pyo3::class::proto_methods::PyProtoRegistry { + static REGISTRY: pyo3::class::proto_methods::PyProtoRegistry + = pyo3::class::proto_methods::PyProtoRegistry::new(); + ®ISTRY + } + } + } +} diff --git a/pyo3-derive-backend/src/lib.rs b/pyo3-derive-backend/src/lib.rs index 178a5fcf52d..44a98ac6b23 100644 --- a/pyo3-derive-backend/src/lib.rs +++ b/pyo3-derive-backend/src/lib.rs @@ -3,6 +3,7 @@ #![recursion_limit = "1024"] +mod common; mod defs; mod func; mod konst; diff --git a/pyo3-derive-backend/src/pyclass.rs b/pyo3-derive-backend/src/pyclass.rs index 41b31448675..fbaf2e65b54 100644 --- a/pyo3-derive-backend/src/pyclass.rs +++ b/pyo3-derive-backend/src/pyclass.rs @@ -1,9 +1,11 @@ // Copyright (c) 2017-present PyO3 Project and Contributors +use crate::common::{impl_extractext, impl_methods_inventory, impl_proto_registry}; use crate::method::{FnType, SelfType}; use crate::pymethod::{ impl_py_getter_def, impl_py_setter_def, impl_wrap_getter, impl_wrap_setter, PropertyType, }; + use crate::utils; use proc_macro2::{Span, TokenStream}; use quote::quote; @@ -206,47 +208,6 @@ fn parse_descriptors(item: &mut syn::Field) -> syn::Result> { Ok(descs) } -/// To allow multiple #[pymethods]/#[pyproto] block, we define inventory types. -pub fn impl_methods_inventory(cls: &syn::Ident) -> TokenStream { - // Try to build a unique type for better error messages - let name = format!("Pyo3MethodsInventoryFor{}", cls); - let inventory_cls = syn::Ident::new(&name, Span::call_site()); - - quote! { - #[doc(hidden)] - pub struct #inventory_cls { - methods: &'static [pyo3::class::PyMethodDefType], - } - impl pyo3::class::methods::PyMethodsInventory for #inventory_cls { - fn new(methods: &'static [pyo3::class::PyMethodDefType]) -> Self { - Self { methods } - } - fn get(&self) -> &'static [pyo3::class::PyMethodDefType] { - self.methods - } - } - - impl pyo3::class::methods::HasMethodsInventory for #cls { - type Methods = #inventory_cls; - } - - pyo3::inventory::collect!(#inventory_cls); - } -} - -/// Implement `HasProtoRegistry` for the class for lazy protocol initialization. -fn impl_proto_registry(cls: &syn::Ident) -> TokenStream { - quote! { - impl pyo3::class::proto_methods::HasProtoRegistry for #cls { - fn registry() -> &'static pyo3::class::proto_methods::PyProtoRegistry { - static REGISTRY: pyo3::class::proto_methods::PyProtoRegistry - = pyo3::class::proto_methods::PyProtoRegistry::new(); - ®ISTRY - } - } - } -} - fn get_class_python_name(cls: &syn::Ident, attr: &PyClassArgs) -> TokenStream { match &attr.name { Some(name) => quote! { #name }, @@ -290,6 +251,7 @@ fn impl_class( let path = syn::Path::from(syn::PathSegment::from(cls.clone())); let ty = syn::Type::from(syn::TypePath { path, qself: None }); let desc_impls = impl_descriptors(&ty, descriptors)?; + use crate::common::{impl_extractext, impl_methods_inventory, impl_proto_registry}; quote! { #desc_impls #extra @@ -394,6 +356,8 @@ fn impl_class( quote! { pyo3::pyclass::ThreadCheckerStub<#cls> } }; + let extractext = impl_extractext(cls); + Ok(quote! { unsafe impl pyo3::type_object::PyTypeInfo for #cls { type Type = #cls; @@ -422,15 +386,7 @@ fn impl_class( type BaseNativeType = #base_nativetype; } - impl<'a> pyo3::derive_utils::ExtractExt<'a> for &'a #cls - { - type Target = pyo3::PyRef<'a, #cls>; - } - - impl<'a> pyo3::derive_utils::ExtractExt<'a> for &'a mut #cls - { - type Target = pyo3::PyRefMut<'a, #cls>; - } + #extractext impl pyo3::pyclass::PyClassSend for #cls { type ThreadChecker = #thread_checker; diff --git a/pyo3-derive-backend/src/pyenum.rs b/pyo3-derive-backend/src/pyenum.rs index 3219e75a87e..9c0d738c6d9 100644 --- a/pyo3-derive-backend/src/pyenum.rs +++ b/pyo3-derive-backend/src/pyenum.rs @@ -1,6 +1,6 @@ // Copyright (c) 2017-present PyO3 Project and Contributors -use crate::pyclass::impl_methods_inventory; +use crate::common::{impl_extractext, impl_methods_inventory, impl_proto_registry}; use proc_macro2::TokenStream; use quote::quote; @@ -29,15 +29,30 @@ pub fn build_py_enum(enum_: &syn::ItemEnum) -> syn::Result { fn impl_enum( enum_: &syn::Ident, - _variants: Vec<(syn::Ident, syn::ExprLit)>, + variants: Vec<(syn::Ident, syn::ExprLit)>, ) -> syn::Result { - let inventory = impl_methods_inventory(enum_); + let enum_cls = impl_class(enum_)?; + let variant_cls = variants + .iter() + .map(|(ident, _)| impl_class(ident)) + .collect::>>()?; - let enum_name = enum_.to_string(); + Ok(quote! { + #enum_cls + #(#variant_cls)* + }) +} + +fn impl_class(cls: &syn::Ident) -> syn::Result { + let inventory = impl_methods_inventory(cls); + let extractext = impl_extractext(cls); + let protoregistry = impl_proto_registry(cls); + + let clsname = cls.to_string(); Ok(quote! { - unsafe impl pyo3::type_object::PyTypeInfo for #enum_ { - type Type = #enum_; + unsafe impl pyo3::type_object::PyTypeInfo for #cls { + type Type = #cls; type BaseType = pyo3::PyAny; type Layout = pyo3::PyCell; type BaseLayout = pyo3::pycell::PyCellBase; @@ -45,7 +60,7 @@ fn impl_enum( type Initializer = pyo3::pyclass_init::PyClassInitializer; type AsRefTarget = pyo3::PyCell; - const NAME: &'static str = #enum_name; + const NAME: &'static str = #clsname; const MODULE: Option<&'static str> = None; const DESCRIPTION: &'static str = "y'know, an enum\0"; // TODO const FLAGS: usize = 0; @@ -59,35 +74,21 @@ fn impl_enum( } - impl pyo3::PyClass for #enum_ { + impl pyo3::PyClass for #cls { type Dict = pyo3::pyclass_slots::PyClassDummySlot ; type WeakRef = pyo3::pyclass_slots::PyClassDummySlot; type BaseNativeType = pyo3::PyAny; } - impl<'a> pyo3::derive_utils::ExtractExt<'a> for &'a #enum_ - { - type Target = pyo3::PyRef<'a, #enum_>; - } - - impl<'a> pyo3::derive_utils::ExtractExt<'a> for &'a mut #enum_ - { - type Target = pyo3::PyRefMut<'a, #enum_>; - } + #protoregistry - impl pyo3::class::proto_methods::HasProtoRegistry for #enum_ { - fn registry() -> &'static pyo3::class::proto_methods::PyProtoRegistry { - static REGISTRY: pyo3::class::proto_methods::PyProtoRegistry - = pyo3::class::proto_methods::PyProtoRegistry::new(); - ®ISTRY - } - } + #extractext - impl pyo3::pyclass::PyClassAlloc for #enum_ {} + impl pyo3::pyclass::PyClassAlloc for #cls {} // TODO: handle not in send - impl pyo3::pyclass::PyClassSend for #enum_ { - type ThreadChecker = pyo3::pyclass::ThreadCheckerStub<#enum_>; + impl pyo3::pyclass::PyClassSend for #cls { + type ThreadChecker = pyo3::pyclass::ThreadCheckerStub<#cls>; } #inventory diff --git a/tests/test_enum.rs b/tests/test_enum.rs index 625bfd15527..cf75285829b 100644 --- a/tests/test_enum.rs +++ b/tests/test_enum.rs @@ -1,7 +1,19 @@ use pyo3::prelude::*; +mod common; + #[pyenum] pub enum MyEnum { Variant = 1, OtherVariant = 2, } + +#[test] + +fn test_reflexive() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let mynum = py.get_type::(); + py_assert!(py, mynum, "mynum.Variant == mynum.Variant"); + py_assert!(py, mynum, "mynum.OtherVariant == mynum.OtherVariant"); +} From 3cc9d7fa80ccc26d371efda99e957a92924f1a52 Mon Sep 17 00:00:00 2001 From: Ellie Frost Date: Sun, 26 Jul 2020 16:23:29 -0400 Subject: [PATCH 4/6] Add FromPy implementation for enums --- pyo3-derive-backend/src/pyclass.rs | 1 - pyo3-derive-backend/src/pyenum.rs | 75 ++++++++++++++++++++++++++++-- tests/test_enum.rs | 1 - 3 files changed, 72 insertions(+), 5 deletions(-) diff --git a/pyo3-derive-backend/src/pyclass.rs b/pyo3-derive-backend/src/pyclass.rs index fbaf2e65b54..01087ef75c5 100644 --- a/pyo3-derive-backend/src/pyclass.rs +++ b/pyo3-derive-backend/src/pyclass.rs @@ -251,7 +251,6 @@ fn impl_class( let path = syn::Path::from(syn::PathSegment::from(cls.clone())); let ty = syn::Type::from(syn::TypePath { path, qself: None }); let desc_impls = impl_descriptors(&ty, descriptors)?; - use crate::common::{impl_extractext, impl_methods_inventory, impl_proto_registry}; quote! { #desc_impls #extra diff --git a/pyo3-derive-backend/src/pyenum.rs b/pyo3-derive-backend/src/pyenum.rs index 9c0d738c6d9..3c74c48440b 100644 --- a/pyo3-derive-backend/src/pyenum.rs +++ b/pyo3-derive-backend/src/pyenum.rs @@ -32,14 +32,83 @@ fn impl_enum( variants: Vec<(syn::Ident, syn::ExprLit)>, ) -> syn::Result { let enum_cls = impl_class(enum_)?; - let variant_cls = variants + let variant_names: Vec = variants .iter() - .map(|(ident, _)| impl_class(ident)) + .map(|(ident, _)| variant_enumname(enum_, ident)) .collect::>>()?; + let variant_cls = variant_names + .iter() + .map(impl_class) + .collect::>>()?; + let variant_consts = variants + .iter() + .map(|(ident, _)| impl_const(enum_, ident)) + .collect::>>()?; + + let to_py = impl_to_py(enum_, variants)?; + Ok(quote! { + #enum_cls - #(#variant_cls)* + + #( + struct #variant_names; + )* + + #( + #variant_cls + )* + + #to_py + + #[pymethods] + impl #enum_ { + #( + #variant_consts + )* + } + + }) +} + +fn impl_to_py( + enum_: &syn::Ident, + variants: Vec<(syn::Ident, syn::ExprLit)>, +) -> syn::Result { + let matches = variants + .iter() + .map(|(ident, _)| { + variant_enumname(enum_, ident).map(|cls| { + quote! { + #enum_::#ident => <#cls as pyo3::type_object::PyTypeObject>::type_object(py).to_object(py), + } + }) + }) + .collect::>>()?; + + Ok(quote! { + impl pyo3::FromPy<#enum_> for pyo3::PyObject { + fn from_py(v: #enum_, py: Python) -> Self { + match v { + #( + #matches + )* + } + } + } + }) +} + +fn variant_enumname(enum_: &syn::Ident, cls: &syn::Ident) -> syn::Result { + let name = format!("{}_EnumVariant_{}", enum_, cls); + syn::parse_str(&name) +} + +fn impl_const(enum_: &syn::Ident, cls: &syn::Ident) -> syn::Result { + Ok(quote! { + #[classattr] + const #cls: #enum_ = #enum_::#cls; }) } diff --git a/tests/test_enum.rs b/tests/test_enum.rs index cf75285829b..7c6aa0ff18f 100644 --- a/tests/test_enum.rs +++ b/tests/test_enum.rs @@ -9,7 +9,6 @@ pub enum MyEnum { } #[test] - fn test_reflexive() { let gil = Python::acquire_gil(); let py = gil.python(); From ddb7ca580e21ea4ab7057a62423f5310a08e51e6 Mon Sep 17 00:00:00 2001 From: Ellie Frost Date: Sun, 26 Jul 2020 16:36:34 -0400 Subject: [PATCH 5/6] Test returning an enum --- pyo3-derive-backend/src/pyenum.rs | 1 + tests/test_enum.rs | 16 ++++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/pyo3-derive-backend/src/pyenum.rs b/pyo3-derive-backend/src/pyenum.rs index 3c74c48440b..de7283e86ed 100644 --- a/pyo3-derive-backend/src/pyenum.rs +++ b/pyo3-derive-backend/src/pyenum.rs @@ -65,6 +65,7 @@ fn impl_enum( #[pymethods] impl #enum_ { #( + #[allow(non_upper_case_globals)] #variant_consts )* } diff --git a/tests/test_enum.rs b/tests/test_enum.rs index 7c6aa0ff18f..03f97554a0d 100644 --- a/tests/test_enum.rs +++ b/tests/test_enum.rs @@ -1,4 +1,5 @@ use pyo3::prelude::*; +use pyo3::{py_run, wrap_pyfunction}; mod common; @@ -16,3 +17,18 @@ fn test_reflexive() { py_assert!(py, mynum, "mynum.Variant == mynum.Variant"); py_assert!(py, mynum, "mynum.OtherVariant == mynum.OtherVariant"); } + +#[pyfunction] +fn return_enum() -> MyEnum { + MyEnum::Variant +} + +#[test] +fn test_return_enum() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let f = wrap_pyfunction!(return_enum)(py); + let mynum = py.get_type::(); + + py_run!(py, f mynum, "assert f() == mynum.Variant") +} From 170021ef7ba285d26ae0d54e79fd119e8a6979fe Mon Sep 17 00:00:00 2001 From: Ellie Frost Date: Sun, 26 Jul 2020 18:55:50 -0400 Subject: [PATCH 6/6] (WIP) Passing an enum as an argument --- pyo3-derive-backend/src/pyenum.rs | 84 ++++++++++++++++++++++++------- tests/test_enum.rs | 16 ++++++ 2 files changed, 82 insertions(+), 18 deletions(-) diff --git a/pyo3-derive-backend/src/pyenum.rs b/pyo3-derive-backend/src/pyenum.rs index de7283e86ed..6a0ed062153 100644 --- a/pyo3-derive-backend/src/pyenum.rs +++ b/pyo3-derive-backend/src/pyenum.rs @@ -31,22 +31,23 @@ fn impl_enum( enum_: &syn::Ident, variants: Vec<(syn::Ident, syn::ExprLit)>, ) -> syn::Result { - let enum_cls = impl_class(enum_)?; + let enum_cls = impl_class(enum_, None)?; let variant_names: Vec = variants .iter() .map(|(ident, _)| variant_enumname(enum_, ident)) .collect::>>()?; - let variant_cls = variant_names + let variant_cls = variants .iter() - .map(impl_class) + .map(|(ident, _)| impl_class(ident, Some(enum_))) .collect::>>()?; let variant_consts = variants .iter() .map(|(ident, _)| impl_const(enum_, ident)) .collect::>>()?; - let to_py = impl_to_py(enum_, variants)?; + let to_py = impl_to_py(enum_, &variants)?; + let from_py = impl_from_py(enum_, &variants)?; Ok(quote! { @@ -61,6 +62,7 @@ fn impl_enum( )* #to_py + #from_py #[pymethods] impl #enum_ { @@ -73,9 +75,47 @@ fn impl_enum( }) } +fn impl_from_py( + enum_: &syn::Ident, + variants: &Vec<(syn::Ident, syn::ExprLit)>, +) -> syn::Result { + let matches = variants + .iter() + .map(|(ident, _)| { + variant_enumname(enum_, ident).map(|cls| { + let name = cls.to_string(); + // TODO: this should be + // #cls as pyo3::type_object::PyTypeObject>::type_object(py).compare() + // but I can't figure out how to get py inside extract() + quote! { + if typ_name == #name { + return Ok(#enum_::#ident) + } + } + }) + }) + .collect::>>()?; + + let errormsg = format!("Could not convert to {} enum", enum_); + + Ok(quote! { + impl pyo3::conversion::FromPyObject<'_> for #enum_ { + fn extract(ob: &pyo3::PyAny) -> pyo3::PyResult { + let typ: &pyo3::types::PyType = ob.extract(); + let typ_name = typ.name(); + #( + #matches + )* + + Err(pyo3::exceptions::PyValueError::into(#errormsg)) + } + } + }) +} + fn impl_to_py( enum_: &syn::Ident, - variants: Vec<(syn::Ident, syn::ExprLit)>, + variants: &Vec<(syn::Ident, syn::ExprLit)>, ) -> syn::Result { let matches = variants .iter() @@ -90,7 +130,7 @@ fn impl_to_py( Ok(quote! { impl pyo3::FromPy<#enum_> for pyo3::PyObject { - fn from_py(v: #enum_, py: Python) -> Self { + fn from_py(v: #enum_, py: pyo3::Python) -> Self { match v { #( #matches @@ -113,16 +153,25 @@ fn impl_const(enum_: &syn::Ident, cls: &syn::Ident) -> syn::Result }) } -fn impl_class(cls: &syn::Ident) -> syn::Result { - let inventory = impl_methods_inventory(cls); - let extractext = impl_extractext(cls); - let protoregistry = impl_proto_registry(cls); +fn impl_class(cls: &syn::Ident, parent: Option<&syn::Ident>) -> syn::Result { + let (typ, desc) = if let Some(p) = parent { + ( + variant_enumname(p, cls)?, + format!("variant {} of enum {}", cls, p), + ) + } else { + // Clone is just to align the types + (cls.clone(), format!("enum {}", cls)) + }; let clsname = cls.to_string(); + let inventory = impl_methods_inventory(&typ); + let extractext = impl_extractext(&typ); + let protoregistry = impl_proto_registry(&typ); Ok(quote! { - unsafe impl pyo3::type_object::PyTypeInfo for #cls { - type Type = #cls; + unsafe impl pyo3::type_object::PyTypeInfo for #typ { + type Type = #typ; type BaseType = pyo3::PyAny; type Layout = pyo3::PyCell; type BaseLayout = pyo3::pycell::PyCellBase; @@ -132,7 +181,7 @@ fn impl_class(cls: &syn::Ident) -> syn::Result { const NAME: &'static str = #clsname; const MODULE: Option<&'static str> = None; - const DESCRIPTION: &'static str = "y'know, an enum\0"; // TODO + const DESCRIPTION: &'static str = #desc; const FLAGS: usize = 0; #[inline] @@ -141,10 +190,9 @@ fn impl_class(cls: &syn::Ident) -> syn::Result { static TYPE_OBJECT: LazyStaticType = LazyStaticType::new(); TYPE_OBJECT.get_or_init::(py) } - } - impl pyo3::PyClass for #cls { + impl pyo3::PyClass for #typ { type Dict = pyo3::pyclass_slots::PyClassDummySlot ; type WeakRef = pyo3::pyclass_slots::PyClassDummySlot; type BaseNativeType = pyo3::PyAny; @@ -154,11 +202,11 @@ fn impl_class(cls: &syn::Ident) -> syn::Result { #extractext - impl pyo3::pyclass::PyClassAlloc for #cls {} + impl pyo3::pyclass::PyClassAlloc for #typ {} // TODO: handle not in send - impl pyo3::pyclass::PyClassSend for #cls { - type ThreadChecker = pyo3::pyclass::ThreadCheckerStub<#cls>; + impl pyo3::pyclass::PyClassSend for #typ { + type ThreadChecker = pyo3::pyclass::ThreadCheckerStub<#typ>; } #inventory diff --git a/tests/test_enum.rs b/tests/test_enum.rs index 03f97554a0d..8500d62ee01 100644 --- a/tests/test_enum.rs +++ b/tests/test_enum.rs @@ -4,6 +4,7 @@ use pyo3::{py_run, wrap_pyfunction}; mod common; #[pyenum] +#[derive(Debug, PartialEq, Clone)] pub enum MyEnum { Variant = 1, OtherVariant = 2, @@ -32,3 +33,18 @@ fn test_return_enum() { py_run!(py, f mynum, "assert f() == mynum.Variant") } + +#[pyfunction] +fn enum_arg(e: MyEnum) { + assert_eq!(MyEnum::OtherVariant, e) +} + +#[test] +fn test_enum_arg() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let f = wrap_pyfunction!(enum_arg)(py); + let mynum = py.get_type::(); + + py_run!(py, f mynum, "f(mynum.Variant)") +}