3763 Enhance the doc of ThreadDataLoader for num_workers #3770

Merged: 8 commits, Feb 6, 2022
monai/data/thread_buffer.py (21 changes: 11 additions & 10 deletions)
@@ -89,6 +89,12 @@ class ThreadDataLoader(DataLoader):
on the same batch will still produce good training with minimal short-term overfitting while allowing a slow batch
generation process more time to produce a result.

+ Another typical usage is to accelerate light-weight preprocessing (usually when all the deterministic transforms
+ have been cached and no IO operations remain), because it executes the preprocessing in a separate thread and thus
+ avoids the unnecessary IPC between the multiple workers of DataLoader. And because CUDA may not work well with the
+ multi-processing of DataLoader, `ThreadDataLoader` can be useful for GPU transforms. For more details:
+ https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_model_training_guide.md.

See:
* Fischetti et al. "Faster SGD training by minibatch persistency." ArXiv (2018) https://arxiv.org/abs/1806.07353
* Choi et al., "Faster Neural Network Training with Data Echoing" ArXiv (2020) https://arxiv.org/abs/1907.05550
@@ -99,20 +105,15 @@ class ThreadDataLoader(DataLoader):
dataset: input dataset.
buffer_size: number of items to buffer from the data source.
buffer_timeout: time to wait for an item from the buffer, or to wait while the buffer is full when adding items.
-   num_workers: number of the multi-processing workers in PyTorch DataLoader.
-   repeats: number of times to yield the same batch
+   repeats: number of times to yield the same batch.
kwargs: other arguments for `DataLoader` except for `dataset`.

"""

def __init__(
-   self,
-   dataset: Dataset,
-   buffer_size: int = 1,
-   buffer_timeout: float = 0.01,
-   num_workers: int = 0,
-   repeats: int = 1,
-   **kwargs,
+   self, dataset: Dataset, buffer_size: int = 1, buffer_timeout: float = 0.01, repeats: int = 1, **kwargs
):
-   super().__init__(dataset, num_workers, **kwargs)
+   super().__init__(dataset, **kwargs)
self.buffer_size = buffer_size
self.buffer_timeout = buffer_timeout
self.repeats = repeats
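
To make the documented behavior concrete, here is a minimal usage sketch (not part of the PR; the in-memory data and parameter values are purely illustrative). It pairs `ThreadDataLoader` with a `CacheDataset` so the deterministic transforms are cached up front and the remaining per-epoch preprocessing is light enough for the single producer thread, while `num_workers` is simply forwarded through `kwargs` to the underlying PyTorch `DataLoader`, which is the behavior this PR documents:

```python
import torch
from monai.data import CacheDataset, ThreadDataLoader
from monai.transforms import Compose, ScaleIntensityd

# Hypothetical in-memory data: ten single-channel 32x32 "images".
data = [{"image": torch.rand(1, 32, 32)} for _ in range(10)]

# Deterministic transforms run once and are cached, so each epoch
# only pays for buffering, not for repeated preprocessing.
dataset = CacheDataset(
    data=data,
    transform=Compose([ScaleIntensityd(keys="image")]),
    cache_rate=1.0,
)

# num_workers is forwarded via kwargs to the underlying DataLoader;
# num_workers=0 avoids multi-processing entirely (no IPC, no CUDA/fork
# pitfalls), and repeats=2 yields each batch twice (mini-batch persistency).
loader = ThreadDataLoader(dataset, batch_size=2, num_workers=0, repeats=2)

for batch in loader:
    images = batch["image"]  # shape (2, 1, 32, 32); training step goes here
```

Keeping `num_workers=0` runs everything in one process, which avoids both the IPC overhead and the CUDA multi-processing issues mentioned in the docstring; raising it re-enables standard multi-process loading through the forwarded kwarg.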