Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Silence deprecations and use the DataLoaderConfig #29779

Merged
merged 2 commits into from
Mar 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@
if is_deepspeed_available():
from accelerate.utils import DeepSpeedSchedulerWrapper

if is_accelerate_available("0.28.0"):
from accelerate.utils import DataLoaderConfiguration


def _is_peft_model(model):
if is_peft_available():
Expand Down Expand Up @@ -4248,12 +4251,26 @@ def create_accelerator_and_postprocess(self):
grad_acc_kwargs["sync_with_dataloader"] = False
gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)

accelerator_config = self.args.accelerator_config.to_dict()

if is_accelerate_available("0.28.0"):
dataloader_config = DataLoaderConfiguration(
split_batches=accelerator_config.pop("split_batches"),
dispatch_batches=accelerator_config.pop("dispatch_batches"),
even_batches=accelerator_config.pop("even_batches"),
use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
)
args = {
"deepspeed_plugin": self.args.deepspeed_plugin,
"gradient_accumulation_plugin": gradient_accumulation_plugin,
}
if is_accelerate_available("0.28.0"):
args["dataloader_config"] = dataloader_config
else:
args.update(accelerator_config)
Comment on lines +4256 to +4270
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This design method allows us to be prepared for future versions of accelerator_config where we might want to add more values that don't directly relate to the DataLoaderConfiguration


# create accelerator object
self.accelerator = Accelerator(
deepspeed_plugin=self.args.deepspeed_plugin,
gradient_accumulation_plugin=gradient_accumulation_plugin,
**self.args.accelerator_config.to_dict(),
)
self.accelerator = Accelerator(**args)
# some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
self.gather_function = self.accelerator.gather_for_metrics

Expand Down
Loading