Commit f610418

Authored by Yuanyuan Chen <cyyever@outlook.com>
Fix outdated version checks of accelerator (#40969)
* Fix outdated version checks of accelerator

  Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>

* Fix outdated version checks of accelerator

  Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>

---------

Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
1 parent c532575 commit f610418
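
Background: Trainer historically gated accelerate imports and code paths behind runtime version checks. Once transformers' declared minimum accelerate version rises above a guarded threshold, the guard is always true and can be dropped. A minimal sketch of the pattern this commit removes, assuming an installed accelerate newer than 0.23.0 (which the dependency floor now guarantees):

from packaging import version

from accelerate import __version__ as accelerate_version

# Dead gate: always true once the package already requires a newer accelerate,
# so the import can simply be made unconditional.
if version.parse(accelerate_version) > version.parse("0.23.0"):
    from accelerate.data_loader import SeedableRandomSampler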

File tree: 2 files changed, +4 −19 lines

src/transformers/trainer.py

Lines changed: 3 additions & 6 deletions
@@ -241,10 +241,9 @@
     DATA_SAMPLERS = [RandomSampler]
     if version.parse(accelerate_version) > version.parse("1.3.0"):
         from accelerate.utils import TorchTensorParallelPlugin
-    if version.parse(accelerate_version) > version.parse("0.23.0"):
-        from accelerate.data_loader import SeedableRandomSampler
+    from accelerate.data_loader import SeedableRandomSampler

-        DATA_SAMPLERS += [SeedableRandomSampler]
+    DATA_SAMPLERS += [SeedableRandomSampler]

     if is_deepspeed_available():
         from accelerate.utils import DeepSpeedSchedulerWrapper
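
For context, DATA_SAMPLERS is the list of sampler classes Trainer treats as re-seedable when resuming training. A rough, hedged sketch of how such a sampler is constructed (the constructor mirrors torch's RandomSampler; exact keyword names may differ across accelerate versions):

import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate.data_loader import SeedableRandomSampler

dataset = TensorDataset(torch.arange(10).float())
# SeedableRandomSampler subclasses RandomSampler, adding an epoch-aware seed
# so shuffling is reproducible across resumed runs.
sampler = SeedableRandomSampler(dataset)
loader = DataLoader(dataset, sampler=sampler, batch_size=2)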
@@ -4196,9 +4195,7 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
         elif (tp_size := getattr(self.model, "_tp_size", 0)) is not None and tp_size > 1:
             self._save(output_dir)
         elif self.is_fsdp_enabled:
-            if ("FULL_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type)) and (
-                version.parse(accelerate_version) > version.parse("0.24.1")
-            ):
+            if "FULL_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type):
                 state_dict = self.accelerator.get_state_dict(self.model)
                 if self.args.should_save:
                     self._save(output_dir, state_dict=state_dict)
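
The deleted clause made full-state-dict saving conditional on accelerate > 0.24.1, a floor every supported install now clears. As a hedged sketch of the behavior being guarded (not the Trainer's exact code): under FSDP, weights are sharded across ranks, and with the FULL_STATE_DICT state-dict type, accelerator.get_state_dict() gathers the unsharded weights so one process can serialize them.

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = accelerator.prepare(torch.nn.Linear(8, 8))  # stand-in model

# Consolidates FSDP shards into a full state dict on the main process.
state_dict = accelerator.get_state_dict(model)
if accelerator.is_main_process:
    accelerator.save(state_dict, "pytorch_model.bin")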

tests/fsdp/test_fsdp.py

Lines changed: 1 addition & 13 deletions
@@ -88,22 +88,11 @@ def get_master_port(real_launcher=False):


 if is_torch_available():
-    from tests.trainer.test_trainer import (  # noqa
-        RegressionModelConfig,
-        RegressionPreTrainedModel,
-    )
-
     # hack to restore original logging level pre #21700
     get_regression_trainer = partial(tests.trainer.test_trainer.get_regression_trainer, log_level="info")

-require_fsdp_version = require_fsdp
 if is_accelerate_available():
-    from accelerate.utils.constants import (
-        FSDP_PYTORCH_VERSION,
-        FSDP_SHARDING_STRATEGY,
-    )
-
-    require_fsdp_version = partial(require_fsdp, min_version=FSDP_PYTORCH_VERSION)
+    from accelerate.utils.constants import FSDP_SHARDING_STRATEGY


 FSDP2_ACCELERATE_VERSION = "1.6.0"
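
The deleted block used functools.partial to pre-bind a minimum-version keyword onto a skip decorator; since FSDP is available on every torch version the test suite still supports, the plain require_fsdp marker suffices. A simplified, hedged stand-in for the removed pattern (not transformers' real decorator):

import unittest
from functools import partial

import torch
from packaging import version

def require_fsdp(test_case, min_version: str = "1.12.0"):
    # Skip the decorated test unless torch meets the FSDP version floor.
    ok = version.parse(torch.__version__) >= version.parse(min_version)
    return unittest.skipUnless(ok, f"test requires torch>={min_version}")(test_case)

# What the deleted line did: bake in a stricter floor taken from accelerate.
require_fsdp_version = partial(require_fsdp, min_version="2.1.0")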
@@ -142,7 +131,6 @@ def _parameterized_custom_name_func(func, param_num, param):

 @require_accelerate
 @require_torch_accelerator
-@require_fsdp_version
 class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
     def setUp(self):
         super().setUp()
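
Class-level markers like @require_accelerate guard the whole TestCase, so dropping @require_fsdp_version removes only the now-redundant version gate. A toy, self-contained illustration of how such a marker skips a class (the real decorators live in transformers.testing_utils):

import importlib.util
import unittest

def require_accelerate(test_case):
    # Toy version of the real decorator: skip when accelerate is not installed.
    available = importlib.util.find_spec("accelerate") is not None
    return unittest.skipUnless(available, "test requires accelerate")(test_case)

@require_accelerate
class ExampleFSDPTest(unittest.TestCase):
    def test_noop(self):
        self.assertTrue(True)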
