Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Paddle Tensor No.6][BUPT] 新增 Tensor.__ror__ #69463

Merged
merged 25 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
transpose_,
)
from .logic import ( # noqa: F401
__ror__,
allclose,
bitwise_and,
bitwise_and_,
Expand Down Expand Up @@ -859,6 +860,7 @@
magic_method_func = [
('__and__', 'bitwise_and'),
('__or__', 'bitwise_or'),
('__ror__', '__ror__'),
('__xor__', 'bitwise_xor'),
('__invert__', 'bitwise_not'),
('__pos__', 'positive'),
Expand Down
15 changes: 15 additions & 0 deletions python/paddle/tensor/logic.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个文件应该不用改

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

感谢指正,那这个__ror__函数我应该在哪个文件里实现呢?

Original file line number Diff line number Diff line change
Expand Up @@ -1290,6 +1290,21 @@ def bitwise_or(
)


def __ror__(
x: Tensor,
y: Tensor | int | bool,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

y不会是Tensor

out: Tensor | None = None,
name: str | None = None,
) -> Tensor:
if isinstance(y, (int, bool)):
y = paddle.to_tensor(y, dtype=x.dtype)
return bitwise_or(y, x, out=out, name=name)
else:
raise TypeError(
f"unsupported operand type(s) for |: '{type(y).__name__}' and 'Tensor'"
)


@inplace_apis_in_dygraph_only
def bitwise_or_(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
r"""
Expand Down
36 changes: 36 additions & 0 deletions test/legacy_test/test_math_op_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,42 @@ def test_bitwise_or(self):
)
np.testing.assert_array_equal(out[0], out_np)

@prog_scope()
def test_ror(self):
place = (
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
else paddle.CPUPlace()
)
x_int = 5
y_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32")
y = paddle.static.data("y", y_np.shape, dtype=y_np.dtype)
z = x_int | y
exe = paddle.static.Executor(place)
out = exe.run(
feed={'x': x_int, 'y': y_np},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

feed里的'x': x_int应该是不需要的吧?可以确认下

fetch_list=[z],
)
out_ref = x_int | y_np
np.testing.assert_array_equal(out[0], out_ref)
x_bool = True
res_ror_bool = x_bool | y
out_bool = exe.run(
feed={'x': x_bool, 'y': y_np},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,'x': x_bool应该可以删掉,因为这只是一个int,没有对应的Variable占位符

fetch_list=[res_ror_bool],
)
res_py_bool = x_bool | y_np
np.testing.assert_array_equal(out_bool[0], res_py_bool)

for x_invalid in (
np.float32(5.0),
np.float64(5.0),
np.complex64(5),
np.complex128(5.0 + 2j),
):
with self.assertRaises(TypeError):
x_invalid | y

@prog_scope()
def test_bitwise_xor(self):
x_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32")
Expand Down
54 changes: 54 additions & 0 deletions test/legacy_test/test_math_op_patch_pir.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个文件能加一个动态图的测试吗?期望是 np.ndarray | Tensor 或者 int | Tensor,来触发 __ror__,参考:https://github.com/PaddlePaddle/Paddle/pull/69348/files#diff-60c4d2d54c7500c4405914e9c50c03203923aa4fb48bba1e68d0362be2377ca1

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

单测修改一下,感觉没测试到预期的代码上

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

还有一个:test/legacy_test/test_math_op_patch.py 也需要添加单测吧?

Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,60 @@ def test_bitwise_or(self):
np.testing.assert_array_equal(res_np_c, c_np)
np.testing.assert_array_equal(res_np_d, d_np)

def test_dygraph_ror(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参考同文件的其它代码,添加pir的测试

paddle.disable_static()
x_int32 = 5
x_bool = True
y_np = np.random.randint(0, 2, [2, 3, 5]).astype("int32")
y_tensor = paddle.to_tensor(y_np)
res_ror_int32 = x_int32 | y_tensor
res_py_int32 = x_int32 | y_tensor.numpy()
np.testing.assert_array_equal(res_py_int32, res_ror_int32.numpy())
res_ror_bool = x_bool | y_tensor
res_py_bool = x_bool | y_tensor.numpy()
np.testing.assert_array_equal(res_py_bool, res_ror_bool.numpy())
for x_np in (
np.float32(5.0),
np.float64(5.0),
np.complex64(5),
np.complex128(5.0 + 2j),
):
with self.assertRaises(TypeError):
x_np | y_tensor
paddle.enable_static()
with paddle.pir_utils.IrGuard():
main_program, exe, program_guard = new_program()
with program_guard:
x_int = 5
y_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32")
y = paddle.static.data("y", y_np.shape, dtype=y_np.dtype)
z = x_int | y
out = exe.run(
main_program,
feed={'x': x_int, 'y': y_np},
fetch_list=[z],
)
out_ref = x_int | y_np
np.testing.assert_array_equal(out[0], out_ref)
x_bool = True
res_ror_bool = x_bool | y
out_bool = exe.run(
main_program,
feed={'x': x_bool, 'y': y_np},
fetch_list=[res_ror_bool],
)
res_py_bool = x_bool | y_np
np.testing.assert_array_equal(out_bool[0], res_py_bool)

for x_invalid in (
np.float32(5.0),
np.float64(5.0),
np.complex64(5),
np.complex128(5.0 + 2j),
):
with self.assertRaises(TypeError):
x_invalid | y
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pir的单测不要写到test_dygraph_ror里,可以新建一个函数


def test_bitwise_and(self):
paddle.disable_static()
x_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32")
Expand Down