Skip to content

Commit

Permalink
[Dy2St] pir dy2st unittest verification - Part 5 (#58965)
Browse files Browse the repository at this point in the history
  • Loading branch information
gouzil authored Nov 14, 2023
1 parent e80c65e commit fb85aa3
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 32 deletions.
18 changes: 10 additions & 8 deletions test/dygraph_to_static/test_ast_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from dygraph_to_static_utils_new import (
Dy2StTestBase,
test_ast_only,
test_legacy_and_pir,
test_legacy_and_pir_api,
)
from ifelse_simple_func import (
dyfunc_with_if_else,
Expand Down Expand Up @@ -48,6 +48,7 @@ def _ast2func(self, func):
return transformed_func

@test_ast_only
@test_legacy_and_pir_api
def test_ast2func(self):
def func(x, y):
return x + y
Expand All @@ -56,19 +57,19 @@ def func(x, y):
self.assertEqual(func(x, y), self._ast2func(func)(x, y))

@test_ast_only
@test_legacy_and_pir_api
def test_ast2func_dygraph(self):
paddle.disable_static()
funcs = [dyfunc_with_if_else, dyfunc_with_if_else2, nested_if_else]
x_data = np.random.random([10, 16]).astype('float32')
for func in funcs:
with base.dygraph.guard():
x_v = base.dygraph.to_variable(x_data)
true_ret = func(x_v).numpy()
test_ret = self._ast2func(func)(x_v).numpy()
self.assertTrue((true_ret == test_ret).all())
x_v = base.dygraph.to_variable(x_data)
true_ret = func(x_v).numpy()
test_ret = self._ast2func(func)(x_v).numpy()
self.assertTrue((true_ret == test_ret).all())

@test_legacy_and_pir
@test_ast_only
@test_legacy_and_pir_api
def test_ast2func_static(self):
paddle.enable_static()

Expand All @@ -83,11 +84,12 @@ def func(x):
x_v = paddle.assign(x_data)
true_ret = func(x_v)
test_ret = self._ast2func(func)(x_v)
exe = base.Executor(base.CPUPlace())
exe = base.Executor(paddle.CPUPlace())
ret = exe.run(main_program, fetch_list=[true_ret, test_ret])
self.assertTrue((ret[0] == ret[1]).all())

@test_ast_only
@test_legacy_and_pir_api
def test_ast2func_error(self):
with self.assertRaises(Exception) as e:
self.assertRaises(TypeError, ast_to_func("x = a + b", 'foo'))
Expand Down
42 changes: 22 additions & 20 deletions test/dygraph_to_static/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,30 @@
Dy2StTestBase,
test_ast_only,
test_legacy_and_pir,
test_legacy_and_pir_exe_and_pir_api,
)

from paddle import base
from paddle.jit.api import to_static
import paddle
from paddle.base.dygraph import to_variable

SEED = 2020
np.random.seed(SEED)


def test_bool_cast(x):
x = base.dygraph.to_variable(x)
x = to_variable(x)
x = bool(x)
return x


def test_int_cast(x):
x = base.dygraph.to_variable(x)
x = to_variable(x)
x = int(x)
return x


def test_float_cast(x):
x = base.dygraph.to_variable(x)
x = to_variable(x)
x = float(x)
return x

Expand All @@ -52,7 +53,7 @@ def test_not_var_cast(x):


def test_mix_cast(x):
x = base.dygraph.to_variable(x)
x = to_variable(x)
x = int(x)
x = float(x)
x = bool(x)
Expand All @@ -63,12 +64,11 @@ def test_mix_cast(x):
class TestCastBase(Dy2StTestBase):
def setUp(self):
self.place = (
base.CUDAPlace(0)
if base.is_compiled_with_cuda()
else base.CPUPlace()
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
else paddle.CPUPlace()
)
self.prepare()
self.set_func()

def prepare(self):
self.input_shape = (16, 32)
Expand All @@ -81,16 +81,16 @@ def prepare(self):
self.cast_dtype = 'bool'

def set_func(self):
self.func = to_static(full_graph=True)(test_bool_cast)
self.func = paddle.jit.to_static(full_graph=True)(test_bool_cast)

def do_test(self):
with base.dygraph.guard():
res = self.func(self.input)
return res
res = self.func(self.input)
return res

@test_ast_only # TODO: add new sot only test.
@test_legacy_and_pir
@test_legacy_and_pir_exe_and_pir_api
def test_cast_result(self):
self.set_func()
res = self.do_test().numpy()
self.assertTrue(
res.dtype == self.cast_dtype,
Expand Down Expand Up @@ -119,7 +119,7 @@ def prepare(self):
self.cast_dtype = 'int32'

def set_func(self):
self.func = to_static(full_graph=True)(test_int_cast)
self.func = paddle.jit.to_static(full_graph=True)(test_int_cast)


class TestFloatCast(TestCastBase):
Expand All @@ -134,7 +134,7 @@ def prepare(self):
self.cast_dtype = 'float32'

def set_func(self):
self.func = to_static(full_graph=True)(test_float_cast)
self.func = paddle.jit.to_static(full_graph=True)(test_float_cast)


class TestMixCast(TestCastBase):
Expand All @@ -152,11 +152,12 @@ def prepare(self):
self.cast_dtype = 'float32'

def set_func(self):
self.func = to_static(full_graph=True)(test_mix_cast)
self.func = paddle.jit.to_static(full_graph=True)(test_mix_cast)

@test_ast_only # TODO: add new symbolic only test.
@test_legacy_and_pir
@test_legacy_and_pir_exe_and_pir_api
def test_cast_result(self):
self.set_func()
res = self.do_test().numpy()
self.assertTrue(
res.dtype == self.cast_dtype,
Expand Down Expand Up @@ -184,11 +185,12 @@ def prepare(self):
self.cast_dtype = 'int'

def set_func(self):
self.func = to_static(full_graph=True)(test_not_var_cast)
self.func = paddle.jit.to_static(full_graph=True)(test_not_var_cast)

@test_ast_only
@test_legacy_and_pir
def test_cast_result(self):
self.set_func()
# breakpoint()
# print("run once!!!")
res = self.do_test()
Expand Down
4 changes: 2 additions & 2 deletions test/dygraph_to_static/test_cinn_prim.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from dygraph_to_static_utils_new import (
Dy2StTestBase,
test_ast_only,
test_legacy_and_pir,
test_legacy_and_pir_exe_and_pir_api,
)

import paddle
Expand Down Expand Up @@ -171,7 +171,7 @@ def test_cinn_prim(self):


class TestBackend(Dy2StTestBase):
@test_legacy_and_pir
@test_legacy_and_pir_exe_and_pir_api
def test_backend(self):
x = paddle.randn([2, 4])
out1 = self.forward(x, 'CINN')
Expand Down
6 changes: 5 additions & 1 deletion test/dygraph_to_static/test_tensor_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
import unittest

import numpy as np
from dygraph_to_static_utils_new import Dy2StTestBase
from dygraph_to_static_utils_new import (
Dy2StTestBase,
test_legacy_and_pir_exe_and_pir_api,
)

import paddle
from paddle import nn
Expand Down Expand Up @@ -94,6 +97,7 @@ def h(g):
loss.backward()
np.testing.assert_allclose(x.grad.numpy(), x_jit.grad.numpy())

@test_legacy_and_pir_exe_and_pir_api
def test_hook_in_init_for_layer(self):
def hook(grad):
return grad * 2
Expand Down
3 changes: 2 additions & 1 deletion test/dygraph_to_static/test_variable_trans_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@

import unittest

from dygraph_to_static_utils_new import Dy2StTestBase
from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir_api

from paddle.jit.dy2static.utils import ast_to_source_code
from paddle.jit.dy2static.variable_trans_func import create_fill_constant_node


class TestVariableTransFunc(Dy2StTestBase):
@test_legacy_and_pir_api
def test_create_fill_constant_node(self):
node = create_fill_constant_node("a", 1.0)
source = "a = paddle.full(shape=[1], dtype='float64', fill_value=1.0, name='a')"
Expand Down

0 comments on commit fb85aa3

Please sign in to comment.