diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ec080b627a..fe182a0f027 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `PyCode` and `PyFrame` high level objects. [#2408](https://github.com/PyO3/pyo3/pull/2408) - Add FFI definitions `Py_fstring_input`, `sendfunc`, and `_PyErr_StackItem`. [#2423](https://github.com/PyO3/pyo3/pull/2423) - Add `PyDateTime::new_with_fold`, `PyTime::new_with_fold`, `PyTime::get_fold`, `PyDateTime::get_fold` for PyPy. [#2428](https://github.com/PyO3/pyo3/pull/2428) +- Allow `#[classattr]` take `Python` argument. [#2383](https://github.com/PyO3/pyo3/issues/2383) ### Changed diff --git a/pyo3-macros-backend/src/method.rs b/pyo3-macros-backend/src/method.rs index c095c23cc24..877a00f0c54 100644 --- a/pyo3-macros-backend/src/method.rs +++ b/pyo3-macros-backend/src/method.rs @@ -374,13 +374,7 @@ impl<'a> FnSpec<'a> { let (fn_type, skip_first_arg, fixed_convention) = match fn_type_attr { Some(MethodTypeAttribute::StaticMethod) => (FnType::FnStatic, false, None), - Some(MethodTypeAttribute::ClassAttribute) => { - ensure_spanned!( - sig.inputs.is_empty(), - sig.inputs.span() => "class attribute methods cannot take arguments" - ); - (FnType::ClassAttribute, false, None) - } + Some(MethodTypeAttribute::ClassAttribute) => (FnType::ClassAttribute, false, None), Some(MethodTypeAttribute::New) => { if let Some(name) = &python_name { bail_spanned!(name.span() => "`name` not allowed with `#[new]`"); diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index 768c91639db..1a31c95ae59 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -178,7 +178,7 @@ pub fn gen_py_method( // Class attributes go before protos so that class attributes can be used to set proto // method to None. (_, FnType::ClassAttribute) => { - GeneratedPyMethod::Method(impl_py_class_attribute(cls, spec)) + GeneratedPyMethod::Method(impl_py_class_attribute(cls, spec)?) } (PyMethodKind::Proto(proto_kind), _) => { ensure_no_forbidden_protocol_attributes(spec, &method.method_name)?; @@ -348,12 +348,25 @@ fn impl_traverse_slot(cls: &syn::Type, spec: FnSpec<'_>) -> TokenStream { }} } -fn impl_py_class_attribute(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream { +fn impl_py_class_attribute(cls: &syn::Type, spec: &FnSpec<'_>) -> syn::Result { + let (py_arg, args) = split_off_python_arg(&spec.args); + ensure_spanned!( + args.is_empty(), + args[0].ty.span() => "#[classattr] can only have one argument (of type pyo3::Python)" + ); + let name = &spec.name; + let fncall = if py_arg.is_some() { + quote!(#cls::#name(py)) + } else { + quote!(#cls::#name()) + }; + let wrapper_ident = format_ident!("__pymethod_{}__", name); let deprecations = &spec.deprecations; let python_name = spec.null_terminated_python_name(); - quote! { + + let classattr = quote! { _pyo3::class::PyMethodDefType::ClassAttribute({ _pyo3::class::PyClassAttributeDef::new( #python_name, @@ -363,7 +376,7 @@ fn impl_py_class_attribute(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream { #[allow(non_snake_case)] fn #wrapper_ident(py: _pyo3::Python<'_>) -> _pyo3::PyResult<_pyo3::PyObject> { #deprecations - let mut ret = #cls::#name(); + let mut ret = #fncall; if false { use _pyo3::impl_::ghost::IntoPyResult; ret.assert_into_py_result(); @@ -375,7 +388,8 @@ fn impl_py_class_attribute(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream { }) ) }) - } + }; + Ok(classattr) } fn impl_call_setter(cls: &syn::Type, spec: &FnSpec<'_>) -> syn::Result { diff --git a/tests/test_class_attributes.rs b/tests/test_class_attributes.rs index 55353d405a1..fa6e38af92e 100644 --- a/tests/test_class_attributes.rs +++ b/tests/test_class_attributes.rs @@ -45,6 +45,11 @@ impl Foo { fn a_foo() -> Foo { Foo { x: 1 } } + + #[classattr] + fn a_foo_with_py(py: Python<'_>) -> Py { + Py::new(py, Foo { x: 1 }).unwrap() + } } #[test] @@ -57,6 +62,7 @@ fn class_attributes() { py_assert!(py, foo_obj, "foo_obj.a == 5"); py_assert!(py, foo_obj, "foo_obj.B == 'bar'"); py_assert!(py, foo_obj, "foo_obj.a_foo.x == 1"); + py_assert!(py, foo_obj, "foo_obj.a_foo_with_py.x == 1"); } // Ignored because heap types are not immutable: diff --git a/tests/ui/invalid_pymethods.stderr b/tests/ui/invalid_pymethods.stderr index 57c98123099..590d0295e64 100644 --- a/tests/ui/invalid_pymethods.stderr +++ b/tests/ui/invalid_pymethods.stderr @@ -1,8 +1,8 @@ -error: class attribute methods cannot take arguments - --> tests/ui/invalid_pymethods.rs:9:29 +error: #[classattr] can only have one argument (of type pyo3::Python) + --> tests/ui/invalid_pymethods.rs:9:34 | 9 | fn class_attr_with_args(foo: i32) {} - | ^^^ + | ^^^ error: `#[classattr]` does not take any arguments --> tests/ui/invalid_pymethods.rs:14:5