Skip to content

Commit

Permalink
Merge pull request #2014 from b05902132/default_impl
Browse files Browse the repository at this point in the history
Support default method implementation
  • Loading branch information
davidhewitt authored Nov 29, 2021
2 parents 7455518 + aac0e56 commit 8a03778
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 1 deletion.
28 changes: 27 additions & 1 deletion pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use crate::attributes::{self, take_pyo3_options, NameAttribute, TextSignatureAttribute};
use crate::deprecations::Deprecations;
use crate::konst::{ConstAttributes, ConstSpec};
use crate::pyimpl::{gen_py_const, PyClassMethodsType};
use crate::pyimpl::{gen_default_slot_impls, gen_py_const, PyClassMethodsType};
use crate::pymethod::{impl_py_getter_def, impl_py_setter_def, PropertyType};
use crate::utils::{self, unwrap_group, PythonDoc};
use proc_macro2::{Span, TokenStream};
Expand Down Expand Up @@ -425,6 +425,27 @@ fn impl_enum_class(
.impl_all();
let descriptors = unit_variants_as_descriptors(cls, variants.iter().map(|v| v.ident));

let default_repr_impl = {
let variants_repr = variants.iter().map(|variant| {
let variant_name = variant.ident;
// Assuming all variants are unit variants because they are the only type we support.
let repr = format!("{}.{}", cls, variant_name);
quote! { #cls::#variant_name => #repr, }
});
quote! {
#[doc(hidden)]
#[allow(non_snake_case)]
#[pyo3(name = "__repr__")]
fn __pyo3__repr__(&self) -> &'static str {
match self {
#(#variants_repr)*
_ => unreachable!("Unsupported variant type."),
}
}
}
};

let default_impls = gen_default_slot_impls(cls, vec![default_repr_impl]);
Ok(quote! {

#pytypeinfo
Expand All @@ -433,6 +454,8 @@ fn impl_enum_class(

#descriptors

#default_impls

})
}

Expand Down Expand Up @@ -758,6 +781,9 @@ impl<'a> PyClassImplsBuilder<'a> {
// Implementation which uses dtolnay specialization to load all slots.
use ::pyo3::class::impl_::*;
let collector = PyClassImplCollector::<Self>::new();
// This depends on Python implementation detail;
// an old slot entry will be overriden by newer ones.
visitor(collector.py_class_default_slots());
visitor(collector.object_protocol_slots());
visitor(collector.number_protocol_slots());
visitor(collector.iter_protocol_slots());
Expand Down
41 changes: 41 additions & 0 deletions pyo3-macros-backend/src/pyimpl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,47 @@ pub fn gen_py_const(cls: &syn::Type, spec: &ConstSpec) -> TokenStream {
}
}

pub fn gen_default_slot_impls(cls: &syn::Ident, method_defs: Vec<TokenStream>) -> TokenStream {
// This function uses a lot of `unwrap()`; since method_defs are provided by us, they should
// all succeed.
let ty: syn::Type = syn::parse_quote!(#cls);

let mut method_defs: Vec<_> = method_defs
.into_iter()
.map(|token| syn::parse2::<syn::ImplItemMethod>(token).unwrap())
.collect();

let mut proto_impls = Vec::new();

for meth in &mut method_defs {
let options = PyFunctionOptions::from_attrs(&mut meth.attrs).unwrap();
match pymethod::gen_py_method(&ty, &mut meth.sig, &mut meth.attrs, options).unwrap() {
GeneratedPyMethod::Proto(token_stream) => {
let attrs = get_cfg_attributes(&meth.attrs);
proto_impls.push(quote!(#(#attrs)* #token_stream))
}
GeneratedPyMethod::SlotTraitImpl(..) => {
panic!("SlotFragment methods cannot have default implementation!")
}
GeneratedPyMethod::Method(_) | GeneratedPyMethod::TraitImpl(_) => {
panic!("Only protocol methods can have default implementation!")
}
}
}

quote! {
impl #cls {
#(#method_defs)*
}
impl ::pyo3::class::impl_::PyClassDefaultSlots<#cls>
for ::pyo3::class::impl_::PyClassImplCollector<#cls> {
fn py_class_default_slots(self) -> &'static [::pyo3::ffi::PyType_Slot] {
&[#(#proto_impls),*]
}
}
}
}

fn impl_py_methods(ty: &syn::Type, methods: Vec<TokenStream>) -> TokenStream {
quote! {
impl ::pyo3::class::impl_::PyMethods<#ty>
Expand Down
3 changes: 3 additions & 0 deletions src/class/impl_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,9 @@ slots_trait!(PyAsyncProtocolSlots, async_protocol_slots);
slots_trait!(PySequenceProtocolSlots, sequence_protocol_slots);
slots_trait!(PyBufferProtocolSlots, buffer_protocol_slots);

// slots that PyO3 implements by default, but can be overidden by the users.
slots_trait!(PyClassDefaultSlots, py_class_default_slots);

// Protocol slots from #[pymethods] if not using inventory.
#[cfg(not(feature = "multiple-pymethods"))]
slots_trait!(PyMethodsProtocolSlots, methods_protocol_slots);
Expand Down
41 changes: 41 additions & 0 deletions tests/test_default_impls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use pyo3::prelude::*;

mod common;

// Test default generated __repr__.
#[pyclass]
enum TestDefaultRepr {
Var,
}

#[test]
fn test_default_slot_exists() {
Python::with_gil(|py| {
let test_object = Py::new(py, TestDefaultRepr::Var).unwrap();
py_assert!(
py,
test_object,
"repr(test_object) == 'TestDefaultRepr.Var'"
);
})
}

#[pyclass]
enum OverrideSlot {
Var,
}

#[pymethods]
impl OverrideSlot {
fn __repr__(&self) -> &str {
"overriden"
}
}

#[test]
fn test_override_slot() {
Python::with_gil(|py| {
let test_object = Py::new(py, OverrideSlot::Var).unwrap();
py_assert!(py, test_object, "repr(test_object) == 'overriden'");
})
}
10 changes: 10 additions & 0 deletions tests/test_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,13 @@ fn test_enum_arg() {

py_run!(py, f mynum, "f(mynum.Variant)")
}

#[test]
fn test_default_repr_correct() {
Python::with_gil(|py| {
let var1 = Py::new(py, MyEnum::Variant).unwrap();
let var2 = Py::new(py, MyEnum::OtherVariant).unwrap();
py_assert!(py, var1, "repr(var1) == 'MyEnum.Variant'");
py_assert!(py, var2, "repr(var2) == 'MyEnum.OtherVariant'");
})
}

0 comments on commit 8a03778

Please sign in to comment.