Skip to content

Commit

Permalink
support Bound for classmethod and pass_module (#3831)
Browse files Browse the repository at this point in the history
* support `Bound` for `classmethod` and `pass_module`

* `from_ref_to_ptr` -> `ref_from_ptr`

* add detailed docs to `ref_from_ptr`
  • Loading branch information
davidhewitt authored Feb 16, 2024
1 parent 05aedc9 commit ec6d587
Show file tree
Hide file tree
Showing 16 changed files with 205 additions and 51 deletions.
10 changes: 5 additions & 5 deletions guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ This is the equivalent of the Python decorator `@classmethod`.
#[pymethods]
impl MyClass {
#[classmethod]
fn cls_method(cls: &PyType) -> PyResult<i32> {
fn cls_method(cls: &Bound<'_, PyType>) -> PyResult<i32> {
Ok(10)
}
}
Expand Down Expand Up @@ -719,10 +719,10 @@ To create a constructor which takes a positional class argument, you can combine
impl BaseClass {
#[new]
#[classmethod]
fn py_new<'p>(cls: &'p PyType, py: Python<'p>) -> PyResult<Self> {
fn py_new(cls: &Bound<'_, PyType>) -> PyResult<Self> {
// Get an abstract attribute (presumably) declared on a subclass of this class.
let subclass_attr = cls.getattr("a_class_attr")?;
Ok(Self(subclass_attr.to_object(py)))
let subclass_attr: Bound<'_, PyAny> = cls.getattr("a_class_attr")?;
Ok(Self(subclass_attr.unbind()))
}
}
```
Expand Down Expand Up @@ -928,7 +928,7 @@ impl MyClass {
// similarly for classmethod arguments, use $cls
#[classmethod]
#[pyo3(text_signature = "($cls, e, f)")]
fn my_class_method(cls: &PyType, e: i32, f: i32) -> i32 {
fn my_class_method(cls: &Bound<'_, PyType>, e: i32, f: i32) -> i32 {
e + f
}
#[staticmethod]
Expand Down
3 changes: 2 additions & 1 deletion guide/src/function.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,11 @@ The `#[pyo3]` attribute can be used to modify properties of the generated Python

```rust
use pyo3::prelude::*;
use pyo3::types::PyString;

#[pyfunction]
#[pyo3(pass_module)]
fn pyfunction_with_module(module: &PyModule) -> PyResult<&str> {
fn pyfunction_with_module<'py>(module: &Bound<'py, PyModule>) -> PyResult<Bound<'py, PyString>> {
module.name()
}

Expand Down
14 changes: 11 additions & 3 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,21 @@ impl FnType {
let slf: Ident = syn::Ident::new("_slf", Span::call_site());
quote_spanned! { *span =>
#[allow(clippy::useless_conversion)]
::std::convert::Into::into(_pyo3::types::PyType::from_type_ptr(#py, #slf.cast())),
::std::convert::Into::into(
_pyo3::impl_::pymethods::BoundRef::ref_from_ptr(#py, &#slf.cast())
.downcast_unchecked::<_pyo3::types::PyType>()
),
}
}
FnType::FnModule(span) => {
let py = syn::Ident::new("py", Span::call_site());
let slf: Ident = syn::Ident::new("_slf", Span::call_site());
quote_spanned! { *span =>
#[allow(clippy::useless_conversion)]
::std::convert::Into::into(py.from_borrowed_ptr::<_pyo3::types::PyModule>(_slf)),
::std::convert::Into::into(
_pyo3::impl_::pymethods::BoundRef::ref_from_ptr(#py, &#slf.cast())
.downcast_unchecked::<_pyo3::types::PyModule>()
),
}
}
}
Expand Down Expand Up @@ -409,7 +417,7 @@ impl<'a> FnSpec<'a> {
// will error on incorrect type.
Some(syn::FnArg::Typed(first_arg)) => first_arg.ty.span(),
Some(syn::FnArg::Receiver(_)) | None => bail_spanned!(
sig.paren_token.span.join() => "Expected `&PyType` or `Py<PyType>` as the first argument to `#[classmethod]`"
sig.paren_token.span.join() => "Expected `&Bound<PyType>` or `Py<PyType>` as the first argument to `#[classmethod]`"
),
};
FnType::FnClass(span)
Expand Down
20 changes: 20 additions & 0 deletions pytests/src/pyclasses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,25 @@ struct AssertingBaseClass;

#[pymethods]
impl AssertingBaseClass {
#[new]
#[classmethod]
fn new(cls: &Bound<'_, PyType>, expected_type: Bound<'_, PyType>) -> PyResult<Self> {
if !cls.is(&expected_type) {
return Err(PyValueError::new_err(format!(
"{:?} != {:?}",
cls, expected_type
)));
}
Ok(Self)
}
}

