diff --git a/test/dygraph_to_static/test_partial_program_hook.py b/test/dygraph_to_static/test_partial_program_hook.py
index 950fb570e635a..1b50b5b4add91 100644
--- a/test/dygraph_to_static/test_partial_program_hook.py
+++ b/test/dygraph_to_static/test_partial_program_hook.py
@@ -14,7 +14,7 @@

 import unittest

-from dygraph_to_static_util import dy2static_unittest
+from dygraph_to_static_utils_new import Dy2StTestBase

 import paddle
 from paddle.base import core
@@ -22,8 +22,7 @@
 from paddle.jit.dy2static import partial_program, program_translator


-@dy2static_unittest
-class TestPartiaProgramLayerHook(unittest.TestCase):
+class TestPartiaProgramLayerHook(Dy2StTestBase):
     def setUp(self):
         ENV_ENABLE_SOT.set(False)
         self._hook = partial_program.PartialProgramLayerHook()
@@ -38,8 +37,7 @@ def test_after_infer(self):
         self.assertIsNone(self._hook.after_infer(None))


-@dy2static_unittest
-class TestPrimHook(unittest.TestCase):
+class TestPrimHook(Dy2StTestBase):
     def setUp(self):
         ENV_ENABLE_SOT.set(False)
         core._set_prim_all_enabled(False)
diff --git a/test/dygraph_to_static/test_place.py b/test/dygraph_to_static/test_place.py
index f1cb7e80589a3..f9aaca6932906 100644
--- a/test/dygraph_to_static/test_place.py
+++ b/test/dygraph_to_static/test_place.py
@@ -14,13 +14,12 @@

 import unittest

-from dygraph_to_static_util import dy2static_unittest
+from dygraph_to_static_utils_new import Dy2StTestBase

 import paddle


-@dy2static_unittest
-class TestPlace(unittest.TestCase):
+class TestPlace(Dy2StTestBase):
     def test_place(self):
         paddle.enable_static()
         x = paddle.to_tensor([1, 2, 3, 4])
diff --git a/test/dygraph_to_static/test_print.py b/test/dygraph_to_static/test_print.py
index 251bca776e700..35022512ce7f6 100644
--- a/test/dygraph_to_static/test_print.py
+++ b/test/dygraph_to_static/test_print.py
@@ -15,10 +15,7 @@
 import unittest

 import numpy
-from dygraph_to_static_util import (
-    dy2static_unittest,
-    test_and_compare_with_new_ir,
-)
+from dygraph_to_static_utils_new import Dy2StTestBase, compare_legacy_with_pir

 import paddle
 from paddle import base
@@ -87,8 +84,7 @@ def dyfunc_print_with_kwargs(x):
     print("Tensor", x_t, end='\n\n', sep=': ')


-@dy2static_unittest
-class TestPrintBase(unittest.TestCase):
+class TestPrintBase(Dy2StTestBase):
     def setUp(self):
         self.input = numpy.ones(5).astype("int32")
         self.place = (
@@ -110,7 +106,7 @@ def _run(self, to_static):
     def get_dygraph_output(self):
         self._run(to_static=False)

-    @test_and_compare_with_new_ir(True)
+    @compare_legacy_with_pir
     def get_static_output(self):
         self._run(to_static=True)

diff --git a/test/dygraph_to_static/test_program_translator.py b/test/dygraph_to_static/test_program_translator.py
index 9447bbbf4f608..253a1a9b7d67e 100644
--- a/test/dygraph_to_static/test_program_translator.py
+++ b/test/dygraph_to_static/test_program_translator.py
@@ -18,7 +18,7 @@

 import astor
 import numpy as np
-from dygraph_to_static_util import ast_only_test, dy2static_unittest
+from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only
 from ifelse_simple_func import (
     dyfunc_with_if_else_early_return1,
     dyfunc_with_if_else_early_return2,
@@ -212,13 +212,12 @@ def forward(self, x):
         return y


-@dy2static_unittest
-class TestEnableDeclarative(unittest.TestCase):
+class TestEnableDeclarative(Dy2StTestBase):
     def setUp(self):
         self.x = np.random.randn(30, 10, 32).astype('float32')
         self.weight = np.random.randn(32, 64).astype('float32')

-    @ast_only_test
+    @test_ast_only
     def test_raise_error(self):
         with base.dygraph.guard():
             paddle.jit.enable_to_static(True)
@@ -268,9 +267,8 @@ def switch_mode_function():
     return True


-@dy2static_unittest
-class TestFunctionTrainEvalMode(unittest.TestCase):
-    @ast_only_test
+class TestFunctionTrainEvalMode(Dy2StTestBase):
+    @test_ast_only
     def test_switch_mode(self):
         paddle.disable_static()
         switch_mode_function.eval()
@@ -299,8 +297,7 @@ def test_raise_error(self):
             net.foo.train()


-@dy2static_unittest
-class TestIfElseEarlyReturn(unittest.TestCase):
+class TestIfElseEarlyReturn(Dy2StTestBase):
     def test_ifelse_early_return1(self):
         answer = np.zeros([2, 2]) + 1
         static_func = paddle.jit.to_static(dyfunc_with_if_else_early_return1)
@@ -314,8 +311,7 @@ def test_ifelse_early_return2(self):
         np.testing.assert_allclose(answer, out[0].numpy(), rtol=1e-05)


-@dy2static_unittest
-class TestRemoveCommentInDy2St(unittest.TestCase):
+class TestRemoveCommentInDy2St(Dy2StTestBase):
     def func_with_comment(self):
         # Comment1
         x = paddle.to_tensor([1, 2, 3])
@@ -356,8 +352,7 @@ def func1(x):
         return func1(data)


-@dy2static_unittest
-class TestParameterRecorder(unittest.TestCase):
+class TestParameterRecorder(Dy2StTestBase):
     def test_recorder(self):
         """function calls nn.Layer case."""
         net = Net()
diff --git a/test/dygraph_to_static/test_ptb_lm.py b/test/dygraph_to_static/test_ptb_lm.py
index 76a35d57ac9ba..87a6cbd5a8fe1 100644
--- a/test/dygraph_to_static/test_ptb_lm.py
+++ b/test/dygraph_to_static/test_ptb_lm.py
@@ -17,10 +17,7 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import (
-    dy2static_unittest,
-    test_and_compare_with_new_ir,
-)
+from dygraph_to_static_utils_new import Dy2StTestBase, compare_legacy_with_pir

 import paddle
 from paddle import base
@@ -318,14 +315,13 @@ def train_dygraph(place):
     return train(place)


-@test_and_compare_with_new_ir(True)
+@compare_legacy_with_pir
 def train_static(place):
     paddle.jit.enable_to_static(True)
     return train(place)


-@dy2static_unittest
-class TestPtb(unittest.TestCase):
+class TestPtb(Dy2StTestBase):
     def setUp(self):
         self.place = (
             base.CUDAPlace(0)
diff --git a/test/dygraph_to_static/test_ptb_lm_v2.py b/test/dygraph_to_static/test_ptb_lm_v2.py
index 92d4d43d9d4ea..abc351d17f1ec 100644
--- a/test/dygraph_to_static/test_ptb_lm_v2.py
+++ b/test/dygraph_to_static/test_ptb_lm_v2.py
@@ -17,7 +17,7 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import dy2static_unittest
+from dygraph_to_static_utils_new import Dy2StTestBase

 import paddle

@@ -323,8 +323,7 @@ def train_static(place):
     return train(place)


-@dy2static_unittest
-class TestPtb(unittest.TestCase):
+class TestPtb(Dy2StTestBase):
     def setUp(self):
         self.place = (
             paddle.CUDAPlace(0)
diff --git a/test/dygraph_to_static/test_pylayer.py b/test/dygraph_to_static/test_pylayer.py
index d047b6d5cd1cb..5b78575f7fdb0 100644
--- a/test/dygraph_to_static/test_pylayer.py
+++ b/test/dygraph_to_static/test_pylayer.py
@@ -26,7 +26,7 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import dy2static_unittest
+from dygraph_to_static_utils_new import Dy2StTestBase
 from test_jit_save_load import train

 import paddle
@@ -263,8 +263,7 @@ def forward(self, x):
         return out


-@dy2static_unittest
-class TestPyLayerBase(unittest.TestCase):
+class TestPyLayerBase(Dy2StTestBase):
     def setUp(self):
         self.place = "gpu" if paddle.is_compiled_with_cuda() else "cpu"
         self.to_static = False
@@ -514,8 +513,7 @@ def test_pylayer_net_with_no_grad(self):
         self._run_and_compare(input1, input2)


-@dy2static_unittest
-class PyLayerTrainHelper(unittest.TestCase):
+class PyLayerTrainHelper(Dy2StTestBase):
     def setUp(self):
         self.place = "gpu" if paddle.is_compiled_with_cuda() else "cpu"

@@ -588,8 +586,7 @@ def test_pylayer_net_no_grad(self):
         )


-@dy2static_unittest
-class TestPyLayerJitSaveLoad(unittest.TestCase):
+class TestPyLayerJitSaveLoad(Dy2StTestBase):
     def setUp(self):
         self.temp_dir = tempfile.TemporaryDirectory()
         self.model_path = os.path.join(
diff --git a/test/dygraph_to_static/test_reinforcement_learning.py b/test/dygraph_to_static/test_reinforcement_learning.py
index ffbd0e315229d..a47607b561f8d 100644
--- a/test/dygraph_to_static/test_reinforcement_learning.py
+++ b/test/dygraph_to_static/test_reinforcement_learning.py
@@ -18,10 +18,7 @@

 import gym
 import numpy as np
-from dygraph_to_static_util import (
-    dy2static_unittest,
-    test_and_compare_with_new_ir,
-)
+from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir

 import paddle
 import paddle.nn.functional as F
@@ -206,8 +203,7 @@ def finish_episode():
     return np.array(loss_data)


-@dy2static_unittest
-class TestDeclarative(unittest.TestCase):
+class TestDeclarative(Dy2StTestBase):
     def setUp(self):
         self.place = (
             base.CUDAPlace(0)
@@ -216,7 +212,7 @@ def setUp(self):
         )
         self.args = Args()

-    @test_and_compare_with_new_ir(False)
+    @test_legacy_and_pir
     def test_train(self):
         st_out = train(self.args, self.place, to_static=True)
         dy_out = train(self.args, self.place, to_static=False)
diff --git a/test/dygraph_to_static/test_resnet.py b/test/dygraph_to_static/test_resnet.py
index cb57ce234b263..dc4c36fc4ca5f 100644
--- a/test/dygraph_to_static/test_resnet.py
+++ b/test/dygraph_to_static/test_resnet.py
@@ -19,7 +19,7 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import dy2static_unittest, test_with_new_ir
+from dygraph_to_static_utils_new import Dy2StTestBase, test_pir_only
 from predictor_utils import PredictorTools

 import paddle
@@ -386,8 +386,7 @@ def predict_analysis_inference(self, data):
         return out


-@dy2static_unittest
-class TestResnet(unittest.TestCase):
+class TestResnet(Dy2StTestBase):
     def setUp(self):
         self.resnet_helper = ResNetHelper()

@@ -420,7 +419,7 @@ def verify_predict(self):
             err_msg=f'predictor_pre:\n {predictor_pre}\n, st_pre: \n{st_pre}.',
         )

-    @test_with_new_ir
+    @test_pir_only
     def test_resnet_new_ir(self):
         static_loss = self.train(to_static=True)
         dygraph_loss = self.train(to_static=False)
diff --git a/test/dygraph_to_static/test_resnet_amp.py b/test/dygraph_to_static/test_resnet_amp.py
index 0255c0c00db3b..d8e3b6963fbec 100644
--- a/test/dygraph_to_static/test_resnet_amp.py
+++ b/test/dygraph_to_static/test_resnet_amp.py
@@ -16,10 +16,7 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import (
-    dy2static_unittest,
-    test_and_compare_with_new_ir,
-)
+from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir
 from test_resnet import SEED, ResNet, optimizer_setting

 import paddle
@@ -114,13 +111,12 @@ def train(to_static, build_strategy=None):
     return total_loss.numpy()


-@dy2static_unittest
-class TestResnet(unittest.TestCase):
+class TestResnet(Dy2StTestBase):
     def train(self, to_static):
         paddle.jit.enable_to_static(to_static)
         return train(to_static)

-    @test_and_compare_with_new_ir(False)
+    @test_legacy_and_pir
     def test_resnet(self):
         static_loss = self.train(to_static=True)
         dygraph_loss = self.train(to_static=False)
diff --git a/test/dygraph_to_static/test_resnet_pure_fp16.py b/test/dygraph_to_static/test_resnet_pure_fp16.py
index 771f9033f99d7..b5c132ce43df0 100644
--- a/test/dygraph_to_static/test_resnet_pure_fp16.py
+++ b/test/dygraph_to_static/test_resnet_pure_fp16.py
@@ -16,10 +16,7 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import (
-    dy2static_unittest,
-    test_and_compare_with_new_ir,
-)
+from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir
 from test_resnet import SEED, ResNet, optimizer_setting

 import paddle
@@ -115,8 +112,7 @@ def train(to_static, build_strategy=None):
     return loss_data


-@dy2static_unittest
-class TestResnet(unittest.TestCase):
+class TestResnet(Dy2StTestBase):
     def train(self, to_static):
         paddle.jit.enable_to_static(to_static)
         build_strategy = paddle.static.BuildStrategy()
@@ -125,7 +121,7 @@ def train(self, to_static):
         build_strategy.enable_inplace = False
         return train(to_static, build_strategy)

-    @test_and_compare_with_new_ir(False)
+    @test_legacy_and_pir
     def test_resnet(self):
         if base.is_compiled_with_cuda():
             static_loss = self.train(to_static=True)
diff --git a/test/dygraph_to_static/test_resnet_v2.py b/test/dygraph_to_static/test_resnet_v2.py
index 0f5d804427ca6..856ee246e202f 100644
--- a/test/dygraph_to_static/test_resnet_v2.py
+++ b/test/dygraph_to_static/test_resnet_v2.py
@@ -19,7 +19,7 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import dy2static_unittest, test_with_new_ir
+from dygraph_to_static_utils_new import Dy2StTestBase, test_pir_only
 from predictor_utils import PredictorTools

 import paddle
@@ -242,8 +242,7 @@ def __len__(self):
         return len(self.img)


-@dy2static_unittest
-class TestResnet(unittest.TestCase):
+class TestResnet(Dy2StTestBase):
     def setUp(self):
         self.temp_dir = tempfile.TemporaryDirectory()

@@ -427,7 +426,7 @@ def verify_predict(self):
             err_msg=f'predictor_pre:\n {predictor_pre}\n, st_pre: \n{st_pre}.',
         )

-    @test_with_new_ir
+    @test_pir_only
     def test_resnet_new_ir(self):
         static_loss = self.train(to_static=True)
         dygraph_loss = self.train(to_static=False)
diff --git a/test/dygraph_to_static/test_return.py b/test/dygraph_to_static/test_return.py
index dc79b8456ed3b..3c1e1136d7364 100644
--- a/test/dygraph_to_static/test_return.py
+++ b/test/dygraph_to_static/test_return.py
@@ -15,7 +15,7 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import ast_only_test, dy2static_unittest
+from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only
 from ifelse_simple_func import dyfunc_with_if_else

 import paddle
@@ -264,8 +264,7 @@ def func():
     return func()


-@dy2static_unittest
-class TestReturnBase(unittest.TestCase):
+class TestReturnBase(Dy2StTestBase):
     def setUp(self):
         self.input = np.ones(1).astype('int32')
         self.place = (
@@ -303,6 +302,7 @@ def _test_value_impl(self):
         else:
             self.assertEqual(dygraph_res, static_res)

+    @test_ast_only
     def test_transformed_static_result(self):
         if hasattr(self, "error"):
             with self.assertRaisesRegex(Dygraph2StaticException, self.error):
@@ -351,20 +351,12 @@ def init_dygraph_func(self):
         self.dygraph_func = test_return_in_while_2
         self.error = "Found return statement in While or For body and loop"

-    @ast_only_test
-    def test_transformed_static_result(self):
-        super().test_transformed_static_result()
-

 class TestReturnInFor2(TestReturnBase):
     def init_dygraph_func(self):
         self.dygraph_func = test_return_in_for_2
         self.error = "Found return statement in While or For body and loop"

-    @ast_only_test
-    def test_transformed_static_result(self):
-        super().test_transformed_static_result()
-

 class TestRecursiveReturn(TestReturnBase):
     def init_dygraph_func(self):
@@ -377,20 +369,12 @@ def init_dygraph_func(self):
         self.dygraph_func = test_return_different_length_if_body
         self.error = "Your if/else have different number of return value."

-    @ast_only_test
-    def test_transformed_static_result(self):
-        super().test_transformed_static_result()
-

 class TestReturnDifferentLengthElse(TestReturnBase):
     def init_dygraph_func(self):
         self.dygraph_func = test_return_different_length_else
         self.error = "Your if/else have different number of return value."

-    @ast_only_test
-    def test_transformed_static_result(self):
-        super().test_transformed_static_result()
-

 class TestNoReturn(TestReturnBase):
     def init_dygraph_func(self):
@@ -402,20 +386,12 @@ def init_dygraph_func(self):
         self.dygraph_func = test_return_none
         self.error = "Your if/else have different number of return value."

-    @ast_only_test
-    def test_transformed_static_result(self):
-        super().test_transformed_static_result()
-

 class TestReturnNoVariable(TestReturnBase):
     def init_dygraph_func(self):
         self.dygraph_func = test_return_no_variable
         self.error = "Your if/else have different number of return value."

-    @ast_only_test
-    def test_transformed_static_result(self):
-        super().test_transformed_static_result()
-

 class TestReturnListOneValue(TestReturnBase):
     def init_dygraph_func(self):
diff --git a/test/dygraph_to_static/test_rollback.py b/test/dygraph_to_static/test_rollback.py
index 7ee3456747b51..2cba4d9ed7d85 100644
--- a/test/dygraph_to_static/test_rollback.py
+++ b/test/dygraph_to_static/test_rollback.py
@@ -15,10 +15,10 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import (
-    ast_only_test,
-    dy2static_unittest,
-    test_and_compare_with_new_ir,
+from dygraph_to_static_utils_new import (
+    Dy2StTestBase,
+    test_ast_only,
+    test_legacy_and_pir,
 )

 import paddle
@@ -71,12 +71,11 @@ def foo(x, flag=False):
     return out


-@dy2static_unittest
-class TestRollBackPlainFunction(unittest.TestCase):
+class TestRollBackPlainFunction(Dy2StTestBase):
     def setUp(self):
         paddle.set_device("cpu")

-    @test_and_compare_with_new_ir(False)
+    @test_legacy_and_pir
     def test_plain_func(self):
         st_foo = paddle.jit.to_static(foo)
         x = paddle.randn([3, 4])
@@ -91,13 +90,12 @@ def test_plain_func(self):
         np.testing.assert_array_equal(st_out.numpy(), dy_out.numpy())


-@dy2static_unittest
-class TestRollBackNet(unittest.TestCase):
+class TestRollBackNet(Dy2StTestBase):
     def setUp(self):
         paddle.set_device("cpu")

-    @ast_only_test
-    @test_and_compare_with_new_ir(False)
+    @test_ast_only
+    @test_legacy_and_pir
     def test_net(self):
         net = paddle.jit.to_static(Net())
         x = paddle.randn([3, 4])
@@ -143,10 +141,9 @@ def func(self, x):
         return x + 2


-@dy2static_unittest
-class TestRollBackNotForward(unittest.TestCase):
-    @ast_only_test
-    @test_and_compare_with_new_ir(False)
+class TestRollBackNotForward(Dy2StTestBase):
+    @test_ast_only
+    @test_legacy_and_pir
     def test_rollback(self):
         x = paddle.zeros([2, 2])
         net = FuncRollback()
diff --git a/test/dygraph_to_static/test_save_inference_model.py b/test/dygraph_to_static/test_save_inference_model.py
index e765aec9670e1..8103ac8821d51 100644
--- a/test/dygraph_to_static/test_save_inference_model.py
+++ b/test/dygraph_to_static/test_save_inference_model.py
@@ -17,10 +17,11 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import (
-    ast_only_test,
-    dy2static_unittest,
-    test_and_compare_with_new_ir,
+from dygraph_to_static_utils_new import (
+    Dy2StTestBase,
+    compare_legacy_with_pir,
+    test_ast_only,
+    test_legacy_and_pir,
 )

 import paddle
@@ -77,15 +78,14 @@ def forward(self, x):
         return loss, out


-@dy2static_unittest
-class TestDyToStaticSaveInferenceModel(unittest.TestCase):
+class TestDyToStaticSaveInferenceModel(Dy2StTestBase):
     def setUp(self):
         self.temp_dir = tempfile.TemporaryDirectory()

     def tearDown(self):
         self.temp_dir.cleanup()

-    @ast_only_test
+    @test_ast_only
     def test_save_inference_model(self):
         fc_size = 20
         x_data = np.random.random((fc_size, fc_size)).astype('float32')
@@ -127,7 +127,7 @@ def test_save_inference_model(self):
             layer, [x_data], dygraph_out.numpy(), feed=[x]
         )

-    @ast_only_test
+    @test_ast_only
     def test_save_pylayer_model(self):
         fc_size = 20
         x_data = np.random.random((fc_size, fc_size)).astype('float32')
@@ -191,7 +191,7 @@ def check_save_inference_model(
             output_spec=fetch if fetch else None,
         )
         if enable_new_ir:
-            wrapped_load_and_run_inference = test_and_compare_with_new_ir(True)(
+            wrapped_load_and_run_inference = compare_legacy_with_pir(
                 self.load_and_run_inference
             )
             infer_out = wrapped_load_and_run_inference(
@@ -228,10 +228,9 @@ def load_and_run_inference(
         return np.array(results[0])


-@dy2static_unittest
-class TestPartialProgramRaiseError(unittest.TestCase):
-    @ast_only_test
-    @test_and_compare_with_new_ir(False)
+class TestPartialProgramRaiseError(Dy2StTestBase):
+    @test_ast_only
+    @test_legacy_and_pir
     def test_param_type(self):
         paddle.jit.enable_to_static(True)
         x_data = np.random.random((20, 20)).astype('float32')
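
Note on the pattern applied throughout this diff: tests drop the `@dy2static_unittest` class decorator and inherit from `Dy2StTestBase` instead of `unittest.TestCase`, and the old decorators are swapped for their renamed counterparts (`ast_only_test` -> `test_ast_only`, `test_with_new_ir` -> `test_pir_only`, `test_and_compare_with_new_ir(True)` -> `compare_legacy_with_pir`, `test_and_compare_with_new_ir(False)` -> `test_legacy_and_pir`). Below is a minimal sketch of a test migrated to the new utilities; `add_one` and `AddOneTest` are hypothetical names used only for illustration and are not part of this diff.

import numpy as np
from dygraph_to_static_utils_new import (
    Dy2StTestBase,        # base class replaces @dy2static_unittest + unittest.TestCase
    test_ast_only,        # was: ast_only_test
    test_legacy_and_pir,  # was: test_and_compare_with_new_ir(False)
)

import paddle


def add_one(x):
    return x + 1


class AddOneTest(Dy2StTestBase):  # no class-level decorator needed anymore
    @test_ast_only
    @test_legacy_and_pir
    def test_add_one(self):
        # Run the function through dy2st and compare with eager execution.
        st_fn = paddle.jit.to_static(add_one)
        x = paddle.to_tensor([1.0, 2.0])
        np.testing.assert_allclose(st_fn(x).numpy(), add_one(x).numpy())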