From ad3177533f2b1bfe3043d0b9c1b8ed1ee3864382 Mon Sep 17 00:00:00 2001 From: WANG Xuerui Date: Wed, 19 Jun 2024 20:01:52 +0800 Subject: [PATCH] Add several missing wrappers to PyAnyMethods --- newsfragments/4264.changed.md | 1 + src/types/any.rs | 52 +++++++++++++++++++++++++++++++++++ tests/test_arithmetics.rs | 44 +++++++++++++++++++++++++++++ 3 files changed, 97 insertions(+) create mode 100644 newsfragments/4264.changed.md diff --git a/newsfragments/4264.changed.md b/newsfragments/4264.changed.md new file mode 100644 index 00000000000..a3e89c6f98c --- /dev/null +++ b/newsfragments/4264.changed.md @@ -0,0 +1 @@ +Added `PyAnyMethods::{bitnot, matmul, floor_div, rem, divmod}` for completeness. diff --git a/src/types/any.rs b/src/types/any.rs index f2a86ff528d..c8c6d67e534 100644 --- a/src/types/any.rs +++ b/src/types/any.rs @@ -1143,6 +1143,9 @@ pub trait PyAnyMethods<'py>: crate::sealed::Sealed { /// Equivalent to the Python expression `abs(self)`. fn abs(&self) -> PyResult>; + /// Computes `~self`. + fn bitnot(&self) -> PyResult>; + /// Tests whether this object is less than another. /// /// This is equivalent to the Python expression `self < other`. @@ -1200,11 +1203,31 @@ pub trait PyAnyMethods<'py>: crate::sealed::Sealed { where O: ToPyObject; + /// Computes `self @ other`. + fn matmul(&self, other: O) -> PyResult> + where + O: ToPyObject; + /// Computes `self / other`. fn div(&self, other: O) -> PyResult> where O: ToPyObject; + /// Computes `self // other`. + fn floor_div(&self, other: O) -> PyResult> + where + O: ToPyObject; + + /// Computes `self % other`. + fn rem(&self, other: O) -> PyResult> + where + O: ToPyObject; + + /// Computes `divmod(self, other)`. + fn divmod(&self, other: O) -> PyResult> + where + O: ToPyObject; + /// Computes `self << other`. fn lshift(&self, other: O) -> PyResult> where @@ -1898,6 +1921,14 @@ impl<'py> PyAnyMethods<'py> for Bound<'py, PyAny> { inner(self) } + fn bitnot(&self) -> PyResult> { + fn inner<'py>(any: &Bound<'py, PyAny>) -> PyResult> { + unsafe { ffi::PyNumber_Invert(any.as_ptr()).assume_owned_or_err(any.py()) } + } + + inner(self) + } + fn lt(&self, other: O) -> PyResult where O: ToPyObject, @@ -1949,13 +1980,34 @@ impl<'py> PyAnyMethods<'py> for Bound<'py, PyAny> { implement_binop!(add, PyNumber_Add, "+"); implement_binop!(sub, PyNumber_Subtract, "-"); implement_binop!(mul, PyNumber_Multiply, "*"); + implement_binop!(matmul, PyNumber_MatrixMultiply, "@"); implement_binop!(div, PyNumber_TrueDivide, "/"); + implement_binop!(floor_div, PyNumber_FloorDivide, "//"); + implement_binop!(rem, PyNumber_Remainder, "%"); implement_binop!(lshift, PyNumber_Lshift, "<<"); implement_binop!(rshift, PyNumber_Rshift, ">>"); implement_binop!(bitand, PyNumber_And, "&"); implement_binop!(bitor, PyNumber_Or, "|"); implement_binop!(bitxor, PyNumber_Xor, "^"); + /// Computes `divmod(self, other)`. + fn divmod(&self, other: O) -> PyResult> + where + O: ToPyObject, + { + fn inner<'py>( + any: &Bound<'py, PyAny>, + other: Bound<'_, PyAny>, + ) -> PyResult> { + unsafe { + ffi::PyNumber_Divmod(any.as_ptr(), other.as_ptr()).assume_owned_or_err(any.py()) + } + } + + let py = self.py(); + inner(self, other.to_object(py).into_bound(py)) + } + /// Computes `self ** other % modulus` (`pow(self, other, modulus)`). /// `py.None()` may be passed for the `modulus`. fn pow(&self, other: O1, modulus: O2) -> PyResult> diff --git a/tests/test_arithmetics.rs b/tests/test_arithmetics.rs index 007f42a79e8..0cee2f9cf84 100644 --- a/tests/test_arithmetics.rs +++ b/tests/test_arithmetics.rs @@ -35,6 +35,10 @@ impl UnaryArithmetic { Self::new(self.inner.abs()) } + fn __invert__(&self) -> Self { + Self::new(self.inner.recip()) + } + #[pyo3(signature=(_ndigits=None))] fn __round__(&self, _ndigits: Option) -> Self { Self::new(self.inner.round()) @@ -48,8 +52,18 @@ fn unary_arithmetic() { py_run!(py, c, "assert repr(-c) == 'UA(-2.7)'"); py_run!(py, c, "assert repr(+c) == 'UA(2.7)'"); py_run!(py, c, "assert repr(abs(c)) == 'UA(2.7)'"); + py_run!(py, c, "assert repr(~c) == 'UA(0.37037037037037035)'"); py_run!(py, c, "assert repr(round(c)) == 'UA(3)'"); py_run!(py, c, "assert repr(round(c, 1)) == 'UA(3)'"); + + let c: Bound<'_, PyAny> = c.extract(py).unwrap(); + assert_py_eq!(c.neg().unwrap().repr().unwrap().as_any(), "UA(-2.7)"); + assert_py_eq!(c.pos().unwrap().repr().unwrap().as_any(), "UA(2.7)"); + assert_py_eq!(c.abs().unwrap().repr().unwrap().as_any(), "UA(2.7)"); + assert_py_eq!( + c.bitnot().unwrap().repr().unwrap().as_any(), + "UA(0.37037037037037035)" + ); }); } @@ -179,10 +193,26 @@ impl BinaryArithmetic { format!("BA * {:?}", rhs) } + fn __matmul__(&self, rhs: &Bound<'_, PyAny>) -> String { + format!("BA @ {:?}", rhs) + } + fn __truediv__(&self, rhs: &Bound<'_, PyAny>) -> String { format!("BA / {:?}", rhs) } + fn __floordiv__(&self, rhs: &Bound<'_, PyAny>) -> String { + format!("BA // {:?}", rhs) + } + + fn __mod__(&self, rhs: &Bound<'_, PyAny>) -> String { + format!("BA % {:?}", rhs) + } + + fn __divmod__(&self, rhs: &Bound<'_, PyAny>) -> String { + format!("divmod(BA, {:?})", rhs) + } + fn __lshift__(&self, rhs: &Bound<'_, PyAny>) -> String { format!("BA << {:?}", rhs) } @@ -217,6 +247,11 @@ fn binary_arithmetic() { py_run!(py, c, "assert c + 1 == 'BA + 1'"); py_run!(py, c, "assert c - 1 == 'BA - 1'"); py_run!(py, c, "assert c * 1 == 'BA * 1'"); + py_run!(py, c, "assert c @ 1 == 'BA @ 1'"); + py_run!(py, c, "assert c / 1 == 'BA / 1'"); + py_run!(py, c, "assert c // 1 == 'BA // 1'"); + py_run!(py, c, "assert c % 1 == 'BA % 1'"); + py_run!(py, c, "assert divmod(c, 1) == 'divmod(BA, 1)'"); py_run!(py, c, "assert c << 1 == 'BA << 1'"); py_run!(py, c, "assert c >> 1 == 'BA >> 1'"); py_run!(py, c, "assert c & 1 == 'BA & 1'"); @@ -230,6 +265,11 @@ fn binary_arithmetic() { py_expect_exception!(py, c, "1 + c", PyTypeError); py_expect_exception!(py, c, "1 - c", PyTypeError); py_expect_exception!(py, c, "1 * c", PyTypeError); + py_expect_exception!(py, c, "1 @ c", PyTypeError); + py_expect_exception!(py, c, "1 / c", PyTypeError); + py_expect_exception!(py, c, "1 // c", PyTypeError); + py_expect_exception!(py, c, "1 % c", PyTypeError); + py_expect_exception!(py, c, "divmod(1, c)", PyTypeError); py_expect_exception!(py, c, "1 << c", PyTypeError); py_expect_exception!(py, c, "1 >> c", PyTypeError); py_expect_exception!(py, c, "1 & c", PyTypeError); @@ -243,7 +283,11 @@ fn binary_arithmetic() { assert_py_eq!(c.add(&c).unwrap(), "BA + BA"); assert_py_eq!(c.sub(&c).unwrap(), "BA - BA"); assert_py_eq!(c.mul(&c).unwrap(), "BA * BA"); + assert_py_eq!(c.matmul(&c).unwrap(), "BA @ BA"); assert_py_eq!(c.div(&c).unwrap(), "BA / BA"); + assert_py_eq!(c.floor_div(&c).unwrap(), "BA // BA"); + assert_py_eq!(c.rem(&c).unwrap(), "BA % BA"); + assert_py_eq!(c.divmod(&c).unwrap(), "divmod(BA, BA)"); assert_py_eq!(c.lshift(&c).unwrap(), "BA << BA"); assert_py_eq!(c.rshift(&c).unwrap(), "BA >> BA"); assert_py_eq!(c.bitand(&c).unwrap(), "BA & BA");