From 7ee852fbd2c810ae4198a85d442202bdb1cd9452 Mon Sep 17 00:00:00 2001 From: gouzil <66515297+gouzil@users.noreply.github.com> Date: Tue, 31 Oct 2023 23:36:14 +0800 Subject: [PATCH] [Dy2St] Refactor dy2st unittest decorators name - Part 7 (#58509) --------- Co-authored-by: SigureMo --- .../dygraph_to_static_util.py | 173 ------------------ test/dygraph_to_static/test_container.py | 5 +- test/dygraph_to_static/test_ifelse_basic.py | 13 -- test/dygraph_to_static/test_lac.py | 5 +- test/dygraph_to_static/test_no_gradient.py | 5 +- test/dygraph_to_static/test_sentiment.py | 10 +- .../test_write_python_container.py | 15 +- test/legacy_test/test_cond.py | 14 +- test/legacy_test/test_while_loop_op.py | 16 +- test/legacy_test/test_while_op.py | 8 +- 10 files changed, 35 insertions(+), 229 deletions(-) delete mode 100644 test/dygraph_to_static/dygraph_to_static_util.py delete mode 100644 test/dygraph_to_static/test_ifelse_basic.py diff --git a/test/dygraph_to_static/dygraph_to_static_util.py b/test/dygraph_to_static/dygraph_to_static_util.py deleted file mode 100644 index 9a5b9bf22d92a4..00000000000000 --- a/test/dygraph_to_static/dygraph_to_static_util.py +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import contextlib -import os -from functools import wraps - -import numpy as np - -from paddle import set_flags, static -from paddle.base import core - - -@contextlib.contextmanager -def enable_fallback_guard(enable): - flag = os.environ.get("ENABLE_FALL_BACK", None) - os.environ["ENABLE_FALL_BACK"] = enable - yield - if flag is not None: - os.environ["ENABLE_FALL_BACK"] = flag - else: - del os.environ["ENABLE_FALL_BACK"] - - -def to_ast(func): - """ - convert run fall_back to ast - """ - - def impl(*args, **kwargs): - with enable_fallback_guard("False"): - func(*args, **kwargs) - - return impl - - -def to_sot(func): - """ - convert run fall_back to ast - """ - # TODO(SigureMo): ENABLE_SOT should always be True, remove this - enable_sot = os.environ.get("ENABLE_SOT", "True") == "True" - - def impl(*args, **kwargs): - if enable_sot: - with enable_fallback_guard("True"): - func(*args, **kwargs) - else: - return - - return impl - - -def dy2static_unittest(cls): - """ - dy2static unittest must be decorated to each Dy2static Unittests. - run both in Fallback and Ast mode. - - Examples: - - >>> @dy2static_unittest - ... class TestA(unittest.TestCase): - ... ... - """ - for key in dir(cls): - if key.startswith("test"): - if not key.endswith("_ast"): - test_func = getattr(cls, key) - setattr(cls, key + "_ast", to_ast(test_func)) - test_func = getattr(cls, key) - setattr(cls, key, to_sot(test_func)) - return cls - - -def ast_only_test(func): - """ - run this test function in ast only mode. - - Examples: - - >>> @dy2static_unittest - ... class TestA(unittest.TestCase): - ... @ast_only_test - ... def test_ast_only(self): - ... pass - """ - - def impl(*args, **kwargs): - if os.environ.get("ENABLE_FALL_BACK", "False") == "False": - func(*args, **kwargs) - - return impl - - -def sot_only_test(func): - """ - run this test function in ast only mode. - - Examples: - - >>> @dy2static_unittest - ... class TestA(unittest.TestCase): - ... @sot_only_test - ... def test_sot_only(self): - ... pass - """ - - def impl(*args, **kwargs): - if os.environ.get("ENABLE_FALL_BACK", "False") == "True": - func(*args, **kwargs) - - return impl - - -def test_with_new_ir(func): - @wraps(func) - def impl(*args, **kwargs): - ir_outs = None - if os.environ.get('FLAGS_use_stride_kernel', False): - return - with static.scope_guard(static.Scope()): - with static.program_guard(static.Program()): - try: - new_ir_flag = 'FLAGS_enable_new_ir_in_executor' - os.environ[new_ir_flag] = 'True' - set_flags({new_ir_flag: True}) - ir_outs = func(*args, **kwargs) - finally: - del os.environ[new_ir_flag] - set_flags({new_ir_flag: False}) - return ir_outs - - return impl - - -def test_and_compare_with_new_ir(need_check_output: bool = True): - def decorator(func): - @wraps(func) - def impl(*args, **kwargs): - outs = func(*args, **kwargs) - if core._is_bwd_prim_enabled() or core._is_fwd_prim_enabled(): - return outs - ir_outs = test_with_new_ir(func)(*args, **kwargs) - if not need_check_output: - return outs - np.testing.assert_equal( - outs, - ir_outs, - err_msg='Dy2St Unittest Check (' - + func.__name__ - + ') has diff ' - + '\nExpect ' - + str(outs) - + '\n' - + 'But Got' - + str(ir_outs), - ) - return outs - - return impl - - return decorator diff --git a/test/dygraph_to_static/test_container.py b/test/dygraph_to_static/test_container.py index 412362ba725c5b..964bc270b59a43 100644 --- a/test/dygraph_to_static/test_container.py +++ b/test/dygraph_to_static/test_container.py @@ -17,7 +17,7 @@ import unittest import numpy as np -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase import paddle @@ -70,8 +70,7 @@ def forward(self, x): return self.layers(x) -@dy2static_unittest -class TestSequential(unittest.TestCase): +class TestSequential(Dy2StTestBase): def setUp(self): paddle.set_device('cpu') self.seed = 2021 diff --git a/test/dygraph_to_static/test_ifelse_basic.py b/test/dygraph_to_static/test_ifelse_basic.py deleted file mode 100644 index 97043fd7ba6885..00000000000000 --- a/test/dygraph_to_static/test_ifelse_basic.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/test/dygraph_to_static/test_lac.py b/test/dygraph_to_static/test_lac.py index 461b03fe7a5edc..d1feacae222627 100644 --- a/test/dygraph_to_static/test_lac.py +++ b/test/dygraph_to_static/test_lac.py @@ -22,7 +22,7 @@ os.environ["CUDA_VISIBLE_DEVICES"] = "2" -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase import paddle from paddle import _legacy_C_ops, base @@ -515,8 +515,7 @@ def create_dataloader(reader, place): return data_loader -@dy2static_unittest -class TestLACModel(unittest.TestCase): +class TestLACModel(Dy2StTestBase): def setUp(self): self.args = Args() self.place = ( diff --git a/test/dygraph_to_static/test_no_gradient.py b/test/dygraph_to_static/test_no_gradient.py index ec6443d5290179..b3bc726762ee48 100644 --- a/test/dygraph_to_static/test_no_gradient.py +++ b/test/dygraph_to_static/test_no_gradient.py @@ -15,7 +15,7 @@ import unittest import numpy -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase import paddle @@ -32,8 +32,7 @@ def main_func(x, index): return out -@dy2static_unittest -class TestNoGradientCase(unittest.TestCase): +class TestNoGradientCase(Dy2StTestBase): def test_no_gradient(self): paddle.disable_static() x = paddle.randn([10, 3]) diff --git a/test/dygraph_to_static/test_sentiment.py b/test/dygraph_to_static/test_sentiment.py index 60d3678a5a72b0..3c6a52dd9bad0e 100644 --- a/test/dygraph_to_static/test_sentiment.py +++ b/test/dygraph_to_static/test_sentiment.py @@ -15,10 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir from test_lac import DynamicGRU import paddle @@ -372,12 +369,11 @@ def train(args, to_static): return loss_data -@dy2static_unittest -class TestSentiment(unittest.TestCase): +class TestSentiment(Dy2StTestBase): def setUp(self): self.args = Args() - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def train_model(self, model_type='cnn_net'): self.args.model_type = model_type st_out = train(self.args, True) diff --git a/test/dygraph_to_static/test_write_python_container.py b/test/dygraph_to_static/test_write_python_container.py index a175b881d86c75..c22a5c7cba0a9a 100644 --- a/test/dygraph_to_static/test_write_python_container.py +++ b/test/dygraph_to_static/test_write_python_container.py @@ -14,10 +14,10 @@ import unittest -from dygraph_to_static_util import ( - ast_only_test, - dy2static_unittest, - sot_only_test, +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + test_ast_only, + test_sot_only, ) import paddle @@ -99,8 +99,7 @@ def func_ifelse_write_nest_list_dict(x): return res -@dy2static_unittest -class TestWriteContainer(unittest.TestCase): +class TestWriteContainer(Dy2StTestBase): def setUp(self): self.set_func() self.set_getitem_path() @@ -117,7 +116,7 @@ def get_raw_value(self, container, getitem_path): out = out[path] return out - @sot_only_test + @test_sot_only def test_write_container_sot(self): func_static = paddle.jit.to_static(self.func) input = paddle.to_tensor([1, 2, 3]) @@ -125,7 +124,7 @@ def test_write_container_sot(self): out_dygraph = self.get_raw_value(self.func(input), self.getitem_path) self.assertEqual(out_static, out_dygraph) - @ast_only_test + @test_ast_only def test_write_container(self): func_static = paddle.jit.to_static(self.func) input = paddle.to_tensor([1, 2, 3]) diff --git a/test/legacy_test/test_cond.py b/test/legacy_test/test_cond.py index cec7664ae6cb63..76467328f77253 100644 --- a/test/legacy_test/test_cond.py +++ b/test/legacy_test/test_cond.py @@ -19,7 +19,7 @@ from simple_nets import batchnorm_fc_with_inputs, simple_fc_net_with_inputs sys.path.append("../dygraph_to_static") -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_utils_new import compare_legacy_with_pir import paddle from paddle import base @@ -31,7 +31,7 @@ class TestCondInputOutput(unittest.TestCase): - @test_and_compare_with_new_ir() + @compare_legacy_with_pir def test_return_single_var(self): """ pseudocode: @@ -78,7 +78,7 @@ def false_func(): np.asarray(ret), np.full((3, 2), -1, np.int32), rtol=1e-05 ) - @test_and_compare_with_new_ir() + @compare_legacy_with_pir def test_return_0d_tensor(self): """ pseudocode: @@ -116,7 +116,7 @@ def false_func(): np.testing.assert_allclose(np.asarray(ret), np.array(2), rtol=1e-05) self.assertEqual(ret.shape, ()) - @test_and_compare_with_new_ir() + @compare_legacy_with_pir def test_0d_tensor_as_cond(self): """ pseudocode: @@ -217,7 +217,7 @@ def test_0d_tensor_dygraph(self): ) self.assertEqual(a.grad.shape, []) - @test_and_compare_with_new_ir() + @compare_legacy_with_pir def test_return_var_tuple(self): """ pseudocode: @@ -265,7 +265,7 @@ def false_func(): np.asarray(ret[1]), np.full((2, 3), True, bool), rtol=1e-05 ) - @test_and_compare_with_new_ir() + @compare_legacy_with_pir def test_pass_and_modify_var(self): """ pseudocode: @@ -356,7 +356,7 @@ def false_func(): self.assertIsNone(out2) self.assertIsNone(out3) - @test_and_compare_with_new_ir() + @compare_legacy_with_pir def test_wrong_structure_exception(self): """ test returning different number of tensors cannot merge into output diff --git a/test/legacy_test/test_while_loop_op.py b/test/legacy_test/test_while_loop_op.py index f7b1973c2261f9..231fb0bed32f9e 100644 --- a/test/legacy_test/test_while_loop_op.py +++ b/test/legacy_test/test_while_loop_op.py @@ -25,13 +25,13 @@ from paddle.base.framework import Program, program_guard sys.path.append("../dygraph_to_static") -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_utils_new import compare_legacy_with_pir paddle.enable_static() class TestApiWhileLoop(unittest.TestCase): - @test_and_compare_with_new_ir() + @compare_legacy_with_pir def test_var_tuple(self): def cond(i): return paddle.less_than(i, ten) @@ -60,7 +60,7 @@ def body(i): np.asarray(res[0]), np.full(1, 10, np.int64), rtol=1e-05 ) - # @test_and_compare_with_new_ir() + # @compare_legacy_with_pir def test_var_list(self): def cond(i, mem): return paddle.less_than(i, ten) @@ -97,7 +97,7 @@ def body(i, mem): data = np.add(data, data_one) np.testing.assert_allclose(np.asarray(res[1]), data, rtol=1e-05) - @test_and_compare_with_new_ir() + @compare_legacy_with_pir def test_var_dict(self): def cond(i, ten, test_dict, test_list, test_list_dict): return paddle.less_than(i, ten) @@ -182,7 +182,7 @@ def body(i, ten, test_dict, test_list, test_list_dict): class TestApiWhileLoop_Nested(unittest.TestCase): - # @test_and_compare_with_new_ir() + # @compare_legacy_with_pir def test_nested_net(self): def external_cond(i, j, init, sums): return paddle.less_than(i, loop_len1) @@ -436,7 +436,7 @@ def internal_body(j, x, mem_array): class TestApiWhileLoopWithSwitchCase(unittest.TestCase): - # @test_and_compare_with_new_ir() + # @compare_legacy_with_pir def test_with_switch_case(self): def cond(i): return paddle.less_than(i, ten) @@ -486,7 +486,7 @@ def fn_add_one(): class TestApiWhileLoop_Error(unittest.TestCase): - @test_and_compare_with_new_ir() + @compare_legacy_with_pir def test_error(self): def cond_returns_constant(i): return 1 @@ -655,7 +655,7 @@ def value_error_body_returns_with_mutable_list(): class TestApiWhileLoopSliceInBody(unittest.TestCase): - # @test_and_compare_with_new_ir() + # @compare_legacy_with_pir def test_var_slice(self): def cond(z, i): return i + 1 <= x_shape[0] diff --git a/test/legacy_test/test_while_op.py b/test/legacy_test/test_while_op.py index bf4924815a4743..766c23dbdceb07 100644 --- a/test/legacy_test/test_while_op.py +++ b/test/legacy_test/test_while_op.py @@ -25,7 +25,7 @@ from paddle.incubate.layers.nn import shuffle_batch sys.path.append("../dygraph_to_static") -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_utils_new import compare_legacy_with_pir paddle.enable_static() @@ -121,7 +121,7 @@ def test_simple_net_forward(self): for _ in range(2): exe.run(binary, feed={'d0': d[0], 'd1': d[1], 'd2': d[2]}) - @test_and_compare_with_new_ir() + @compare_legacy_with_pir def test_exceptions(self): i = paddle.zeros(shape=[2], dtype='int64') array_len = paddle.tensor.fill_constant( @@ -136,7 +136,7 @@ def test_exceptions(self): class BadInputTest(unittest.TestCase): - @test_and_compare_with_new_ir() + @compare_legacy_with_pir def test_error(self): with base.program_guard(base.Program()): @@ -192,7 +192,7 @@ def body_func(i, ten, batch_info, origin_seq): class TestOutputsMustExistsInputs(unittest.TestCase): - @test_and_compare_with_new_ir() + @compare_legacy_with_pir def test_outputs_exists_inputs(self): """ We guarantee that the output tensor must be in the input tensor, so that the output and input can correspond to each other, but the input can be greater than the number of outputs. It's required in paddle2onnx.