3763 Enhance the doc of ThreadDataLoader for num_workers (#3770)
Nic-Ma authored and wyli committed Feb 9, 2022
1 parent 10f52ab commit 6750ee9
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions monai/data/thread_buffer.py
@@ -89,6 +89,12 @@ class ThreadDataLoader(DataLoader):
     on the same batch will still produce good training with minimal short-term overfitting while allowing a slow batch
     generation process more time to produce a result.
 
+    Another typical usage is to accelerate light-weight preprocessing (usually cached all the deterministic transforms
+    and no IO operations), because it leverages the separate thread to execute preprocessing to avoid unnecessary IPC
+    between multiple workers of DataLoader. And as CUDA may not work well with the multi-processing of DataLoader,
+    `ThreadDataLoader` can be useful for GPU transforms. For more details:
+    https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_model_training_guide.md.
+
     See:
         * Fischetti et al. "Faster SGD training by minibatch persistency." ArXiv (2018) https://arxiv.org/abs/1806.07353
         * Dami et al., "Faster Neural Network Training with Data Echoing" ArXiv (2020) https://arxiv.org/abs/1907.05550
@@ -99,20 +105,15 @@ class ThreadDataLoader(DataLoader):
         dataset: input dataset.
         buffer_size: number of items to buffer from the data source.
         buffer_timeout: time to wait for an item from the buffer, or to wait while the buffer is full when adding items.
-        num_workers: number of the multi-processing workers in PyTorch DataLoader.
-        repeats: number of times to yield the same batch
+        repeats: number of times to yield the same batch.
         kwargs: other arguments for `DataLoader` except for `dataset`.
     """
 
     def __init__(
-        self,
-        dataset: Dataset,
-        buffer_size: int = 1,
-        buffer_timeout: float = 0.01,
-        num_workers: int = 0,
-        repeats: int = 1,
-        **kwargs,
+        self, dataset: Dataset, buffer_size: int = 1, buffer_timeout: float = 0.01, repeats: int = 1, **kwargs
     ):
-        super().__init__(dataset, num_workers, **kwargs)
+        super().__init__(dataset, **kwargs)
         self.buffer_size = buffer_size
         self.buffer_timeout = buffer_timeout
         self.repeats = repeats
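For context, the constructor after this change can be exercised as below. This is a minimal sketch, not part of the commit: the in-memory toy data, the `ScaleIntensityd` transform, and all parameter values are illustrative. The point shown is that `num_workers` is no longer an explicit parameter and instead reaches `torch.utils.data.DataLoader` through `**kwargs`:

    import torch
    from monai.data import CacheDataset, ThreadDataLoader
    from monai.transforms import Compose, ScaleIntensityd

    # Hypothetical in-memory samples; CacheDataset caches the results of the
    # deterministic transforms, so iteration only pays for the lightweight remainder.
    data = [{"img": torch.rand(1, 32, 32)} for _ in range(8)]
    ds = CacheDataset(data=data, transform=Compose([ScaleIntensityd(keys="img")]))

    # repeats=2 yields every buffered batch twice (minibatch persistency);
    # num_workers=0 travels through **kwargs to the base DataLoader.
    loader = ThreadDataLoader(ds, buffer_size=1, repeats=2, batch_size=2, num_workers=0)

    for batch in loader:
        print(batch["img"].shape)  # torch.Size([2, 1, 32, 32])

With 8 samples and batch_size=2, the loop runs 8 times: repeats=2 yields each of the 4 distinct batches twice.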

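The docstring paragraph added above recommends `ThreadDataLoader` for GPU transforms and links the fast-training guide. A sketch of that pattern, assuming `ToDeviced` from `monai.transforms` is available and again using illustrative data and transforms:

    import torch
    from monai.data import CacheDataset, ThreadDataLoader
    from monai.transforms import Compose, ScaleIntensityd, ToDeviced

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    data = [{"img": torch.rand(1, 32, 32)} for _ in range(8)]

    # ToDeviced moves each cached item onto the GPU once, up front; running the
    # loader in a separate thread rather than worker processes avoids sharing
    # CUDA tensors across processes, so keep the default num_workers=0 here.
    ds = CacheDataset(
        data=data,
        transform=Compose([ScaleIntensityd(keys="img"), ToDeviced(keys="img", device=device)]),
    )
    loader = ThreadDataLoader(ds, batch_size=2)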