diff --git a/guide/src/class.md b/guide/src/class.md index 2bcfe75911e..ab0c82fc88b 100644 --- a/guide/src/class.md +++ b/guide/src/class.md @@ -327,8 +327,12 @@ explicitly. To get a parent class from a child, use [`PyRef`] instead of `&self` for methods, or [`PyRefMut`] instead of `&mut self`. -Then you can access a parent class by `self_.as_ref()` as `&Self::BaseClass`, -or by `self_.into_super()` as `PyRef`. +Then you can access a parent class by `self_.as_super()` as `&PyRef`, +or by `self_.into_super()` as `PyRef` (and similar for the `PyRefMut` +case). For convenience, `self_.as_ref()` can also be used to get `&Self::BaseClass` +directly; however, this approach does not let you access base clases higher in the +inheritance hierarchy, for which you would need to chain multiple `as_super` or +`into_super` calls. ```rust # use pyo3::prelude::*; @@ -345,7 +349,7 @@ impl BaseClass { BaseClass { val1: 10 } } - pub fn method(&self) -> PyResult { + pub fn method1(&self) -> PyResult { Ok(self.val1) } } @@ -363,8 +367,8 @@ impl SubClass { } fn method2(self_: PyRef<'_, Self>) -> PyResult { - let super_ = self_.as_ref(); // Get &BaseClass - super_.method().map(|x| x * self_.val2) + let super_ = self_.as_super(); // Get &PyRef + super_.method1().map(|x| x * self_.val2) } } @@ -381,11 +385,28 @@ impl SubSubClass { } fn method3(self_: PyRef<'_, Self>) -> PyResult { + let base = self_.as_super().as_super(); // Get &PyRef<'_, BaseClass> + base.method1().map(|x| x * self_.val3) + } + + fn method4(self_: PyRef<'_, Self>) -> PyResult { let v = self_.val3; let super_ = self_.into_super(); // Get PyRef<'_, SubClass> SubClass::method2(super_).map(|x| x * v) } + fn get_values(self_: PyRef<'_, Self>) -> (usize, usize, usize) { + let val1 = self_.as_super().as_super().val1; + let val2 = self_.as_super().val2; + (val1, val2, self_.val3) + } + + fn double_values(mut self_: PyRefMut<'_, Self>) { + self_.as_super().as_super().val1 *= 2; + self_.as_super().val2 *= 2; + self_.val3 *= 2; + } + #[staticmethod] fn factory_method(py: Python<'_>, val: usize) -> PyResult { let base = PyClassInitializer::from(BaseClass::new()); @@ -400,7 +421,13 @@ impl SubSubClass { } # Python::with_gil(|py| { # let subsub = pyo3::Py::new(py, SubSubClass::new()).unwrap(); -# pyo3::py_run!(py, subsub, "assert subsub.method3() == 3000"); +# pyo3::py_run!(py, subsub, "assert subsub.method1() == 10"); +# pyo3::py_run!(py, subsub, "assert subsub.method2() == 150"); +# pyo3::py_run!(py, subsub, "assert subsub.method3() == 200"); +# pyo3::py_run!(py, subsub, "assert subsub.method4() == 3000"); +# pyo3::py_run!(py, subsub, "assert subsub.get_values() == (10, 15, 20)"); +# pyo3::py_run!(py, subsub, "assert subsub.double_values() == None"); +# pyo3::py_run!(py, subsub, "assert subsub.get_values() == (20, 30, 40)"); # let subsub = SubSubClass::factory_method(py, 2).unwrap(); # let subsubsub = SubSubClass::factory_method(py, 3).unwrap(); # let cls = py.get_type_bound::(); diff --git a/newsfragments/4219.added.md b/newsfragments/4219.added.md new file mode 100644 index 00000000000..cea8fa1c314 --- /dev/null +++ b/newsfragments/4219.added.md @@ -0,0 +1,3 @@ +- Added `as_super` methods to `PyRef` and `PyRefMut` for accesing the base class by reference +- Updated user guide to recommend `as_super` for referencing the base class instead of `as_ref` +- Added `pyo3::internal_tricks::ptr_from_mut` function for casting `&mut T` to `*mut T` \ No newline at end of file diff --git a/src/internal_tricks.rs b/src/internal_tricks.rs index 62ec0d02166..a8873dda007 100644 --- a/src/internal_tricks.rs +++ b/src/internal_tricks.rs @@ -223,3 +223,9 @@ pub(crate) fn extract_c_string( pub(crate) const fn ptr_from_ref(t: &T) -> *const T { t as *const T } + +// TODO: use ptr::from_mut on MSRV 1.76 +#[inline] +pub(crate) fn ptr_from_mut(t: &mut T) -> *mut T { + t as *mut T +} diff --git a/src/pycell.rs b/src/pycell.rs index 1d601474bda..9ed6c8aca7d 100644 --- a/src/pycell.rs +++ b/src/pycell.rs @@ -196,13 +196,13 @@ use crate::conversion::AsPyPointer; use crate::exceptions::PyRuntimeError; use crate::ffi_ptr_ext::FfiPtrExt; +use crate::internal_tricks::{ptr_from_mut, ptr_from_ref}; use crate::pyclass::{boolean_struct::False, PyClass}; use crate::types::any::PyAnyMethods; #[cfg(feature = "gil-refs")] use crate::{ conversion::ToPyObject, impl_::pyclass::PyClassImpl, - internal_tricks::ptr_from_ref, pyclass::boolean_struct::True, pyclass_init::PyClassInitializer, type_object::{PyLayout, PySizedLayout}, @@ -612,6 +612,7 @@ impl fmt::Debug for PyCell { /// ``` /// /// See the [module-level documentation](self) for more information. +#[repr(transparent)] pub struct PyRef<'p, T: PyClass> { // TODO: once the GIL Ref API is removed, consider adding a lifetime parameter to `PyRef` to // store `Borrowed` here instead, avoiding reference counting overhead. @@ -631,7 +632,7 @@ where U: PyClass, { fn as_ref(&self) -> &T::BaseType { - unsafe { &*self.inner.get_class_object().ob_base.get_ptr() } + self.as_super() } } @@ -743,6 +744,58 @@ where }, } } + + /// Borrows a shared reference to `PyRef`. + /// + /// With the help of this method, you can access attributes and call methods + /// on the superclass without consuming the `PyRef`. This method can also + /// be chained to access the super-superclass (and so on). + /// + /// # Examples + /// ``` + /// # use pyo3::prelude::*; + /// #[pyclass(subclass)] + /// struct Base { + /// base_name: &'static str, + /// } + /// #[pymethods] + /// impl Base { + /// fn base_name_len(&self) -> usize { + /// self.base_name.len() + /// } + /// } + /// + /// #[pyclass(extends=Base)] + /// struct Sub { + /// sub_name: &'static str, + /// } + /// + /// #[pymethods] + /// impl Sub { + /// #[new] + /// fn new() -> (Self, Base) { + /// (Self { sub_name: "sub_name" }, Base { base_name: "base_name" }) + /// } + /// fn sub_name_len(&self) -> usize { + /// self.sub_name.len() + /// } + /// fn format_name_lengths(slf: PyRef<'_, Self>) -> String { + /// format!("{} {}", slf.as_super().base_name_len(), slf.sub_name_len()) + /// } + /// } + /// # Python::with_gil(|py| { + /// # let sub = Py::new(py, Sub::new()).unwrap(); + /// # pyo3::py_run!(py, sub, "assert sub.format_name_lengths() == '9 8'") + /// # }); + /// ``` + pub fn as_super(&self) -> &PyRef<'p, U> { + let ptr = ptr_from_ref::>(&self.inner) + // `Bound` has the same layout as `Bound` + .cast::>() + // `Bound` has the same layout as `PyRef` + .cast::>(); + unsafe { &*ptr } + } } impl<'p, T: PyClass> Deref for PyRef<'p, T> { @@ -799,6 +852,7 @@ impl fmt::Debug for PyRef<'_, T> { /// A wrapper type for a mutably borrowed value from a [`Bound<'py, T>`]. /// /// See the [module-level documentation](self) for more information. +#[repr(transparent)] pub struct PyRefMut<'p, T: PyClass> { // TODO: once the GIL Ref API is removed, consider adding a lifetime parameter to `PyRef` to // store `Borrowed` here instead, avoiding reference counting overhead. @@ -818,7 +872,7 @@ where U: PyClass, { fn as_ref(&self) -> &T::BaseType { - unsafe { &*self.inner.get_class_object().ob_base.get_ptr() } + PyRefMut::downgrade(self).as_super() } } @@ -828,7 +882,7 @@ where U: PyClass, { fn as_mut(&mut self) -> &mut T::BaseType { - unsafe { &mut *self.inner.get_class_object().ob_base.get_ptr() } + self.as_super() } } @@ -870,6 +924,11 @@ impl<'py, T: PyClass> PyRefMut<'py, T> { .try_borrow_mut() .map(|_| Self { inner: obj.clone() }) } + + pub(crate) fn downgrade(slf: &Self) -> &PyRef<'py, T> { + // `PyRefMut` and `PyRef` have the same layout + unsafe { &*ptr_from_ref(slf).cast() } + } } impl<'p, T, U> PyRefMut<'p, T> @@ -891,6 +950,23 @@ where }, } } + + /// Borrows a mutable reference to `PyRefMut`. + /// + /// With the help of this method, you can mutate attributes and call mutating + /// methods on the superclass without consuming the `PyRefMut`. This method + /// can also be chained to access the super-superclass (and so on). + /// + /// See [`PyRef::as_super`] for more. + pub fn as_super(&mut self) -> &mut PyRefMut<'p, U> { + let ptr = ptr_from_mut::>(&mut self.inner) + // `Bound` has the same layout as `Bound` + .cast::>() + // `Bound` has the same layout as `PyRefMut`, + // and the mutable borrow on `self` prevents aliasing + .cast::>(); + unsafe { &mut *ptr } + } } impl<'p, T: PyClass> Deref for PyRefMut<'p, T> { @@ -1140,4 +1216,88 @@ mod tests { unsafe { ffi::Py_DECREF(ptr) }; }) } + + #[crate::pyclass] + #[pyo3(crate = "crate", subclass)] + struct BaseClass { + val1: usize, + } + + #[crate::pyclass] + #[pyo3(crate = "crate", extends=BaseClass, subclass)] + struct SubClass { + val2: usize, + } + + #[crate::pyclass] + #[pyo3(crate = "crate", extends=SubClass)] + struct SubSubClass { + val3: usize, + } + + #[crate::pymethods] + #[pyo3(crate = "crate")] + impl SubSubClass { + #[new] + fn new(py: Python<'_>) -> crate::Py { + let init = crate::PyClassInitializer::from(BaseClass { val1: 10 }) + .add_subclass(SubClass { val2: 15 }) + .add_subclass(SubSubClass { val3: 20 }); + crate::Py::new(py, init).expect("allocation error") + } + + fn get_values(self_: PyRef<'_, Self>) -> (usize, usize, usize) { + let val1 = self_.as_super().as_super().val1; + let val2 = self_.as_super().val2; + (val1, val2, self_.val3) + } + + fn double_values(mut self_: PyRefMut<'_, Self>) { + self_.as_super().as_super().val1 *= 2; + self_.as_super().val2 *= 2; + self_.val3 *= 2; + } + } + + #[test] + fn test_pyref_as_super() { + Python::with_gil(|py| { + let obj = SubSubClass::new(py).into_bound(py); + let pyref = obj.borrow(); + assert_eq!(pyref.as_super().as_super().val1, 10); + assert_eq!(pyref.as_super().val2, 15); + assert_eq!(pyref.as_ref().val2, 15); // `as_ref` also works + assert_eq!(pyref.val3, 20); + assert_eq!(SubSubClass::get_values(pyref), (10, 15, 20)); + }); + } + + #[test] + fn test_pyrefmut_as_super() { + Python::with_gil(|py| { + let obj = SubSubClass::new(py).into_bound(py); + assert_eq!(SubSubClass::get_values(obj.borrow()), (10, 15, 20)); + { + let mut pyrefmut = obj.borrow_mut(); + assert_eq!(pyrefmut.as_super().as_ref().val1, 10); + pyrefmut.as_super().as_super().val1 -= 5; + pyrefmut.as_super().val2 -= 3; + pyrefmut.as_mut().val2 -= 2; // `as_mut` also works + pyrefmut.val3 -= 5; + } + assert_eq!(SubSubClass::get_values(obj.borrow()), (5, 10, 15)); + SubSubClass::double_values(obj.borrow_mut()); + assert_eq!(SubSubClass::get_values(obj.borrow()), (10, 20, 30)); + }); + } + + #[test] + fn test_pyrefs_in_python() { + Python::with_gil(|py| { + let obj = SubSubClass::new(py); + crate::py_run!(py, obj, "assert obj.get_values() == (10, 15, 20)"); + crate::py_run!(py, obj, "assert obj.double_values() is None"); + crate::py_run!(py, obj, "assert obj.get_values() == (20, 30, 40)"); + }); + } }