Skip to content

Commit

Permalink
[Dy2St] Refactor dy2st unittest decorators name - Part 3 (PaddlePaddl…
Browse files Browse the repository at this point in the history
  • Loading branch information
gouzil authored Oct 29, 2023
1 parent 8662d9c commit e8c334a
Show file tree
Hide file tree
Showing 16 changed files with 79 additions and 97 deletions.
19 changes: 10 additions & 9 deletions test/dygraph_to_static/dygraph_to_static_utils_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class MyTest(Dy2StTestBase):
@set_to_static_mode(
ToStaticMode.LEGACY_AST | ToStaticMode.SOT | ToStaticMode.PIR_AST
)
@set_ir_mode(IrMode.LEGACY_PROGRAM | IrMode.PIR)
@set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR)
def test_case1(self):
raise ValueError("MyTest 1")
Expand All @@ -58,15 +58,15 @@ def lower_case_name(self):


class IrMode(Flag):
LEGACY_PROGRAM = auto()
LEGACY_IR = auto()
PIR = auto()

def lower_case_name(self):
return self.name.lower()


DEFAULT_TO_STATIC_MODE = ToStaticMode.LEGACY_AST | ToStaticMode.SOT
DEFAULT_IR_MODE = IrMode.LEGACY_PROGRAM
DEFAULT_IR_MODE = IrMode.LEGACY_IR


def to_legacy_ast_test(fn):
Expand Down Expand Up @@ -101,9 +101,10 @@ def to_pir_ast_test(fn):
raise TypeError("Don't enable PIR AST mode now!")


def to_legacy_program_test(fn):
def to_legacy_ir_test(fn):
def impl(*args, **kwargs):
logger.info("[Program] running legacy program")
logger.info("[Program] running legacy ir")
# breakpoint()
return fn(*args, **kwargs)

return impl
Expand Down Expand Up @@ -140,7 +141,7 @@ class Dy2StTestMeta(type):
}

IR_HANDLER_MAP = {
IrMode.LEGACY_PROGRAM: to_legacy_program_test,
IrMode.LEGACY_IR: to_legacy_ir_test,
IrMode.PIR: to_pir_test,
}

