add dataloader prefetch factor in training args and trainer (#28498)
* add dataloader prefetch factor in training args and trainer

* remove trailing spaces

* prevent dataloader_num_workers == 0 and dataloader_prefetch_factor != None

dataloader_prefetch_factor works only when data is loaded in a worker process separate from the main one. This commit adds the necessary check so that prefetch_factor cannot be set when no such process exists (a short DataLoader sketch below the commit metadata illustrates the underlying PyTorch constraint).

* Remove whitespace in empty line

* Update src/transformers/training_args.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
2 people authored and Ita Zaporozhets committed May 14, 2024
1 parent 7636f1d commit f946b1e
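The check added here mirrors PyTorch's own DataLoader behavior: prefetch_factor only applies when worker processes exist, and torch rejects it outright when num_workers == 0. A minimal sketch of that constraint, independent of this commit and using plain torch.utils.data:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(100, dtype=torch.float32))

# With workers, each of the 2 workers keeps up to 4 batches loaded ahead of time.
loader = DataLoader(dataset, batch_size=8, num_workers=2, prefetch_factor=4)

# Without workers there is no separate process to prefetch in, so PyTorch raises a ValueError.
try:
    DataLoader(dataset, batch_size=8, num_workers=0, prefetch_factor=4)
except ValueError as err:
    print(f"rejected as expected: {err}")

The new __post_init__ check surfaces the same error earlier, when the TrainingArguments are constructed, instead of when the first DataLoader is built.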
Showing 2 changed files with 27 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/transformers/trainer.py
@@ -806,6 +806,7 @@ def get_train_dataloader(self) -> DataLoader:
dataloader_params["sampler"] = self._get_train_sampler()
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))

@@ -863,6 +864,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

return self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))

@@ -895,6 +897,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
if not isinstance(test_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_eval_sampler(test_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

# We use the same batch_size as for eval.
return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params))
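With the three hooks above, the prefetch factor is forwarded to the train, eval, and test dataloaders alike. A hypothetical usage sketch (output_dir and the numbers are placeholders, assuming a transformers version that includes this change):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",              # placeholder output directory
    dataloader_num_workers=4,      # prefetching requires worker processes
    dataloader_prefetch_factor=2,  # each worker keeps 2 batches ready, so up to 8 in flight
)

# Trainer.get_train_dataloader / get_eval_dataloader / get_test_dataloader now pass
# args.dataloader_prefetch_factor straight into DataLoader(prefetch_factor=...).
print(args.dataloader_num_workers, args.dataloader_prefetch_factor)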
25 changes: 24 additions & 1 deletion src/transformers/training_args.py
@@ -532,6 +532,9 @@ class TrainingArguments:
If True, the data loader will not shut down the worker processes after a dataset has been consumed once.
This allows to maintain the workers Dataset instances alive. Can potentially speed up training, but will
increase RAM usage. Will default to `False`.
dataloader_prefetch_factor (`int`, *optional*):
Number of batches loaded in advance by each worker.
2 means there will be a total of 2 * num_workers batches prefetched across all workers.
skip_memory_metrics (`bool`, *optional*, defaults to `True`):
Whether to skip adding of memory profiler reports to metrics. This is skipped by default because it slows
down the training and evaluation speed.
@@ -989,7 +992,16 @@ class TrainingArguments:
)
},
)

dataloader_prefetch_factor: Optional[int] = field(
default=None,
metadata={
"help": (
"Number of batches loaded in advance by each worker. "
"2 means there will be a total of 2 * num_workers batches prefetched across all workers. "
"Default is unset"
)
},
)
past_index: int = field(
default=-1,
metadata={"help": "If >=0, uses the corresponding part of the output as the past state for next step."},
@@ -1737,6 +1749,12 @@ def __post_init__(self):
if self.use_cpu:
self.dataloader_pin_memory = False

if self.dataloader_num_workers == 0 and self.dataloader_prefetch_factor is not None:
raise ValueError(
"--dataloader_prefetch_factor can only be set when data is loaded in a different process, i.e."
" when --dataloader_num_workers > 1."
)

if self.push_to_hub_token is not None:
warnings.warn(
"`--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
@@ -2634,6 +2652,7 @@ def set_dataloader(
num_workers: int = 0,
pin_memory: bool = True,
persistent_workers: bool = False,
prefetch_factor: Optional[int] = None,
auto_find_batch_size: bool = False,
ignore_data_skip: bool = False,
sampler_seed: Optional[int] = None,
@@ -2654,6 +2673,9 @@
If True, the data loader will not shut down the worker processes after a dataset has been consumed
once. This allows to maintain the workers Dataset instances alive. Can potentially speed up training,
but will increase RAM usage. Will default to `False`.
prefetch_factor (`int`, *optional*):
Number of batches loaded in advance by each worker.
2 means there will be a total of 2 * num_workers batches prefetched across all workers.
auto_find_batch_size (`bool`, *optional*, defaults to `False`)
Whether to find a batch size that will fit into memory automatically through exponential decay,
avoiding CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`)
@@ -2684,6 +2706,7 @@ def set_dataloader(
self.dataloader_num_workers = num_workers
self.dataloader_pin_memory = pin_memory
self.dataloader_persistent_workers = persistent_workers
self.dataloader_prefetch_factor = prefetch_factor
self.auto_find_batch_size = auto_find_batch_size
self.ignore_data_skip = ignore_data_skip
self.data_seed = sampler_seed
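Putting the training_args.py pieces together: the new field, the __post_init__ validation, and the set_dataloader convenience setter. A short sketch, assuming a transformers release that contains this commit and leaving set_dataloader's other parameters at their defaults:

from transformers import TrainingArguments

# Valid: workers are enabled, so a prefetch factor is accepted.
args = TrainingArguments(output_dir="out").set_dataloader(num_workers=2, prefetch_factor=4)
print(args.dataloader_prefetch_factor)  # 4

# Invalid: __post_init__ rejects a prefetch factor when no worker process will exist.
try:
    TrainingArguments(output_dir="out", dataloader_num_workers=0, dataloader_prefetch_factor=2)
except ValueError as err:
    print(err)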
