Skip to content

Commit

Permalink
【SCU】【Paddle Tensor No.7】新增 Tensor.__rxor__ 复用已有接口Tensor.__xor__ (#69779
Browse files Browse the repository at this point in the history
)

* rxor实现

* 提交信息
  • Loading branch information
yangrongxinuser authored Nov 29, 2024
1 parent 97bacd1 commit 1a58803
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
from .logic import ( # noqa: F401
__rand__,
__ror__,
__rxor__,
allclose,
bitwise_and,
bitwise_and_,
Expand Down Expand Up @@ -872,6 +873,7 @@
('__or__', 'bitwise_or'),
('__ror__', '__ror__'),
('__xor__', 'bitwise_xor'),
('__rxor__', '__rxor__'),
('__invert__', 'bitwise_not'),
('__pos__', 'positive'),
('__lshift__', '__lshift__'),
Expand Down
15 changes: 15 additions & 0 deletions python/paddle/tensor/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1416,6 +1416,21 @@ def bitwise_xor(
)


def __rxor__(
x: Tensor,
y: int | bool,
out: Tensor | None = None,
name: str | None = None,
) -> Tensor:
if isinstance(y, (int, bool)):
y = paddle.to_tensor(y, dtype=x.dtype)
return bitwise_xor(y, x, out=out, name=name)
else:
raise TypeError(
f"unsupported operand type(s) for |: '{type(y).__name__}' and 'Tensor'"
)


@inplace_apis_in_dygraph_only
def bitwise_xor_(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
r"""
Expand Down
1 change: 1 addition & 0 deletions python/paddle/tensor/tensor.prototype.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ class AbstractTensor:
def __pow__(self, y: _typing.TensorLike) -> Tensor: ...
def __and__(self, y: _typing.TensorLike) -> Tensor: ...
def __ror__(self, y: _typing.TensorLike) -> Tensor: ...
def __rxor__(self, y: _typing.TensorLike) -> Tensor: ...
def __div__(self, y: _typing.TensorLike) -> Tensor: ...
def __radd__(self, y: _typing.TensorLike) -> Tensor: ... # type: ignore
def __rsub__(self, y: _typing.TensorLike) -> Tensor: ... # type: ignore
Expand Down
36 changes: 36 additions & 0 deletions test/legacy_test/test_math_op_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,42 @@ def test_bitwise_xor(self):
)
np.testing.assert_array_equal(out[0], out_np)

@prog_scope()
def test_rxor(self):
place = (
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
else paddle.CPUPlace()
)
x_int = 5
y_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32")
y = paddle.static.data("y", y_np.shape, dtype=y_np.dtype)
z = x_int ^ y
exe = paddle.static.Executor(place)
out = exe.run(
feed={'y': y_np},
fetch_list=[z],
)
out_ref = x_int ^ y_np
np.testing.assert_array_equal(out[0], out_ref)
x_bool = True
res_rxor_bool = x_bool ^ y
out_bool = exe.run(
feed={'y': y_np},
fetch_list=[res_rxor_bool],
)
res_py_bool = x_bool ^ y_np
np.testing.assert_array_equal(out_bool[0], res_py_bool)

for x_invalid in (
np.float32(5.0),
np.float64(5.0),
np.complex64(5),
np.complex128(5.0 + 2j),
):
with self.assertRaises(TypeError):
x_invalid ^ y

@prog_scope()
def test_bitwise_not(self):
x_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32")
Expand Down
57 changes: 57 additions & 0 deletions test/legacy_test/test_math_op_patch_pir.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,63 @@ def test_bitwise_xor(self):
np.testing.assert_array_equal(res_np_c, c_np)
np.testing.assert_array_equal(res_np_d, d_np)

def test_rxor(self):
with dygraph_guard():
x_int32 = 5
x_bool = True
y_np = np.random.randint(0, 2, [2, 3, 5]).astype("int32")
y_tensor = paddle.to_tensor(y_np)
res_ror_int32 = x_int32 ^ y_tensor
res_py_int32 = x_int32 ^ y_tensor.numpy()
np.testing.assert_array_equal(res_py_int32, res_ror_int32.numpy())
res_ror_bool = x_bool ^ y_tensor
res_py_bool = x_bool ^ y_tensor.numpy()
np.testing.assert_array_equal(res_py_bool, res_ror_bool.numpy())
for x_np in (
np.float32(5.0),
np.float64(5.0),
np.complex64(5),
np.complex128(5.0 + 2j),
):
with self.assertRaises(TypeError):
x_np ^ y_tensor

with static_guard():
with paddle.pir_utils.IrGuard():
main_program, exe, program_guard = new_program()
with program_guard:
x_int = 5
y_np = np.random.randint(-100, 100, [2, 3, 5]).astype(
"int32"
)
y = paddle.static.data("y", y_np.shape, dtype=y_np.dtype)
z = x_int ^ y
out = exe.run(
main_program,
feed={'y': y_np},
fetch_list=[z],
)
out_ref = x_int ^ y_np
np.testing.assert_array_equal(out[0], out_ref)
x_bool = True
res_rxor_bool = x_bool ^ y
out_bool = exe.run(
main_program,
feed={'y': y_np},
fetch_list=[res_rxor_bool],
)
res_py_bool = x_bool ^ y_np
np.testing.assert_array_equal(out_bool[0], res_py_bool)

for x_invalid in (
np.float32(5.0),
np.float64(5.0),
np.complex64(5),
np.complex128(5.0 + 2j),
):
with self.assertRaises(TypeError):
x_invalid ^ y

def test_bitwise_or(self):
paddle.disable_static()
x_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32")
Expand Down

0 comments on commit 1a58803

Please sign in to comment.