Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Paddle.Fleet】fix dataset zip py3 bug #31441

Merged
merged 2 commits into from
Mar 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions python/paddle/distributed/fleet/data_generator/data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ def set_batch(self, batch_size):
'''
Set batch size of current DataGenerator
This is necessary only if a user wants to define generator_batch

Example:

.. code-block:: python

import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):

Expand All @@ -52,7 +52,7 @@ def local_iter():
yield ("words", s[1].extend([s[1][0]]))
mydata = MyData()
mydata.set_batch(128)

'''
self.batch_size_ = batch_size

Expand All @@ -63,7 +63,7 @@ def run_from_memory(self):

Example:
.. code-block:: python

import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):

Expand Down Expand Up @@ -100,9 +100,9 @@ def run_from_stdin(self):
generated.

Example:

.. code-block:: python

import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):

Expand Down Expand Up @@ -161,7 +161,7 @@ def generate_sample(self, line):
The data format is list or tuple:
[(name, [feasign, ...]), ...]
or ((name, [feasign, ...]), ...)

For example:
[("words", [1926, 8, 17]), ("label", [1])]
or (("words", [1926, 8, 17]), ("label", [1]))
Expand All @@ -174,7 +174,7 @@ def generate_sample(self, line):
Example:

.. code-block:: python

import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):

Expand Down Expand Up @@ -206,7 +206,7 @@ def generate_batch(self, samples):
Example:

.. code-block:: python

import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):

Expand Down Expand Up @@ -259,6 +259,9 @@ def _gen_str(self, line):
Returns:
Return a string data that can be read directly by the MultiSlotDataFeed.
'''
if sys.version > '3' and isinstance(line, zip):
line = list(line)

if not isinstance(line, list) and not isinstance(line, tuple):
raise ValueError(
"the output of process() must be in list or tuple type"
Expand Down Expand Up @@ -289,7 +292,7 @@ def _gen_str(self, line):
>>> [ids_num id1 id2 ...] ...
The proto_info will be in this format:
>>> [(name, type), ...]

For example, if the input is like this:
>>> [("words", [1926, 8, 17]), ("label", [1])]
>>> or (("words", [1926, 8, 17]), ("label", [1]))
Expand All @@ -304,6 +307,9 @@ def _gen_str(self, line):
Returns:
Return a string data that can be read directly by the MultiSlotDataFeed.
'''
if sys.version > '3' and isinstance(line, zip):
line = list(line)

if not isinstance(line, list) and not isinstance(line, tuple):
raise ValueError(
"the output of process() must be in list or tuple type"
Expand Down
40 changes: 40 additions & 0 deletions python/paddle/fluid/tests/unittests/test_data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,32 @@ def data_iter():
return data_iter


class MyMultiSlotStringDataGenerator_zip(fleet.MultiSlotStringDataGenerator):
    """String-slot test generator whose samples are ``zip`` objects.

    Under Python 3 ``zip`` returns a lazy iterator rather than a list, so
    each sample produced here is neither a list nor a tuple.
    """

    def generate_sample(self, line):
        """Return a generator function producing the test samples.

        Args:
            line: Unused; the samples are hard-coded.

        Returns:
            A zero-argument generator function. For each of 40 iterations
            it yields one ``zip`` pairing the slot names with their string
            values; on the iteration with index 1 it yields ``None`` first.
        """

        def data_iter():
            for step in range(40):
                if step == 1:
                    # NOTE(review): presumably checks that a None sample is
                    # tolerated by the runner -- confirm against the caller.
                    yield None
                slot_names = ["words", "label"]
                slot_values = [["1", "2", "3", "4"], ["0"]]
                yield zip(slot_names, slot_values)

        return data_iter


class MyMultiSlotDataGenerator_zip(fleet.MultiSlotDataGenerator):
    """Integer-slot test generator whose samples are ``zip`` objects.

    Under Python 3 ``zip`` returns a lazy iterator rather than a list, so
    each sample produced here is neither a list nor a tuple.
    """

    def generate_sample(self, line):
        """Return a generator function producing the test samples.

        Args:
            line: Unused; the samples are hard-coded.

        Returns:
            A zero-argument generator function. For each of 40 iterations
            it yields one ``zip`` pairing the slot names with their integer
            values; on the iteration with index 1 it yields ``None`` first.
        """

        def data_iter():
            for step in range(40):
                if step == 1:
                    # NOTE(review): presumably checks that a None sample is
                    # tolerated by the runner -- confirm against the caller.
                    yield None
                slot_names = ["words", "label"]
                slot_values = [[1, 2, 3, 4], [0]]
                yield zip(slot_names, slot_values)

        return data_iter


class TestMultiSlotDataGenerator(unittest.TestCase):
def test_MultiSlotDataGenerator_basic(self):
my_ms_dg = MyMultiSlotDataGenerator()
Expand Down Expand Up @@ -149,5 +175,19 @@ def test_MultiSlotDataGenerator_error(self):
my_ms_dg.run_from_memory()


class TestMultiSlotStringDataGeneratorZip(unittest.TestCase):
    """Smoke test for a string-slot generator that yields ``zip`` samples."""

    def test_MultiSlotStringDataGenerator_zip(self):
        """Run the generator from memory with batch size 1.

        There are no explicit assertions; the test passes when the run
        completes without raising.
        """
        generator = MyMultiSlotStringDataGenerator_zip()
        generator.set_batch(1)
        generator.run_from_memory()


class TestMultiSlotDataGeneratorZip(unittest.TestCase):
    """Smoke test for an integer-slot generator that yields ``zip`` samples."""

    def test_MultiSlotDataGenerator_zip(self):
        """Run the generator from memory with batch size 1.

        There are no explicit assertions; the test passes when the run
        completes without raising.
        """
        generator = MyMultiSlotDataGenerator_zip()
        generator.set_batch(1)
        generator.run_from_memory()


# Run the test suite when this file is executed directly as a script.
if "__main__" == __name__:
    unittest.main()