Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enum WIP #1045

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions pyo3-derive-backend/src/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use proc_macro2::{Span, TokenStream};
use quote::quote;

/// To allow multiple #[pymethods]/#[pyproto] block, we define inventory types.
pub fn impl_methods_inventory(cls: &syn::Ident) -> TokenStream {
// Try to build a unique type for better error messages
let name = format!("Pyo3MethodsInventoryFor{}", cls);
let inventory_cls = syn::Ident::new(&name, Span::call_site());

quote! {
#[doc(hidden)]
pub struct #inventory_cls {
methods: &'static [pyo3::class::PyMethodDefType],
}
impl pyo3::class::methods::PyMethodsInventory for #inventory_cls {
fn new(methods: &'static [pyo3::class::PyMethodDefType]) -> Self {
Self { methods }
}
fn get(&self) -> &'static [pyo3::class::PyMethodDefType] {
self.methods
}
}

impl pyo3::class::methods::HasMethodsInventory for #cls {
type Methods = #inventory_cls;
}

pyo3::inventory::collect!(#inventory_cls);
}
}

pub fn impl_extractext(cls: &syn::Ident) -> TokenStream {
quote! {
impl<'a> pyo3::derive_utils::ExtractExt<'a> for &'a #cls
{
type Target = pyo3::PyRef<'a, #cls>;
}

impl<'a> pyo3::derive_utils::ExtractExt<'a> for &'a mut #cls
{
type Target = pyo3::PyRefMut<'a, #cls>;
}
}
}

/// Implement `HasProtoRegistry` for the class for lazy protocol initialization.
pub fn impl_proto_registry(cls: &syn::Ident) -> TokenStream {
quote! {
impl pyo3::class::proto_methods::HasProtoRegistry for #cls {
fn registry() -> &'static pyo3::class::proto_methods::PyProtoRegistry {
static REGISTRY: pyo3::class::proto_methods::PyProtoRegistry
= pyo3::class::proto_methods::PyProtoRegistry::new();
&REGISTRY
}
}
}
}
3 changes: 3 additions & 0 deletions pyo3-derive-backend/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@

#![recursion_limit = "1024"]

mod common;
mod defs;
mod func;
mod konst;
mod method;
mod module;
mod pyclass;
mod pyenum;
mod pyfunction;
mod pyimpl;
mod pymethod;
Expand All @@ -17,6 +19,7 @@ mod utils;

pub use module::{add_fn_to_module, process_functions_in_module, py_init};
pub use pyclass::{build_py_class, PyClassArgs};
pub use pyenum::build_py_enum;
pub use pyfunction::{build_py_function, PyFunctionAttr};
pub use pyimpl::{build_py_methods, impl_methods};
pub use pyproto::build_py_proto;
Expand Down
55 changes: 5 additions & 50 deletions pyo3-derive-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
// Copyright (c) 2017-present PyO3 Project and Contributors

use crate::common::{impl_extractext, impl_methods_inventory, impl_proto_registry};
use crate::method::{FnType, SelfType};
use crate::pymethod::{
impl_py_getter_def, impl_py_setter_def, impl_wrap_getter, impl_wrap_setter, PropertyType,
};

use crate::utils;
use proc_macro2::{Span, TokenStream};
use quote::quote;
Expand Down Expand Up @@ -206,47 +208,6 @@ fn parse_descriptors(item: &mut syn::Field) -> syn::Result<Vec<FnType>> {
Ok(descs)
}

/// To allow multiple #[pymethods]/#[pyproto] block, we define inventory types.
fn impl_methods_inventory(cls: &syn::Ident) -> TokenStream {
// Try to build a unique type for better error messages
let name = format!("Pyo3MethodsInventoryFor{}", cls);
let inventory_cls = syn::Ident::new(&name, Span::call_site());

quote! {
#[doc(hidden)]
pub struct #inventory_cls {
methods: &'static [pyo3::class::PyMethodDefType],
}
impl pyo3::class::methods::PyMethodsInventory for #inventory_cls {
fn new(methods: &'static [pyo3::class::PyMethodDefType]) -> Self {
Self { methods }
}
fn get(&self) -> &'static [pyo3::class::PyMethodDefType] {
self.methods
}
}

impl pyo3::class::methods::HasMethodsInventory for #cls {
type Methods = #inventory_cls;
}

pyo3::inventory::collect!(#inventory_cls);
}
}

