enable average tokens across devices #34373
Conversation
Thanks! It's a good start, let's improve it a little :)
src/transformers/trainer.py
Outdated
```python
if self.compute_loss_func is not None:
    if (self.args.average_tokens_across_devices and num_items_in_batch is not None and
            self.args.world_size > 1):
        num_items_in_batch_tensor = torch.tensor(num_items_in_batch, device=self.args.device)
        num_items_in_batch = int(self.accelerator.gather(num_items_in_batch_tensor).sum().cpu())
        device_count_for_loss = self.args.world_size
    else:
        device_count_for_loss = 1
    loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
    loss *= device_count_for_loss
elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
    loss = self.label_smoother(outputs, labels, shift_labels=True)
```
Thanks! We also need to handle the case where we don't define this, e.g. when it's passed to the model `forward()`. So what would be better is to perform the `gather` much earlier, and pass the new `num_items_in_batch` as part of the call to `compute_loss`. Then perform the `loss *=` where we call `loss *= self.args.gradient_accumulation_steps` later (right before we call `accelerator.backward()`).
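A rough sketch of what that reordering could look like (illustrative only, not the exact `Trainer` code; `compute_loss` accepting `num_items_in_batch` directly is assumed here):

```python
# Sketch: gather the token count once, before the loss is computed, so that
# models which consume num_items_in_batch in forward() also see the global count.
if self.args.average_tokens_across_devices and num_items_in_batch is not None:
    num_items_in_batch = int(
        self.accelerator.gather(
            torch.tensor(num_items_in_batch, device=self.args.device)
        ).sum()
    )

loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

# Undo the implicit mean over processes and compensate for gradient accumulation
# in one place, right before the backward pass.
loss *= self.args.gradient_accumulation_steps
if self.args.average_tokens_across_devices:
    loss *= self.args.world_size
self.accelerator.backward(loss)
```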
Thanks for the advice! I've already fixed it, please check again.
Thanks! We're quite close. Two more nits
src/transformers/training_args.py
Outdated
```python
average_tokens_across_devices: Optional[bool] = field(
    default=False,
    metadata={
        "help": "Whether or not to average tokens across devices."
    }
)
```
During `__post_init__` we call `setup_devices`. We can change the `average_tokens_across_devices` value to `False` if the world size is 1, I think! This then simplifies the earlier check to just `if self.args.average_tokens_across_devices`.
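Sketched out, that could look roughly like this inside `TrainingArguments.__post_init__` (a sketch only; the warning text is an assumption, and it presumes `self.world_size` is available once `setup_devices` has run):

```python
# Sketch of the suggested __post_init__ adjustment: disable the flag when
# there is only a single process, so downstream checks stay simple.
if self.average_tokens_across_devices and self.world_size <= 1:
    logger.warning(
        "average_tokens_across_devices is only meaningful with more than one "
        "process; setting it to False."
    )
    self.average_tokens_across_devices = False
```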
Thanks, I've already fixed it, please check.
src/transformers/trainer.py
Outdated
```python
loss *= self.args.gradient_accumulation_steps
if (self.args.average_tokens_across_devices and num_items_in_batch is not None and
        self.args.world_size > 1):
    loss *= self.args.world_size
```
Both of these chunks, I think, can go under an `if num_items_in_batch is not None and self.model_accepts_loss_kwargs`, since both need to be valid for the `loss *= self.args.gradient_accumulation_steps`.
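In other words, something along these lines (a sketch of the reviewer's suggestion; as the next reply notes, the author ultimately structured it differently):

```python
# Sketch: gate both scalings on the same condition, since the
# gradient-accumulation compensation only applies on the loss-kwargs path.
if num_items_in_batch is not None and self.model_accepts_loss_kwargs:
    loss *= self.args.gradient_accumulation_steps
    if self.args.average_tokens_across_devices and self.args.world_size > 1:
        loss *= self.args.world_size
```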
I think gradient accumulation is orthogonal to DDP, so I used a new if statement. Please check my code. Thanks.
It is, it's a matter of `self.model_accepts_loss_kwargs`.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
This indeed looks like the correct solution! Thanks!
Thanks for your work! LGTM!
src/transformers/training_args.py
Outdated
```python
average_tokens_across_devices: Optional[bool] = field(
    default=False,
    metadata={
        "help": "Whether or not to average tokens across devices."
    }
)
```
Maybe we could share a bit more about why this arg could be useful?
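For illustration, an expanded help string could read something like this (hypothetical wording, not necessarily the text that was merged):

```python
average_tokens_across_devices: Optional[bool] = field(
    default=False,
    metadata={
        "help": (
            "Whether or not to average tokens across devices. If enabled, the per-device "
            "num_items_in_batch values are summed across processes, so the loss is normalized "
            "by the global token count instead of the local one, which gives a more precise "
            "loss under DDP. See https://github.com/huggingface/transformers/issues/34242."
        )
    }
)
```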
Done, please review my code.
@muellerzr @SunMarc I fixed one typo to make ruff happy and pass the auto check, and added more explanation to the argument. Please check again, thank you.
One required test failed, but it's due to a timeout. Can anyone help fix it? Thanks a lot.
Thanks! Overall LG2M, cc @SunMarc
We'll merge once our CI allows us to 😅
```python
return loss.detach() / self.args.gradient_accumulation_steps
```
We modified how the loss is computed quite a lot; are we sure that this loss is the same as the one applied?
Good catch, it should be `loss.detach() / self.args.gradient_accumulation_steps / self.accelerator.num_processes` (dividing by num processes if and only if we did our loss-function num-tokens logic).
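Put concretely, the returned value would then look roughly like this (a sketch; the exact guard condition is an assumption):

```python
# Sketch: only divide by the number of processes when the gathered-token-count
# path was taken; otherwise keep the original scaling.
scaled_loss = loss.detach() / self.args.gradient_accumulation_steps
if self.args.average_tokens_across_devices and num_items_in_batch is not None:
    scaled_loss = scaled_loss / self.accelerator.num_processes
return scaled_loss
```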
Thanks! LGTM! Just a nit.
cc @ydshieh
Thanks! Rebase to make sure we have green CIs, we will merge anyway!
Thanks @techkang 🤗
* enable average tokens across devices
* reduce earlier in case model needs it
* simplify if statement
* reformat code to make ruff happy
* add doc for argument: average_tokens_across_devices
* cannot find world size when pytorch is unavailable
* format code

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
What does this PR do?
Fix #34242
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@muellerzr please help to check.