Skip to content

Commit

Permalink
[Dy2St] Refactor dy2st unittest decorators name - Part 7 (#58509)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: SigureMo <sigure.qaq@gmail.com>
  • Loading branch information
gouzil and SigureMo authored Oct 31, 2023
1 parent 75afa68 commit 4cae876
Show file tree
Hide file tree
Showing 10 changed files with 35 additions and 229 deletions.
173 changes: 0 additions & 173 deletions test/dygraph_to_static/dygraph_to_static_util.py

This file was deleted.

5 changes: 2 additions & 3 deletions test/dygraph_to_static/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import dy2static_unittest
from dygraph_to_static_utils_new import Dy2StTestBase

import paddle

Expand Down Expand Up @@ -70,8 +70,7 @@ def forward(self, x):
return self.layers(x)


@dy2static_unittest
class TestSequential(unittest.TestCase):
class TestSequential(Dy2StTestBase):
def setUp(self):
paddle.set_device('cpu')
self.seed = 2021
Expand Down
13 changes: 0 additions & 13 deletions test/dygraph_to_static/test_ifelse_basic.py

This file was deleted.

5 changes: 2 additions & 3 deletions test/dygraph_to_static/test_lac.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

os.environ["CUDA_VISIBLE_DEVICES"] = "2"

from dygraph_to_static_util import dy2static_unittest
from dygraph_to_static_utils_new import Dy2StTestBase

import paddle
from paddle import _legacy_C_ops, base
Expand Down Expand Up @@ -515,8 +515,7 @@ def create_dataloader(reader, place):
return data_loader


@dy2static_unittest
class TestLACModel(unittest.TestCase):
class TestLACModel(Dy2StTestBase):
def setUp(self):
self.args = Args()
self.place = (
Expand Down
5 changes: 2 additions & 3 deletions test/dygraph_to_static/test_no_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import unittest

import numpy
from dygraph_to_static_util import dy2static_unittest
from dygraph_to_static_utils_new import Dy2StTestBase

import paddle

Expand All @@ -32,8 +32,7 @@ def main_func(x, index):
return out


@dy2static_unittest
class TestNoGradientCase(unittest.TestCase):
class TestNoGradientCase(Dy2StTestBase):
def test_no_gradient(self):
paddle.disable_static()
x = paddle.randn([10, 3])
Expand Down
10 changes: 3 additions & 7 deletions test/dygraph_to_static/test_sentiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import (
dy2static_unittest,
test_and_compare_with_new_ir,
)
from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir
from test_lac import DynamicGRU

import paddle
Expand Down Expand Up @@ -372,12 +369,11 @@ def train(args, to_static):
return loss_data


@dy2static_unittest
class TestSentiment(unittest.TestCase):
class TestSentiment(Dy2StTestBase):
def setUp(self):
self.args = Args()

@test_and_compare_with_new_ir(False)
@test_legacy_and_pir
def train_model(self, model_type='cnn_net'):
self.args.model_type = model_type
st_out = train(self.args, True)
Expand Down
15 changes: 7 additions & 8 deletions test/dygraph_to_static/test_write_python_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@

import unittest

from dygraph_to_static_util import (
ast_only_test,
dy2static_unittest,
sot_only_test,
from dygraph_to_static_utils_new import (
Dy2StTestBase,
test_ast_only,
test_sot_only,
)

import paddle
Expand Down Expand Up @@ -99,8 +99,7 @@ def func_ifelse_write_nest_list_dict(x):
return res


@dy2static_unittest
class TestWriteContainer(unittest.TestCase):
class TestWriteContainer(Dy2StTestBase):
def setUp(self):
self.set_func()
self.set_getitem_path()
Expand All @@ -117,15 +116,15 @@ def get_raw_value(self, container, getitem_path):
out = out[path]
return out

@sot_only_test
@test_sot_only
def test_write_container_sot(self):
func_static = paddle.jit.to_static(self.func)
input = paddle.to_tensor([1, 2, 3])
out_static = self.get_raw_value(func_static(input), self.getitem_path)
out_dygraph = self.get_raw_value(self.func(input), self.getitem_path)
self.assertEqual(out_static, out_dygraph)

@ast_only_test
@test_ast_only
def test_write_container(self):
func_static = paddle.jit.to_static(self.func)
input = paddle.to_tensor([1, 2, 3])
Expand Down
14 changes: 7 additions & 7 deletions test/legacy_test/test_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from simple_nets import batchnorm_fc_with_inputs, simple_fc_net_with_inputs

sys.path.append("../dygraph_to_static")
from dygraph_to_static_util import test_and_compare_with_new_ir
from dygraph_to_static_utils_new import compare_legacy_with_pir

import paddle
from paddle import base
Expand All @@ -31,7 +31,7 @@


class TestCondInputOutput(unittest.TestCase):
@test_and_compare_with_new_ir()
@compare_legacy_with_pir
def test_return_single_var(self):
"""
pseudocode:
Expand Down Expand Up @@ -78,7 +78,7 @@ def false_func():
np.asarray(ret), np.full((3, 2), -1, np.int32), rtol=1e-05
)

@test_and_compare_with_new_ir()
@compare_legacy_with_pir
def test_return_0d_tensor(self):
"""
pseudocode:
Expand Down Expand Up @@ -116,7 +116,7 @@ def false_func():
np.testing.assert_allclose(np.asarray(ret), np.array(2), rtol=1e-05)
self.assertEqual(ret.shape, ())

@test_and_compare_with_new_ir()
@compare_legacy_with_pir
def test_0d_tensor_as_cond(self):
"""
pseudocode:
Expand Down Expand Up @@ -217,7 +217,7 @@ def test_0d_tensor_dygraph(self):
)
self.assertEqual(a.grad.shape, [])

@test_and_compare_with_new_ir()
@compare_legacy_with_pir
def test_return_var_tuple(self):
"""
pseudocode:
Expand Down Expand Up @@ -265,7 +265,7 @@ def false_func():
np.asarray(ret[1]), np.full((2, 3), True, bool), rtol=1e-05
)

@test_and_compare_with_new_ir()
@compare_legacy_with_pir
def test_pass_and_modify_var(self):
"""
pseudocode:
Expand Down Expand Up @@ -356,7 +356,7 @@ def false_func():
self.assertIsNone(out2)
self.assertIsNone(out3)

@test_and_compare_with_new_ir()
@compare_legacy_with_pir
def test_wrong_structure_exception(self):
"""
test returning different number of tensors cannot merge into output
Expand Down
Loading

0 comments on commit 4cae876

Please sign in to comment.