From 8d48a7fad3409412aa31deb5e3482e1370042486 Mon Sep 17 00:00:00 2001 From: WANG Xuerui Date: Wed, 19 Jun 2024 20:01:52 +0800 Subject: [PATCH] Add several missing wrappers to PyAnyMethods --- newsfragments/4264.changed.md | 1 + src/types/any.rs | 52 +++++++++++++++++++++++++++++++++++ tests/test_arithmetics.rs | 30 ++++++++++++++++++++ 3 files changed, 83 insertions(+) create mode 100644 newsfragments/4264.changed.md diff --git a/newsfragments/4264.changed.md b/newsfragments/4264.changed.md new file mode 100644 index 00000000000..a3e89c6f98c --- /dev/null +++ b/newsfragments/4264.changed.md @@ -0,0 +1 @@ +Added `PyAnyMethods::{bitnot, matmul, floor_div, rem, divmod}` for completeness. diff --git a/src/types/any.rs b/src/types/any.rs index f2a86ff528d..c8c6d67e534 100644 --- a/src/types/any.rs +++ b/src/types/any.rs @@ -1143,6 +1143,9 @@ pub trait PyAnyMethods<'py>: crate::sealed::Sealed { /// Equivalent to the Python expression `abs(self)`. fn abs(&self) -> PyResult>; + /// Computes `~self`. + fn bitnot(&self) -> PyResult>; + /// Tests whether this object is less than another. /// /// This is equivalent to the Python expression `self < other`. @@ -1200,11 +1203,31 @@ pub trait PyAnyMethods<'py>: crate::sealed::Sealed { where O: ToPyObject; + /// Computes `self @ other`. + fn matmul(&self, other: O) -> PyResult> + where + O: ToPyObject; + /// Computes `self / other`. fn div(&self, other: O) -> PyResult> where O: ToPyObject; + /// Computes `self // other`. + fn floor_div(&self, other: O) -> PyResult> + where + O: ToPyObject; + + /// Computes `self % other`. + fn rem(&self, other: O) -> PyResult> + where + O: ToPyObject; + + /// Computes `divmod(self, other)`. + fn divmod(&self, other: O) -> PyResult> + where + O: ToPyObject; + /// Computes `self << other`. fn lshift(&self, other: O) -> PyResult> where @@ -1898,6 +1921,14 @@ impl<'py> PyAnyMethods<'py> for Bound<'py, PyAny> { inner(self) } + fn bitnot(&self) -> PyResult> { + fn inner<'py>(any: &Bound<'py, PyAny>) -> PyResult> { + unsafe { ffi::PyNumber_Invert(any.as_ptr()).assume_owned_or_err(any.py()) } + } + + inner(self) + } + fn lt(&self, other: O) -> PyResult where O: ToPyObject, @@ -1949,13 +1980,34 @@ impl<'py> PyAnyMethods<'py> for Bound<'py, PyAny> { implement_binop!(add, PyNumber_Add, "+"); implement_binop!(sub, PyNumber_Subtract, "-"); implement_binop!(mul, PyNumber_Multiply, "*"); + implement_binop!(matmul, PyNumber_MatrixMultiply, "@"); implement_binop!(div, PyNumber_TrueDivide, "/"); + implement_binop!(floor_div, PyNumber_FloorDivide, "//"); + implement_binop!(rem, PyNumber_Remainder, "%"); implement_binop!(lshift, PyNumber_Lshift, "<<"); implement_binop!(rshift, PyNumber_Rshift, ">>"); implement_binop!(bitand, PyNumber_And, "&"); implement_binop!(bitor, PyNumber_Or, "|"); implement_binop!(bitxor, PyNumber_Xor, "^"); + /// Computes `divmod(self, other)`. + fn divmod(&self, other: O) -> PyResult> + where + O: ToPyObject, + { + fn inner<'py>( + any: &Bound<'py, PyAny>, + other: Bound<'_, PyAny>, + ) -> PyResult> { + unsafe { + ffi::PyNumber_Divmod(any.as_ptr(), other.as_ptr()).assume_owned_or_err(any.py()) + } + } + + let py = self.py(); + inner(self, other.to_object(py).into_bound(py)) + } + /// Computes `self ** other % modulus` (`pow(self, other, modulus)`). /// `py.None()` may be passed for the `modulus`. fn pow(&self, other: O1, modulus: O2) -> PyResult> diff --git a/tests/test_arithmetics.rs b/tests/test_arithmetics.rs index 007f42a79e8..ade41be406e 100644 --- a/tests/test_arithmetics.rs +++ b/tests/test_arithmetics.rs @@ -179,10 +179,26 @@ impl BinaryArithmetic { format!("BA * {:?}", rhs) } + fn __matmul__(&self, rhs: &Bound<'_, PyAny>) -> String { + format!("BA @ {:?}", rhs) + } + fn __truediv__(&self, rhs: &Bound<'_, PyAny>) -> String { format!("BA / {:?}", rhs) } + fn __floordiv__(&self, rhs: &Bound<'_, PyAny>) -> String { + format!("BA // {:?}", rhs) + } + + fn __mod__(&self, rhs: &Bound<'_, PyAny>) -> String { + format!("BA % {:?}", rhs) + } + + fn __divmod__(&self, rhs: &Bound<'_, PyAny>) -> String { + format!("divmod(BA, {:?})", rhs) + } + fn __lshift__(&self, rhs: &Bound<'_, PyAny>) -> String { format!("BA << {:?}", rhs) } @@ -217,6 +233,11 @@ fn binary_arithmetic() { py_run!(py, c, "assert c + 1 == 'BA + 1'"); py_run!(py, c, "assert c - 1 == 'BA - 1'"); py_run!(py, c, "assert c * 1 == 'BA * 1'"); + py_run!(py, c, "assert c @ 1 == 'BA @ 1'"); + py_run!(py, c, "assert c / 1 == 'BA / 1'"); + py_run!(py, c, "assert c // 1 == 'BA // 1'"); + py_run!(py, c, "assert c % 1 == 'BA % 1'"); + py_run!(py, c, "assert divmod(c, 1) == 'divmod(BA, 1)'"); py_run!(py, c, "assert c << 1 == 'BA << 1'"); py_run!(py, c, "assert c >> 1 == 'BA >> 1'"); py_run!(py, c, "assert c & 1 == 'BA & 1'"); @@ -230,6 +251,11 @@ fn binary_arithmetic() { py_expect_exception!(py, c, "1 + c", PyTypeError); py_expect_exception!(py, c, "1 - c", PyTypeError); py_expect_exception!(py, c, "1 * c", PyTypeError); + py_expect_exception!(py, c, "1 @ c", PyTypeError); + py_expect_exception!(py, c, "1 / c", PyTypeError); + py_expect_exception!(py, c, "1 // c", PyTypeError); + py_expect_exception!(py, c, "1 % c", PyTypeError); + py_expect_exception!(py, c, "divmod(1, c)", PyTypeError); py_expect_exception!(py, c, "1 << c", PyTypeError); py_expect_exception!(py, c, "1 >> c", PyTypeError); py_expect_exception!(py, c, "1 & c", PyTypeError); @@ -243,7 +269,11 @@ fn binary_arithmetic() { assert_py_eq!(c.add(&c).unwrap(), "BA + BA"); assert_py_eq!(c.sub(&c).unwrap(), "BA - BA"); assert_py_eq!(c.mul(&c).unwrap(), "BA * BA"); + assert_py_eq!(c.matmul(&c).unwrap(), "BA @ BA"); assert_py_eq!(c.div(&c).unwrap(), "BA / BA"); + assert_py_eq!(c.floor_div(&c).unwrap(), "BA // BA"); + assert_py_eq!(c.rem(&c).unwrap(), "BA % BA"); + assert_py_eq!(c.divmod(&c).unwrap(), "divmod(BA, BA)"); assert_py_eq!(c.lshift(&c).unwrap(), "BA << BA"); assert_py_eq!(c.rshift(&c).unwrap(), "BA >> BA"); assert_py_eq!(c.bitand(&c).unwrap(), "BA & BA");