diff --git a/test/legacy_test/test_case.py b/test/legacy_test/test_case.py index d1f67f1cd70cd..1a5cf3e459e6b 100644 --- a/test/legacy_test/test_case.py +++ b/test/legacy_test/test_case.py @@ -169,14 +169,14 @@ def fn_3(): np.testing.assert_allclose(res[4], 2, rtol=1e-05) self.assertEqual(res[4].shape, ()) - # Todo(zhangbo): grad_list can not find dx in oir mode - # @test_with_pir_api + @test_with_pir_api def test_0d_tensor_backward(self): main_program = paddle.static.Program() startup_program = paddle.static.Program() with paddle.static.program_guard(main_program, startup_program): x = paddle.full(shape=[], dtype='float32', fill_value=-2.0) x.stop_gradient = False + x.persistable = True pred = paddle.full(shape=[], dtype='bool', fill_value=0) # pred is False, so out = -x out = paddle.static.nn.case(