diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index 08f1939b216..0154ad9e397 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -3,7 +3,7 @@ use crate::attributes::{self, take_pyo3_options, NameAttribute, TextSignatureAttribute}; use crate::deprecations::Deprecations; use crate::konst::{ConstAttributes, ConstSpec}; -use crate::pyimpl::{gen_py_const, PyClassMethodsType}; +use crate::pyimpl::{gen_default_slot_impls, gen_py_const, PyClassMethodsType}; use crate::pymethod::{impl_py_getter_def, impl_py_setter_def, PropertyType}; use crate::utils::{self, unwrap_group, PythonDoc}; use proc_macro2::{Span, TokenStream}; @@ -422,6 +422,24 @@ fn impl_enum_class( .impl_all(); let descriptors = unit_variants_as_descriptors(cls, variants.iter().map(|v| v.ident)); + let variants_repr = variants.iter().map(|variant| { + let variant_name = variant.ident; + // Assuming all variants are unit variants because they are the only type we support. + let repr = format!("{}.{}", cls, variant_name); + quote! { #cls::#variant_name => #repr, } + }); + + let default_repr_impl = quote! { + #[allow(non_snake_case)] + #[pyo3(name = "__repr__")] + fn __pyo3__repr__(&self) -> &'static str { + match self { + #(#variants_repr)* + _ => unreachable!("Unsupported variant type."), + } + } + }; + let default_impls = gen_default_slot_impls(cls, vec![default_repr_impl]); Ok(quote! { #pytypeinfo @@ -430,6 +448,8 @@ fn impl_enum_class( #descriptors + #default_impls + }) } diff --git a/pyo3-macros-backend/src/pyimpl.rs b/pyo3-macros-backend/src/pyimpl.rs index e2fbf8d802f..6bc90d44601 100644 --- a/pyo3-macros-backend/src/pyimpl.rs +++ b/pyo3-macros-backend/src/pyimpl.rs @@ -139,6 +139,47 @@ pub fn gen_py_const(cls: &syn::Type, spec: &ConstSpec) -> TokenStream { } } +pub fn gen_default_slot_impls(cls: &syn::Ident, method_defs: Vec) -> TokenStream { + // This function uses a lot of `unwrap()`; since method_defs are provided by us, they should + // all succeed. + let ty: syn::Type = syn::parse_quote!(#cls); + + let mut method_defs: Vec<_> = method_defs + .into_iter() + .map(|token| syn::parse2::(token).unwrap()) + .collect(); + + let mut proto_impls = Vec::new(); + + for meth in &mut method_defs { + let options = PyFunctionOptions::from_attrs(&mut meth.attrs).unwrap(); + match pymethod::gen_py_method(&ty, &mut meth.sig, &mut meth.attrs, options).unwrap() { + GeneratedPyMethod::Proto(token_stream) => { + let attrs = get_cfg_attributes(&meth.attrs); + proto_impls.push(quote!(#(#attrs)* #token_stream)) + } + GeneratedPyMethod::SlotTraitImpl(..) => { + todo!() + } + GeneratedPyMethod::Method(_) | GeneratedPyMethod::TraitImpl(_) => { + panic!("Only protocol methods can have default implementation!") + } + } + } + + quote! { + impl #cls { + #(#method_defs)* + } + impl ::pyo3::class::impl_::PyClassDefaultSlots<#cls> + for ::pyo3::class::impl_::PyClassImplCollector<#cls> { + fn py_class_default_slots(self) -> &'static [::pyo3::ffi::PyType_Slot] { + &[#(#proto_impls),*] + } + } + } +} + fn impl_py_methods(ty: &syn::Type, methods: Vec) -> TokenStream { quote! { impl ::pyo3::class::impl_::PyMethods<#ty> diff --git a/tests/test_default_impls.rs b/tests/test_default_impls.rs index effaa1d1d7e..fa3b49d51a5 100644 --- a/tests/test_default_impls.rs +++ b/tests/test_default_impls.rs @@ -1,110 +1,33 @@ -#![allow(non_snake_case)] -use pyo3::class::PyMethodDefType; use pyo3::prelude::*; -use pyo3::py_run; mod common; -// Tests for PyClassDefaultSlots +// Test default generated __repr__. #[pyclass] -struct TestDefaultSlots; - -// generated using `Cargo expand` -// equivalent to -// ``` -// impl TestDefaultSlots {{ -// fn __str__(&self) -> &'static str { -// "default" -// } -// } -// ``` -impl TestDefaultSlots { - fn __pyo3__str__(&self) -> &'static str { - "default" - } -} - -impl ::pyo3::class::impl_::PyClassDefaultSlots - for ::pyo3::class::impl_::PyClassImplCollector -{ - fn py_class_default_slots(self) -> &'static [::pyo3::ffi::PyType_Slot] { - &[{ - unsafe extern "C" fn __wrap( - _raw_slf: *mut ::pyo3::ffi::PyObject, - ) -> *mut ::pyo3::ffi::PyObject { - let _slf = _raw_slf; - ::pyo3::callback::handle_panic(|_py| { - let _cell = _py - .from_borrowed_ptr::<::pyo3::PyAny>(_slf) - .downcast::<::pyo3::PyCell>()?; - let _ref = _cell.try_borrow()?; - let _slf = &_ref; - ::pyo3::callback::convert(_py, TestDefaultSlots::__pyo3__str__(_slf)) - }) - } - ::pyo3::ffi::PyType_Slot { - slot: ::pyo3::ffi::Py_tp_str, - pfunc: __wrap as ::pyo3::ffi::reprfunc as _, - } - }] - } +enum TestDefaultRepr { + Var, } #[test] fn test_default_slot_exists() { Python::with_gil(|py| { - let test_object = Py::new(py, TestDefaultSlots).unwrap(); - py_assert!(py, test_object, "str(test_object) == 'default'"); + let test_object = Py::new(py, TestDefaultRepr::Var).unwrap(); + py_assert!( + py, + test_object, + "repr(test_object) == 'TestDefaultRepr.Var'" + ); }) } #[pyclass] -struct OverrideSlot; - -// generated using `Cargo expand` -// equivalent to -// ``` -// impl OverrideMagicMethod { -// fn __str__(&self) -> &'static str { -// "default" -// } -// } -// ``` -impl OverrideSlot { - fn __pyo3__str__(&self) -> &'static str { - "default" - } -} - -impl ::pyo3::class::impl_::PyClassDefaultSlots - for ::pyo3::class::impl_::PyClassImplCollector -{ - fn py_class_default_slots(self) -> &'static [::pyo3::ffi::PyType_Slot] { - &[{ - unsafe extern "C" fn __wrap( - _raw_slf: *mut ::pyo3::ffi::PyObject, - ) -> *mut ::pyo3::ffi::PyObject { - let _slf = _raw_slf; - ::pyo3::callback::handle_panic(|_py| { - let _cell = _py - .from_borrowed_ptr::<::pyo3::PyAny>(_slf) - .downcast::<::pyo3::PyCell>()?; - let _ref = _cell.try_borrow()?; - let _slf = &_ref; - ::pyo3::callback::convert(_py, OverrideSlot::__pyo3__str__(_slf)) - }) - } - ::pyo3::ffi::PyType_Slot { - slot: ::pyo3::ffi::Py_tp_str, - pfunc: __wrap as ::pyo3::ffi::reprfunc as _, - } - }] - } +enum OverrideSlot { + Var, } #[pymethods] impl OverrideSlot { - fn __str__(&self) -> &str { + fn __repr__(&self) -> &str { "overriden" } } @@ -112,7 +35,7 @@ impl OverrideSlot { #[test] fn test_override_slot() { Python::with_gil(|py| { - let test_object = Py::new(py, OverrideSlot).unwrap(); - py_assert!(py, test_object, "str(test_object) == 'overriden'"); + let test_object = Py::new(py, OverrideSlot::Var).unwrap(); + py_assert!(py, test_object, "repr(test_object) == 'overriden'"); }) }