diff --git a/python/paddle/static/nn/control_flow.py b/python/paddle/static/nn/control_flow.py index a6a2027ac9a3d..7f6f607e0ed4e 100644 --- a/python/paddle/static/nn/control_flow.py +++ b/python/paddle/static/nn/control_flow.py @@ -868,7 +868,6 @@ def case(pred_fn_pairs, default=None, name=None): ... print(res_1, res_2) [[1. 1.]] [3 3 3] ''' - helper = LayerHelper('case', **locals()) def _case_check_args(pred_fn_pairs, default): ''' @@ -899,16 +898,9 @@ def _case_check_args(pred_fn_pairs, default): ) pred, fn = pred_fn - if not isinstance(pred, Variable): - raise TypeError( - _error_message( - "The pred's type", - "pred_fn_pairs", - "case", - "boolean Variable", - type(pred), - ) - ) + check_variable_and_dtype( + pred, 'pred', ['bool'], 'paddle.static.nn.case' + ) if not callable(fn): raise TypeError( diff --git a/test/legacy_test/test_case.py b/test/legacy_test/test_case.py index 294f43542bfe6..d1f67f1cd70cd 100644 --- a/test/legacy_test/test_case.py +++ b/test/legacy_test/test_case.py @@ -22,11 +22,13 @@ from paddle.base import core from paddle.base.backward import append_backward from paddle.base.framework import Program, program_guard +from paddle.pir_utils import test_with_pir_api paddle.enable_static() class TestAPICase(unittest.TestCase): + @test_with_pir_api def test_return_single_var(self): def fn_1(): return paddle.tensor.fill_constant( @@ -43,9 +45,9 @@ def fn_3(): shape=[4, 3], dtype='int32', value=3 ) - main_program = Program() - startup_program = Program() - with program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): x = paddle.tensor.fill_constant( shape=[1], dtype='float32', value=0.3 ) @@ -100,6 +102,7 @@ def fn_3(): np.testing.assert_allclose(res[3], 2, rtol=1e-05) np.testing.assert_allclose(res[4], 2, rtol=1e-05) + @test_with_pir_api def test_0d_tensor(self): def fn_1(): return paddle.full(shape=[], dtype='int32', fill_value=1) @@ -110,9 +113,9 @@ def fn_2(): def fn_3(): return paddle.full(shape=[], dtype='int32', fill_value=3) - main_program = Program() - startup_program = Program() - with program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): x = paddle.full(shape=[], dtype='float32', fill_value=0.3) y = paddle.full(shape=[], dtype='float32', fill_value=0.1) z = paddle.full(shape=[], dtype='float32', fill_value=0.2) @@ -166,10 +169,12 @@ def fn_3(): np.testing.assert_allclose(res[4], 2, rtol=1e-05) self.assertEqual(res[4].shape, ()) + # Todo(zhangbo): grad_list can not find dx in oir mode + # @test_with_pir_api def test_0d_tensor_backward(self): - main_program = Program() - startup_program = Program() - with program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): x = paddle.full(shape=[], dtype='float32', fill_value=-2.0) x.stop_gradient = False pred = paddle.full(shape=[], dtype='bool', fill_value=0) @@ -177,7 +182,7 @@ def test_0d_tensor_backward(self): out = paddle.static.nn.case( pred_fn_pairs=[(pred, lambda: x)], default=lambda: -x ) - append_backward(out) + grad_list = append_backward(out) place = ( base.CUDAPlace(0) @@ -186,7 +191,14 @@ def test_0d_tensor_backward(self): ) exe = base.Executor(place) - res = exe.run(main_program, fetch_list=[out.name, x.grad_name]) + if paddle.framework.in_pir_mode(): + for p, g in grad_list: + if p.is_same(x): + dx = g + res = exe.run(main_program, fetch_list=[out, dx]) + else: + res = exe.run(main_program, fetch_list=[out.name, x.grad_name]) + np.testing.assert_allclose( np.asarray(res[0]), np.array(2.0), rtol=1e-05 ) @@ -252,6 +264,7 @@ def fn_3(): paddle.enable_static() + @test_with_pir_api def test_return_var_tuple(self): def fn_1(): return paddle.tensor.fill_constant( @@ -269,14 +282,14 @@ def fn_2(): def fn_3(): return paddle.tensor.fill_constant( - shape=[5], dtype='int32', value=5 + shape=[5, 6], dtype='int32', value=5 ), paddle.tensor.fill_constant( shape=[5, 6], dtype='float32', value=6 ) - main_program = Program() - startup_program = Program() - with program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): x = paddle.tensor.fill_constant(shape=[1], dtype='float32', value=1) y = paddle.tensor.fill_constant(shape=[1], dtype='float32', value=1) z = paddle.tensor.fill_constant(shape=[1], dtype='float32', value=3) @@ -305,6 +318,7 @@ def fn_3(): class TestAPICase_Nested(unittest.TestCase): + @test_with_pir_api def test_nested_case(self): def fn_1(x=1): var_5 = paddle.tensor.fill_constant( @@ -383,9 +397,9 @@ def fn_3(): ) return out - main_program = Program() - startup_program = Program() - with program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): x = paddle.tensor.fill_constant( shape=[1], dtype='float32', value=0.3 ) @@ -423,6 +437,7 @@ def fn_3(): np.testing.assert_allclose(res[1], 2, rtol=1e-05) np.testing.assert_allclose(res[2], 3, rtol=1e-05) + @test_with_pir_api def test_nested_0d_tensor(self): def fn_1(x=1): var_5 = paddle.full(shape=[], dtype='int32', fill_value=5) @@ -489,9 +504,9 @@ def fn_3(): ) return out - main_program = Program() - startup_program = Program() - with program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): x = paddle.full(shape=[], dtype='float32', fill_value=0.3) y = paddle.full(shape=[], dtype='float32', fill_value=0.1) z = paddle.full(shape=[], dtype='float32', fill_value=0.2)