/// Implement `HasProtoRegistry` for the class for lazy protocol initialization.
fn impl_proto_registry(cls: &syn::Ident) -> TokenStream {
quote! {
impl pyo3::class::proto_methods::HasProtoRegistry for #cls {
fn registry() -> &'static pyo3::class::proto_methods::PyProtoRegistry {
static REGISTRY: pyo3::class::proto_methods::PyProtoRegistry
= pyo3::class::proto_methods::PyProtoRegistry::new();
&REGISTRY
}
}
}
}

fn get_class_python_name(cls: &syn::Ident, attr: &PyClassArgs) -> TokenStream {
match &attr.name {
Some(name) => quote! { #name },
Expand Down Expand Up @@ -394,6 +355,8 @@ fn impl_class(
quote! { pyo3::pyclass::ThreadCheckerStub<#cls> }
};

let extractext = impl_extractext(cls);

Ok(quote! {
unsafe impl pyo3::type_object::PyTypeInfo for #cls {
type Type = #cls;
Expand Down Expand Up @@ -422,15 +385,7 @@ fn impl_class(
type BaseNativeType = #base_nativetype;
}

impl<'a> pyo3::derive_utils::ExtractExt<'a> for &'a #cls
{
type Target = pyo3::PyRef<'a, #cls>;
}

impl<'a> pyo3::derive_utils::ExtractExt<'a> for &'a mut #cls
{
type Target = pyo3::PyRefMut<'a, #cls>;
}
#extractext

impl pyo3::pyclass::PyClassSend for #cls {
type ThreadChecker = #thread_checker;
Expand Down
214 changes: 214 additions & 0 deletions pyo3-derive-backend/src/pyenum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
// Copyright (c) 2017-present PyO3 Project and Contributors

use crate::common::{impl_extractext, impl_methods_inventory, impl_proto_registry};
use proc_macro2::TokenStream;
use quote::quote;

pub fn build_py_enum(enum_: &syn::ItemEnum) -> syn::Result<TokenStream> {
let mut variants = Vec::new();

for variant in enum_.variants.iter() {
if !variant.fields.is_empty() {
return Err(syn::Error::new_spanned(
variant,
"#[pyenum] only supports unit enums",
));
}
if let Some((_, syn::Expr::Lit(lit))) = &variant.discriminant {
variants.push((variant.ident.clone(), lit.clone()))
} else {
return Err(syn::Error::new_spanned(
variant,
"#[pyenum] requires explicit discriminant (MyVal = 4)",
));
}
}

impl_enum(&enum_.ident, variants)
}

fn impl_enum(
enum_: &syn::Ident,
variants: Vec<(syn::Ident, syn::ExprLit)>,
) -> syn::Result<TokenStream> {
let enum_cls = impl_class(enum_, None)?;
let variant_names: Vec<syn::Ident> = variants
.iter()
.map(|(ident, _)| variant_enumname(enum_, ident))
.collect::<syn::Result<Vec<_>>>()?;

let variant_cls = variants
.iter()
.map(|(ident, _)| impl_class(ident, Some(enum_)))
.collect::<syn::Result<Vec<_>>>()?;
let variant_consts = variants
.iter()
.map(|(ident, _)| impl_const(enum_, ident))
.collect::<syn::Result<Vec<_>>>()?;

let to_py = impl_to_py(enum_, &variants)?;
let from_py = impl_from_py(enum_, &variants)?;

Ok(quote! {

#enum_cls

#(
struct #variant_names;
)*

#(
#variant_cls
)*
Comment on lines +56 to +62
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this bit is needed - instead you should just be making the variants be instances of the the enum class. (Like how it is with the enum module.)


#to_py
#from_py

#[pymethods]
impl #enum_ {
#(
#[allow(non_upper_case_globals)]
#variant_consts
)*
}

})
}

fn impl_from_py(
enum_: &syn::Ident,
variants: &Vec<(syn::Ident, syn::ExprLit)>,
) -> syn::Result<TokenStream> {
let matches = variants
.iter()
.map(|(ident, _)| {
variant_enumname(enum_, ident).map(|cls| {
let name = cls.to_string();
// TODO: this should be
// #cls as pyo3::type_object::PyTypeObject>::type_object(py).compare()
// but I can't figure out how to get py inside extract()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use pyo3::PyNativeType::py(ob)

quote! {
if typ_name == #name {
return Ok(#enum_::#ident)
}
}
})
})
.collect::<syn::Result<Vec<_>>>()?;

let errormsg = format!("Could not convert to {} enum", enum_);

Ok(quote! {
impl pyo3::conversion::FromPyObject<'_> for #enum_ {
fn extract(ob: &pyo3::PyAny) -> pyo3::PyResult<Self> {
let typ: &pyo3::types::PyType = ob.extract();
let typ_name = typ.name();
#(
#matches
)*

Err(pyo3::exceptions::PyValueError::into(#errormsg))
}
}
})
}

fn impl_to_py(
enum_: &syn::Ident,
variants: &Vec<(syn::Ident, syn::ExprLit)>,
) -> syn::Result<TokenStream> {
let matches = variants
.iter()
.map(|(ident, _)| {
variant_enumname(enum_, ident).map(|cls| {
quote! {
#enum_::#ident => <#cls as pyo3::type_object::PyTypeObject>::type_object(py).to_object(py),
}
})
})
.collect::<syn::Result<Vec<_>>>()?;

Ok(quote! {
impl pyo3::FromPy<#enum_> for pyo3::PyObject {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the rebase on master, you'll need to change this to impl IntoPy<PyObject> for #enum

fn from_py(v: #enum_, py: pyo3::Python) -> Self {
match v {
#(
#matches
)*
}
}
}
})
}

