Shifting labels for causal LM when using label smoother #17987

Merged
3 commits merged into huggingface:main from shift_labels_for_causalLM on Jul 1, 2022

Conversation

@seungeunrho (Contributor) commented on Jul 1, 2022

What does this PR do?

Fixes #17960

When training a causal LM such as GPT-2, the loss is computed within the model's forward() function, and the labels are shifted internally. However, if label smoothing is applied, the loss is instead computed in the Trainer's compute_loss function, where the labels are not shifted. This misaligns the labels with their corresponding input_ids. This commit resolves the misalignment.
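To make the misalignment concrete, here is a minimal sketch of a label-smoothed causal LM loss that applies the shift itself, in the spirit of this fix. The function name, signature, and defaults are illustrative assumptions, not the exact code of this PR:

```python
# Illustrative sketch, assuming logits of shape (batch, seq_len, vocab)
# and labels of shape (batch, seq_len); not the library's exact API.
import torch
import torch.nn.functional as F

def smoothed_causal_lm_loss(logits, labels, epsilon=0.1, ignore_index=-100):
    # Shift so that tokens < n predict token n, mirroring the shift the
    # model applies internally when it computes its own loss in forward().
    logits = logits[..., :-1, :].contiguous()
    labels = labels[..., 1:].contiguous()

    log_probs = -F.log_softmax(logits, dim=-1)
    # Clamp padding labels to a valid index so gather() does not fail;
    # their contribution is masked out below.
    safe_labels = labels.clamp_min(0).unsqueeze(-1)
    nll_loss = log_probs.gather(dim=-1, index=safe_labels)
    smoothed_loss = log_probs.sum(dim=-1, keepdim=True)

    padding_mask = labels.unsqueeze(-1).eq(ignore_index)
    nll_loss.masked_fill_(padding_mask, 0.0)
    smoothed_loss.masked_fill_(padding_mask, 0.0)

    num_active = padding_mask.numel() - padding_mask.long().sum()
    nll_loss = nll_loss.sum() / num_active
    smoothed_loss = smoothed_loss.sum() / (num_active * log_probs.size(-1))
    return (1 - epsilon) * nll_loss + epsilon * smoothed_loss
```

Without the shift on the first two lines of the function, position n's logits would be scored against position n's label rather than position n+1's, which is exactly the bug described above.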

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@sgugger

When training a causal LM, the loss is computed within the model's forward() function
and the labels are shifted internally. However, if label smoothing is applied, the loss
is computed in the Trainer's compute_loss function and the labels are not shifted.
This misaligns the labels with their corresponding inputs. This commit resolves
the misalignment.

Resolves huggingface#17960

On branch shift_labels_for_causalLM
Changes to be committed:
	modified:   src/transformers/trainer.py
	modified:   src/transformers/trainer_pt_utils.py
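
For context on the Trainer-side change, here is a hedged sketch of how compute_loss could dispatch the shift. It assumes transformers' unwrap_model helper and the MODEL_FOR_CAUSAL_LM_MAPPING_NAMES mapping, and the shift_labels keyword mirrors the approach described above rather than quoting the merged diff verbatim:

```python
# Sketch of the idea inside Trainer.compute_loss (illustrative, not the
# verbatim diff from this PR).
from transformers.modeling_utils import unwrap_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

def compute_loss(self, model, inputs, return_outputs=False):
    labels = inputs.pop("labels") if self.label_smoother is not None and "labels" in inputs else None
    outputs = model(**inputs)
    if labels is not None:
        # Causal LMs shift labels inside forward(); since the smoother
        # bypasses that path, ask it to apply the shift itself.
        is_causal_lm = unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values()
        loss = self.label_smoother(outputs, labels, shift_labels=is_causal_lm)
    else:
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
    return (loss, outputs) if return_outputs else loss
```

Keying the shift on the model class keeps the behavior unchanged for encoder-style models, which do not shift labels in forward().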
@sgugger (Collaborator) left a comment

Thanks a lot for your PR! Left a small comment and it should be good to merge.
Make sure to run make style on your branch to apply formatting.

[Review comment on src/transformers/trainer.py: outdated, resolved]
@HuggingFaceDocBuilderDev commented on Jul 1, 2022

The documentation is not available anymore as the PR was closed or merged.

@seungeunrho requested a review from @sgugger on Jul 1, 2022 at 18:13
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
@sgugger merged commit 6890d19 into huggingface:main on Jul 1, 2022
viclzhu pushed a commit to viclzhu/transformers that referenced this pull request Jul 18, 2022
Shifting labels for causal LM when using label smoother (huggingface#17987)

* Shifting labels for causal LM when using label smoother

* Update trainer.py

* Update src/transformers/trainer.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Successfully merging this pull request may close these issues.

Suggestion for introducing "shift_labels" argument for Trainer (#17960)