From 8f8aa76c4d308b661aceb5de369d297d208ea077 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Thu, 21 Mar 2024 09:03:02 -0400
Subject: [PATCH 1/2] Remove deprecations

---
 src/transformers/trainer.py | 28 +++++++++++++++++++++++-----
 1 file changed, 23 insertions(+), 5 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index ac014c672e73e3..ad3cd7d1992c72 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -221,6 +221,9 @@
 if is_deepspeed_available():
     from accelerate.utils import DeepSpeedSchedulerWrapper
 
+if is_accelerate_available("0.28.0"):
+    from accelerate.utils import DataLoaderConfiguration
+
 
 def _is_peft_model(model):
     if is_peft_available():
@@ -4248,12 +4251,27 @@ def create_accelerator_and_postprocess(self):
             grad_acc_kwargs["sync_with_dataloader"] = False
         gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
 
+        accelerator_config = self.args.accelerator_config.to_dict()
+
+        if is_accelerate_available("0.28.0"):
+            dataloader_config = DataLoaderConfiguration(
+                split_batches=accelerator_config.pop("split_batches"),
+                dispatch_batches=accelerator_config.pop("dispatch_batches"),
+                even_batches=accelerator_config.pop("even_batches"),
+                use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
+            )
+        args = {
+            "deepspeed_plugin":self.args.deepspeed_plugin,
+            "gradient_accumulation_plugin":gradient_accumulation_plugin
+        }
+        if is_accelerate_available("0.28.0"):
+            args["dataloader_config"] = dataloader_config
+        else:
+            args.update(accelerator_config)
+
+
         # create accelerator object
-        self.accelerator = Accelerator(
-            deepspeed_plugin=self.args.deepspeed_plugin,
-            gradient_accumulation_plugin=gradient_accumulation_plugin,
-            **self.args.accelerator_config.to_dict(),
-        )
+        self.accelerator = Accelerator(**args)
 
         # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
         self.gather_function = self.accelerator.gather_for_metrics

From 7574f68d54b613c3ed36b8f00744976eb496c0b6 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Thu, 21 Mar 2024 09:10:39 -0400
Subject: [PATCH 2/2] Clean

---
 src/transformers/trainer.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index ad3cd7d1992c72..2cb426ac2ce88b 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -4261,15 +4261,14 @@ def create_accelerator_and_postprocess(self):
                 use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
             )
         args = {
-            "deepspeed_plugin":self.args.deepspeed_plugin,
-            "gradient_accumulation_plugin":gradient_accumulation_plugin
+            "deepspeed_plugin": self.args.deepspeed_plugin,
+            "gradient_accumulation_plugin": gradient_accumulation_plugin,
         }
         if is_accelerate_available("0.28.0"):
             args["dataloader_config"] = dataloader_config
         else:
             args.update(accelerator_config)
-
 
         # create accelerator object
         self.accelerator = Accelerator(**args)
 
         # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
         self.gather_function = self.accelerator.gather_for_metrics
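
Note (illustrative sketch, not part of the patch): the snippet below mirrors the version-gated construction the patch introduces, where dataloader options move into a DataLoaderConfiguration on accelerate >= 0.28.0 and remain plain keyword arguments on older versions. The `accelerator_config` dict is a hypothetical stand-in for `TrainingArguments.accelerator_config.to_dict()`, and the DeepSpeed and gradient-accumulation plugins are omitted for brevity.

# Minimal standalone sketch of the pattern used in create_accelerator_and_postprocess.
from accelerate import Accelerator
from transformers.utils import is_accelerate_available

# Hypothetical stand-in for TrainingArguments.accelerator_config.to_dict().
accelerator_config = {
    "split_batches": False,
    "dispatch_batches": None,
    "even_batches": True,
    "use_seedable_sampler": True,
}

args = {}
if is_accelerate_available("0.28.0"):
    # Guarded import, as in the patch: the class only exists in accelerate >= 0.28.0.
    from accelerate.utils import DataLoaderConfiguration

    # New path: group the dataloader options and pass them via `dataloader_config`.
    args["dataloader_config"] = DataLoaderConfiguration(
        split_batches=accelerator_config.pop("split_batches"),
        dispatch_batches=accelerator_config.pop("dispatch_batches"),
        even_batches=accelerator_config.pop("even_batches"),
        use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
    )
else:
    # Legacy path kept for older accelerate: forward the options directly.
    args.update(accelerator_config)

accelerator = Accelerator(**args)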