Skip to content

Commit

Permalink
Merge pull request #905 from scalexm/master
Browse files Browse the repository at this point in the history
Add `#[classattr]` methods to define Python class attributes
  • Loading branch information
kngwyu authored May 8, 2020
2 parents 8d28291 + e3d9544 commit 8aeae6c
Show file tree
Hide file tree
Showing 7 changed files with 194 additions and 6 deletions.
33 changes: 33 additions & 0 deletions guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,37 @@ impl MyClass {
}
```

## Class attributes

To create a class attribute (also called [class variable][classattr]), a method without
any arguments can be annotated with the `#[classattr]` attribute. The return type must be `T` for
some `T` that implements `IntoPy<PyObject>`.

```rust
# use pyo3::prelude::*;
# #[pyclass]
# struct MyClass {}
#[pymethods]
impl MyClass {
#[classattr]
fn my_attribute() -> String {
"hello".to_string()
}
}

let gil = Python::acquire_gil();
let py = gil.python();
let my_class = py.get_type::<MyClass>();
pyo3::py_run!(py, my_class, "assert my_class.my_attribute == 'hello'")
```

Note that unlike class variables defined in Python code, class attributes defined in Rust cannot
be mutated at all:
```rust,ignore
// Would raise a `TypeError: can't set attributes of built-in/extension type 'MyClass'`
pyo3::py_run!(py, my_class, "my_class.my_attribute = 'foo'")
```

## Callable objects

To specify a custom `__call__` method for a custom class, the method needs to be annotated with
Expand Down Expand Up @@ -914,3 +945,5 @@ To escape this we use [inventory](https://github.com/dtolnay/inventory), which a
[`PyClassInitializer<T>`]: https://pyo3.rs/master/doc/pyo3/pyclass_init/struct.PyClassInitializer.html

[`RefCell`]: https://doc.rust-lang.org/std/cell/struct.RefCell.html

[classattr]: https://docs.python.org/3/tutorial/classes.html#class-and-instance-variables
14 changes: 13 additions & 1 deletion pyo3-derive-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub enum FnType {
FnCall,
FnClass,
FnStatic,
ClassAttribute,
/// For methods taht have `self_: &PyCell<Self>` instead of self receiver
PySelfRef(syn::TypeReference),
/// For methods taht have `self_: PyRef<Self>` or `PyRefMut<Self>` instead of self receiver
Expand Down Expand Up @@ -139,6 +140,15 @@ impl<'a> FnSpec<'a> {
};
}

if let FnType::ClassAttribute = &fn_type {
if self_.is_some() || !arguments.is_empty() {
return Err(syn::Error::new_spanned(
name,
"Class attribute methods cannot take arguments",
));
}
}

// "Tweak" getter / setter names: strip off set_ and get_ if needed
if let FnType::Getter | FnType::Setter = &fn_type {
if python_name.is_none() {
Expand Down Expand Up @@ -178,7 +188,7 @@ impl<'a> FnSpec<'a> {
"text_signature not allowed on __new__; if you want to add a signature on \
__new__, put it on the struct definition instead",
)?,
FnType::FnCall | FnType::Getter | FnType::Setter => {
FnType::FnCall | FnType::Getter | FnType::Setter | FnType::ClassAttribute => {
parse_erroneous_text_signature("text_signature not allowed with this attribute")?
}
};
Expand Down Expand Up @@ -331,6 +341,8 @@ fn parse_method_attributes(
res = Some(FnType::FnClass)
} else if name.is_ident("staticmethod") {
res = Some(FnType::FnStatic)
} else if name.is_ident("classattr") {
res = Some(FnType::ClassAttribute)
} else if name.is_ident("setter") || name.is_ident("getter") {
if let syn::AttrStyle::Inner(_) = attr.style {
return Err(syn::Error::new_spanned(
Expand Down
30 changes: 30 additions & 0 deletions pyo3-derive-backend/src/pymethod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ pub fn gen_py_method(
FnType::FnCall => impl_py_method_def_call(&spec, &impl_wrap(cls, &spec, false)),
FnType::FnClass => impl_py_method_def_class(&spec, &impl_wrap_class(cls, &spec)),
FnType::FnStatic => impl_py_method_def_static(&spec, &impl_wrap_static(cls, &spec)),
FnType::ClassAttribute => {
impl_py_class_attribute(&spec, &impl_wrap_class_attribute(cls, &spec))
}
FnType::Getter => impl_py_getter_def(
&spec.python_name,
&spec.doc,
Expand Down Expand Up @@ -246,6 +249,19 @@ pub fn impl_wrap_static(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream {
}
}

/// Generate a wrapper for initialization of a class attribute.
/// To be called in `pyo3::pyclass::initialize_type_object`.
pub fn impl_wrap_class_attribute(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream {
let name = &spec.name;
let cb = quote! { #cls::#name() };

quote! {
fn __wrap(py: pyo3::Python<'_>) -> pyo3::PyObject {
pyo3::IntoPy::into_py(#cb, py)
}
}
}

fn impl_call_getter(spec: &FnSpec) -> syn::Result<TokenStream> {
let (py_arg, args) = split_off_python_arg(&spec.args);
if !args.is_empty() {
Expand Down Expand Up @@ -615,6 +631,20 @@ pub fn impl_py_method_def_static(spec: &FnSpec, wrapper: &TokenStream) -> TokenS
}
}

pub fn impl_py_class_attribute(spec: &FnSpec<'_>, wrapper: &TokenStream) -> TokenStream {
let python_name = &spec.python_name;
quote! {
pyo3::class::PyMethodDefType::ClassAttribute({
#wrapper

pyo3::class::PyClassAttributeDef {
name: stringify!(#python_name),
meth: __wrap,
}
})
}
}

pub fn impl_py_method_def_call(spec: &FnSpec, wrapper: &TokenStream) -> TokenStream {
let python_name = &spec.python_name;
let doc = &spec.doc;
Expand Down
21 changes: 20 additions & 1 deletion src/class/methods.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
// Copyright (c) 2017-present PyO3 Project and Contributors

use crate::ffi;
use crate::{ffi, PyObject, Python};
use libc::c_int;
use std::ffi::CString;
use std::fmt;

/// `PyMethodDefType` represents different types of Python callable objects.
/// It is used by the `#[pymethods]` and `#[pyproto]` annotations.
Expand All @@ -18,6 +19,8 @@ pub enum PyMethodDefType {
Static(PyMethodDef),
/// Represents normal method
Method(PyMethodDef),
/// Represents class attribute, used by `#[attribute]`
ClassAttribute(PyClassAttributeDef),
/// Represents getter descriptor, used by `#[getter]`
Getter(PyGetterDef),
/// Represents setter descriptor, used by `#[setter]`
Expand All @@ -40,6 +43,12 @@ pub struct PyMethodDef {
pub ml_doc: &'static str,
}

#[derive(Copy, Clone)]
pub struct PyClassAttributeDef {
pub name: &'static str,
pub meth: for<'p> fn(Python<'p>) -> PyObject,
}

