
enable average tokens across devices #34373

Merged: 12 commits into huggingface:main on Oct 28, 2024

Conversation

techkang (Contributor):

What does this PR do?

Fix #34242

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@muellerzr, please help review.

@muellerzr (Contributor) left a comment:

Thanks! It's a good start, let's improve it a little :)

Comment on lines 3638 to 3649
if self.compute_loss_func is not None:
    if (self.args.average_tokens_across_devices and num_items_in_batch is not None and
            self.args.world_size > 1):
        num_items_in_batch_tensor = torch.tensor(num_items_in_batch, device=self.args.device)
        num_items_in_batch = int(self.accelerator.gather(num_items_in_batch_tensor).sum().cpu())
        device_count_for_loss = self.args.world_size
    else:
        device_count_for_loss = 1
    loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
    loss *= device_count_for_loss
elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
    loss = self.label_smoother(outputs, labels, shift_labels=True)
Contributor:

Thanks! We also need to handle the case where we don't define this, e.g. when num_items_in_batch is passed to the model's forward(). So it would be better to perform the gather much earlier and pass the new num_items_in_batch as part of the call to compute_loss.

And then perform the loss *= self.args.world_size rescaling where we call loss *= self.args.gradient_accumulation_steps later (right before we call accelerator.backward()).
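For illustration, here is a minimal sketch of the suggested approach: gather the token count once, early, and reuse it everywhere. The helper name and its placement are hypothetical, not the merged code; the attribute names follow the snippet above.

    import torch

    def _gather_num_items_in_batch(self, num_items_in_batch):
        # Sketch: gather per-device token counts so every rank normalizes the
        # loss by the same global count; only active in multi-device runs.
        if (
            self.args.average_tokens_across_devices
            and num_items_in_batch is not None
            and self.args.world_size > 1
        ):
            counts = torch.tensor(num_items_in_batch, device=self.args.device)
            num_items_in_batch = int(self.accelerator.gather(counts).sum().cpu())
        return num_items_in_batch

The gathered count would then flow through the compute_loss call (and on to the model's forward() when no custom compute_loss_func is set), with the loss *= self.args.world_size rescaling applied next to the gradient-accumulation scaling, right before accelerator.backward().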

techkang (PR author):

Thanks for the advice! I've fixed it, please check again.

@muellerzr (Contributor) left a comment:

Thanks! We're quite close. Two more nits

Comment on lines 1533 to 1538
average_tokens_across_devices: Optional[bool] = field(
    default=False,
    metadata={
        "help": "Whether or not to average tokens across devices."
    }
)
Contributor:

During __post_init__ we call setup_devices, so we can set average_tokens_across_devices to False there when the world size is 1, I think!

This then simplifies the earlier check to just if self.args.average_tokens_across_devices.
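A rough sketch of that __post_init__ adjustment, assuming world_size has already been resolved by the device setup (simplified; the real TrainingArguments.__post_init__ does far more):

    # Inside TrainingArguments.__post_init__, after setup_devices has run:
    if self.average_tokens_across_devices and self.world_size <= 1:
        # Nothing to average across on a single device, so disable the flag.
        # Later code can then simply check self.args.average_tokens_across_devices.
        self.average_tokens_across_devices = False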

techkang (PR author):

Thanks, I've fixed it, please check.

Comment on lines 3605 to 3608
loss *= self.args.gradient_accumulation_steps
if (self.args.average_tokens_across_devices and num_items_in_batch is not None and
        self.args.world_size > 1):
    loss *= self.args.world_size
@muellerzr (Contributor) commented on Oct 24, 2024:

I think both of these chunks can go under if num_items_in_batch is not None and self.model_accepts_loss_kwargs, since both conditions need to hold for the loss *= self.args.gradient_accumulation_steps.

techkang (PR author):

I think gradient accumulation is orthogonal to DDP, so I used a separate if statement. Please check my code. Thanks.

Contributor:

It is; it's a matter of self.model_accepts_loss_kwargs.
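For reference, a sketch of the combined guard being discussed, using the attribute names from the comments above (illustrative only, not the exact merged code):

    # Both rescalings only apply when the loss was normalized by a gathered
    # token count, i.e. when the model or loss function accepts loss kwargs.
    if num_items_in_batch is not None and self.model_accepts_loss_kwargs:
        loss *= self.args.gradient_accumulation_steps
        if self.args.average_tokens_across_devices and self.args.world_size > 1:
            # DDP averages gradients across ranks, so scale by world size to
            # preserve the global-token normalization.
            loss *= self.args.world_size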


@muellerzr (Contributor) left a comment:

This indeed looks like the correct solution! Thanks!

@muellerzr requested a review from @SunMarc on October 24, 2024 at 14:32.
@SunMarc (Member) left a comment:

Thanks for your work! LGTM!

average_tokens_across_devices: Optional[bool] = field(
    default=False,
    metadata={
        "help": "Whether or not to average tokens across devices."
    }
)
Member:

Maybe we could explain a bit more why this argument could be useful?

techkang (PR author):

Done, please review my code.
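For illustration, an expanded help string along those lines might look like this (the wording is hypothetical, not the exact merged text):

    average_tokens_across_devices: Optional[bool] = field(
        default=False,
        metadata={
            "help": (
                "Whether or not to average tokens across devices. If enabled, the number of tokens is "
                "gathered across all devices and used to normalize the loss, so that multi-device (DDP) "
                "training yields a loss equivalent to single-device training on the same global batch."
            )
        },
    )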

@techkang (PR author):

@muellerzr @SunMarc I fixed one typo to make ruff happy and pass the automated checks, and added more explanation to the argument. Please check again, thank you.

@techkang (PR author):

One required test failed, but it's due to a timeout. Can anyone help fix it? Thanks a lot.

@techkang requested a review from @SunMarc on October 25, 2024 at 01:56.
@techkang (PR author) commented on Oct 25, 2024:

world_size is unavailable when only the TF backend is in use, so I'm using try/except to catch the error.

CI still encountered four errors, but the same errors appeared on a recently merged PR, so they may not be caused by this PR.

CI errors of this PR: ci_error

CI errors for former PR: ci_error
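A minimal sketch of that try/except guard, wrapping the __post_init__ check sketched earlier; the exact exception types are an assumption, not the merged code:

    # Sketch: world_size may not be resolvable without the PyTorch backend,
    # so fall back to disabling the feature instead of crashing.
    try:
        if self.average_tokens_across_devices and self.world_size <= 1:
            self.average_tokens_across_devices = False
    except (ImportError, RuntimeError):
        self.average_tokens_across_devices = False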

@techkang requested a review from @muellerzr on October 25, 2024 at 03:21.
@muellerzr (Contributor) left a comment:

Thanks! Overall LGTM, cc @SunMarc

@muellerzr (Contributor):

We'll merge once our CI allows us to 😅

Comment on lines 3612 to 3613

return loss.detach() / self.args.gradient_accumulation_steps
Member:

We modified how the loss is computed quite a lot; are we sure that this loss is the same as the one that is applied?

@muellerzr (Contributor) commented on Oct 25, 2024:

Good catch; it should be loss.detach() / self.args.gradient_accumulation_steps / self.accelerator.num_processes (dividing by the number of processes if and only if we applied our loss-function num-tokens logic).
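A sketch of the corrected return value following that suggestion; the guard mirrors the earlier token-count condition and is an assumption, not a quote from the PR:

    # Sketch: the logged loss should undo both rescalings so it stays
    # comparable to single-device, single-step training.
    reported_loss = loss.detach() / self.args.gradient_accumulation_steps
    if self.args.average_tokens_across_devices and self.args.world_size > 1:
        # Only divide when the gathered token-count logic was applied.
        reported_loss = reported_loss / self.accelerator.num_processes
    return reported_loss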

@SunMarc (Member) left a comment:

Thanks! LGTM! Just a nit.

@muellerzr (Contributor):

cc @ydshieh

@ArthurZucker (Collaborator) left a comment:

Thanks! Rebase to make sure we have a green CI; we'll merge anyway!

@ArthurZucker ArthurZucker merged commit d21dbd1 into huggingface:main Oct 28, 2024
19 of 24 checks passed
@ArthurZucker (Collaborator):

Thanks @techkang 🤗

ArthurZucker added a commit that referenced this pull request Nov 5, 2024
* enable average tokens across devices

* reduce earlier in case model needs it

* simplify if statement

* reformat code to make ruff happy

* add doc for argument: average_tokens_across_devices

* cannot find world size when pytorch is unavailable

* format code

---------

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request on Dec 5, 2024 (same commit message as above).
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request on Dec 6, 2024 (same commit message as above).
Successfully merging this pull request may close these issues.

Add DDP token averaging for equivalent non-parallel training similar to #34191