Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dy2St] Refactor dy2st unittest decorators name - Part 4 #58465

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions test/dygraph_to_static/test_partial_program_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,15 @@

import unittest

from dygraph_to_static_util import dy2static_unittest
from dygraph_to_static_utils_new import Dy2StTestBase

import paddle
from paddle.base import core
from paddle.jit.api import ENV_ENABLE_SOT
from paddle.jit.dy2static import partial_program, program_translator


@dy2static_unittest
class TestPartiaProgramLayerHook(unittest.TestCase):
class TestPartiaProgramLayerHook(Dy2StTestBase):
def setUp(self):
ENV_ENABLE_SOT.set(False)
self._hook = partial_program.PartialProgramLayerHook()
Expand All @@ -38,8 +37,7 @@ def test_after_infer(self):
self.assertIsNone(self._hook.after_infer(None))


@dy2static_unittest
class TestPrimHook(unittest.TestCase):
class TestPrimHook(Dy2StTestBase):
def setUp(self):
ENV_ENABLE_SOT.set(False)
core._set_prim_all_enabled(False)
Expand Down
5 changes: 2 additions & 3 deletions test/dygraph_to_static/test_place.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@

import unittest

from dygraph_to_static_util import dy2static_unittest
from dygraph_to_static_utils_new import Dy2StTestBase

import paddle


@dy2static_unittest
class TestPlace(unittest.TestCase):
class TestPlace(Dy2StTestBase):
def test_place(self):
paddle.enable_static()
x = paddle.to_tensor([1, 2, 3, 4])
Expand Down
10 changes: 3 additions & 7 deletions test/dygraph_to_static/test_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
import unittest

import numpy
from dygraph_to_static_util import (
dy2static_unittest,
test_and_compare_with_new_ir,
)
from dygraph_to_static_utils_new import Dy2StTestBase, compare_legacy_with_pir

import paddle
from paddle import base
Expand Down Expand Up @@ -87,8 +84,7 @@ def dyfunc_print_with_kwargs(x):
print("Tensor", x_t, end='\n\n', sep=': ')


