Skip to content

Commit

Permalink
Fix TrainingArguments regression with torch <2.0.0 for dataloader_pre…
Browse files Browse the repository at this point in the history
…fetch_factor (#29447)

* Fix TrainingArguments regression with torch <2.0.0 for dataloader_prefetch_factor

dataloader_prefetch_factor was added to TrainingArguments in #28498 with the default value None, but  versions of torch<2.0.0 do not accept None and will raise an error if num_workers == 0 and prefetch_factor != 2

* Add is_torch_available() check

* Use is_torch_greater_or_equal_than_2_0

add back check for dataloader_prefetch_factor
  • Loading branch information
ringohoffman authored Mar 6, 2024
1 parent b27aa20 commit 2890116
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@
import torch
import torch.distributed as dist

from .pytorch_utils import is_torch_greater_or_equal_than_2_0

if is_accelerate_available():
from accelerate.state import AcceleratorState, PartialState
from accelerate.utils import DistributedType
Expand Down Expand Up @@ -1023,13 +1025,13 @@ class TrainingArguments:
)
},
)
dataloader_prefetch_factor: int = field(
default=None,
dataloader_prefetch_factor: Optional[int] = field(
default=None if not is_torch_available() or is_torch_greater_or_equal_than_2_0 else 2,
metadata={
"help": (
"Number of batches loaded in advance by each worker. "
"2 means there will be a total of 2 * num_workers batches prefetched across all workers. "
"Default is unset"
"Default is 2 for PyTorch < 2.0.0 and otherwise None."
)
},
)
Expand Down Expand Up @@ -1807,7 +1809,11 @@ def __post_init__(self):
if self.use_cpu:
self.dataloader_pin_memory = False

if self.dataloader_num_workers == 0 and self.dataloader_prefetch_factor is not None:
if (
(not is_torch_available() or is_torch_greater_or_equal_than_2_0)
and self.dataloader_num_workers == 0
and self.dataloader_prefetch_factor is not None
):
raise ValueError(
"--dataloader_prefetch_factor can only be set when data is loaded in a different process, i.e."
" when --dataloader_num_workers > 1."
Expand Down

0 comments on commit 2890116

Please sign in to comment.