diff --git a/monai/data/thread_buffer.py b/monai/data/thread_buffer.py
index cdd7c05f31..e21af69813 100644
--- a/monai/data/thread_buffer.py
+++ b/monai/data/thread_buffer.py
@@ -89,6 +89,12 @@ class ThreadDataLoader(DataLoader):
     on the same batch will still produce good training with minimal short-term overfitting while allowing a slow batch
     generation process more time to produce a result.
 
+    Another typical usage is to accelerate light-weight preprocessing (usually when all the deterministic transforms
+    are cached and no IO operations remain), because it executes the preprocessing in a separate thread and so avoids
+    the unnecessary IPC between multiple DataLoader workers. In addition, since CUDA may not work well with the
+    multiprocessing of DataLoader, `ThreadDataLoader` can be useful for GPU transforms. For more details:
+    https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_model_training_guide.md.
+
     See:
         * Fischetti et al. "Faster SGD training by minibatch persistency." ArXiv (2018) https://arxiv.org/abs/1806.07353
         * Dami et al., "Faster Neural Network Training with Data Echoing" ArXiv (2020) https://arxiv.org/abs/1907.05550
@@ -99,20 +105,15 @@ class ThreadDataLoader(DataLoader):
         dataset: input dataset.
         buffer_size: number of items to buffer from the data source.
         buffer_timeout: time to wait for an item from the buffer, or to wait while the buffer is full when adding items.
-        num_workers: number of the multi-processing workers in PyTorch DataLoader.
-        repeats: number of times to yield the same batch
+        repeats: number of times to yield the same batch.
+        kwargs: other arguments for `DataLoader` except for `dataset`.
+
     """
 
     def __init__(
-        self,
-        dataset: Dataset,
-        buffer_size: int = 1,
-        buffer_timeout: float = 0.01,
-        num_workers: int = 0,
-        repeats: int = 1,
-        **kwargs,
+        self, dataset: Dataset, buffer_size: int = 1, buffer_timeout: float = 0.01, repeats: int = 1, **kwargs
     ):
-        super().__init__(dataset, num_workers, **kwargs)
+        super().__init__(dataset, **kwargs)
         self.buffer_size = buffer_size
         self.buffer_timeout = buffer_timeout
         self.repeats = repeats
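
A minimal usage sketch of the updated constructor, assuming `ThreadDataLoader` and `Dataset` are imported from `monai.data`. The in-memory data and the `batch_size` value are illustrative only and not part of this change; after this diff, DataLoader options such as `batch_size`, `shuffle` or `num_workers` are forwarded through `**kwargs` rather than dedicated parameters:

    from monai.data import Dataset, ThreadDataLoader

    # illustrative in-memory dataset; a cached dataset is the typical use case
    data = [{"image": i} for i in range(8)]
    dataset = Dataset(data=data)

    # num_workers is no longer a dedicated argument; extra DataLoader options go through **kwargs
    loader = ThreadDataLoader(dataset, buffer_size=1, buffer_timeout=0.01, repeats=2, batch_size=2)

    for batch in loader:
        print(batch["image"])  # each generated batch is yielded `repeats` times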