Skip to content

Commit

Permalink
Implement default slot methods and default __repr__ for enums.
Browse files Browse the repository at this point in the history
  • Loading branch information
jovenlin0527 committed Nov 22, 2021
1 parent 245617a commit 44d05dd
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 92 deletions.
22 changes: 21 additions & 1 deletion pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use crate::attributes::{self, take_pyo3_options, NameAttribute, TextSignatureAttribute};
use crate::deprecations::Deprecations;
use crate::konst::{ConstAttributes, ConstSpec};
use crate::pyimpl::{gen_py_const, PyClassMethodsType};
use crate::pyimpl::{gen_default_slot_impls, gen_py_const, PyClassMethodsType};
use crate::pymethod::{impl_py_getter_def, impl_py_setter_def, PropertyType};
use crate::utils::{self, unwrap_group, PythonDoc};
use proc_macro2::{Span, TokenStream};
Expand Down Expand Up @@ -422,6 +422,24 @@ fn impl_enum_class(
.impl_all();
let descriptors = unit_variants_as_descriptors(cls, variants.iter().map(|v| v.ident));

let variants_repr = variants.iter().map(|variant| {
let variant_name = variant.ident;
// Assuming all variants are unit variants because they are the only type we support.
let repr = format!("{}.{}", cls, variant_name);
quote! { #cls::#variant_name => #repr, }
});

let default_repr_impl = quote! {
#[allow(non_snake_case)]
#[pyo3(name = "__repr__")]
fn __pyo3__repr__(&self) -> &'static str {
match self {
#(#variants_repr)*
_ => unreachable!("Unsupported variant type."),
}
}
};
let default_impls = gen_default_slot_impls(cls, vec![default_repr_impl]);
Ok(quote! {

#pytypeinfo
Expand All @@ -430,6 +448,8 @@ fn impl_enum_class(

#descriptors

#default_impls

})
}

Expand Down
41 changes: 41 additions & 0 deletions pyo3-macros-backend/src/pyimpl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,47 @@ pub fn gen_py_const(cls: &syn::Type, spec: &ConstSpec) -> TokenStream {
}
}