#[derive(Copy, Clone, Debug)]
pub struct PyGetterDef {
pub name: &'static str,
Expand Down Expand Up @@ -85,6 +94,16 @@ impl PyMethodDef {
}
}

// Manual implementation because `Python<'_>` does not implement `Debug` and
// trait bounds on `fn` compiler-generated derive impls are too restrictive.
impl fmt::Debug for PyClassAttributeDef {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("PyClassAttributeDef")
.field("name", &self.name)
.finish()
}
}

impl PyGetterDef {
/// Copy descriptor information to `ffi::PyGetSetDef`
pub fn copy_to(&self, dst: &mut ffi::PyGetSetDef) {
Expand Down
4 changes: 3 additions & 1 deletion src/class/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ pub use self::descr::PyDescrProtocol;
pub use self::gc::{PyGCProtocol, PyTraverseError, PyVisit};
pub use self::iter::PyIterProtocol;
pub use self::mapping::PyMappingProtocol;
pub use self::methods::{PyGetterDef, PyMethodDef, PyMethodDefType, PyMethodType, PySetterDef};
pub use self::methods::{
PyClassAttributeDef, PyGetterDef, PyMethodDef, PyMethodDefType, PyMethodType, PySetterDef,
};
pub use self::number::PyNumberProtocol;
pub use self::pyasync::PyAsyncProtocol;
pub use self::sequence::PySequenceProtocol;
23 changes: 20 additions & 3 deletions src/pyclass.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
//! `PyClass` trait
use crate::class::methods::{PyMethodDefType, PyMethodsImpl};
use crate::class::methods::{PyClassAttributeDef, PyMethodDefType, PyMethodsImpl};
use crate::conversion::{IntoPyPointer, ToPyObject};
use crate::pyclass_slots::{PyClassDict, PyClassWeakRef};
use crate::type_object::{type_flags, PyLayout};
use crate::types::PyDict;
use crate::{class, ffi, PyCell, PyErr, PyNativeType, PyResult, PyTypeInfo, Python};
use std::ffi::CString;
use std::os::raw::c_void;
Expand Down Expand Up @@ -165,13 +167,23 @@ where
// buffer protocol
type_object.tp_as_buffer = to_ptr(<T as class::buffer::PyBufferProtocolImpl>::tp_as_buffer());

let (new, call, mut methods, attrs) = py_class_method_defs::<T>();

// normal methods
let (new, call, mut methods) = py_class_method_defs::<T>();
if !methods.is_empty() {
methods.push(ffi::PyMethodDef_INIT);
type_object.tp_methods = Box::into_raw(methods.into_boxed_slice()) as *mut _;
}

// class attributes
if !attrs.is_empty() {
let dict = PyDict::new(py);
for attr in attrs {
dict.set_item(attr.name, (attr.meth)(py))?;
}
type_object.tp_dict = dict.to_object(py).into_ptr();
}

// __new__ method
type_object.tp_new = new;
// __call__ method
Expand Down Expand Up @@ -219,8 +231,10 @@ fn py_class_method_defs<T: PyMethodsImpl>() -> (
Option<ffi::newfunc>,
Option<ffi::PyCFunctionWithKeywords>,
Vec<ffi::PyMethodDef>,
Vec<PyClassAttributeDef>,
) {
let mut defs = Vec::new();
let mut attrs = Vec::new();
let mut call = None;
let mut new = None;

Expand All @@ -243,6 +257,9 @@ fn py_class_method_defs<T: PyMethodsImpl>() -> (
| PyMethodDefType::Static(ref def) => {
defs.push(def.as_method_def());
}
PyMethodDefType::ClassAttribute(def) => {
attrs.push(def);
}
_ => (),
}
}
Expand All @@ -265,7 +282,7 @@ fn py_class_method_defs<T: PyMethodsImpl>() -> (

py_class_async_methods::<T>(&mut defs);

(new, call, defs)
(new, call, defs, attrs)
}

fn py_class_async_methods<T>(defs: &mut Vec<ffi::PyMethodDef>) {
Expand Down
75 changes: 75 additions & 0 deletions tests/test_class_attributes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use pyo3::prelude::*;

mod common;

#[pyclass]
struct Foo {
#[pyo3(get)]
x: i32,
}

#[pyclass]
struct Bar {
#[pyo3(get)]
x: i32,
}

#[pymethods]
impl Foo {
#[classattr]
fn a() -> i32 {
5
}

#[classattr]
#[name = "B"]
fn b() -> String {
"bar".to_string()
}

#[classattr]
fn foo() -> Foo {
Foo { x: 1 }
}

#[classattr]
fn bar() -> Bar {
Bar { x: 2 }
}
}

#[pymethods]
impl Bar {
#[classattr]
fn foo() -> Foo {
Foo { x: 3 }
}
}

#[test]
fn class_attributes() {
let gil = Python::acquire_gil();
let py = gil.python();
let foo_obj = py.get_type::<Foo>();
py_assert!(py, foo_obj, "foo_obj.a == 5");
py_assert!(py, foo_obj, "foo_obj.B == 'bar'");
}

#[test]
fn class_attributes_are_immutable() {
let gil = Python::acquire_gil();
let py = gil.python();
let foo_obj = py.get_type::<Foo>();
py_expect_exception!(py, foo_obj, "foo_obj.a = 6", TypeError);
}

#[test]
fn recursive_class_attributes() {
let gil = Python::acquire_gil();
let py = gil.python();
let foo_obj = py.get_type::<Foo>();
let bar_obj = py.get_type::<Bar>();
py_assert!(py, foo_obj, "foo_obj.foo.x == 1");
py_assert!(py, foo_obj, "foo_obj.bar.x == 2");
py_assert!(py, bar_obj, "bar_obj.foo.x == 3");
}

0 comments on commit 8aeae6c

Please sign in to comment.