Skip to content

Commit

Permalink
【Paddle.Fleet】fix dataset zip py3 bug (#31441)
Browse files Browse the repository at this point in the history
* fix zip py3 bug
  • Loading branch information
MrChengmo committed Mar 30, 2021
1 parent 9b8f2de commit 4b5423e
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 10 deletions.
26 changes: 16 additions & 10 deletions python/paddle/distributed/fleet/data_generator/data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ def set_batch(self, batch_size):
'''
Set batch size of current DataGenerator
This is necessary only if a user wants to define generator_batch
Example:
.. code-block:: python
import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):
Expand All @@ -52,7 +52,7 @@ def local_iter():
yield ("words", s[1].extend([s[1][0]]))
mydata = MyData()
mydata.set_batch(128)
'''
self.batch_size_ = batch_size

Expand All @@ -63,7 +63,7 @@ def run_from_memory(self):
Example:
.. code-block:: python
import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):
Expand Down Expand Up @@ -100,9 +100,9 @@ def run_from_stdin(self):
generated.
Example:
.. code-block:: python
import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):
Expand Down Expand Up @@ -161,7 +161,7 @@ def generate_sample(self, line):
The data format is list or tuple:
[(name, [feasign, ...]), ...]
or ((name, [feasign, ...]), ...)
For example:
[("words", [1926, 08, 17]), ("label", [1])]
or (("words", [1926, 08, 17]), ("label", [1]))
Expand All @@ -174,7 +174,7 @@ def generate_sample(self, line):
Example:
.. code-block:: python
import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):
Expand Down Expand Up @@ -206,7 +206,7 @@ def generate_batch(self, samples):
Example:
.. code-block:: python
import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):
Expand Down Expand Up @@ -259,6 +259,9 @@ def _gen_str(self, line):
Returns:
Return a string data that can be read directly by the MultiSlotDataFeed.
'''
if sys.version > '3' and isinstance(line, zip):
line = list(line)

if not isinstance(line, list) and not isinstance(line, tuple):
raise ValueError(
"the output of process() must be in list or tuple type"
Expand Down Expand Up @@ -289,7 +292,7 @@ def _gen_str(self, line):
>>> [ids_num id1 id2 ...] ...
The proto_info will be in this format:
>>> [(name, type), ...]
For example, if the input is like this:
>>> [("words", [1926, 08, 17]), ("label", [1])]
>>> or (("words", [1926, 08, 17]), ("label", [1]))
Expand All @@ -304,6 +307,9 @@ def _gen_str(self, line):
Returns:
Return a string data that can be read directly by the MultiSlotDataFeed.
'''
if sys.version > '3' and isinstance(line, zip):
line = list(line)

if not isinstance(line, list) and not isinstance(line, tuple):
raise ValueError(
"the output of process() must be in list or tuple type"
Expand Down
40 changes: 40 additions & 0 deletions python/paddle/fluid/tests/unittests/test_data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,32 @@ def data_iter():
return data_iter


class MyMultiSlotStringDataGenerator_zip(fleet.MultiSlotStringDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(40):
if i == 1:
yield None
feature_name = ["words", "label"]
data = [["1", "2", "3", "4"], ["0"]]
yield zip(feature_name, data)

return data_iter


class MyMultiSlotDataGenerator_zip(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(40):
if i == 1:
yield None
feature_name = ["words", "label"]
data = [[1, 2, 3, 4], [0]]
yield zip(feature_name, data)

return data_iter


class TestMultiSlotDataGenerator(unittest.TestCase):
def test_MultiSlotDataGenerator_basic(self):
my_ms_dg = MyMultiSlotDataGenerator()
Expand Down Expand Up @@ -149,5 +175,19 @@ def test_MultiSlotDataGenerator_error(self):
my_ms_dg.run_from_memory()


class TestMultiSlotStringDataGeneratorZip(unittest.TestCase):
def test_MultiSlotStringDataGenerator_zip(self):
my_ms_dg = MyMultiSlotStringDataGenerator_zip()
my_ms_dg.set_batch(1)
my_ms_dg.run_from_memory()


class TestMultiSlotDataGeneratorZip(unittest.TestCase):
def test_MultiSlotDataGenerator_zip(self):
my_ms_dg = MyMultiSlotDataGenerator_zip()
my_ms_dg.set_batch(1)
my_ms_dg.run_from_memory()


if __name__ == '__main__':
unittest.main()

0 comments on commit 4b5423e

Please sign in to comment.