#[pyclass(subclass)]
#[derive(Clone, Debug)]
struct AssertingBaseClassGilRef;

#[pymethods]
impl AssertingBaseClassGilRef {
#[new]
#[classmethod]
fn new(cls: &PyType, expected_type: &PyType) -> PyResult<Self> {
Expand All @@ -65,6 +84,7 @@ pub fn pyclasses(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<EmptyClass>()?;
m.add_class::<PyClassIter>()?;
m.add_class::<AssertingBaseClass>()?;
m.add_class::<AssertingBaseClassGilRef>()?;
m.add_class::<ClassWithoutConstructor>()?;
Ok(())
}
11 changes: 11 additions & 0 deletions pytests/tests/test_pyclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ def test_new_classmethod():
_ = AssertingSubClass(expected_type=str)


def test_new_classmethod_gil_ref():
class AssertingSubClass(pyclasses.AssertingBaseClassGilRef):
pass

# The `AssertingBaseClass` constructor errors if it is not passed the
# relevant subclass.
_ = AssertingSubClass(expected_type=AssertingSubClass)
with pytest.raises(ValueError):
_ = AssertingSubClass(expected_type=str)


class ClassWithoutConstructorPy:
def __new__(cls):
raise TypeError("No constructor defined")
Expand Down
53 changes: 52 additions & 1 deletion src/impl_/pymethods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ use crate::exceptions::PyStopAsyncIteration;
use crate::gil::LockGIL;
use crate::impl_::panic::PanicTrap;
use crate::internal_tricks::extract_c_string;
use crate::types::{any::PyAnyMethods, PyModule, PyType};
use crate::{
ffi, PyAny, PyCell, PyClass, PyErr, PyObject, PyResult, PyTraverseError, PyVisit, Python,
ffi, Bound, Py, PyAny, PyCell, PyClass, PyErr, PyObject, PyResult, PyTraverseError, PyVisit,
Python,
};
use std::borrow::Cow;
use std::ffi::CStr;
Expand Down Expand Up @@ -466,3 +468,52 @@ pub trait AsyncIterResultOptionKind {
}

impl<Value, Error> AsyncIterResultOptionKind for Result<Option<Value>, Error> {}

/// Used in `#[classmethod]` to pass the class object to the method
/// and also in `#[pyfunction(pass_module)]`.
///
/// This is a wrapper to avoid implementing `From<Bound>` for GIL Refs.
///
/// Once the GIL Ref API is fully removed, it should be possible to simplify
/// this to just `&'a Bound<'py, T>` and `From` implementations.
pub struct BoundRef<'a, 'py, T>(pub &'a Bound<'py, T>);

impl<'a, 'py> BoundRef<'a, 'py, PyAny> {
pub unsafe fn ref_from_ptr(py: Python<'py>, ptr: &'a *mut ffi::PyObject) -> Self {
BoundRef(Bound::ref_from_ptr(py, ptr))
}

pub unsafe fn downcast_unchecked<T>(self) -> BoundRef<'a, 'py, T> {
BoundRef(self.0.downcast_unchecked::<T>())
}
}

// GIL Ref implementations for &'a T ran into trouble with orphan rules,
// so explicit implementations are used instead for the two relevant types.
impl<'a> From<BoundRef<'a, 'a, PyType>> for &'a PyType {
#[inline]
fn from(bound: BoundRef<'a, 'a, PyType>) -> Self {
bound.0.as_gil_ref()
}
}

impl<'a> From<BoundRef<'a, 'a, PyModule>> for &'a PyModule {
#[inline]
fn from(bound: BoundRef<'a, 'a, PyModule>) -> Self {
bound.0.as_gil_ref()
}
}

impl<'a, 'py, T> From<BoundRef<'a, 'py, T>> for &'a Bound<'py, T> {
#[inline]
fn from(bound: BoundRef<'a, 'py, T>) -> Self {
bound.0
}
}

impl<T> From<BoundRef<'_, '_, T>> for Py<T> {
#[inline]
fn from(bound: BoundRef<'_, '_, T>) -> Self {
bound.0.clone().unbind()
}
}
18 changes: 18 additions & 0 deletions src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,24 @@ impl<'py> Bound<'py, PyAny> {
) -> PyResult<Self> {
Py::from_owned_ptr_or_err(py, ptr).map(|obj| Self(py, ManuallyDrop::new(obj)))
}

/// This slightly strange method is used to obtain `&Bound<PyAny>` from a pointer in macro code
/// where we need to constrain the lifetime `'a` safely.
///
/// Note that `'py` is required to outlive `'a` implicitly by the nature of the fact that
/// `&'a Bound<'py>` means that `Bound<'py>` exists for at least the lifetime `'a`.
///
/// # Safety
/// - `ptr` must be a valid pointer to a Python object for the lifetime `'a`. The `ptr` can
/// be either a borrowed reference or an owned reference, it does not matter, as this is
/// just `&Bound` there will never be any ownership transfer.
#[inline]
pub(crate) unsafe fn ref_from_ptr<'a>(
_py: Python<'py>,
ptr: &'a *mut ffi::PyObject,
) -> &'a Self {
&*(ptr as *const *mut ffi::PyObject).cast::<Bound<'py, PyAny>>()
}
}

