diff --git a/python/paddle/fluid/tests/unittests/test_data_generator.py b/python/paddle/fluid/tests/unittests/test_data_generator.py index 7cf7439ddc86b..69d8e01fd464a 100644 --- a/python/paddle/fluid/tests/unittests/test_data_generator.py +++ b/python/paddle/fluid/tests/unittests/test_data_generator.py @@ -95,6 +95,19 @@ def data_iter(): return data_iter +class MyMultiSlotStringDataGenerator_zip(fleet.MultiSlotStringDataGenerator): + def generate_sample(self, line): + def data_iter(): + for i in range(40): + if i == 1: + yield None + feature_name = ["words", "label"] + data = [["1", "2", "3", "4"], ["0"]] + yield zip(feature_name, data) + + return data_iter + + class MyMultiSlotDataGenerator_zip(fleet.MultiSlotDataGenerator): def generate_sample(self, line): def data_iter(): @@ -162,6 +175,13 @@ def test_MultiSlotDataGenerator_error(self): my_ms_dg.run_from_memory() +class TestMultiSlotStringDataGeneratorZip(unittest.TestCase): + def test_MultiSlotStringDataGenerator_zip(self): + my_ms_dg = MyMultiSlotStringDataGenerator_zip() + my_ms_dg.set_batch(1) + my_ms_dg.run_from_memory() + + class TestMultiSlotDataGeneratorZip(unittest.TestCase): def test_MultiSlotDataGenerator_zip(self): my_ms_dg = MyMultiSlotDataGenerator_zip()