
Commit 02ea2b3

Fix TrainingArguments.parallelism_config NameError with accelerate<1.10.1 (#40818)

Fix ParallelismConfig type for accelerate < 1.10.1

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

1 parent d42e96a

File tree

1 file changed: +5 −3 lines changed

src/transformers/training_args.py (5 additions, 3 deletions)

@@ -77,8 +77,10 @@

 from .trainer_pt_utils import AcceleratorConfig

-if is_accelerate_available("1.10.1"):
-    from accelerate.parallelism_config import ParallelismConfig
+if is_accelerate_available("1.10.1"):
+    from accelerate.parallelism_config import ParallelismConfig
+else:
+    ParallelismConfig = Any

 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
@@ -1264,7 +1266,7 @@ class TrainingArguments:
             )
         },
     )
-    parallelism_config: Optional["ParallelismConfig"] = field(
+    parallelism_config: Optional[ParallelismConfig] = field(
        default=None,
        metadata={"help": ("Parallelism configuration for the training run. Requires Accelerate `1.10.1`")},
    )
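
The pattern the commit relies on can be sketched in isolation: when an optional dependency is too old or missing, bind the class name to `typing.Any` so annotations that reference it still evaluate instead of raising a NameError. A minimal, hypothetical reproduction (the class name `TrainingArgsSketch` is invented, and a plain try/except stands in for transformers' `is_accelerate_available(...)` version check):

```python
from dataclasses import dataclass, field
from typing import Any, Optional

try:
    # Present only in sufficiently recent accelerate releases.
    from accelerate.parallelism_config import ParallelismConfig
except ImportError:
    # Fallback: alias the missing class to Any so the annotation
    # Optional[ParallelismConfig] below still evaluates without a
    # NameError when accelerate is old or not installed.
    ParallelismConfig = Any


@dataclass
class TrainingArgsSketch:  # hypothetical stand-in for TrainingArguments
    parallelism_config: Optional[ParallelismConfig] = field(default=None)


args = TrainingArgsSketch()
```

Note that the commit also drops the quotes around the annotation (`Optional["ParallelismConfig"]` becomes `Optional[ParallelismConfig]`): with the `Any` fallback bound at import time, a real name is always available, so the string forward reference is no longer needed.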