@dy2static_unittest
class TestPrintBase(unittest.TestCase):
class TestPrintBase(Dy2StTestBase):
def setUp(self):
self.input = numpy.ones(5).astype("int32")
self.place = (
Expand All @@ -110,7 +106,7 @@ def _run(self, to_static):
def get_dygraph_output(self):
self._run(to_static=False)

@test_and_compare_with_new_ir(True)
@compare_legacy_with_pir
def get_static_output(self):
self._run(to_static=True)

Expand Down
21 changes: 8 additions & 13 deletions test/dygraph_to_static/test_program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import astor
import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only
from ifelse_simple_func import (
dyfunc_with_if_else_early_return1,
dyfunc_with_if_else_early_return2,
Expand Down Expand Up @@ -212,13 +212,12 @@ def forward(self, x):
return y


@dy2static_unittest
class TestEnableDeclarative(unittest.TestCase):
class TestEnableDeclarative(Dy2StTestBase):
def setUp(self):
self.x = np.random.randn(30, 10, 32).astype('float32')
self.weight = np.random.randn(32, 64).astype('float32')

@ast_only_test
@test_ast_only
def test_raise_error(self):
with base.dygraph.guard():
paddle.jit.enable_to_static(True)
Expand Down Expand Up @@ -268,9 +267,8 @@ def switch_mode_function():
return True


@dy2static_unittest
class TestFunctionTrainEvalMode(unittest.TestCase):
@ast_only_test
class TestFunctionTrainEvalMode(Dy2StTestBase):
@test_ast_only
def test_switch_mode(self):
paddle.disable_static()
switch_mode_function.eval()
Expand Down Expand Up @@ -299,8 +297,7 @@ def test_raise_error(self):
net.foo.train()


@dy2static_unittest
class TestIfElseEarlyReturn(unittest.TestCase):
class TestIfElseEarlyReturn(Dy2StTestBase):
def test_ifelse_early_return1(self):
answer = np.zeros([2, 2]) + 1
static_func = paddle.jit.to_static(dyfunc_with_if_else_early_return1)
Expand All @@ -314,8 +311,7 @@ def test_ifelse_early_return2(self):
np.testing.assert_allclose(answer, out[0].numpy(), rtol=1e-05)


@dy2static_unittest
class TestRemoveCommentInDy2St(unittest.TestCase):
class TestRemoveCommentInDy2St(Dy2StTestBase):
def func_with_comment(self):
# Comment1
x = paddle.to_tensor([1, 2, 3])
Expand Down Expand Up @@ -356,8 +352,7 @@ def func1(x):
return func1(data)


@dy2static_unittest
class TestParameterRecorder(unittest.TestCase):
class TestParameterRecorder(Dy2StTestBase):
def test_recorder(self):
"""function calls nn.Layer case."""
net = Net()
Expand Down
10 changes: 3 additions & 7 deletions test/dygraph_to_static/test_ptb_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import (
dy2static_unittest,
test_and_compare_with_new_ir,
)
from dygraph_to_static_utils_new import Dy2StTestBase, compare_legacy_with_pir

import paddle
from paddle import base
Expand Down Expand Up @@ -318,14 +315,13 @@ def train_dygraph(place):
return train(place)


@test_and_compare_with_new_ir(True)
@compare_legacy_with_pir
def train_static(place):
paddle.jit.enable_to_static(True)
return train(place)


@dy2static_unittest
class TestPtb(unittest.TestCase):
class TestPtb(Dy2StTestBase):
def setUp(self):
self.place = (
base.CUDAPlace(0)
Expand Down
5 changes: 2 additions & 3 deletions test/dygraph_to_static/test_ptb_lm_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import dy2static_unittest
from dygraph_to_static_utils_new import Dy2StTestBase

import paddle

Expand Down Expand Up @@ -323,8 +323,7 @@ def train_static(place):
return train(place)


@dy2static_unittest
class TestPtb(unittest.TestCase):
class TestPtb(Dy2StTestBase):
def setUp(self):
self.place = (
paddle.CUDAPlace(0)
Expand Down
11 changes: 4 additions & 7 deletions test/dygraph_to_static/test_pylayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import dy2static_unittest
from dygraph_to_static_utils_new import Dy2StTestBase
from test_jit_save_load import train

import paddle
Expand Down Expand Up @@ -263,8 +263,7 @@ def forward(self, x):
return out


@dy2static_unittest
class TestPyLayerBase(unittest.TestCase):
class TestPyLayerBase(Dy2StTestBase):
def setUp(self):
self.place = "gpu" if paddle.is_compiled_with_cuda() else "cpu"
self.to_static = False
Expand Down Expand Up @@ -514,8 +513,7 @@ def test_pylayer_net_with_no_grad(self):
self._run_and_compare(input1, input2)


@dy2static_unittest
class PyLayerTrainHelper(unittest.TestCase):
class PyLayerTrainHelper(Dy2StTestBase):
def setUp(self):
self.place = "gpu" if paddle.is_compiled_with_cuda() else "cpu"

Expand Down Expand Up @@ -588,8 +586,7 @@ def test_pylayer_net_no_grad(self):
)


@dy2static_unittest
class TestPyLayerJitSaveLoad(unittest.TestCase):
class TestPyLayerJitSaveLoad(Dy2StTestBase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
self.model_path = os.path.join(
Expand Down
10 changes: 3 additions & 7 deletions test/dygraph_to_static/test_reinforcement_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@

import gym
import numpy as np
from dygraph_to_static_util import (
dy2static_unittest,
test_and_compare_with_new_ir,
)
from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir

import paddle
import paddle.nn.functional as F
Expand Down Expand Up @@ -206,8 +203,7 @@ def finish_episode():
return np.array(loss_data)


@dy2static_unittest
class TestDeclarative(unittest.TestCase):
class TestDeclarative(Dy2StTestBase):
def setUp(self):
self.place = (
base.CUDAPlace(0)
Expand All @@ -216,7 +212,7 @@ def setUp(self):
)
self.args = Args()

@test_and_compare_with_new_ir(False)
@test_legacy_and_pir
def test_train(self):
st_out = train(self.args, self.place, to_static=True)
dy_out = train(self.args, self.place, to_static=False)
Expand Down
7 changes: 3 additions & 4 deletions test/dygraph_to_static/test_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import dy2static_unittest, test_with_new_ir
from dygraph_to_static_utils_new import Dy2StTestBase, test_pir_only
from predictor_utils import PredictorTools

import paddle
Expand Down Expand Up @@ -386,8 +386,7 @@ def predict_analysis_inference(self, data):
return out


@dy2static_unittest
class TestResnet(unittest.TestCase):
class TestResnet(Dy2StTestBase):
def setUp(self):
self.resnet_helper = ResNetHelper()

Expand Down Expand Up @@ -420,7 +419,7 @@ def verify_predict(self):
err_msg=f'predictor_pre:\n {predictor_pre}\n, st_pre: \n{st_pre}.',
)

@test_with_new_ir
@test_pir_only
def test_resnet_new_ir(self):
static_loss = self.train(to_static=True)
dygraph_loss = self.train(to_static=False)
Expand Down
10 changes: 3 additions & 7 deletions test/dygraph_to_static/test_resnet_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import (
dy2static_unittest,
test_and_compare_with_new_ir,
)
from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir
from test_resnet import SEED, ResNet, optimizer_setting

import paddle
Expand Down Expand Up @@ -114,13 +111,12 @@ def train(to_static, build_strategy=None):
return total_loss.numpy()


@dy2static_unittest
class TestResnet(unittest.TestCase):
class TestResnet(Dy2StTestBase):
def train(self, to_static):
paddle.jit.enable_to_static(to_static)
return train(to_static)

@test_and_compare_with_new_ir(False)
@test_legacy_and_pir
def test_resnet(self):
static_loss = self.train(to_static=True)
dygraph_loss = self.train(to_static=False)
Expand Down
10 changes: 3 additions & 7 deletions test/dygraph_to_static/test_resnet_pure_fp16.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import (
dy2static_unittest,
test_and_compare_with_new_ir,
)
from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir
from test_resnet import SEED, ResNet, optimizer_setting

import paddle
Expand Down Expand Up @@ -115,8 +112,7 @@ def train(to_static, build_strategy=None):
return loss_data


@dy2static_unittest
class TestResnet(unittest.TestCase):
class TestResnet(Dy2StTestBase):
def train(self, to_static):
paddle.jit.enable_to_static(to_static)
build_strategy = paddle.static.BuildStrategy()
Expand All @@ -125,7 +121,7 @@ def train(self, to_static):
build_strategy.enable_inplace = False
return train(to_static, build_strategy)

@test_and_compare_with_new_ir(False)
@test_legacy_and_pir
def test_resnet(self):
if base.is_compiled_with_cuda():
static_loss = self.train(to_static=True)
Expand Down
7 changes: 3 additions & 4 deletions test/dygraph_to_static/test_resnet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import dy2static_unittest, test_with_new_ir
from dygraph_to_static_utils_new import Dy2StTestBase, test_pir_only
from predictor_utils import PredictorTools

import paddle
Expand Down Expand Up @@ -242,8 +242,7 @@ def __len__(self):
return len(self.img)


@dy2static_unittest
class TestResnet(unittest.TestCase):
class TestResnet(Dy2StTestBase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()

Expand Down Expand Up @@ -427,7 +426,7 @@ def verify_predict(self):
err_msg=f'predictor_pre:\n {predictor_pre}\n, st_pre: \n{st_pre}.',
)

@test_with_new_ir
@test_pir_only
def test_resnet_new_ir(self):
static_loss = self.train(to_static=True)
dygraph_loss = self.train(to_static=False)
Expand Down
Loading