
Commit 431403f
[Dy2St] Refactor dy2st unittest decorators name - Part 4 (#58465)
gouzil authored Oct 29, 2023
1 parent 1abb961 commit 431403f
Showing 15 changed files with 67 additions and 129 deletions.
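
The decorator renaming in this part follows the same mapping as the earlier parts of the series: tests drop the @dy2static_unittest decorator on unittest.TestCase subclasses and instead inherit from Dy2StTestBase, and the helper module moves from dygraph_to_static_util to dygraph_to_static_utils_new. Method decorators are renamed as well: @ast_only_test becomes @test_ast_only, @test_and_compare_with_new_ir(True) becomes @compare_legacy_with_pir, @test_and_compare_with_new_ir(False) becomes @test_legacy_and_pir, and @test_with_new_ir becomes @test_pir_only. The sketch below illustrates the migrated style on a hypothetical TestExample class that is not part of this commit; it assumes it is run from Paddle's test/dygraph_to_static directory so that dygraph_to_static_utils_new resolves, and the test bodies are placeholders.

# Hedged sketch of the post-refactor test style used throughout this commit.
# TestExample and its test bodies are illustrative only; the imported names
# come from the renamed helper module introduced by this refactor series.
import unittest

import paddle
from dygraph_to_static_utils_new import (
    Dy2StTestBase,
    compare_legacy_with_pir,
    test_ast_only,
)


def plus_one(x):
    return x + 1


class TestExample(Dy2StTestBase):  # was: @dy2static_unittest + unittest.TestCase
    @test_ast_only  # was: @ast_only_test
    def test_ast_only_case(self):
        # Placeholder body; real tests exercise dy2st conversion here.
        static_fn = paddle.jit.to_static(plus_one)
        out = static_fn(paddle.to_tensor([1.0]))
        self.assertEqual(out.shape, [1])

    @compare_legacy_with_pir  # was: @test_and_compare_with_new_ir(True)
    def test_compare_case(self):
        x = paddle.to_tensor([1.0, 2.0, 3.0])
        self.assertEqual(x.shape, [3])


if __name__ == '__main__':
    unittest.main()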
8 changes: 3 additions & 5 deletions test/dygraph_to_static/test_partial_program_hook.py
@@ -14,16 +14,15 @@

 import unittest

-from dygraph_to_static_util import dy2static_unittest
+from dygraph_to_static_utils_new import Dy2StTestBase

 import paddle
 from paddle.base import core
 from paddle.jit.api import ENV_ENABLE_SOT
 from paddle.jit.dy2static import partial_program, program_translator


-@dy2static_unittest
-class TestPartiaProgramLayerHook(unittest.TestCase):
+class TestPartiaProgramLayerHook(Dy2StTestBase):
     def setUp(self):
         ENV_ENABLE_SOT.set(False)
         self._hook = partial_program.PartialProgramLayerHook()
@@ -38,8 +37,7 @@ def test_after_infer(self):
         self.assertIsNone(self._hook.after_infer(None))


-@dy2static_unittest
-class TestPrimHook(unittest.TestCase):
+class TestPrimHook(Dy2StTestBase):
     def setUp(self):
         ENV_ENABLE_SOT.set(False)
         core._set_prim_all_enabled(False)
5 changes: 2 additions & 3 deletions test/dygraph_to_static/test_place.py
@@ -14,13 +14,12 @@

 import unittest

-from dygraph_to_static_util import dy2static_unittest
+from dygraph_to_static_utils_new import Dy2StTestBase

 import paddle


-@dy2static_unittest
-class TestPlace(unittest.TestCase):
+class TestPlace(Dy2StTestBase):
     def test_place(self):
         paddle.enable_static()
         x = paddle.to_tensor([1, 2, 3, 4])
10 changes: 3 additions & 7 deletions test/dygraph_to_static/test_print.py
@@ -15,10 +15,7 @@
 import unittest

 import numpy
-from dygraph_to_static_util import (
-    dy2static_unittest,
-    test_and_compare_with_new_ir,
-)
+from dygraph_to_static_utils_new import Dy2StTestBase, compare_legacy_with_pir

 import paddle
 from paddle import base
@@ -87,8 +84,7 @@ def dyfunc_print_with_kwargs(x):
     print("Tensor", x_t, end='\n\n', sep=': ')


-@dy2static_unittest
-class TestPrintBase(unittest.TestCase):
+class TestPrintBase(Dy2StTestBase):
     def setUp(self):
         self.input = numpy.ones(5).astype("int32")
         self.place = (
@@ -110,7 +106,7 @@ def _run(self, to_static):
     def get_dygraph_output(self):
         self._run(to_static=False)

-    @test_and_compare_with_new_ir(True)
+    @compare_legacy_with_pir
     def get_static_output(self):
         self._run(to_static=True)

21 changes: 8 additions & 13 deletions test/dygraph_to_static/test_program_translator.py
@@ -18,7 +18,7 @@

 import astor
 import numpy as np
-from dygraph_to_static_util import ast_only_test, dy2static_unittest
+from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only
 from ifelse_simple_func import (
     dyfunc_with_if_else_early_return1,
     dyfunc_with_if_else_early_return2,
@@ -212,13 +212,12 @@ def forward(self, x):
         return y


-@dy2static_unittest
-class TestEnableDeclarative(unittest.TestCase):
+class TestEnableDeclarative(Dy2StTestBase):
     def setUp(self):
         self.x = np.random.randn(30, 10, 32).astype('float32')
         self.weight = np.random.randn(32, 64).astype('float32')

-    @ast_only_test
+    @test_ast_only
     def test_raise_error(self):
         with base.dygraph.guard():
             paddle.jit.enable_to_static(True)
@@ -268,9 +267,8 @@ def switch_mode_function():
     return True


-@dy2static_unittest
-class TestFunctionTrainEvalMode(unittest.TestCase):
-    @ast_only_test
+class TestFunctionTrainEvalMode(Dy2StTestBase):
+    @test_ast_only
     def test_switch_mode(self):
         paddle.disable_static()
         switch_mode_function.eval()
@@ -299,8 +297,7 @@ def test_raise_error(self):
             net.foo.train()


-@dy2static_unittest
-class TestIfElseEarlyReturn(unittest.TestCase):
+class TestIfElseEarlyReturn(Dy2StTestBase):
     def test_ifelse_early_return1(self):
         answer = np.zeros([2, 2]) + 1
         static_func = paddle.jit.to_static(dyfunc_with_if_else_early_return1)
@@ -314,8 +311,7 @@ def test_ifelse_early_return2(self):
         np.testing.assert_allclose(answer, out[0].numpy(), rtol=1e-05)


-@dy2static_unittest
-class TestRemoveCommentInDy2St(unittest.TestCase):
+class TestRemoveCommentInDy2St(Dy2StTestBase):
     def func_with_comment(self):
         # Comment1
         x = paddle.to_tensor([1, 2, 3])
@@ -356,8 +352,7 @@ def func1(x):
         return func1(data)


-@dy2static_unittest
-class TestParameterRecorder(unittest.TestCase):
+class TestParameterRecorder(Dy2StTestBase):
     def test_recorder(self):
         """function calls nn.Layer case."""
         net = Net()
10 changes: 3 additions & 7 deletions test/dygraph_to_static/test_ptb_lm.py
@@ -17,10 +17,7 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import (
-    dy2static_unittest,
-    test_and_compare_with_new_ir,
-)
+from dygraph_to_static_utils_new import Dy2StTestBase, compare_legacy_with_pir

 import paddle
 from paddle import base
@@ -318,14 +315,13 @@ def train_dygraph(place):
     return train(place)


-@test_and_compare_with_new_ir(True)
+@compare_legacy_with_pir
 def train_static(place):
     paddle.jit.enable_to_static(True)
     return train(place)


-@dy2static_unittest
-class TestPtb(unittest.TestCase):
+class TestPtb(Dy2StTestBase):
     def setUp(self):
         self.place = (
             base.CUDAPlace(0)
5 changes: 2 additions & 3 deletions test/dygraph_to_static/test_ptb_lm_v2.py
@@ -17,7 +17,7 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import dy2static_unittest
+from dygraph_to_static_utils_new import Dy2StTestBase

 import paddle

@@ -323,8 +323,7 @@ def train_static(place):
     return train(place)


-@dy2static_unittest
-class TestPtb(unittest.TestCase):
+class TestPtb(Dy2StTestBase):
     def setUp(self):
         self.place = (
             paddle.CUDAPlace(0)
11 changes: 4 additions & 7 deletions test/dygraph_to_static/test_pylayer.py
@@ -26,7 +26,7 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import dy2static_unittest
+from dygraph_to_static_utils_new import Dy2StTestBase
 from test_jit_save_load import train

 import paddle
@@ -263,8 +263,7 @@ def forward(self, x):
         return out


-@dy2static_unittest
-class TestPyLayerBase(unittest.TestCase):
+class TestPyLayerBase(Dy2StTestBase):
     def setUp(self):
         self.place = "gpu" if paddle.is_compiled_with_cuda() else "cpu"
         self.to_static = False
@@ -514,8 +513,7 @@ def test_pylayer_net_with_no_grad(self):
         self._run_and_compare(input1, input2)


-@dy2static_unittest
-class PyLayerTrainHelper(unittest.TestCase):
+class PyLayerTrainHelper(Dy2StTestBase):
     def setUp(self):
         self.place = "gpu" if paddle.is_compiled_with_cuda() else "cpu"

@@ -588,8 +586,7 @@ def test_pylayer_net_no_grad(self):
         )


-@dy2static_unittest
-class TestPyLayerJitSaveLoad(unittest.TestCase):
+class TestPyLayerJitSaveLoad(Dy2StTestBase):
     def setUp(self):
         self.temp_dir = tempfile.TemporaryDirectory()
         self.model_path = os.path.join(
10 changes: 3 additions & 7 deletions test/dygraph_to_static/test_reinforcement_learning.py
@@ -18,10 +18,7 @@

 import gym
 import numpy as np
-from dygraph_to_static_util import (
-    dy2static_unittest,
-    test_and_compare_with_new_ir,
-)
+from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir

 import paddle
 import paddle.nn.functional as F
@@ -206,8 +203,7 @@ def finish_episode():
     return np.array(loss_data)


-@dy2static_unittest
-class TestDeclarative(unittest.TestCase):
+class TestDeclarative(Dy2StTestBase):
     def setUp(self):
         self.place = (
             base.CUDAPlace(0)
@@ -216,7 +212,7 @@ def setUp(self):
         )
         self.args = Args()

-    @test_and_compare_with_new_ir(False)
+    @test_legacy_and_pir
     def test_train(self):
         st_out = train(self.args, self.place, to_static=True)
         dy_out = train(self.args, self.place, to_static=False)
7 changes: 3 additions & 4 deletions test/dygraph_to_static/test_resnet.py
@@ -19,7 +19,7 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import dy2static_unittest, test_with_new_ir
+from dygraph_to_static_utils_new import Dy2StTestBase, test_pir_only
 from predictor_utils import PredictorTools

 import paddle
@@ -386,8 +386,7 @@ def predict_analysis_inference(self, data):
         return out


-@dy2static_unittest
-class TestResnet(unittest.TestCase):
+class TestResnet(Dy2StTestBase):
     def setUp(self):
         self.resnet_helper = ResNetHelper()

@@ -420,7 +419,7 @@ def verify_predict(self):
             err_msg=f'predictor_pre:\n {predictor_pre}\n, st_pre: \n{st_pre}.',
         )

-    @test_with_new_ir
+    @test_pir_only
     def test_resnet_new_ir(self):
         static_loss = self.train(to_static=True)
         dygraph_loss = self.train(to_static=False)
10 changes: 3 additions & 7 deletions test/dygraph_to_static/test_resnet_amp.py
@@ -16,10 +16,7 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import (
-    dy2static_unittest,
-    test_and_compare_with_new_ir,
-)
+from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir
 from test_resnet import SEED, ResNet, optimizer_setting

 import paddle
@@ -114,13 +111,12 @@ def train(to_static, build_strategy=None):
     return total_loss.numpy()


-@dy2static_unittest
-class TestResnet(unittest.TestCase):
+class TestResnet(Dy2StTestBase):
     def train(self, to_static):
         paddle.jit.enable_to_static(to_static)
         return train(to_static)

-    @test_and_compare_with_new_ir(False)
+    @test_legacy_and_pir
     def test_resnet(self):
         static_loss = self.train(to_static=True)
         dygraph_loss = self.train(to_static=False)
10 changes: 3 additions & 7 deletions test/dygraph_to_static/test_resnet_pure_fp16.py
@@ -16,10 +16,7 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import (
-    dy2static_unittest,
-    test_and_compare_with_new_ir,
-)
+from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir
 from test_resnet import SEED, ResNet, optimizer_setting

 import paddle
@@ -115,8 +112,7 @@ def train(to_static, build_strategy=None):
     return loss_data


-@dy2static_unittest
-class TestResnet(unittest.TestCase):
+class TestResnet(Dy2StTestBase):
     def train(self, to_static):
         paddle.jit.enable_to_static(to_static)
         build_strategy = paddle.static.BuildStrategy()
@@ -125,7 +121,7 @@ def train(self, to_static):
         build_strategy.enable_inplace = False
         return train(to_static, build_strategy)

-    @test_and_compare_with_new_ir(False)
+    @test_legacy_and_pir
     def test_resnet(self):
         if base.is_compiled_with_cuda():
             static_loss = self.train(to_static=True)
7 changes: 3 additions & 4 deletions test/dygraph_to_static/test_resnet_v2.py
@@ -19,7 +19,7 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import dy2static_unittest, test_with_new_ir
+from dygraph_to_static_utils_new import Dy2StTestBase, test_pir_only
 from predictor_utils import PredictorTools

 import paddle
@@ -242,8 +242,7 @@ def __len__(self):
         return len(self.img)


-@dy2static_unittest
-class TestResnet(unittest.TestCase):
+class TestResnet(Dy2StTestBase):
     def setUp(self):
         self.temp_dir = tempfile.TemporaryDirectory()

@@ -427,7 +426,7 @@ def verify_predict(self):
             err_msg=f'predictor_pre:\n {predictor_pre}\n, st_pre: \n{st_pre}.',
         )

-    @test_with_new_ir
+    @test_pir_only
     def test_resnet_new_ir(self):
         static_loss = self.train(to_static=True)
         dygraph_loss = self.train(to_static=False)