
Commit

[Auto Parallel] Support Iterable dataset for auto parallel (PaddlePaddle#45518)

* support iterable dataset for auto parallel

* add split_data proto

* fix unittest bug

* fix recompute bug

* update cmake
Caozhou1995 committed Sep 8, 2022
1 parent cf08c44 commit 0f20c40
Showing 9 changed files with 1,314 additions and 51 deletions.
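For context, a minimal sketch (not part of this commit) of the kind of input this change targets: an iterable-style dataset with no __len__, which previously raised a TypeError in the auto-parallel dataloader. The class name, feature shapes, and sample count below are illustrative assumptions.

import numpy as np
from paddle.io import IterableDataset

class RandomIterableDataset(IterableDataset):
    # Hypothetical dataset used only for illustration: it streams
    # (feature, label) samples and defines no __len__, so the number of
    # steps per epoch cannot be inferred up front.
    def __init__(self, num_samples=100):
        self.num_samples = num_samples

    def __iter__(self):
        for _ in range(self.num_samples):
            yield (np.random.rand(16).astype("float32"),
                   np.random.randint(0, 2, (1,)).astype("int64"))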
1 change: 1 addition & 0 deletions paddle/fluid/framework/distributed_strategy.proto
@@ -343,6 +343,7 @@ message DistributedStrategy {
optional bool is_fl_ps_mode = 39 [ default = false ];
optional bool with_coordinator = 40 [ default = false ];
optional bool qat = 41 [ default = false ];
optional bool split_data = 42 [ default = true ];

optional RecomputeConfig recompute_configs = 101;
optional AMPConfig amp_configs = 102;
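The new proto field is exposed through fleet.DistributedStrategy later in this commit; a minimal sketch of toggling it (standard fleet import path assumed):

import paddle
paddle.enable_static()
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
# split_data defaults to True: each batch is split across data parallel ranks.
# Setting it to False feeds every rank the full, unsplit batch.
strategy.split_data = False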

Large diffs are not rendered by default.

116 changes: 84 additions & 32 deletions python/paddle/distributed/auto_parallel/dist_loader.py
@@ -14,10 +14,13 @@

import abc
import numpy as np
from functools import wraps

import paddle
from .utils import to_list
from paddle.fluid.layers.utils import flatten
from paddle.io import DataLoader, BatchSampler, IterableDataset
from paddle.fluid.dataloader.batch_sampler import _InfiniteIterableSampler
from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collate_fn, default_convert_fn


@@ -29,33 +32,41 @@ def __init__(self,
epochs=1,
data_parallel_world_size=None,
data_parallel_rank=None,
drop_last=False):
drop_last=False,
split_data=True):
if isinstance(dataset, IterableDataset):
raise TypeError("IterableDataset is not supported.")
self.dataset_kind = _DatasetKind.ITER
else:
self.dataset_kind = _DatasetKind.MAP

self.dataset = dataset
self.epochs = epochs
self.drop_lost = drop_last
self.data_parallel_world_size = data_parallel_world_size
self.data_parallel_rank = data_parallel_rank
self.split_data = split_data

if batch_size is None:
self.batch_size = None
self.batch_sampler = None
else:
if data_parallel_world_size is not None:
assert batch_size % data_parallel_world_size == 0, \
"'batch_size' must be divisible by data parallel size"
for dp_world_size in data_parallel_world_size:
if dp_world_size is not None:
assert batch_size % dp_world_size == 0, \
"batch_size must be divisible by dp_world_size value {}".format(str(dp_world_size))
self.batch_size = batch_size
self.batch_sampler = BatchSampler(dataset,
batch_size=batch_size,
shuffle=False,
drop_last=drop_last)
if isinstance(dataset, IterableDataset):
self.batch_sampler = _InfiniteIterableSampler(
dataset, batch_size)
else:
self.batch_sampler = BatchSampler(dataset,
batch_size=batch_size,
shuffle=False,
drop_last=drop_last)

self.auto_collate_batch = self.batch_sampler is not None
self.sampler_iter = iter(self.index_sampler)
self.dp_world_size = 1 if data_parallel_world_size is None else data_parallel_world_size
self.dp_rank = 0 if data_parallel_rank is None else data_parallel_rank

@abc.abstractmethod
def __iter__(self):
Expand All @@ -73,7 +84,7 @@ def index_sampler(self):
if self.dataset_kind == _DatasetKind.MAP:
return list(range(len(self.dataset)))
else:
raise TypeError("Only support datasets in map-style.")
return _InfiniteIterableSampler(self.dataset, 1)


class NonIterableGeneratorLoader(DistributedDataLoader):
@@ -88,15 +99,16 @@ def __init__(self,
collate_fn=None,
data_parallel_world_size=None,
data_parallel_rank=None,
drop_last=False):
drop_last=False,
split_data=True):
self.feed_list = feed_list
self.places = places
self.steps_per_epoch = steps_per_epoch

super(NonIterableGeneratorLoader,
self).__init__(dataset, batch_size, epochs,
data_parallel_world_size, data_parallel_rank,
drop_last)
drop_last, split_data)

if self.auto_collate_batch:
self.collate_fn = collate_fn or default_collate_fn
@@ -115,17 +127,22 @@ def __iter__(self):
return self

def __next__(self):
if self._cur_step < self._steps:
if not self._steps:
self._cur_step += 1
elif self._cur_step < self._steps:
self._cur_step += 1
else:
self._inner_dataloader.reset()
self.sampler_iter = iter(self.index_sampler)
raise StopIteration

def _infer_steps(self):
if self.steps_per_epoch is not None:
return self.steps_per_epoch
try:
if self.batch_size is None:
if isinstance(self.dataset, IterableDataset):
steps_per_epoch = None
elif self.batch_size is None:
steps_per_epoch = len(self.dataset)
else:
steps_per_epoch = len(self.dataset) // self.batch_size
@@ -138,26 +155,61 @@ def _infer_steps(self):
def _create_inner_dataloader(self):

def sample_data_generator():
for indices in self.sampler_iter:
assert len(indices) % self.dp_world_size == 0, \
"Please set batch_size to be divisible by data parallel size"
n = len(indices) // self.dp_world_size
cur_indices = [
indices[i:i + n] for i in range(0, len(indices), n)
]
batch = self.dataset_fetcher.fetch(cur_indices[self.dp_rank])
yield batch[:len(self.feed_list)]
while True:
try:
indices = next(self.sampler_iter)
batch = self.dataset_fetcher.fetch(indices)
if batch is None: break

except StopIteration:
self.dataset_fetcher = _DatasetKind.create_fetcher(
self.dataset_kind, self.dataset,
self.auto_collate_batch, self.collate_fn,
self.drop_lost)
break

partial_data = []
for i, d in enumerate(batch[:len(self.feed_list)]):
array = np.array(d)
if not self.split_data:
partial_data.append(array)
elif self.dp_world_sizes[i] is not None:
partial_data.append(
np.split(array,
self.dp_world_sizes[i])[self.dp_ranks[i]])
else:
partial_data.append(array)
yield partial_data

def batch_data_generator():
for indices in self.sampler_iter:
while True:
try:
indices = next(self.sampler_iter)

batch = self.dataset_fetcher.fetch(indices)
if batch is None: break
except StopIteration:
break

partial_data = []
batch = self.dataset_fetcher.fetch(indices)
for data in batch:
assert data.shape[0] % self.dp_world_size == 0, \
"Please padding dataset's batch_size to be divisible by data parallel size"
partial_data.append(
np.split(data, self.dp_world_size)[self.dp_rank])
yield partial_data[:len(self.feed_list)]
for i, d in enumerate(batch[:len(self.feed_list)]):
array = np.array(d)
if not self.split_data:
partial_data.append(array)
elif self.dp_world_sizes[i] is not None:
partial_data.append(
np.split(array,
self.dp_world_sizes[i])[self.dp_ranks[i]])
else:
partial_data.append(array)
yield partial_data

self.dp_world_sizes = [
1 for _ in range(len(self.feed_list))
] if self.data_parallel_world_size is None else self.data_parallel_world_size
self.dp_ranks = [
0 for _ in range(len(self.feed_list))
] if self.data_parallel_rank is None else self.data_parallel_rank

dataloader = paddle.fluid.io.DataLoader.from_generator(
feed_list=self.feed_list, capacity=70, iterable=False)
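To make the per-feed splitting in sample_data_generator and batch_data_generator concrete, here is a standalone numpy sketch (the shapes, world sizes, and rank values are made up for illustration, not taken from this commit):

import numpy as np

# Two feed variables for one global batch of size 8.
batch = [np.arange(8).reshape(8, 1), np.arange(16).reshape(8, 2)]
dp_world_sizes, dp_ranks = [2, 2], [1, 1]   # per-feed info, as filled in by engine._plan
split_data = True

partial_data = []
for i, d in enumerate(batch):
    array = np.array(d)
    if not split_data or dp_world_sizes[i] is None:
        partial_data.append(array)          # feed the full batch to this rank
    else:
        # keep only this rank's shard of the batch dimension
        partial_data.append(np.split(array, dp_world_sizes[i])[dp_ranks[i]])

print([p.shape for p in partial_data])      # [(4, 1), (4, 2)]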
60 changes: 41 additions & 19 deletions python/paddle/distributed/auto_parallel/engine.py
@@ -237,8 +237,8 @@ def _optimization_tuning(self, mode):
assert "dataset" in self._user_tuning_config, "Optimization Tuning should provide with dataset."
batch_size = self._user_tuning_config["batch_size"]
dataset = self._user_tuning_config["dataset"]
dataset.dp_world_size = self._input_split_size
dataset.dp_rank = self._input_split_rank
dataset.dp_world_size = self.dp_world_sizes
dataset.dp_rank = self.dp_ranks

from .tuner.optimization_tuner import OptimizationTuner
self._optimization_tuner = OptimizationTuner(self._user_tuning_config,
@@ -276,8 +276,13 @@ def _plan(self, mode):
if var.name in block.vars:
feed_list.append(block.vars[var.name])

self._input_split_size, self._input_split_rank = self._get_input_split_info(
feed_list[0], self._dist_contexts[mode])
self.dp_world_sizes = []
self.dp_ranks = []
for feed_var in feed_list:
dp_world_size, dp_rank = self._get_input_split_info(
feed_var, self._dist_contexts[mode])
self.dp_world_sizes.append(dp_world_size)
self.dp_ranks.append(dp_rank)

def _parallel(self, mode, all_ranks=False):
# Parallelize program based on the planner's results
@@ -484,15 +489,23 @@ def fit(self,
for epoch in range(epochs):
train_logs = {"epoch: {:d} ": epoch}
for step, _ in enumerate(train_dataloader):
try:
outs = self._executor.run(self.main_program,
fetch_list=fetch_list,
use_program_cache=use_cache,
return_numpy=return_numpy)
except fluid.core.EOFException:
break

outs = self._executor.run(self.main_program,
fetch_list=fetch_list,
use_program_cache=self._use_cache,
return_numpy=self._return_numpy)
train_logs["step: {:d} "] = step
if lr_scheduler is not None:
lr_scheduler.step()
train_logs["lr: {:5e} "] = self._lr_optimizer.get_lr()
try:
train_logs["lr: {:5e} "] = self._lr_optimizer.get_lr()
except:
train_logs[
"lr: {:5e} "] = self._lr_optimizer._learning_rate.get_lr(
)
# inner fetches
if fetch_loss:
train_logs["loss: {:9f} "] = outs[0][0]
@@ -530,10 +543,13 @@ def evaluate(self,

for step, _ in enumerate(eval_dataloader):
eval_logs = {"step: {:d} ": step}
outs = self._executor.run(self.main_program,
fetch_list=fetch_list,
use_program_cache=self._use_cache,
return_numpy=self._return_numpy)
try:
outs = self._executor.run(self.main_program,
fetch_list=fetch_list,
use_program_cache=use_cache,
return_numpy=return_numpy)
except fluid.core.EOFException:
break
# inner fetches
if fetch_loss:
eval_logs["loss: {:9f} "] = outs[0][0]
@@ -574,10 +590,13 @@ def predict(self, test_data, batch_size=1, collate_fn=None, callbacks=None):
outputs = []
for step, _ in enumerate(test_dataloader):
predict_logs = {"step: {:d} ": step}
outs = self._executor.run(self.main_program,
fetch_list=fetch_list,
use_program_cache=self._use_cache,
return_numpy=self._return_numpy)
try:
outs = self._executor.run(self.main_program,
fetch_list=fetch_list,
use_program_cache=use_cache,
return_numpy=return_numpy)
except fluid.core.EOFException:
break
outputs.append(outs[:len(fetch_outputs)])
for i, out in enumerate(outs):
predict_logs[fetch_map[fetch_list[i]] + ": {}"] = out
@@ -626,8 +645,9 @@ def _create_dataloader(self,
epochs,
steps_per_epoch,
collate_fn,
data_parallel_world_size=self._input_split_size,
data_parallel_rank=self._input_split_rank)
data_parallel_world_size=self.dp_world_sizes,
data_parallel_rank=self.dp_ranks,
split_data=self.strategy.split_data)

# move read op from the end of program to the start of program
new_op_size = len(dist_main_block.ops)
@@ -722,6 +742,8 @@ def _set_recompute_ckpts(self):
self.model, "gpt"
) and self.model.__class__.__name__ == 'GPTForPretraining':
exact_ckpts = self.model.gpt.checkpoints
else:
exact_ckpts = config["checkpoints"]
else:
exact_ckpts = config["checkpoints"]

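The fit, evaluate, and predict loops above share one new pattern: with an iterable dataset the step count may be unknown, so each executor.run call is wrapped and the epoch ends when the data reader raises EOFException. A condensed, hypothetical helper capturing that pattern (the function name and arguments are assumptions, not part of engine.py):

import paddle.fluid as fluid

def run_until_eof(executor, program, dataloader, fetch_list):
    # Iterate the non-iterable generator loader; the loader itself feeds the
    # program, and an exhausted reader surfaces as fluid.core.EOFException.
    results = []
    for step, _ in enumerate(dataloader):
        try:
            outs = executor.run(program,
                                fetch_list=fetch_list,
                                return_numpy=True)
        except fluid.core.EOFException:
            break
        results.append((step, outs))
    return results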
22 changes: 22 additions & 0 deletions python/paddle/distributed/fleet/base/distributed_strategy.py
@@ -1997,6 +1997,28 @@ def auto_search(self, flag):
else:
print("WARNING: auto-search should have value of bool type")

@property
def split_data(self):
"""
Indicate whether the input data is split across data parallel ranks.
If True, each data parallel rank consumes only its own shard of every batch.
Default Value: True
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.split_data = True
"""
return self.strategy.split_data

@split_data.setter
def split_data(self, flag):
if isinstance(flag, bool):
self.strategy.split_data = flag
else:
print("WARNING: split_data should have value of bool type")

@property
def qat(self):
"""