impl<'py, T> Bound<'py, T>
Expand Down
4 changes: 2 additions & 2 deletions src/tests/hygiene/pymethods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ impl Dummy {
#[staticmethod]
fn staticmethod() {}
#[classmethod]
fn clsmethod(_: &crate::types::PyType) {}
fn clsmethod(_: &crate::Bound<'_, crate::types::PyType>) {}
#[pyo3(signature = (*_args, **_kwds))]
fn __call__(
&self,
Expand Down Expand Up @@ -770,7 +770,7 @@ impl Dummy {
#[staticmethod]
fn staticmethod() {}
#[classmethod]
fn clsmethod(_: &crate::types::PyType) {}
fn clsmethod(_: &crate::Bound<'_, crate::types::PyType>) {}
#[pyo3(signature = (*_args, **_kwds))]
fn __call__(
&self,
Expand Down
16 changes: 14 additions & 2 deletions tests/test_class_basics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ fn panic_unsendable_child() {
test_unsendable::<UnsendableChild>().unwrap();
}

fn get_length(obj: &PyAny) -> PyResult<usize> {
fn get_length(obj: &Bound<'_, PyAny>) -> PyResult<usize> {
let length = obj.len()?;

Ok(length)
Expand All @@ -299,7 +299,18 @@ impl ClassWithFromPyWithMethods {
argument
}
#[classmethod]
fn classmethod(_cls: &PyType, #[pyo3(from_py_with = "PyAny::len")] argument: usize) -> usize {
fn classmethod(
_cls: &Bound<'_, PyType>,
#[pyo3(from_py_with = "Bound::<'_, PyAny>::len")] argument: usize,
) -> usize {
argument
}

#[classmethod]
fn classmethod_gil_ref(
_cls: &PyType,
#[pyo3(from_py_with = "PyAny::len")] argument: usize,
) -> usize {
argument
}

Expand All @@ -322,6 +333,7 @@ fn test_pymethods_from_py_with() {
assert instance.instance_method(arg) == 2
assert instance.classmethod(arg) == 2
assert instance.classmethod_gil_ref(arg) == 2
assert instance.staticmethod(arg) == 2
"#
);
Expand Down
20 changes: 15 additions & 5 deletions tests/test_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,13 @@ impl ClassMethod {

#[classmethod]
/// Test class method.
fn method(cls: &PyType) -> PyResult<String> {
fn method(cls: &Bound<'_, PyType>) -> PyResult<String> {
Ok(format!("{}.method()!", cls.as_gil_ref().qualname()?))
}

#[classmethod]
/// Test class method.
fn method_gil_ref(cls: &PyType) -> PyResult<String> {
Ok(format!("{}.method()!", cls.qualname()?))
}

Expand Down Expand Up @@ -108,8 +114,12 @@ struct ClassMethodWithArgs {}
#[pymethods]
impl ClassMethodWithArgs {
#[classmethod]
fn method(cls: &PyType, input: &PyString) -> PyResult<String> {
Ok(format!("{}.method({})", cls.qualname()?, input))
fn method(cls: &Bound<'_, PyType>, input: &PyString) -> PyResult<String> {
Ok(format!(
"{}.method({})",
cls.as_gil_ref().qualname()?,
input
))
}
}

Expand Down Expand Up @@ -915,7 +925,7 @@ impl r#RawIdents {
}

#[classmethod]
pub fn r#class_method(_: &PyType, r#type: PyObject) -> PyObject {
pub fn r#class_method(_: &Bound<'_, PyType>, r#type: PyObject) -> PyObject {
r#type
}

Expand Down Expand Up @@ -1082,7 +1092,7 @@ issue_1506!(

#[classmethod]
fn issue_1506_class(
_cls: &PyType,
_cls: &Bound<'_, PyType>,
_py: Python<'_>,
_arg: &PyAny,
_args: &PyTuple,
Expand Down
Loading

0 comments on commit ec6d587

Please sign in to comment.