Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add #[classattr] methods to define Python class attributes #905

Merged
merged 4 commits into from
May 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,37 @@ impl MyClass {
}
```

## Class attributes

To create a class attribute (also called [class variable][classattr]), a method without
any arguments can be annotated with the `#[classattr]` attribute. The return type must be `T` for
some `T` that implements `IntoPy<PyObject>`.

```rust
# use pyo3::prelude::*;
# #[pyclass]
# struct MyClass {}
#[pymethods]
impl MyClass {
#[classattr]
fn my_attribute() -> String {
"hello".to_string()
}
}

let gil = Python::acquire_gil();
let py = gil.python();
let my_class = py.get_type::<MyClass>();
pyo3::py_run!(py, my_class, "assert my_class.my_attribute == 'hello'")
```

Note that unlike class variables defined in Python code, class attributes defined in Rust cannot
be mutated at all:
```rust,ignore
// Would raise a `TypeError: can't set attributes of built-in/extension type 'MyClass'`
pyo3::py_run!(py, my_class, "my_class.my_attribute = 'foo'")
```

## Callable objects

To specify a custom `__call__` method for a custom class, the method needs to be annotated with
Expand Down Expand Up @@ -914,3 +945,5 @@ To escape this we use [inventory](https://github.com/dtolnay/inventory), which a
[`PyClassInitializer<T>`]: https://pyo3.rs/master/doc/pyo3/pyclass_init/struct.PyClassInitializer.html

[`RefCell`]: https://doc.rust-lang.org/std/cell/struct.RefCell.html

[classattr]: https://docs.python.org/3/tutorial/classes.html#class-and-instance-variables
14 changes: 13 additions & 1 deletion pyo3-derive-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub enum FnType {
FnCall,
FnClass,
FnStatic,
ClassAttribute,
/// For methods taht have `self_: &PyCell<Self>` instead of self receiver
PySelfRef(syn::TypeReference),
/// For methods taht have `self_: PyRef<Self>` or `PyRefMut<Self>` instead of self receiver
Expand Down Expand Up @@ -139,6 +140,15 @@ impl<'a> FnSpec<'a> {
};
}

if let FnType::ClassAttribute = &fn_type {
if self_.is_some() || !arguments.is_empty() {
return Err(syn::Error::new_spanned(
name,
"Class attribute methods cannot take arguments",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may be helpful to optionally allow one argument, of type Python, for those who want to create python types as class attributes. See pymethod/impl_wrap_getter as an example of how we allow this for getters.

One caveat I am not sure about though: as we're currently in the middle of creating a type object, is it safe to run arbitrary Python code?

Copy link
Member

@kngwyu kngwyu May 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it safe to run arbitrary Python code?

Maybe this code can cause SIGSEGV.
Oh I was wrong. This code works. Still investigating this can cause some odd errors, though.

#[pymethods]
impl MyClass {
    #[classattr]
    fn foo() -> MyClass { ... } 
}

I'll open a small PR to prevent this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kngwyu is your conclusion that this is safe?

));
}
}

// "Tweak" getter / setter names: strip off set_ and get_ if needed
if let FnType::Getter | FnType::Setter = &fn_type {
if python_name.is_none() {
Expand Down Expand Up @@ -178,7 +188,7 @@ impl<'a> FnSpec<'a> {
"text_signature not allowed on __new__; if you want to add a signature on \
__new__, put it on the struct definition instead",
)?,
FnType::FnCall | FnType::Getter | FnType::Setter => {
FnType::FnCall | FnType::Getter | FnType::Setter | FnType::ClassAttribute => {
parse_erroneous_text_signature("text_signature not allowed with this attribute")?
}
};
Expand Down Expand Up @@ -331,6 +341,8 @@ fn parse_method_attributes(
res = Some(FnType::FnClass)
} else if name.is_ident("staticmethod") {
res = Some(FnType::FnStatic)
} else if name.is_ident("classattr") {
res = Some(FnType::ClassAttribute)
} else if name.is_ident("setter") || name.is_ident("getter") {
if let syn::AttrStyle::Inner(_) = attr.style {
return Err(syn::Error::new_spanned(
Expand Down
30 changes: 30 additions & 0 deletions pyo3-derive-backend/src/pymethod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ pub fn gen_py_method(
FnType::FnCall => impl_py_method_def_call(&spec, &impl_wrap(cls, &spec, false)),
FnType::FnClass => impl_py_method_def_class(&spec, &impl_wrap_class(cls, &spec)),
FnType::FnStatic => impl_py_method_def_static(&spec, &impl_wrap_static(cls, &spec)),
FnType::ClassAttribute => {
impl_py_class_attribute(&spec, &impl_wrap_class_attribute(cls, &spec))
}
FnType::Getter => impl_py_getter_def(
&spec.python_name,
&spec.doc,
Expand Down Expand Up @@ -246,6 +249,19 @@ pub fn impl_wrap_static(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream {
}
}

/// Generate a wrapper for initialization of a class attribute.
/// To be called in `pyo3::pyclass::initialize_type_object`.
pub fn impl_wrap_class_attribute(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream {
let name = &spec.name;
let cb = quote! { #cls::#name() };

quote! {
fn __wrap(py: pyo3::Python<'_>) -> pyo3::PyObject {
pyo3::IntoPy::into_py(#cb, py)
}
}
}

fn impl_call_getter(spec: &FnSpec) -> syn::Result<TokenStream> {
let (py_arg, args) = split_off_python_arg(&spec.args);
if !args.is_empty() {
Expand Down Expand Up @@ -615,6 +631,20 @@ pub fn impl_py_method_def_static(spec: &FnSpec, wrapper: &TokenStream) -> TokenS
}
}

pub fn impl_py_class_attribute(spec: &FnSpec<'_>, wrapper: &TokenStream) -> TokenStream {
let python_name = &spec.python_name;
quote! {
pyo3::class::PyMethodDefType::ClassAttribute({
#wrapper

pyo3::class::PyClassAttributeDef {
name: stringify!(#python_name),
meth: __wrap,
}
})
}
}

pub fn impl_py_method_def_call(spec: &FnSpec, wrapper: &TokenStream) -> TokenStream {
let python_name = &spec.python_name;
let doc = &spec.doc;
Expand Down
21 changes: 20 additions & 1 deletion src/class/methods.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
// Copyright (c) 2017-present PyO3 Project and Contributors

use crate::ffi;
use crate::{ffi, PyObject, Python};
use libc::c_int;
use std::ffi::CString;
use std::fmt;

/// `PyMethodDefType` represents different types of Python callable objects.
/// It is used by the `#[pymethods]` and `#[pyproto]` annotations.
Expand All @@ -18,6 +19,8 @@ pub enum PyMethodDefType {
Static(PyMethodDef),
/// Represents normal method
Method(PyMethodDef),
/// Represents class attribute, used by `#[attribute]`
ClassAttribute(PyClassAttributeDef),
/// Represents getter descriptor, used by `#[getter]`
Getter(PyGetterDef),
/// Represents setter descriptor, used by `#[setter]`
Expand All @@ -40,6 +43,12 @@ pub struct PyMethodDef {
pub ml_doc: &'static str,
}

#[derive(Copy, Clone)]
pub struct PyClassAttributeDef {
pub name: &'static str,
pub meth: for<'p> fn(Python<'p>) -> PyObject,
}

