diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 926d4bcfb26cbe..55daaae0873ede 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -108,6 +108,7 @@ from .logic import ( # noqa: F401 __rand__, __ror__, + __rxor__, allclose, bitwise_and, bitwise_and_, @@ -872,6 +873,7 @@ ('__or__', 'bitwise_or'), ('__ror__', '__ror__'), ('__xor__', 'bitwise_xor'), + ('__rxor__', '__rxor__'), ('__invert__', 'bitwise_not'), ('__pos__', 'positive'), ('__lshift__', '__lshift__'), diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py index d5714c4136b791..cb00e7b7002a8b 100755 --- a/python/paddle/tensor/logic.py +++ b/python/paddle/tensor/logic.py @@ -1416,6 +1416,21 @@ def bitwise_xor( ) +def __rxor__( + x: Tensor, + y: int | bool, + out: Tensor | None = None, + name: str | None = None, +) -> Tensor: + if isinstance(y, (int, bool)): + y = paddle.to_tensor(y, dtype=x.dtype) + return bitwise_xor(y, x, out=out, name=name) + else: + raise TypeError( + f"unsupported operand type(s) for |: '{type(y).__name__}' and 'Tensor'" + ) + + @inplace_apis_in_dygraph_only def bitwise_xor_(x: Tensor, y: Tensor, name: str | None = None) -> Tensor: r""" diff --git a/python/paddle/tensor/tensor.prototype.pyi b/python/paddle/tensor/tensor.prototype.pyi index d739991f955889..ccbc46306a7909 100644 --- a/python/paddle/tensor/tensor.prototype.pyi +++ b/python/paddle/tensor/tensor.prototype.pyi @@ -172,6 +172,7 @@ class AbstractTensor: def __pow__(self, y: _typing.TensorLike) -> Tensor: ... def __and__(self, y: _typing.TensorLike) -> Tensor: ... def __ror__(self, y: _typing.TensorLike) -> Tensor: ... + def __rxor__(self, y: _typing.TensorLike) -> Tensor: ... def __div__(self, y: _typing.TensorLike) -> Tensor: ... def __radd__(self, y: _typing.TensorLike) -> Tensor: ... # type: ignore def __rsub__(self, y: _typing.TensorLike) -> Tensor: ... # type: ignore diff --git a/test/legacy_test/test_math_op_patch.py b/test/legacy_test/test_math_op_patch.py index d61db98cb4460a..c27dae10188ee4 100644 --- a/test/legacy_test/test_math_op_patch.py +++ b/test/legacy_test/test_math_op_patch.py @@ -389,6 +389,42 @@ def test_bitwise_xor(self): ) np.testing.assert_array_equal(out[0], out_np) + @prog_scope() + def test_rxor(self): + place = ( + paddle.CUDAPlace(0) + if paddle.is_compiled_with_cuda() + else paddle.CPUPlace() + ) + x_int = 5 + y_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32") + y = paddle.static.data("y", y_np.shape, dtype=y_np.dtype) + z = x_int ^ y + exe = paddle.static.Executor(place) + out = exe.run( + feed={'y': y_np}, + fetch_list=[z], + ) + out_ref = x_int ^ y_np + np.testing.assert_array_equal(out[0], out_ref) + x_bool = True + res_rxor_bool = x_bool ^ y + out_bool = exe.run( + feed={'y': y_np}, + fetch_list=[res_rxor_bool], + ) + res_py_bool = x_bool ^ y_np + np.testing.assert_array_equal(out_bool[0], res_py_bool) + + for x_invalid in ( + np.float32(5.0), + np.float64(5.0), + np.complex64(5), + np.complex128(5.0 + 2j), + ): + with self.assertRaises(TypeError): + x_invalid ^ y + @prog_scope() def test_bitwise_not(self): x_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32") diff --git a/test/legacy_test/test_math_op_patch_pir.py b/test/legacy_test/test_math_op_patch_pir.py index 48eb26e19bf576..3ca932e2ccf57b 100644 --- a/test/legacy_test/test_math_op_patch_pir.py +++ b/test/legacy_test/test_math_op_patch_pir.py @@ -225,6 +225,63 @@ def test_bitwise_xor(self): np.testing.assert_array_equal(res_np_c, c_np) np.testing.assert_array_equal(res_np_d, d_np) + def test_rxor(self): + with dygraph_guard(): + x_int32 = 5 + x_bool = True + y_np = np.random.randint(0, 2, [2, 3, 5]).astype("int32") + y_tensor = paddle.to_tensor(y_np) + res_ror_int32 = x_int32 ^ y_tensor + res_py_int32 = x_int32 ^ y_tensor.numpy() + np.testing.assert_array_equal(res_py_int32, res_ror_int32.numpy()) + res_ror_bool = x_bool ^ y_tensor + res_py_bool = x_bool ^ y_tensor.numpy() + np.testing.assert_array_equal(res_py_bool, res_ror_bool.numpy()) + for x_np in ( + np.float32(5.0), + np.float64(5.0), + np.complex64(5), + np.complex128(5.0 + 2j), + ): + with self.assertRaises(TypeError): + x_np ^ y_tensor + + with static_guard(): + with paddle.pir_utils.IrGuard(): + main_program, exe, program_guard = new_program() + with program_guard: + x_int = 5 + y_np = np.random.randint(-100, 100, [2, 3, 5]).astype( + "int32" + ) + y = paddle.static.data("y", y_np.shape, dtype=y_np.dtype) + z = x_int ^ y + out = exe.run( + main_program, + feed={'y': y_np}, + fetch_list=[z], + ) + out_ref = x_int ^ y_np + np.testing.assert_array_equal(out[0], out_ref) + x_bool = True + res_rxor_bool = x_bool ^ y + out_bool = exe.run( + main_program, + feed={'y': y_np}, + fetch_list=[res_rxor_bool], + ) + res_py_bool = x_bool ^ y_np + np.testing.assert_array_equal(out_bool[0], res_py_bool) + + for x_invalid in ( + np.float32(5.0), + np.float64(5.0), + np.complex64(5), + np.complex128(5.0 + 2j), + ): + with self.assertRaises(TypeError): + x_invalid ^ y + def test_bitwise_or(self): paddle.disable_static() x_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32")