[Cherry-pick] to Release/2.3, Add prefetch_factor in dataloader #43674

Merged: 4 commits, Jun 21, 2022

python/paddle/fluid/dataloader/dataloader_iter.py (14 changes: 8 additions & 6 deletions)

@@ -96,6 +96,7 @@ def __init__(self, loader):
         self._auto_collate_batch = loader.auto_collate_batch
         self._num_workers = loader.num_workers
         self._use_buffer_reader = loader.use_buffer_reader
+        self._prefetch_factor = loader.prefetch_factor
         self._use_shared_memory = loader.use_shared_memory
         self._timeout = loader.timeout if loader.timeout > 0 else MP_STATUS_CHECK_INTERVAL
         self._worker_init_fn = loader.worker_init_fn

@@ -166,9 +167,10 @@ def __init__(self, loader):
         self._structure_infos = []

         # NOTE: len(self._places) batch data compose as an output
-        # iteration, set blocking_queue can cache 2 iteration data
+        # iteration, set blocking_queue can cache "self._prefetch_factor" iteration data
         # at most here
-        self._blocking_queue_capacity = 1 * len(self._places)
+        self._blocking_queue_capacity = self._prefetch_factor * len(
+            self._places)

         self._init_thread()
         self._shutdown = False

@@ -363,11 +365,11 @@ def __init__(self, loader):
         # indices outstand as _outstanding_capacity at first, and
         # blocking_queue capacity is also _outstanding_capacity.
         # _outstanding_capacity here to make sure each indices_queue
-        # has at least 2 indices, and outstanding batch cached
-        # output data for at least 2 iterations(Note that len(_places)
+        # has at least "_prefetch_factor" indices, and outstanding batch cached
+        # output data for at least "_prefetch_factor" iterations(Note that len(_places)
         # batches will be composed as an iteration output)
-        self._outstanding_capacity = 2 * max(self._num_workers,
-                                             len(self._places))
+        self._outstanding_capacity = self._prefetch_factor * max(
+            self._num_workers, len(self._places))

         # see _try_put_indices
         self._thread_lock = threading.Lock()
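
Both capacity formulas now scale with the new argument instead of a hard-coded constant. A minimal sketch of the resulting arithmetic, with made-up values for num_workers and the number of places (illustrative only, not code from this PR):

    # How the two capacities scale with prefetch_factor (assumed values).
    prefetch_factor = 2   # the new DataLoader argument, default 2
    num_workers = 4       # data-loading subprocesses
    num_places = 2        # len(self._places): devices composing one iteration

    # Single-process path: the blocking queue caches at most
    # prefetch_factor iterations, each made of num_places batches.
    blocking_queue_capacity = prefetch_factor * num_places                 # 2 * 2 = 4

    # Multi-process path: keep enough batch indices outstanding that every
    # worker stays busy while prefetch_factor iterations of output are cached.
    outstanding_capacity = prefetch_factor * max(num_workers, num_places)  # 2 * 4 = 8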

python/paddle/fluid/reader.py (34 changes: 20 additions & 14 deletions)

@@ -314,56 +314,58 @@ class DataLoader(object):
         dataset(Dataset): the dataset to load data from, should be an
             instance of subclass of :code:`paddle.io.Dataset` or
             :code:`paddle.io.IterableDataset`.
-        feed_list (list(Tensor)|tuple(Tensor)): feed Tensor list.
+        feed_list (list(Tensor)|tuple(Tensor), optional): feed Tensor list.
             The Tensors should be created by :code:`paddle.static.data()`.
             :attr:`feed_list` must be set if :attr:`return_list` is
             False. Default None.
-        places(list(Place)|tuple(Place)|list(str)|optional): a list of Place,
+        places(list(Place)|tuple(Place)|list(str), optional): a list of Place,
             to put data onto, :attr:`places` can be None, if
             :attr:`places` is None, default place(CPUPlace or CUDAPlace(0))
             will be used. Default None. If ``places`` is list of string,
             the string in the list can be ``cpu``, ``gpu:x`` and ``gpu_pinned``,
             where ``x`` is the index of the GPUs.
-        return_list (bool): whether the return value on each device is
+        return_list (bool, optional): whether the return value on each device is
             presented as a list. If :attr:`return_list=False`, the return
             value on each device would be a dict of str -> Tensor, where
             the key of the dict is the name of each fed Tensors. If
             :attr:`return_list=True`, the return value on each device would
             be a list(Tensor). :attr:`return_list` can only be True
             in dynamic graph mode. Default True.
-        batch_sampler(BatchSampler): an instance of `paddle.io.BatchSampler`
+        batch_sampler(BatchSampler, optional): an instance of `paddle.io.BatchSampler`
             to generate batch indices to draw samples from :attr:`dataset`
             and combine a batch. Default None.
-        batch_size(int|None): sample number in a mini-batch, a substitution
+        batch_size(int|None, optional): sample number in a mini-batch, a substitution
             parameter for :attr:`batch_sampler`, if :attr:`batch_sampler`
             is not set, a default `paddle.io.BatchSampler` will be used
             and initialized by :attr:`batch_size`, :attr:`shuffle` and
             :attr:`drop_last`. Default 1.
-        shuffle(bool): whether to shuffle indices order before generating
+        shuffle(bool, optional): whether to shuffle indices order before generating
             batch indices, a substitution parameter for :attr:`batch_sampler`
             see :attr:`batch_size`. Default False.
-        drop_last(bool): whether to drop the last incomplete batch when dataset size
+        drop_last(bool, optional): whether to drop the last incomplete batch when dataset size
             is not divisible by the batch size, a substitution parameter
             for :attr:`batch_sampler`, see :attr:`batch_size`. Default False
-        collate_fn(callable): function to generate mini-batch data by merging
+        collate_fn(callable, optional): function to generate mini-batch data by merging
             the sample list, None for only stack each fields of sample in axis
             0(same as :attr:`np.stack(..., axis=0)`). Default None
-        num_workers(int): the number of subprocesses to load data, 0 for no
+        num_workers(int, optional): the number of subprocesses to load data, 0 for no
             subprocess used and loading data in main process. Default 0
-        use_buffer_reader (bool): whether to use buffered reader.
-            If use_buffer_reader=True, the DataLoader would prefetch next
+        use_buffer_reader (bool, optional): whether to use buffered reader.
+            If use_buffer_reader=True, the DataLoader would prefetch
             batch data asynchronously, so it would speed up data feeding
             and occupies a little more CPU or GPU memory, i.e., the memory
             of one batch input data. Default True.
-        use_shared_memory (bool): whether to use shared memory to speed up
+        prefetch_factor (int, optional): Number of batch data the DataLoader would prefetch
+            if use_buffer_reader=True. Default 2.
+        use_shared_memory (bool, optional): whether to use shared memory to speed up
             putting data into inter-process queue, set :attr:`use_shared_memory`
             as True only when the shared memory space on your machine(e.g.
             space of '/dev/shm' on Linux operating system) is large enough.
             Shared memory will only be enabled in multi-process mode(num_workers
             > 0). Default True.
-        timeout(int): the timeout value for getting data from output queue
+        timeout(int, optional): the timeout value for getting data from output queue
             of subprocesses. Default 0.
-        worker_init_fn(callable): init function which will be called with
+        worker_init_fn(callable, optional): init function which will be called with
             worker id on each subprocess starting if not set as None. Default
             None.

@@ -450,13 +452,15 @@ def __init__(self,
                  collate_fn=None,
                  num_workers=0,
                  use_buffer_reader=True,
+                 prefetch_factor=2,
                  use_shared_memory=True,
                  timeout=0,
                  worker_init_fn=None,
                  persistent_workers=False):
         self.return_list = return_list
         self.collate_fn = collate_fn
         self.use_buffer_reader = use_buffer_reader
+        self.prefetch_factor = prefetch_factor
         self.worker_init_fn = worker_init_fn

         self.dataset = dataset

@@ -483,6 +487,8 @@ def __init__(self,
             num_workers = 0
         self.num_workers = num_workers

+        assert prefetch_factor > 0, "prefetch_factor should be a positive value"
+
         self.use_shared_memory = use_shared_memory
         if use_shared_memory and num_workers == 0:
             self.use_shared_memory = False
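
From the user side, prefetch_factor is just one more keyword argument to paddle.io.DataLoader. A minimal usage sketch (the RandomDataset below is hypothetical, written only to make the example self-contained; per the assert above, prefetch_factor must be a positive integer):

    import numpy as np
    from paddle.io import DataLoader, Dataset

    class RandomDataset(Dataset):
        # Toy dataset returning random feature vectors.
        def __init__(self, num_samples=100):
            self.num_samples = num_samples

        def __getitem__(self, idx):
            return np.random.random([784]).astype('float32')

        def __len__(self):
            return self.num_samples

    loader = DataLoader(RandomDataset(),
                        batch_size=32,
                        num_workers=2,
                        use_buffer_reader=True,  # prefetch_factor only takes effect when True
                        prefetch_factor=4)       # cache up to 4 iterations of batch data

    for batch in loader:
        pass  # consume prefetched batches as usual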