From 03c2458a93b84f46b74faed1864c072a2e1da411 Mon Sep 17 00:00:00 2001 From: MrChengmo Date: Thu, 18 Mar 2021 19:03:08 +0800 Subject: [PATCH] add unittest --- .../tests/unittests/test_data_generator.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_data_generator.py b/python/paddle/fluid/tests/unittests/test_data_generator.py index 7cf7439ddc86b..69d8e01fd464a 100644 --- a/python/paddle/fluid/tests/unittests/test_data_generator.py +++ b/python/paddle/fluid/tests/unittests/test_data_generator.py @@ -95,6 +95,19 @@ def data_iter(): return data_iter +class MyMultiSlotStringDataGenerator_zip(fleet.MultiSlotStringDataGenerator): + def generate_sample(self, line): + def data_iter(): + for i in range(40): + if i == 1: + yield None + feature_name = ["words", "label"] + data = [["1", "2", "3", "4"], ["0"]] + yield zip(feature_name, data) + + return data_iter + + class MyMultiSlotDataGenerator_zip(fleet.MultiSlotDataGenerator): def generate_sample(self, line): def data_iter(): @@ -162,6 +175,13 @@ def test_MultiSlotDataGenerator_error(self): my_ms_dg.run_from_memory() +class TestMultiSlotStringDataGeneratorZip(unittest.TestCase): + def test_MultiSlotStringDataGenerator_zip(self): + my_ms_dg = MyMultiSlotStringDataGenerator_zip() + my_ms_dg.set_batch(1) + my_ms_dg.run_from_memory() + + class TestMultiSlotDataGeneratorZip(unittest.TestCase): def test_MultiSlotDataGenerator_zip(self): my_ms_dg = MyMultiSlotDataGenerator_zip()