Skip to content

Commit

Permalink
add test for zero dimensional tensor for real, imag, angle, conj, as_real and sequence_pad (PaddlePaddle#49921)
Browse files Browse the repository at this point in the history
  • Loading branch information
Feiyu Chan authored and pangengzheng committed Feb 2, 2023
1 parent 52a814e commit fddca4c
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 0 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/operators/sequence_ops/sequence_pad_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class SequencePadOp : public framework::OperatorWithKernel {
auto pad_value_dims = ctx->GetInputDim("PadValue");
PADDLE_ENFORCE_EQ(
pad_value_dims == phi::make_ddim({1}) ||
pad_value_dims == phi::make_ddim({}) ||
pad_value_dims == time_step_dims,
true,
platform::errors::InvalidArgument(
Expand Down
187 changes: 187 additions & 0 deletions python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2282,6 +2282,21 @@ def test_t(self):
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ())

@prog_scope()
def test_sequence_pad(self):
    """Check that sequence_pad accepts a 0-D (scalar) pad-value tensor."""
    x = paddle.static.data("x", [-1, 2], dtype=paddle.int64, lod_level=1)
    # A 0-D pad value; squeeze() drops the wrapping dimension.
    pad_value = paddle.to_tensor(1000, dtype=paddle.int64).squeeze()
    out = paddle.static.nn.sequence_pad(x, pad_value)

    # Three sequences of lengths 3, 3 and 4, each step a 2-element row.
    feed_data = np.arange(20).astype(np.int64).reshape(-1, 2)
    lod_tensor = paddle.fluid.create_lod_tensor(
        feed_data, [[3, 3, 4]], place=self.exe.place
    )
    main_prog = paddle.static.default_main_program()
    res = self.exe.run(main_prog, feed={"x": lod_tensor}, fetch_list=[out])
    # Padded output: (num_sequences, max_seq_len, feature) == (3, 4, 2).
    self.assertEqual(res[0].shape, (3, 4, 2))


# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
Expand Down Expand Up @@ -2633,5 +2648,177 @@ def test_one_hot_label(self):
self.assertEqual(res[0][2], 1)


# Unary elementwise APIs that accept a complex-valued tensor; the test
# classes below exercise each of them with 0-D (zero-dimensional) input.
unary_apis_with_complex_input = [
    paddle.real,
    paddle.imag,
    paddle.angle,
    paddle.conj,
]


class TestUnaryElementwiseAPIWithComplexInput(unittest.TestCase):
    """0-D tensor coverage for unary APIs taking complex input."""

    def test_dygraph_unary(self):
        paddle.disable_static()
        for fn in unary_apis_with_complex_input:
            inp = paddle.to_tensor(2.0 + 3.0j).squeeze()
            inp.stop_gradient = False
            inp.retain_grads()
            result = fn(inp)
            result.retain_grads()
            result.backward()

            # Input and output stay 0-D; when a grad exists it is 0-D too.
            self.assertEqual(inp.shape, [])
            self.assertEqual(result.shape, [])
            if inp.grad is not None:
                self.assertEqual(inp.grad.shape, [])
                self.assertEqual(result.grad.shape, [])

        paddle.enable_static()

    def test_static_unary(self):
        paddle.enable_static()

        for fn in unary_apis_with_complex_input:
            main_prog = paddle.static.Program()
            block = main_prog.global_block()
            exe = paddle.static.Executor()
            with paddle.static.program_guard(
                main_prog, paddle.static.Program()
            ):
                # before full support for complex, we cannot create complex
                # tensor with the same code as in dynamic graph
                x = paddle.complex(
                    paddle.to_tensor(2.0), paddle.to_tensor(2.0)
                ).squeeze()
                x.stop_gradient = False
                out = fn(x)
                # TODO(zhouwei):
                # ScaleLossGradOp / append_backward set grad shape to [1]
                # after output 0D, may change it to []
                # use out.sum() to avoid this two problem now
                loss = out.sum()
                paddle.static.append_backward(loss)

                fetch_list = [x, out]
                if block.has_var(x.grad_name):
                    fetch_list.extend([x.grad_name, out.grad_name])

                # 1) Plain Program: every fetched value is a 0-D array.
                for fetched in exe.run(main_prog, fetch_list=fetch_list):
                    self.assertEqual(fetched.shape, ())

                # 2) CompiledProgram: one replica on GPU, four on CPU.
                if paddle.device.is_compiled_with_cuda():
                    places = [paddle.CUDAPlace(0)]
                    expect_shape = ()
                else:
                    places = [paddle.CPUPlace()] * 4
                    expect_shape = (4,)
                compiled = paddle.static.CompiledProgram(
                    main_prog
                ).with_data_parallel(loss.name, places=places)

                # return_merged=False: per-replica results, each 0-D.
                unmerged = exe.run(
                    compiled, fetch_list=fetch_list, return_merged=False
                )
                for per_var in unmerged:
                    for per_replica in per_var:
                        self.assertEqual(per_replica.shape, ())

                # return_merged=True: replicas stacked along axis 0.
                merged = exe.run(
                    compiled, fetch_list=fetch_list, return_merged=True
                )
                for fetched in merged:
                    self.assertEqual(fetched.shape, expect_shape)

        paddle.disable_static()


class TestAsReal(unittest.TestCase):
    """0-D tensor coverage for paddle.as_real (0-D complex -> shape-(2,) real)."""

    def test_dygraph(self):
        paddle.disable_static()
        # FIX: the original looped over unary_apis_with_complex_input without
        # ever using the loop variable, running this identical as_real check
        # four times (copy-paste residue from the unary-API tests above).
        # Run it once.
        x = paddle.to_tensor(2.0 + 3.0j).squeeze()
        x.stop_gradient = False
        x.retain_grads()
        out = paddle.as_real(x)
        out.retain_grads()
        out.backward()

        # as_real maps a 0-D complex tensor to a shape-[2] real tensor.
        self.assertEqual(x.shape, [])
        self.assertEqual(out.shape, [2])
        if x.grad is not None:
            self.assertEqual(x.grad.shape, [])
            self.assertEqual(out.grad.shape, [2])

        paddle.enable_static()

    def test_static(self):
        paddle.enable_static()

        # FIX: same dead loop removed here (loop variable was unused).
        main_prog = paddle.static.Program()
        block = main_prog.global_block()
        exe = paddle.static.Executor()
        with paddle.static.program_guard(main_prog, paddle.static.Program()):
            # before full support for complex, we cannot create complex
            # tensor with the same code as in dynamic graph
            x = paddle.complex(
                paddle.to_tensor(2.0), paddle.to_tensor(2.0)
            ).squeeze()
            x.stop_gradient = False
            out = paddle.as_real(x)
            self.assertEqual(x.shape, ())
            self.assertEqual(out.shape, (2,))
            # TODO(zhouwei):
            # ScaleLossGradOp / append_backward set grad shape to [1]
            # after output 0D, may change it to []
            # use out.sum() to avoid this two problem now
            loss = out.abs().sum()
            paddle.static.append_backward(loss)

            fetch_list = [x, out]
            if block.has_var(x.grad_name):
                fetch_list.extend([x.grad_name, out.grad_name])

            # 1) Test Program: x and its grad are 0-D, out and its grad (2,).
            res = exe.run(main_prog, fetch_list=fetch_list)
            self.assertEqual(res[0].shape, ())
            self.assertEqual(res[1].shape, (2,))
            self.assertEqual(res[2].shape, ())
            self.assertEqual(res[3].shape, (2,))

            # 2) Test CompiledProgram: one replica on GPU, four on CPU;
            # merged CPU results concatenate the four replicas.
            if paddle.device.is_compiled_with_cuda():
                places = [paddle.CUDAPlace(0)]
                expect_shapes = (), (2,), (), (2,)
            else:
                places = [paddle.CPUPlace()] * 4
                expect_shapes = (4,), (8,), (4,), (8,)
            compile_prog = paddle.static.CompiledProgram(
                main_prog
            ).with_data_parallel(loss.name, places=places)

            # return_merged=False: per-replica results keep per-replica shape.
            res = exe.run(
                compile_prog, fetch_list=fetch_list, return_merged=False
            )
            for out_i, expect in zip(res, [(), (2,), (), (2,)]):
                for replica in out_i:
                    self.assertEqual(replica.shape, expect)

            # return_merged=True: replicas merged along the leading axis.
            res = exe.run(
                compile_prog, fetch_list=fetch_list, return_merged=True
            )
            for actual, expect in zip(res, expect_shapes):
                self.assertEqual(actual.shape, expect)

        paddle.disable_static()


# Standard unittest entry point: run all test classes in this module.
if __name__ == "__main__":
    unittest.main()

0 comments on commit fddca4c

Please sign in to comment.