add multiprocessing to MapDataset (PaddlePaddle#308)
* add multiprocessing to MapDataset

* Add multiprocess requirement

* Add contiguous argument to shard function

* change shard and add split check for load_dataset

* minor fix

Co-authored-by: Guo Sheng <whucsgs@163.com>
smallv0221 and guoshengCS authored Apr 28, 2021
1 parent aca073a commit 2aef4b3
Showing 2 changed files with 120 additions and 26 deletions.
143 changes: 118 additions & 25 deletions paddlenlp/datasets/dataset.py
@@ -19,6 +19,7 @@
import warnings
import sys
import inspect
from multiprocess import Pool, RLock

import paddle.distributed as dist
from paddle.io import Dataset, IterableDataset
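
Note that the new import is `multiprocess`, the dill-based fork of the standard library's `multiprocessing`; presumably it was chosen because dill can serialize callables such as lambdas and locally defined functions, which the default pickler cannot, and those callables are exactly what gets shipped to the worker processes below. A minimal sketch of that difference (not part of this commit):

from multiprocess import Pool

if __name__ == '__main__':
    with Pool(2) as pool:
        # A lambda can be sent to the workers because dill serializes it;
        # the stdlib multiprocessing.Pool would typically raise a PicklingError here.
        print(pool.map(lambda x: x * x, range(4)))  # [0, 1, 4, 9]
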
@@ -105,10 +106,33 @@ def load_dataset(path_or_read_func,
return reader_instance.read(**custom_kwargs)
else:
reader_cls = import_main_class(path_or_read_func)
if not name:
reader_instance = reader_cls(lazy=lazy, **kwargs)
reader_instance = reader_cls(lazy=lazy, name=name, **kwargs)

# Check if the selected name and splits are valid for this DatasetBuilder
if hasattr(reader_instance, 'BUILDER_CONFIGS'):
if name in reader_cls.BUILDER_CONFIGS.keys():
split_names = reader_cls.BUILDER_CONFIGS[name]['splits'].keys()
else:
raise ValueError(
'Invalid name "{}". Should be one of {}.'.format(
name, list(reader_cls.BUILDER_CONFIGS.keys())))
elif hasattr(reader_instance, 'SPLITS'):
split_names = reader_instance.SPLITS.keys()
else:
reader_instance = reader_cls(lazy=lazy, name=name, **kwargs)
raise AttributeError(
"Either 'SPLITS' or 'BUILDER_CONFIGS' must be implemented for DatasetBuilder."
)

selected_splits = []
selected_splits += data_files.keys() if isinstance(
data_files, dict) else selected_splits
selected_splits = selected_splits + splits if isinstance(
splits, list) else selected_splits + [splits]

for split_name in selected_splits:
if split_name not in split_names and split_name != None:
raise ValueError('Invalid split "{}". Should be one of {}.'.
format(split_name, list(split_names)))

datasets = reader_instance.read_datasets(
data_files=data_files, splits=splits)
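
With this validation in place, an invalid `name` or split passed to `load_dataset` fails fast with a descriptive error instead of surfacing later inside the reader. A hypothetical usage sketch (the dataset, config name, and split names are illustrative, not taken from this commit):

from paddlenlp.datasets import load_dataset

# Valid config name and splits load as before.
train_ds, dev_ds = load_dataset('glue', name='sst-2', splits=['train', 'dev'])

# An unknown split is now rejected up front, roughly:
#   ValueError: Invalid split "validation". Should be one of ['train', 'dev', 'test'].
# load_dataset('glue', name='sst-2', splits=['validation'])
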
@@ -158,25 +182,57 @@ def __len__(self):
"""
return len(self.new_data)

def filter(self, fn):
def filter(self, fn, num_workers=0):
"""
Filters samples by the filter function and uses the filtered data to
update this dataset.
Args:
fn (callable): A filter function that takes a sample as input and
returns a boolean. Samples for which `fn` returns False are discarded.
num_workers (int, optional): Number of worker processes for multiprocessing. If
set to 0, multiprocessing is not used. Default: 0.
"""
assert num_workers >= 0, "num_workers should be a non-negative value"
if num_workers > 0:
with Pool(num_workers, initargs=(RLock(), )) as pool:

def filter_shard(num_workers, index, fn):
self.shard(
num_shards=num_workers, index=index, contiguous=True)
self._filter(fn=fn)
return self

kwds_per_shard = [
dict(
num_workers=num_workers, index=rank, fn=fn)
for rank in range(num_workers)
]
results = [
pool.apply_async(
filter_shard, kwds=kwds) for kwds in kwds_per_shard
]
transformed_shards = [r.get() for r in results]

self.new_data = []
for i in range(num_workers):
self.new_data += transformed_shards[i].new_data
return self
else:
return self._filter(fn)

def _filter(self, fn):
self.new_data = [
self.new_data[idx] for idx in range(len(self.new_data))
if fn(self.new_data[idx])
]
return self
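
A short usage sketch of the new multiprocessing path in `filter` (the dataset name and field are illustrative): with `num_workers > 0` the dataset is cut into contiguous shards, each worker filters its shard, and the shards' `new_data` are concatenated back together.

from paddlenlp.datasets import load_dataset

train_ds = load_dataset('chnsenticorp', splits='train')

# Keep only short examples, filtering in 4 worker processes.
train_ds.filter(lambda example: len(example['text']) < 50, num_workers=4)
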

def shard(self, num_shards=None, index=None):
def shard(self, num_shards=None, index=None, contiguous=False):
"""
Uses samples whose indices mod `index` equals 0 to update this dataset.
Split the dataset into `num_shards` pieces. Note that the size of each
shard might be different because the original dataset may not be evenly
divisible.
Args:
num_shards (int, optional): An integer representing the number of
@@ -185,24 +241,32 @@ def shard(self, num_shards=None, index=None):
index (int, optional): An integer representing the index of the
current shard. If None, `index` would be the current trainer rank
id. Default: None.
contiguous (bool, optional): If True, contiguous chunks of data
will be selected for sharding, and the total number of examples across
all shards will equal the size of the original dataset. Otherwise each
shard will contain all examples of the dataset whose index mod
`num_shards` equals `index`. Default: False.
"""
if num_shards is None:
num_shards = dist.get_world_size()
if index is None:
index = dist.get_rank()

num_samples = int(math.ceil(len(self.new_data) * 1.0 / num_shards))
# add extra samples to make it evenly divisible
self.new_data = [
self.new_data[idx] for idx in range(len(self.new_data))
if idx % num_shards == index
]
if len(self.new_data) < num_samples:
self.new_data.append(self.new_data[index + 1 - num_shards])
if contiguous:
div = len(self) // num_shards
mod = len(self) % num_shards
start = div * index + min(index, mod)
end = start + div + (1 if index < mod else 0)
self.new_data = self.new_data[start:end]
else:
num_samples = int(math.ceil(len(self.new_data) * 1.0 / num_shards))
self.new_data = [
self.new_data[idx] for idx in range(len(self.new_data))
if idx % num_shards == index
]

return self
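
A worked example of the two sharding modes, assuming a toy dataset of 10 examples split into 3 shards: the contiguous mode produces shards of sizes 4, 3 and 3 that together cover every example exactly once, while the modulo mode assigns each example to the shard matching its index mod `num_shards`.

# Sketch of the index arithmetic only; `data` stands in for self.new_data.
data = list(range(10))
num_shards = 3

# contiguous=True: slice [start:end) per shard.
for index in range(num_shards):
    div, mod = divmod(len(data), num_shards)
    start = div * index + min(index, mod)
    end = start + div + (1 if index < mod else 0)
    print(index, data[start:end])
# 0 [0, 1, 2, 3]
# 1 [4, 5, 6]
# 2 [7, 8, 9]

# contiguous=False: every example whose index mod num_shards == index.
for index in range(num_shards):
    print(index, [x for i, x in enumerate(data) if i % num_shards == index])
# 0 [0, 3, 6, 9]
# 1 [1, 4, 7]
# 2 [2, 5, 8]
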

def map(self, fn, lazy=True, batched=False):
def map(self, fn, lazy=True, batched=False, num_workers=0):
"""
Performs specific function on the dataset to transform and update every sample.
@@ -215,8 +279,44 @@ def map(self, fn, lazy=True, batched=False):
result on all epochs. Default: False.
batched(bool, optional): If True, transformations would take all examples as
input and return a collection of transformed examples. Note that if set
True, `lazy` option would be ignored.
True, `lazy` option would be ignored. Default: False.
num_workers (int, optional): Number of worker processes for multiprocessing. If
set to 0, multiprocessing is not used. Note that if set to a positive
value, the `lazy` option would be ignored. Default: 0.
"""

assert num_workers >= 0, "num_workers should be a non-negative value"
if num_workers > 0:
with Pool(num_workers, initargs=(RLock(), )) as pool:

def map_shard(num_workers, index, fn, batched):
self.shard(
num_shards=num_workers, index=index, contiguous=True)
self._map(fn=fn, lazy=False, batched=batched)
return self

kwds_per_shard = [
dict(
num_workers=num_workers,
index=rank,
fn=fn,
batched=batched) for rank in range(num_workers)
]
results = [
pool.apply_async(
map_shard, kwds=kwds) for kwds in kwds_per_shard
]
transformed_shards = [r.get() for r in results]

self.new_data = []
for i in range(num_workers):
self.new_data += transformed_shards[i].new_data

return self
else:
return self._map(fn, lazy=lazy, batched=batched)

def _map(self, fn, lazy=True, batched=False):
if batched:
self.new_data = fn(self.new_data)
elif lazy:
@@ -225,12 +325,8 @@ def map(self, fn, lazy=True, batched=False):
self.new_data = [
fn(self.new_data[idx]) for idx in range(len(self.new_data))
]

return self
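
A usage sketch of the multiprocessing path in `map` (the dataset name, field, and helper function are illustrative, not part of this commit); note that each worker calls `_map` with `lazy=False`, so the `lazy` argument is effectively ignored whenever `num_workers > 0`.

from paddlenlp.datasets import load_dataset

train_ds = load_dataset('chnsenticorp', splits='train')

def add_text_len(example):
    # Toy transformation: record the raw text length on each example.
    example['text_len'] = len(example['text'])
    return example

# Apply the transformation eagerly across 4 worker processes.
train_ds.map(add_text_len, num_workers=4)
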

def __getattr__(self, name):
return getattr(self.data, name)


class IterDataset(IterableDataset):
"""
@@ -311,7 +407,7 @@ def filter(self, fn):

def shard(self, num_shards=None, index=None):
"""
Uses samples whose indices mod `index` equals 0 to update this dataset.
Split the dataset into `num_shards` pieces.
Args:
num_shards (int, optional): An integer representing the number of
@@ -349,9 +445,6 @@ def map(self, fn):

return self

def __getattr__(self, name):
return getattr(self.data, name)


class DatasetBuilder:
"""
@@ -381,7 +474,7 @@ def read_datasets(self, splits=None, data_files=None):
data_files, dict
) or isinstance(data_files, tuple) or isinstance(
data_files, list
), "`data_files` should be a string or tuple or list or a dictionary whose key is split name ande value is a path of data file."
), "`data_files` should be a string or tuple or list or a dictionary whose key is split name and value is the path of data file."
if isinstance(data_files, str):
split = 'train'
datasets.append(self.read(filename=data_files, split=split))
3 changes: 2 additions & 1 deletion requirements.txt
@@ -3,4 +3,5 @@ jieba
h5py
colorlog
colorama
seqeval
seqeval
multiprocess