fn variant_enumname(enum_: &syn::Ident, cls: &syn::Ident) -> syn::Result<syn::Ident> {
let name = format!("{}_EnumVariant_{}", enum_, cls);
syn::parse_str(&name)
}

fn impl_const(enum_: &syn::Ident, cls: &syn::Ident) -> syn::Result<TokenStream> {
Ok(quote! {
#[classattr]
const #cls: #enum_ = #enum_::#cls;
})
}

fn impl_class(cls: &syn::Ident, parent: Option<&syn::Ident>) -> syn::Result<TokenStream> {
let (typ, desc) = if let Some(p) = parent {
(
variant_enumname(p, cls)?,
format!("variant {} of enum {}", cls, p),
)
} else {
// Clone is just to align the types
(cls.clone(), format!("enum {}", cls))
};

let clsname = cls.to_string();
let inventory = impl_methods_inventory(&typ);
let extractext = impl_extractext(&typ);
let protoregistry = impl_proto_registry(&typ);

Ok(quote! {
unsafe impl pyo3::type_object::PyTypeInfo for #typ {
type Type = #typ;
type BaseType = pyo3::PyAny;
type Layout = pyo3::PyCell<Self>;
type BaseLayout = pyo3::pycell::PyCellBase<pyo3::PyAny>;

type Initializer = pyo3::pyclass_init::PyClassInitializer<Self>;
type AsRefTarget = pyo3::PyCell<Self>;

const NAME: &'static str = #clsname;
const MODULE: Option<&'static str> = None;
const DESCRIPTION: &'static str = #desc;
const FLAGS: usize = 0;

#[inline]
fn type_object_raw(py: pyo3::Python) -> *mut pyo3::ffi::PyTypeObject {
use pyo3::type_object::LazyStaticType;
static TYPE_OBJECT: LazyStaticType = LazyStaticType::new();
TYPE_OBJECT.get_or_init::<Self>(py)
}
}

impl pyo3::PyClass for #typ {
type Dict = pyo3::pyclass_slots::PyClassDummySlot ;
type WeakRef = pyo3::pyclass_slots::PyClassDummySlot;
type BaseNativeType = pyo3::PyAny;
}

#protoregistry

#extractext

impl pyo3::pyclass::PyClassAlloc for #typ {}

// TODO: handle not in send
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's reasonable to expect #[pyenum] to always be Send? It wouldn't ever carry nontrivial data afaik.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe go further and have pyenum always be 'Copy' as we are doing small data here rather than ADTs?

Copy link
Member

@davidhewitt davidhewitt Aug 11, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps, though I'm not sure it's our place to add #[derive(Copy, Clone)] automatically to the struct. If we don't need it for the #[pyenum] implementation, I'd rather give users the freedom to choose this themselves.

impl pyo3::pyclass::PyClassSend for #typ {
type ThreadChecker = pyo3::pyclass::ThreadCheckerStub<#typ>;
}

#inventory
})
}
Loading