Expand Down Expand Up @@ -192,9 +193,9 @@ def __new__(cls, name, bases, attrs):
for to_static_mode, ir_mode in to_static_with_ir_modes:
if (
to_static_mode == ToStaticMode.PIR_AST
and ir_mode == IrMode.LEGACY_PROGRAM
and ir_mode == IrMode.LEGACY_IR
):
# PIR with LEGACY_PROGRAM is not a valid combination
# PIR with LEGACY_IR is not a valid combination
continue
new_attrs[
Dy2StTestMeta.test_case_name(
Expand Down Expand Up @@ -264,7 +265,7 @@ def test_pir_only(fn):


def test_legacy_and_pir(fn):
fn = set_ir_mode(IrMode.LEGACY_PROGRAM | IrMode.PIR)(fn)
fn = set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR)(fn)
return fn


Expand Down
26 changes: 13 additions & 13 deletions test/dygraph_to_static/test_ifelse.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
import numpy as np
from dygraph_to_static_utils_new import (
Dy2StTestBase,
compare_legacy_with_pir,
test_ast_only,
test_legacy_and_pir,
)
from ifelse_simple_func import (
NetWithControlFlowIf,
Expand Down Expand Up @@ -66,7 +66,7 @@ def setUp(self):
self.error = "Your if/else have different number of return value."

@test_ast_only
@compare_legacy_with_pir
@test_legacy_and_pir
def test_error(self):
if self.dyfunc:
with self.assertRaisesRegex(Dygraph2StaticException, self.error):
Expand Down Expand Up @@ -98,7 +98,7 @@ def _run_dygraph(self, to_static=False):
ret = self.dyfunc(x_v)
return ret.numpy()

@compare_legacy_with_pir
@test_legacy_and_pir
def test_ast_to_func(self):
self.assertTrue((self._run_dygraph() == self._run_static()).all())

Expand Down Expand Up @@ -144,7 +144,7 @@ def _run_dygraph(self, to_static=False):
ret = self.dyfunc(x_v)
return ret.numpy()

@compare_legacy_with_pir
@test_legacy_and_pir
def test_ast_to_func(self):
self.assertTrue((self._run_dygraph() == self._run_static()).all())

Expand Down Expand Up @@ -270,7 +270,7 @@ def _run_dygraph(self, to_static=False):
ret = self.dyfunc(x_v)
return ret.numpy()

@compare_legacy_with_pir
@test_legacy_and_pir
def test_ast_to_func(self):
self.assertTrue((self._run_dygraph() == self._run_static()).all())

Expand Down Expand Up @@ -300,7 +300,7 @@ def _run(self, to_static=False):
ret = net(x_v)
return ret.numpy()

@compare_legacy_with_pir
@test_legacy_and_pir
def test_ast_to_func(self):
self.assertTrue((self._run_dygraph() == self._run_static()).all())

Expand Down Expand Up @@ -354,7 +354,7 @@ def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.Net = NetWithExternalFunc

@compare_legacy_with_pir
@test_legacy_and_pir
def test_ast_to_func(self):
self.assertTrue((self._run_dygraph() == self._run_static()).all())

Expand Down Expand Up @@ -412,7 +412,7 @@ def _run(self, mode, to_static):
ret = net(self.x, self.y)
return ret.numpy()

@compare_legacy_with_pir
@test_legacy_and_pir
def test_train_mode(self):
self.assertTrue(
(
Expand All @@ -421,7 +421,7 @@ def test_train_mode(self):
).all()
)

@compare_legacy_with_pir
@test_legacy_and_pir
def test_infer_mode(self):
self.assertTrue(
(
Expand All @@ -437,7 +437,7 @@ def init_net(self):


class TestNewVarCreateInOneBranch(Dy2StTestBase):
@compare_legacy_with_pir
@test_legacy_and_pir
def test_var_used_in_another_for(self):
def case_func(training):
# targets and targets_list is dynamically defined by training
Expand Down Expand Up @@ -474,7 +474,7 @@ def get_dy2stat_out(self):
return out

@test_ast_only
@compare_legacy_with_pir
@test_legacy_and_pir
def test_ast_to_func(self):
self.setUp()
self.assertIsInstance(self.out[0], (paddle.Tensor, core.eager.Tensor))
Expand All @@ -495,7 +495,7 @@ def setUp(self):
self.out = self.get_dy2stat_out()

@test_ast_only
@compare_legacy_with_pir
@test_legacy_and_pir
def test_ast_to_func(self):
self.setUp()
self.assertIsInstance(self.out, (paddle.Tensor, core.eager.Tensor))
Expand All @@ -507,7 +507,7 @@ def setUp(self):
self.dyfunc = paddle.jit.to_static(dyfunc_ifelse_ret_int4)

@test_ast_only
@compare_legacy_with_pir
@test_legacy_and_pir
def test_ast_to_func(self):
paddle.jit.enable_to_static(True)
with self.assertRaises(Dygraph2StaticException):
Expand Down
6 changes: 3 additions & 3 deletions test/dygraph_to_static/test_inplace_assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
import unittest

import numpy as np
from dygraph_to_static_util import test_and_compare_with_new_ir
from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir

import paddle


class TestInplaceAssign(unittest.TestCase):
class TestInplaceAssign(Dy2StTestBase):
def test_case0(self):
a = paddle.ones((1024, 2)) * 1
b = paddle.ones((1024, 3)) * 2
Expand All @@ -45,7 +45,7 @@ def func(x):
y.mean().backward()
np.testing.assert_array_equal(x.grad.numpy(), np.array([2.0]))

@test_and_compare_with_new_ir(False)
@test_legacy_and_pir
def test_case2(self):
@paddle.jit.to_static
def func(a, x):
Expand Down
8 changes: 3 additions & 5 deletions test/dygraph_to_static/test_logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import dy2static_unittest
from dygraph_to_static_utils_new import Dy2StTestBase

import paddle
from paddle import base
Expand Down Expand Up @@ -168,8 +168,7 @@ def test_shape_not_equal(x):
return paddle.ones([1, 2, 3])


@dy2static_unittest
class TestLogicalBase(unittest.TestCase):
class TestLogicalBase(Dy2StTestBase):
def setUp(self):
self.input = np.array([3]).astype('int32')
self.place = (
Expand Down Expand Up @@ -264,8 +263,7 @@ def _set_test_func(self):
self.dygraph_func = test_shape_not_equal


@dy2static_unittest
class TestCmpopNodeToStr(unittest.TestCase):
class TestCmpopNodeToStr(Dy2StTestBase):
def test_exception(self):
with self.assertRaises(KeyError):
cmpop_node_to_str(gast.Or())
Expand Down
14 changes: 5 additions & 9 deletions test/dygraph_to_static/test_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import dy2static_unittest
from dygraph_to_static_utils_new import Dy2StTestBase

import paddle
import paddle.nn.functional as F
Expand Down Expand Up @@ -230,8 +230,7 @@ def for_loop_dufunc_with_listcomp(array):
return res


@dy2static_unittest
class TestNameVisitor(unittest.TestCase):
class TestNameVisitor(Dy2StTestBase):
def setUp(self):
self.loop_funcs = [
while_loop_dyfunc,
Expand Down Expand Up @@ -301,8 +300,7 @@ def test_nested_loop_vars(self):
i += 1


@dy2static_unittest
class TestTransformWhileLoop(unittest.TestCase):
class TestTransformWhileLoop(Dy2StTestBase):
def setUp(self):
self.place = (
base.CUDAPlace(0)
Expand Down Expand Up @@ -381,8 +379,7 @@ def _init_dyfunc(self):
self.dyfunc = loop_var_contains_property


@dy2static_unittest
class TestTransformForLoop(unittest.TestCase):
class TestTransformForLoop(Dy2StTestBase):
def setUp(self):
self.place = (
base.CUDAPlace(0)
Expand Down Expand Up @@ -464,8 +461,7 @@ def forward(self, x):
return out


@dy2static_unittest
class TestForLoopMeetDict(unittest.TestCase):
class TestForLoopMeetDict(Dy2StTestBase):
def test_start(self):
net = Net()
model = paddle.jit.to_static(
Expand Down
21 changes: 11 additions & 10 deletions test/dygraph_to_static/test_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only

import paddle
from paddle import nn
Expand Down Expand Up @@ -45,8 +45,7 @@ def forward(self, x):
return x


@dy2static_unittest
class TestLstm(unittest.TestCase):
class TestLstm(Dy2StTestBase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()

Expand All @@ -71,8 +70,7 @@ def test_lstm_to_static(self):
static_out = self.run_lstm(to_static=True)
np.testing.assert_allclose(dygraph_out, static_out, rtol=1e-05)

@ast_only_test
def test_save_in_eval(self, with_training=True):
def save_in_eval(self, with_training: bool):
paddle.jit.enable_to_static(True)
net = Net(12, 2)
x = paddle.randn((2, 10, 12))
Expand Down Expand Up @@ -115,8 +113,13 @@ def test_save_in_eval(self, with_training=True):
err_msg=f'dygraph_out is {dygraph_out}\n static_out is \n{train_out}',
)

@test_ast_only
def test_save_without_training(self):
self.test_save_in_eval(with_training=False)
self.save_in_eval(with_training=False)

@test_ast_only
def test_save_with_training(self):
self.save_in_eval(with_training=True)


class LinearNet(nn.Layer):
Expand All @@ -132,8 +135,7 @@ def forward(self, x):
return y


@dy2static_unittest
class TestSaveInEvalMode(unittest.TestCase):
class TestSaveInEvalMode(Dy2StTestBase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()

Expand Down Expand Up @@ -176,8 +178,7 @@ def test_save_in_eval(self):
)


@dy2static_unittest
class TestEvalAfterSave(unittest.TestCase):
class TestEvalAfterSave(Dy2StTestBase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()

Expand Down
15 changes: 7 additions & 8 deletions test/dygraph_to_static/test_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
from time import time

import numpy as np
from dygraph_to_static_util import (
ast_only_test,
dy2static_unittest,
test_and_compare_with_new_ir,
from dygraph_to_static_utils_new import (
Dy2StTestBase,
compare_legacy_with_pir,
test_ast_only,
)
from predictor_utils import PredictorTools

Expand Down Expand Up @@ -130,8 +130,7 @@ def inference(self, inputs):
return x


@dy2static_unittest
class TestMNIST(unittest.TestCase):
class TestMNIST(Dy2StTestBase):
def setUp(self):
self.epoch_num = 1
self.batch_size = 64
Expand All @@ -158,14 +157,14 @@ class TestMNISTWithToStatic(TestMNIST):
still works if model is trained in dygraph mode.
"""

@test_and_compare_with_new_ir(True)
@compare_legacy_with_pir
def train_static(self):
return self.train(to_static=True)

def train_dygraph(self):
return self.train(to_static=False)

@ast_only_test
@test_ast_only
def test_mnist_to_static(self):
dygraph_loss = self.train_dygraph()
static_loss = self.train_static()
Expand Down
Loading

0 comments on commit e8c334a

Please sign in to comment.