From 3e3e1a5d3417a77a8f64b1e7b7e53c197665b296 Mon Sep 17 00:00:00 2001 From: juncaipeng <13006307475@163.com> Date: Sat, 18 Feb 2023 16:49:24 +0800 Subject: [PATCH] [Zero-dim] add unittest for static.nn.prelu --- .../tests/unittests/test_zero_dim_tensor.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index 0b3a6c20ec1a9..99c2bc117d74f 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -2750,6 +2750,29 @@ def test_prelu(self): self.assertEqual(res[4].shape, ()) self.assertEqual(res[5].shape, ()) + def test_static_nn_prelu(self): + x1 = paddle.full([], 1.0, 'float32') + x1.stop_gradient = False + out1 = paddle.static.nn.prelu(x1, 'all') + paddle.static.append_backward(out1.sum()) + + prog = paddle.static.default_main_program() + self.exe.run(paddle.static.default_startup_program()) + res = self.exe.run( + prog, + fetch_list=[ + out1, + x1.grad_name, + out1.grad_name, + ], + ) + + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[1].shape, ()) + self.assertEqual(res[2].shape, ()) + np.testing.assert_allclose(res[0], np.array(1)) + np.testing.assert_allclose(res[1], np.array(1)) + @prog_scope() def test_while_loop(self): def cond(i, x):