Skip to content

Commit

Permalink
fix comment
Browse files Browse the repository at this point in the history
  • Loading branch information
zhwesky2010 committed Nov 8, 2022
1 parent b1c1c76 commit 0b68cc5
Showing 1 changed file with 49 additions and 18 deletions.
67 changes: 49 additions & 18 deletions python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,13 +233,6 @@ def test_static(self):
{'func': paddle.subtract, 'cls_method': '__sub__'},
{'func': paddle.multiply, 'cls_method': '__mul__'},
{'func': paddle.divide, 'cls_method': '__div__'},
{'func': paddle.subtract, 'cls_method': '__sub__'},
{'func': paddle.pow, 'cls_method': '__pow__'},
{'func': paddle.add, 'cls_method': '__add__'},
{'func': paddle.subtract, 'cls_method': '__sub__'},
{'func': paddle.multiply, 'cls_method': '__mul__'},
{'func': paddle.divide, 'cls_method': '__div__'},
{'func': paddle.subtract, 'cls_method': '__sub__'},
{'func': paddle.pow, 'cls_method': '__pow__'},
]

Expand All @@ -258,6 +251,12 @@ def test_static(self):
paddle.logical_xor,
]

binary_int_api_list_without_grad = [
paddle.bitwise_and,
paddle.bitwise_or,
paddle.bitwise_xor,
]


# Use to test zero-dim of binary API
class TestBinaryAPI(unittest.TestCase):
Expand All @@ -278,7 +277,6 @@ def test_dygraph_binary(self):
out = api(x, y)

self.assertEqual(out.shape, [])

if api not in binary_api_list_without_grad:
out.backward()
self.assertEqual(x.grad.shape, [])
Expand All @@ -298,7 +296,6 @@ def test_dygraph_binary(self):
out = api(x, y)

self.assertEqual(out.shape, [2, 3, 4])

if api not in binary_api_list_without_grad:
out.backward()
self.assertEqual(x.grad.shape, [2, 3, 4])
Expand All @@ -318,7 +315,6 @@ def test_dygraph_binary(self):
out = api(x, y)

self.assertEqual(out.shape, [2, 3, 4])

if api not in binary_api_list_without_grad:
out.backward()
self.assertEqual(x.grad.shape, [])
Expand All @@ -333,11 +329,30 @@ def test_dygraph_binary(self):
out = getattr(paddle.Tensor, api['cls_method'])(x, y)
self.assertEqual(out.shape, [])

for api in binary_int_api_list_without_grad:
# 1) x/y is 0D
x = paddle.randint(-10, 10, [])
y = paddle.randint(-10, 10, [])
out = api(x, y)
self.assertEqual(out.shape, [])

# 2) x is not 0D , y is 0D
x = paddle.randint(-10, 10, [3, 5])
y = paddle.randint(-10, 10, [])
out = api(x, y)
self.assertEqual(out.shape, [3, 5])

# 3) x is 0D , y is not 0D
x = paddle.randint(-10, 10, [])
y = paddle.randint(-10, 10, [3, 5])
out = api(x, y)
self.assertEqual(out.shape, [3, 5])

paddle.enable_static()

def test_static_unary(self):
paddle.enable_static()
for api in binary_api_list:
for api in binary_api_list + binary_api_list_without_grad:
main_prog = fluid.Program()
with fluid.program_guard(main_prog, fluid.Program()):
# 1) x/y is 0D
Expand All @@ -355,17 +370,12 @@ def test_static_unary(self):
out = api(x, y)
fluid.backward.append_backward(out)

# append_backward always set grad shape to [1]
prog = paddle.static.default_main_program()
block = prog.global_block()

# Test compile shape
self.assertEqual(out.shape, ())

exe = fluid.Executor()
result = exe.run(main_prog, fetch_list=[x, y, out])
out_np = exe.run(main_prog, fetch_list=[out])[0]
# Test runtime shape
self.assertEqual(result[2].shape, ())
self.assertEqual(out_np.shape, ())

# 2) x is 0D , y is scalar
x = paddle.rand([])
Expand All @@ -377,6 +387,27 @@ def test_static_unary(self):
)
self.assertEqual(out.shape, ())

for api in binary_int_api_list_without_grad:
main_prog = fluid.Program()
with fluid.program_guard(main_prog, fluid.Program()):
# 1) x/y is 0D
x = paddle.randint(-10, 10, [])
y = paddle.randint(-10, 10, [])
out = api(x, y)
self.assertEqual(out.shape, ())

# 2) x is not 0D , y is 0D
x = paddle.randint(-10, 10, [3, 5])
y = paddle.randint(-10, 10, [])
out = api(x, y)
self.assertEqual(out.shape, (3, 5))

# 3) x is 0D , y is not 0D
x = paddle.randint(-10, 10, [])
y = paddle.randint(-10, 10, [3, 5])
out = api(x, y)
self.assertEqual(out.shape, (3, 5))

paddle.disable_static()


Expand Down

0 comments on commit 0b68cc5

Please sign in to comment.