add prim test for some ops (#51749)
* add tanh and cast prim test

* fix tanh test

* fix 0-d test

* add sqrt fp16 prim test

* add public_python_api in prim test

* fix test_squeeze2_op

* add tanh prim test

* add dropout prim test

* [Dy2St]Fix clone for test state problem

* clean code

* modify test_cumsum_op

* modify test_cumsum_op

* fix dropout test

* add dropout in cmake

* fix dropout test

---------

Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Charles-hit and Aurelius84 authored Mar 27, 2023
1 parent 20befde commit e1674e8
Showing 4 changed files with 435 additions and 47 deletions.
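
For context before the diffs: the changes below repeatedly apply the same prim-test pattern (set prim_op_type, expose a public_python_api, and pass check_prim=True to check_grad). A minimal, illustrative sketch of that pattern follows; it is not part of this commit, the class name TestTanhPrim is hypothetical, and the OpTest import path is an assumption that may differ by Paddle version.

import numpy as np

import paddle
from eager_op_test import OpTest  # assumed import path for Paddle's OpTest base class


class TestTanhPrim(OpTest):
    def setUp(self):
        self.op_type = "tanh"
        self.prim_op_type = "prim"  # ask the harness to also exercise the composite (prim) rules
        self.python_api = paddle.tanh  # dygraph API used to compute the reference output
        self.public_python_api = paddle.tanh  # public API used when decomposing into prim ops
        x = np.random.uniform(0.1, 1, [11, 17]).astype(np.float32)
        self.inputs = {'X': x}
        self.outputs = {'Out': np.tanh(x)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # check_prim=True additionally compares the gradient produced by the
        # composite (prim) backward against the numeric gradient
        self.check_grad(['X'], 'Out', check_prim=True)
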
4 changes: 3 additions & 1 deletion python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -1212,7 +1212,9 @@ set(TEST_CINN_OPS
     test_mean_op
     test_unsqueeze2_op
     test_meshgrid_op
-    test_gather_op)
+    test_gather_op
+    test_cast_op
+    test_dropout_op)
 
 foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
   if(WITH_CINN)
15 changes: 12 additions & 3 deletions python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -469,9 +469,12 @@ def test_errors(self):
 class TestTanh(TestActivation, TestParameter):
     def setUp(self):
         self.op_type = "tanh"
+        self.prim_op_type = "prim"
         self.python_api = paddle.tanh
+        self.public_python_api = paddle.tanh
         self.init_dtype()
         self.init_shape()
+        self.if_enable_cinn()
 
         np.random.seed(1024)
         x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
@@ -483,19 +486,25 @@ def setUp(self):
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_prim=True)
 
     def init_dtype(self):
         # TODO If dtype is float64, the output (Out) has diff at CPUPlace
         # when using and not using inplace. Therefore, set dtype as float32
         # for now.
         self.dtype = np.float32
 
+    def if_enable_cinn(self):
+        pass
+
 
 class TestTanh_ZeroDim(TestTanh):
     def init_shape(self):
         self.shape = []
 
+    def if_enable_cinn(self):
+        self.enable_cinn = False
+
 
 class TestTanhAPI(unittest.TestCase):
     # test paddle.tanh, paddle.nn.tanh, paddle.nn.functional.tanh
@@ -601,7 +610,7 @@ def test_dygraph(self):
         self.assertEqual(z, z_expected)
 
 
-class TestAtan_ZeroDim(TestTanh):
+class TestAtan_ZeroDim(TestAtan):
     def init_shape(self):
         self.shape = []
 
Expand Down Expand Up @@ -3910,7 +3919,7 @@ def test_check_grad(self):
 create_test_act_fp16_class(TestTanhshrink)
 create_test_act_fp16_class(TestHardShrink)
 create_test_act_fp16_class(TestSoftshrink)
-create_test_act_fp16_class(TestSqrt)
+create_test_act_fp16_class(TestSqrt, check_prim=True)
 create_test_act_fp16_class(TestSqrtComp, check_prim=True)
 create_test_act_fp16_class(TestAbs, check_prim=True)
 create_test_act_fp16_class(TestCeil, grad_check=False)
57 changes: 26 additions & 31 deletions python/paddle/fluid/tests/unittests/test_cast_op.py
@@ -28,33 +28,8 @@
 from paddle.fluid import Program, core, program_guard
 
 
-def convert_to_dtype_(dtype):
-    if dtype == 5:
-        return core.VarDesc.VarType.FP32
-    elif dtype == 6:
-        return core.VarDesc.VarType.FP64
-    elif dtype == 4:
-        return core.VarDesc.VarType.FP16
-    elif dtype == 2:
-        return core.VarDesc.VarType.INT32
-    elif dtype == 1:
-        return core.VarDesc.VarType.INT16
-    elif dtype == 3:
-        return core.VarDesc.VarType.INT64
-    elif dtype == 0:
-        return core.VarDesc.VarType.BOOL
-    elif dtype == 22:
-        return core.VarDesc.VarType.BF16
-    elif dtype == 20:
-        return core.VarDesc.VarType.UINT8
-    elif dtype == 21:
-        return core.VarDesc.VarType.INT8
-    elif dtype == np.complex64:
-        raise ValueError("Not supported dtype %s" % dtype)
-
-
 def cast_wrapper(x, out_dtype=None):
-    return paddle.tensor.cast(x, convert_to_dtype_(out_dtype))
+    return paddle.cast(x, paddle.dtype(out_dtype))
 
 
 class TestCastOpFp32ToFp64(OpTest):
@@ -67,13 +42,15 @@ def setUp(self):
             'out_dtype': int(core.VarDesc.VarType.FP64),
         }
         self.op_type = 'cast'
+        self.prim_op_type = "prim"
         self.python_api = cast_wrapper
+        self.public_python_api = cast_wrapper
 
     def test_check_output(self):
         self.check_output()
 
     def test_grad(self):
-        self.check_grad(['X'], ['Out'])
+        self.check_grad(['X'], ['Out'], check_prim=True)
 
 
 class TestCastOpFp16ToFp32(OpTest):
@@ -86,12 +63,16 @@ def setUp(self):
             'out_dtype': int(core.VarDesc.VarType.FP32),
         }
         self.op_type = 'cast'
-        self.__class__.no_need_check_grad = True
+        self.prim_op_type = "prim"
         self.python_api = cast_wrapper
+        self.public_python_api = cast_wrapper
 
     def test_check_output(self):
         self.check_output(atol=1e-3)
 
+    def test_grad(self):
+        self.check_grad(['X'], ['Out'], check_prim=True, only_check_prim=True)
+
 
 class TestCastOpFp32ToFp16(OpTest):
     def setUp(self):
@@ -103,12 +84,16 @@ def setUp(self):
             'out_dtype': int(core.VarDesc.VarType.FP16),
         }
         self.op_type = 'cast'
-        self.__class__.no_need_check_grad = True
+        self.prim_op_type = "prim"
         self.python_api = cast_wrapper
+        self.public_python_api = cast_wrapper
 
     def test_check_output(self):
         self.check_output(atol=1e-3)
 
+    def test_grad(self):
+        self.check_grad(['X'], ['Out'], check_prim=True, only_check_prim=True)
+
 
 class TestCastOpBf16ToFp32(OpTest):
     def setUp(self):
@@ -120,12 +105,17 @@ def setUp(self):
             'out_dtype': int(core.VarDesc.VarType.FP32),
         }
         self.op_type = 'cast'
-        self.__class__.no_need_check_grad = True
+        self.prim_op_type = "prim"
         self.python_api = cast_wrapper
+        self.public_python_api = cast_wrapper
+        self.enable_cinn = False
 
     def test_check_output(self):
         self.check_output()
 
+    def test_grad(self):
+        self.check_grad(['X'], ['Out'], check_prim=True, only_check_prim=True)
+
 
 class TestCastOpFp32ToBf16(OpTest):
     def setUp(self):
@@ -137,12 +127,17 @@ def setUp(self):
             'out_dtype': int(core.VarDesc.VarType.BF16),
         }
         self.op_type = 'cast'
-        self.__class__.no_need_check_grad = True
+        self.prim_op_type = "prim"
         self.python_api = cast_wrapper
+        self.public_python_api = cast_wrapper
+        self.enable_cinn = False
 
     def test_check_output(self):
         self.check_output()
 
+    def test_grad(self):
+        self.check_grad(['X'], ['Out'], check_prim=True, only_check_prim=True)
+
 
 class TestCastOpError(unittest.TestCase):
     def test_errors(self):
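
A side note on the rewritten cast_wrapper above: it now calls the public paddle.cast API (converting the integer attribute via paddle.dtype). A minimal usage sketch of paddle.cast itself, purely illustrative and not part of this commit:

import numpy as np

import paddle

# cast a float32 tensor to float64, the same conversion TestCastOpFp32ToFp64 exercises
x = paddle.to_tensor(np.random.rand(3).astype('float32'))
y = paddle.cast(x, 'float64')
print(y.dtype)  # paddle.float64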