-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Paddle Tensor No.6][BUPT] 新增 Tensor.__ror__
#69463
Changes from 17 commits
0b7d37c
9c656f6
da50f62
fbf7bd8
d8753a3
d7cf30c
fe7d090
90aa056
94871b1
61d58ec
a94b756
ea4125e
f458b03
b926b32
4d6ad06
1e367e2
3e22e18
7f6ab52
f877fb6
71ecd47
e738d94
9321efa
1935338
d88d66c
0e6b587
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1290,6 +1290,21 @@ def bitwise_or( | |
) | ||
|
||
|
||
def __ror__( | ||
x: Tensor, | ||
y: Tensor | int | bool, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. y不会是Tensor |
||
out: Tensor | None = None, | ||
name: str | None = None, | ||
) -> Tensor: | ||
if isinstance(y, (int, bool)): | ||
y = paddle.to_tensor(y, dtype=x.dtype) | ||
return bitwise_or(y, x, out=out, name=name) | ||
else: | ||
raise TypeError( | ||
f"unsupported operand type(s) for |: '{type(y).__name__}' and 'Tensor'" | ||
) | ||
|
||
|
||
@inplace_apis_in_dygraph_only | ||
def bitwise_or_(x: Tensor, y: Tensor, name: str | None = None) -> Tensor: | ||
r""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -333,6 +333,42 @@ def test_bitwise_or(self): | |
) | ||
np.testing.assert_array_equal(out[0], out_np) | ||
|
||
@prog_scope() | ||
def test_ror(self): | ||
place = ( | ||
paddle.CUDAPlace(0) | ||
if paddle.is_compiled_with_cuda() | ||
else paddle.CPUPlace() | ||
) | ||
x_int = 5 | ||
y_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32") | ||
y = paddle.static.data("y", y_np.shape, dtype=y_np.dtype) | ||
z = x_int | y | ||
exe = paddle.static.Executor(place) | ||
out = exe.run( | ||
feed={'x': x_int, 'y': y_np}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. feed里的 |
||
fetch_list=[z], | ||
) | ||
out_ref = x_int | y_np | ||
np.testing.assert_array_equal(out[0], out_ref) | ||
x_bool = True | ||
res_ror_bool = x_bool | y | ||
out_bool = exe.run( | ||
feed={'x': x_bool, 'y': y_np}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上, |
||
fetch_list=[res_ror_bool], | ||
) | ||
res_py_bool = x_bool | y_np | ||
np.testing.assert_array_equal(out_bool[0], res_py_bool) | ||
|
||
for x_invalid in ( | ||
np.float32(5.0), | ||
np.float64(5.0), | ||
np.complex64(5), | ||
np.complex128(5.0 + 2j), | ||
): | ||
with self.assertRaises(TypeError): | ||
x_invalid | y | ||
|
||
@prog_scope() | ||
def test_bitwise_xor(self): | ||
x_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32") | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个文件能加一个动态图的测试吗?期望是 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 单测修改一下,感觉没测试到预期的代码上 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 还有一个:test/legacy_test/test_math_op_patch.py 也需要添加单测吧? |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -248,6 +248,60 @@ def test_bitwise_or(self): | |
np.testing.assert_array_equal(res_np_c, c_np) | ||
np.testing.assert_array_equal(res_np_d, d_np) | ||
|
||
def test_dygraph_ror(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 参考同文件的其它代码,添加pir的测试 |
||
paddle.disable_static() | ||
x_int32 = 5 | ||
x_bool = True | ||
y_np = np.random.randint(0, 2, [2, 3, 5]).astype("int32") | ||
y_tensor = paddle.to_tensor(y_np) | ||
res_ror_int32 = x_int32 | y_tensor | ||
res_py_int32 = x_int32 | y_tensor.numpy() | ||
np.testing.assert_array_equal(res_py_int32, res_ror_int32.numpy()) | ||
res_ror_bool = x_bool | y_tensor | ||
res_py_bool = x_bool | y_tensor.numpy() | ||
np.testing.assert_array_equal(res_py_bool, res_ror_bool.numpy()) | ||
for x_np in ( | ||
np.float32(5.0), | ||
np.float64(5.0), | ||
np.complex64(5), | ||
np.complex128(5.0 + 2j), | ||
): | ||
with self.assertRaises(TypeError): | ||
x_np | y_tensor | ||
paddle.enable_static() | ||
with paddle.pir_utils.IrGuard(): | ||
main_program, exe, program_guard = new_program() | ||
with program_guard: | ||
x_int = 5 | ||
y_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32") | ||
y = paddle.static.data("y", y_np.shape, dtype=y_np.dtype) | ||
z = x_int | y | ||
out = exe.run( | ||
main_program, | ||
feed={'x': x_int, 'y': y_np}, | ||
fetch_list=[z], | ||
) | ||
out_ref = x_int | y_np | ||
np.testing.assert_array_equal(out[0], out_ref) | ||
x_bool = True | ||
res_ror_bool = x_bool | y | ||
out_bool = exe.run( | ||
main_program, | ||
feed={'x': x_bool, 'y': y_np}, | ||
fetch_list=[res_ror_bool], | ||
) | ||
res_py_bool = x_bool | y_np | ||
np.testing.assert_array_equal(out_bool[0], res_py_bool) | ||
|
||
for x_invalid in ( | ||
np.float32(5.0), | ||
np.float64(5.0), | ||
np.complex64(5), | ||
np.complex128(5.0 + 2j), | ||
): | ||
with self.assertRaises(TypeError): | ||
x_invalid | y | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. pir的单测不要写到 |
||
|
||
def test_bitwise_and(self): | ||
paddle.disable_static() | ||
x_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个文件应该不用改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感谢指正,那这个__ror__函数我应该在哪个文件里实现呢?