#[derive(Copy, Clone, Debug)]
pub struct PyGetterDef {
pub name: &'static str,
Expand Down Expand Up @@ -85,6 +94,16 @@ impl PyMethodDef {
}
}

// Manual implementation because `Python<'_>` does not implement `Debug` and
// trait bounds on `fn` compiler-generated derive impls are too restrictive.
impl fmt::Debug for PyClassAttributeDef {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("PyClassAttributeDef")
.field("name", &self.name)
.finish()
}
}

impl PyGetterDef {
/// Copy descriptor information to `ffi::PyGetSetDef`
pub fn copy_to(&self, dst: &mut ffi::PyGetSetDef) {
Expand Down
4 changes: 3 additions & 1 deletion src/class/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ pub use self::descr::PyDescrProtocol;
pub use self::gc::{PyGCProtocol, PyTraverseError, PyVisit};
pub use self::iter::PyIterProtocol;
pub use self::mapping::PyMappingProtocol;
pub use self::methods::{PyGetterDef, PyMethodDef, PyMethodDefType, PyMethodType, PySetterDef};
pub use self::methods::{
PyClassAttributeDef, PyGetterDef, PyMethodDef, PyMethodDefType, PyMethodType, PySetterDef,
};
pub use self::number::PyNumberProtocol;
pub use self::pyasync::PyAsyncProtocol;
pub use self::sequence::PySequenceProtocol;
23 changes: 20 additions & 3 deletions src/pyclass.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
//! `PyClass` trait
use crate::class::methods::{PyMethodDefType, PyMethodsImpl};
use crate::class::methods::{PyClassAttributeDef, PyMethodDefType, PyMethodsImpl};
use crate::conversion::{IntoPyPointer, ToPyObject};
use crate::pyclass_slots::{PyClassDict, PyClassWeakRef};
use crate::type_object::{type_flags, PyLayout};
use crate::types::PyDict;
use crate::{class, ffi, PyCell, PyErr, PyNativeType, PyResult, PyTypeInfo, Python};
use std::ffi::CString;
use std::os::raw::c_void;
Expand Down Expand Up @@ -165,13 +167,23 @@ where
// buffer protocol
type_object.tp_as_buffer = to_ptr(<T as class::buffer::PyBufferProtocolImpl>::tp_as_buffer());

let (new, call, mut methods, attrs) = py_class_method_defs::<T>();

// normal methods
let (new, call, mut methods) = py_class_method_defs::<T>();
if !methods.is_empty() {
methods.push(ffi::PyMethodDef_INIT);
type_object.tp_methods = Box::into_raw(methods.into_boxed_slice()) as *mut _;
}

// class attributes
if !attrs.is_empty() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Considering a recursive case(e.g.,

#[pymethods]
impl MyClass {
    #[classattr]
    fn foo() -> MyClass { ... } 
}

), could you please move this code after PyType_Ready?
Then an incomplete type object is never used.

Copy link
Contributor Author

@scalexm scalexm May 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was precisely one of my use cases indeed. But I think it’s still possible to return MyOtherClass where MyOtherClass has not yet been initialized?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But I think it’s still possible to return MyOtherClass where MyOtherClass has not yet been initialized?

Yes, but I'm not sure there's no corner case 🤔

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it safe to modify type_object after PyType_Ready ?

Copy link
Member

@kngwyu kngwyu May 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realized that it cannot be a problem since PyCell::new doesn't use the type object at all.
Thanks @scalexm

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’ll add a few « recursive » test cases to be sure that we don’t break that assumption in the future.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it safe to modify type_object after PyType_Ready ?

no, it's not safe, according to the C API docs

let dict = PyDict::new(py);
for attr in attrs {
dict.set_item(attr.name, (attr.meth)(py))?;
}
type_object.tp_dict = dict.to_object(py).into_ptr();
}

// __new__ method
type_object.tp_new = new;
// __call__ method
Expand Down Expand Up @@ -219,8 +231,10 @@ fn py_class_method_defs<T: PyMethodsImpl>() -> (
Option<ffi::newfunc>,
Option<ffi::PyCFunctionWithKeywords>,
Vec<ffi::PyMethodDef>,
Vec<PyClassAttributeDef>,
) {
let mut defs = Vec::new();
let mut attrs = Vec::new();
let mut call = None;
let mut new = None;

Expand All @@ -243,6 +257,9 @@ fn py_class_method_defs<T: PyMethodsImpl>() -> (
| PyMethodDefType::Static(ref def) => {
defs.push(def.as_method_def());
}
PyMethodDefType::ClassAttribute(def) => {
attrs.push(def);
}
_ => (),
}
}
Expand All @@ -265,7 +282,7 @@ fn py_class_method_defs<T: PyMethodsImpl>() -> (

py_class_async_methods::<T>(&mut defs);

(new, call, defs)
(new, call, defs, attrs)
}

fn py_class_async_methods<T>(defs: &mut Vec<ffi::PyMethodDef>) {
Expand Down
75 changes: 75 additions & 0 deletions tests/test_class_attributes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use pyo3::prelude::*;

mod common;

#[pyclass]
struct Foo {
#[pyo3(get)]
x: i32,
}

#[pyclass]
struct Bar {
#[pyo3(get)]
x: i32,
}

#[pymethods]
impl Foo {
#[classattr]
fn a() -> i32 {
5
}

#[classattr]
#[name = "B"]
fn b() -> String {
"bar".to_string()
}

#[classattr]
fn foo() -> Foo {
Foo { x: 1 }
}

#[classattr]
fn bar() -> Bar {
Bar { x: 2 }
}
}

#[pymethods]
impl Bar {
#[classattr]
fn foo() -> Foo {
Foo { x: 3 }
}
}

#[test]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mentioned this in my comment, but adding an in-line note so it doesn't get lost: I think it would be a good idea to add tests for the error cases, to ensure that they raise the correct exception.

Could be done here or it could be done in the examples/rustapi_module tests, which are written in Python and where you can use pytest.raises.

fn class_attributes() {
let gil = Python::acquire_gil();
let py = gil.python();
let foo_obj = py.get_type::<Foo>();
py_assert!(py, foo_obj, "foo_obj.a == 5");
py_assert!(py, foo_obj, "foo_obj.B == 'bar'");
}

#[test]
fn class_attributes_are_immutable() {
let gil = Python::acquire_gil();
let py = gil.python();
let foo_obj = py.get_type::<Foo>();
py_expect_exception!(py, foo_obj, "foo_obj.a = 6", TypeError);
}

#[test]
fn recursive_class_attributes() {
let gil = Python::acquire_gil();
let py = gil.python();
let foo_obj = py.get_type::<Foo>();
let bar_obj = py.get_type::<Bar>();
py_assert!(py, foo_obj, "foo_obj.foo.x == 1");
py_assert!(py, foo_obj, "foo_obj.bar.x == 2");
py_assert!(py, bar_obj, "bar_obj.foo.x == 3");
}