Add split, split_with_num tests
co63oc committed Apr 13, 2023
1 parent 2b40434 commit c4f3c64
Showing 2 changed files with 71 additions and 18 deletions.
88 changes: 70 additions & 18 deletions python/paddle/fluid/tests/unittests/test_split_op.py
@@ -65,7 +65,7 @@ def test_check_grad(self):
 
 
 # test with attr(num)
-class TestSplitOp_2(OpTest):
+class TestSplitWithNumOp(OpTest):
     def setUp(self):
         self.python_api = paddle.split
         self.public_python_api = paddle.split
@@ -74,18 +74,32 @@ def setUp(self):
         self.prim_op_type = "prim"
         self.dtype = self.get_dtype()
         self.init_data()
-        self.inputs = {'X': self.x}
         self.attrs = {
             'axis': self.axis,
             'sections': self.sections,
             'num': self.num,
         }
-
-        out = np.split(self.x, self.indices_or_sections, self.axis)
-        self.outputs = {'Out': [('out%d' % i, out[i]) for i in range(len(out))]}
+        if self.dtype == np.uint16:
+            self.inputs = {'X': convert_float_to_uint16(self.x)}
+            out = np.split(self.x, self.indices_or_sections, self.axis)
+            self.outputs = {
+                'Out': [
+                    ('out%d' % i, convert_float_to_uint16(out[i]))
+                    for i in range(len(out))
+                ]
+            }
+        else:
+            self.inputs = {'X': self.x}
+            out = np.split(self.x, self.indices_or_sections, self.axis)
+            self.outputs = {
+                'Out': [('out%d' % i, out[i]) for i in range(len(out))]
+            }
 
     def init_data(self):
-        self.x = np.random.random((4, 5, 6)).astype(self.dtype)
+        if self.dtype == np.uint16:
+            self.x = np.random.random((4, 5, 6)).astype(np.float32)
+        else:
+            self.x = np.random.random((4, 5, 6)).astype(self.dtype)
         self.axis = 2
         self.sections = []
         self.num = 3
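
The uint16 branches above exist because Paddle carries bfloat16 tensors as raw uint16 bit patterns; convert_float_to_uint16 packs float32 data into that form. A minimal numpy sketch of the round trip the test relies on (illustrative names; assumes plain truncation, whereas the real op-test helper may also round):

import numpy as np

def float32_to_bf16_bits(x):
    # bfloat16 is the upper half of an IEEE-754 float32, so dropping the
    # low 16 mantissa bits yields the uint16 pattern Paddle stores.
    return (x.astype(np.float32).view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_float32(bits):
    # Widen back by parking the stored bits in the high half of a uint32;
    # the discarded mantissa bits return as zeros.
    return (bits.astype(np.uint32) << 16).view(np.float32)

x = np.random.random((4, 5, 6)).astype(np.float32)
roundtrip = bf16_bits_to_float32(float32_to_bf16_bits(x))
assert np.allclose(roundtrip, x, atol=4e-3)  # only ~8 mantissa bits survive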
@@ -241,28 +255,41 @@ def create_test_fp16(parent):
     @unittest.skipIf(
         not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
     )
-    class TestSplitFp16(parent):
+    class TestSplitFP16Op(parent):
         def get_dtype(self):
             return np.float16
 
-        def test_check_grad(self):
-            pass
+    cls_name = "{}_{}".format(parent.__name__, "FP16Op")
+    TestSplitFP16Op.__name__ = cls_name
+    globals()[cls_name] = TestSplitFP16Op
+
+
+def create_test_split_with_num_fp16(parent):
+    @unittest.skipIf(
+        not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+    )
+    class TestSplitWithNumFP16Op(parent):
+        def get_dtype(self):
+            return np.float16
 
-    cls_name = "{}_{}".format(parent.__name__, "Fp16")
-    TestSplitFp16.__name__ = cls_name
-    globals()[cls_name] = TestSplitFp16
+    cls_name = "{}_{}".format(parent.__name__, "FP16Op")
+    TestSplitWithNumFP16Op.__name__ = cls_name
+    globals()[cls_name] = TestSplitWithNumFP16Op
 
 
 create_test_fp16(TestSplitOp)
+create_test_split_with_num_fp16(TestSplitWithNumOp)
 
 # ----------------Split Bf16----------------
 
 
 def create_test_bf16(parent):
     @unittest.skipIf(
-        not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+        not core.is_compiled_with_cuda()
+        or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+        "core is not compiled with CUDA or not support bfloat16",
     )
-    class TestSplitBf16(parent):
+    class TestSplitBF16Op(parent):
         def get_dtype(self):
             return np.uint16
 
@@ -271,14 +298,39 @@ def test_check_output(self):
             self.check_output_with_place(place)
 
         def test_check_grad(self):
-            pass
+            place = core.CUDAPlace(0)
+            self.check_grad_with_place(place, ['X'], 'out2')
+
+    cls_name = "{}_{}".format(parent.__name__, "BF16Op")
+    TestSplitBF16Op.__name__ = cls_name
+    globals()[cls_name] = TestSplitBF16Op
+
+
+def create_test_split_with_num_bf16(parent):
+    @unittest.skipIf(
+        not core.is_compiled_with_cuda()
+        or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+        "core is not compiled with CUDA or not support bfloat16",
+    )
+    class TestSplitWithNumBF16Op(TestSplitWithNumOp):
+        def get_dtype(self):
+            return np.uint16
+
+        def test_check_output(self):
+            place = core.CUDAPlace(0)
+            self.check_output_with_place(place)
+
+        def test_check_grad(self):
+            place = core.CUDAPlace(0)
+            self.check_grad_with_place(place, ['X'], 'out2')
 
-    cls_name = "{}_{}".format(parent.__name__, "Bf16")
-    TestSplitBf16.__name__ = cls_name
-    globals()[cls_name] = TestSplitBf16
+    cls_name = "{}_{}".format(parent.__name__, "BF16Op")
+    TestSplitWithNumBF16Op.__name__ = cls_name
+    globals()[cls_name] = TestSplitWithNumBF16Op
 
 
 create_test_bf16(TestSplitOp)
+create_test_split_with_num_bf16(TestSplitWithNumOp)
 
 
 class TestSplitAPI(unittest.TestCase):
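A note on the create_test_* factories above: unittest discovery only collects classes bound at module level, so each factory subclasses its parent, renames the subclass, and publishes it through globals(). A stripped-down sketch of the same idiom (class and method names here are illustrative, not from the diff):

import unittest

def create_test_fp16(parent):
    # Subclass the base test, overriding only the dtype hook.
    class TestFP16(parent):
        def get_dtype(self):
            return 'float16'

    # Rename and publish at module scope so discovery picks it up.
    cls_name = "{}_{}".format(parent.__name__, "FP16Op")
    TestFP16.__name__ = cls_name
    globals()[cls_name] = TestFP16

class TestBase(unittest.TestCase):
    def get_dtype(self):
        return 'float32'

    def test_dtype_is_valid(self):
        self.assertIn(self.get_dtype(), ('float32', 'float16'))

create_test_fp16(TestBase)  # registers TestBase_FP16Op for discovery
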
1 change: 1 addition & 0 deletions python/paddle/tensor/manipulation.py
@@ -1976,6 +1976,7 @@ def split(x, num_or_sections, axis=0, name=None):
             'int32',
             'int64',
             'uint8',
+            'uint16',
             'int8',
         ],
         'split',
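
The single manipulation.py addition widens the dtype whitelist that paddle.split enforces on static-graph inputs; without 'uint16' (the carrier dtype for bfloat16), the new bfloat16 tests would fail this check before reaching the kernel. A sketch of the guard, abridged to the entries visible in the diff (check_variable_and_dtype is Paddle's validation helper; the wrapper and the truncated list are assumptions for illustration):

from paddle.fluid.data_feeder import check_variable_and_dtype

def _validate_split_input(x):
    # Hypothetical wrapper; the dtype list is abridged to the diff's
    # context lines, with 'uint16' being the entry this commit adds.
    check_variable_and_dtype(
        x,
        'x',
        ['int32', 'int64', 'uint8', 'uint16', 'int8'],
        'split',
    )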

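For orientation, the two split modes these tests now cover, sketched against the same (4, 5, 6) shape the unit tests use:

import paddle

x = paddle.rand([4, 5, 6])

# num mode (the new TestSplitWithNumOp coverage): an int that must evenly
# divide the axis length, here 6 / 3 along axis 2.
a, b, c = paddle.split(x, num_or_sections=3, axis=2)   # each [4, 5, 2]

# sections mode: explicit sizes summing to the axis length.
d, e = paddle.split(x, num_or_sections=[2, 4], axis=2)  # [4, 5, 2] and [4, 5, 4]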