pub fn gen_default_slot_impls(cls: &syn::Ident, method_defs: Vec<TokenStream>) -> TokenStream {
// This function uses a lot of `unwrap()`; since method_defs are provided by us, they should
// all succeed.
let ty: syn::Type = syn::parse_quote!(#cls);

let mut method_defs: Vec<_> = method_defs
.into_iter()
.map(|token| syn::parse2::<syn::ImplItemMethod>(token).unwrap())
.collect();

let mut proto_impls = Vec::new();

for meth in &mut method_defs {
let options = PyFunctionOptions::from_attrs(&mut meth.attrs).unwrap();
match pymethod::gen_py_method(&ty, &mut meth.sig, &mut meth.attrs, options).unwrap() {
GeneratedPyMethod::Proto(token_stream) => {
let attrs = get_cfg_attributes(&meth.attrs);
proto_impls.push(quote!(#(#attrs)* #token_stream))
}
GeneratedPyMethod::SlotTraitImpl(..) => {
todo!()
}
GeneratedPyMethod::Method(_) | GeneratedPyMethod::TraitImpl(_) => {
panic!("Only protocol methods can have default implementation!")
}
}
}

quote! {
impl #cls {
#(#method_defs)*
}
impl ::pyo3::class::impl_::PyClassDefaultSlots<#cls>
for ::pyo3::class::impl_::PyClassImplCollector<#cls> {
fn py_class_default_slots(self) -> &'static [::pyo3::ffi::PyType_Slot] {
&[#(#proto_impls),*]
}
}
}
}

fn impl_py_methods(ty: &syn::Type, methods: Vec<TokenStream>) -> TokenStream {
quote! {
impl ::pyo3::class::impl_::PyMethods<#ty>
Expand Down
105 changes: 14 additions & 91 deletions tests/test_default_impls.rs
Original file line number Diff line number Diff line change
@@ -1,118 +1,41 @@
#![allow(non_snake_case)]
use pyo3::class::PyMethodDefType;
use pyo3::prelude::*;
use pyo3::py_run;

mod common;

// Tests for PyClassDefaultSlots
// Test default generated __repr__.
#[pyclass]
struct TestDefaultSlots;

// generated using `Cargo expand`
// equivalent to
// ```
// impl TestDefaultSlots {{
// fn __str__(&self) -> &'static str {
// "default"
// }
// }
// ```
impl TestDefaultSlots {
fn __pyo3__str__(&self) -> &'static str {
"default"
}
}

impl ::pyo3::class::impl_::PyClassDefaultSlots<TestDefaultSlots>
for ::pyo3::class::impl_::PyClassImplCollector<TestDefaultSlots>
{
fn py_class_default_slots(self) -> &'static [::pyo3::ffi::PyType_Slot] {
&[{
unsafe extern "C" fn __wrap(
_raw_slf: *mut ::pyo3::ffi::PyObject,
) -> *mut ::pyo3::ffi::PyObject {
let _slf = _raw_slf;
::pyo3::callback::handle_panic(|_py| {
let _cell = _py
.from_borrowed_ptr::<::pyo3::PyAny>(_slf)
.downcast::<::pyo3::PyCell<TestDefaultSlots>>()?;
let _ref = _cell.try_borrow()?;
let _slf = &_ref;
::pyo3::callback::convert(_py, TestDefaultSlots::__pyo3__str__(_slf))
})
}
::pyo3::ffi::PyType_Slot {
slot: ::pyo3::ffi::Py_tp_str,
pfunc: __wrap as ::pyo3::ffi::reprfunc as _,
}
}]
}
enum TestDefaultRepr {
Var,
}

#[test]
fn test_default_slot_exists() {
Python::with_gil(|py| {
let test_object = Py::new(py, TestDefaultSlots).unwrap();
py_assert!(py, test_object, "str(test_object) == 'default'");
let test_object = Py::new(py, TestDefaultRepr::Var).unwrap();
py_assert!(
py,
test_object,
"repr(test_object) == 'TestDefaultRepr.Var'"
);
})
}

#[pyclass]
struct OverrideSlot;

// generated using `Cargo expand`
// equivalent to
// ```
// impl OverrideMagicMethod {
// fn __str__(&self) -> &'static str {
// "default"
// }
// }
// ```
impl OverrideSlot {
fn __pyo3__str__(&self) -> &'static str {
"default"
}
}

impl ::pyo3::class::impl_::PyClassDefaultSlots<OverrideSlot>
for ::pyo3::class::impl_::PyClassImplCollector<OverrideSlot>
{
fn py_class_default_slots(self) -> &'static [::pyo3::ffi::PyType_Slot] {
&[{
unsafe extern "C" fn __wrap(
_raw_slf: *mut ::pyo3::ffi::PyObject,
) -> *mut ::pyo3::ffi::PyObject {
let _slf = _raw_slf;
::pyo3::callback::handle_panic(|_py| {
let _cell = _py
.from_borrowed_ptr::<::pyo3::PyAny>(_slf)
.downcast::<::pyo3::PyCell<OverrideSlot>>()?;
let _ref = _cell.try_borrow()?;
let _slf = &_ref;
::pyo3::callback::convert(_py, OverrideSlot::__pyo3__str__(_slf))
})
}
::pyo3::ffi::PyType_Slot {
slot: ::pyo3::ffi::Py_tp_str,
pfunc: __wrap as ::pyo3::ffi::reprfunc as _,
}
}]
}
enum OverrideSlot {
Var,
}

#[pymethods]
impl OverrideSlot {
fn __str__(&self) -> &str {
fn __repr__(&self) -> &str {
"overriden"
}
}

#[test]
fn test_override_slot() {
Python::with_gil(|py| {
let test_object = Py::new(py, OverrideSlot).unwrap();
py_assert!(py, test_object, "str(test_object) == 'overriden'");
let test_object = Py::new(py, OverrideSlot::Var).unwrap();
py_assert!(py, test_object, "repr(test_object) == 'overriden'");
})
}

0 comments on commit 44d05dd

Please sign in to comment.