
Fix ONNX export for causal LM sequence classifiers by removing reverse indexing #28144

Merged: 5 commits merged into huggingface:main from causal_classification_onnx2 on Dec 22, 2023

Conversation

@dwyatte (Contributor) commented on Dec 19, 2023

What does this PR do?

Follow-up to #27450 and another step toward fixing huggingface/optimum#1527. ONNX implements indexing using a combination of its own operators, and when reverse indexing is used (e.g., -1 to indicate the last element of an array), it can produce incorrect results (see PyTorch's ONNX export code). In practice, this can cause the batch dimension to get shuffled.

Causal LM sequence classifiers were previously using -1 to index the last token. Adding sequence_lengths = torch.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1) effectively removes the reverse indexing, as illustrated in the sketch below.
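For illustration, a minimal, self-contained sketch of the pattern being fixed (toy tensors, not the actual model code):

```python
import torch

# Toy batch: row 0 is right-padded, row 1 fills the full sequence.
pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],
                          [5, 6, 7, 8, 9]])

# First pad position minus one gives the index of the last real token,
# but for the unpadded row argmax returns 0, so the result is -1: a
# reverse index that the ONNX export can resolve incorrectly.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
print(sequence_lengths)  # tensor([ 2, -1])

# The fix: rewrite any negative index as the explicit last position.
sequence_lengths = torch.where(
    sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1
)
print(sequence_lengths)  # tensor([2, 4])
```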

While this could be fixed in https://github.com/huggingface/optimum by forcing the inputs used to trace the graph to contain a pad token and thereby avoiding reverse indexing, it seems better to fix it in transformers, with the added benefit of bringing the code in line with the TensorFlow implementations of the same logic (e.g., https://github.com/huggingface/transformers/pull/25085/files#diff-7c6fdd54ac4b8ce0c09bb17da15f176d3e5827df39dd8234fd802631e99ef38dR801-R804).

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker, @amyeroberts, @younesbelkada (CC @fxmarty)

@amyeroberts (Collaborator) left a comment:


Thanks for fixing this!

General comment on the indexing assumptions made here

Comment on lines 801 to 803:

sequence_lengths = torch.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1).to(
    logits.device
)

  • Let's split across lines - it'll use fewer lines overall
  • Can we instead use modulo to convert the negative index to its positive equivalent? The logic at the moment assumes we only ever want to take the final index, and I think it'll be faster than torch.where
Suggested change:

```diff
- sequence_lengths = torch.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1).to(
-     logits.device
- )
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
+ sequence_lengths = sequence_lengths.to(logits.device)
```
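For context, a quick check of the modulo equivalence (toy values; PyTorch's % follows Python's sign-of-divisor semantics):

```python
import torch

seq_len = 5
sequence_lengths = torch.tensor([2, -1])

# -1 % 5 == 4, so modulo maps the reverse index to the explicit last
# position while leaving non-negative indices unchanged.
print(sequence_lengths % seq_len)  # tensor([2, 4])
```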

@dwyatte (Contributor, Author) replied:

That should work. Updated all changes

@dwyatte requested a review from amyeroberts on December 19, 2023 18:28
@amyeroberts (Collaborator) left a comment:

LGTM - thanks for iterating!

I'd like to have another approval from either @younesbelkada or @ArthurZucker before merging, as they know the causal LM models well

@ArthurZucker (Collaborator) left a comment:

Thanks both, looks good yep!
More models that use sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1).to(logits.device) and could benefit from this (not sure if argmax is faster? see the sketch after this list):

  • bloom
  • falcon
  • mpt

let's try to unify this!
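(For reference, a sketch of the modulo-based pattern this PR unifies the models on; it uses argmax over the pad mask rather than ne().sum(), matching the question above, and the tensors are toy stand-ins for the model's input_ids and logits:)

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],
                          [5, 6, 7, 8, 9]])
logits = torch.randn(2, 5, 3)  # (batch, seq_len, num_labels), illustrative

# argmax finds the first pad position; subtracting 1 gives the last real
# token. With no padding present, argmax returns 0 and 0 - 1 = -1, which
# the modulo then maps to the explicit last index (seq_len - 1).
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)

# Gather one logit vector per batch row at the last-token position.
pooled_logits = logits[torch.arange(input_ids.shape[0]), sequence_lengths]
print(pooled_logits.shape)  # torch.Size([2, 3])
```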

@dwyatte force-pushed the causal_classification_onnx2 branch from 02c782a to 55d4013 on December 20, 2023 16:09
@dwyatte (Contributor, Author) commented on Dec 20, 2023

Thanks @ArthurZucker @amyeroberts. I've unified the Bloom/Falcon/MPT implementations, but doing so triggered what looks to be an unrelated CI failure. Can someone take a look and fix/disable that test if it is indeed unrelated?

FAILED tests/models/seamless_m4t/test_modeling_seamless_m4t.py::SeamlessM4TModelWithTextInputTest::test_retain_grad_hidden_states_attentions - AttributeError: 'NoneType' object has no attribute 'retain_grad'

@amyeroberts (Collaborator) commented:

@dwyatte Yep, that's a flaky test. A patch to skip it in the testing suite was recently merged into main to prevent it from affecting unrelated PRs like this one :) Could you rebase to include recent updates and trigger a new CI run?

@dwyatte force-pushed the causal_classification_onnx2 branch from 55d4013 to b0db02c on December 20, 2023 17:04
@dwyatte (Contributor, Author) commented on Dec 20, 2023

@amyeroberts Hm, b0db02c contains the latest commit on main (224ab70), so I think tests/models/seamless_m4t/test_modeling_seamless_m4t.py::SeamlessM4TModelWithTextInputTest::test_retain_grad_hidden_states_attentions is still broken/flaking there

@amyeroberts (Collaborator) commented:

@dwyatte Hm, that's odd. The test shouldn't even be running, as it's explicitly skipped. Locally, on this branch, do you see the skip condition in test_modeling_seamless_m4t.py?

@dwyatte force-pushed the causal_classification_onnx2 branch from b0db02c to bb55859 on December 20, 2023 18:36
@dwyatte (Contributor, Author) commented on Dec 20, 2023

@amyeroberts I see what's going on -- the failure is on SeamlessM4TModelWithTextInputTest but the explicit skip exists on SeamlessM4TModelWithSpeechInputTest. Let me know if I should add the same skip to SeamlessM4TModelWithTextInputTest on my branch or if you prefer a different fix/PR
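(For reference, such a skip is just a decorator on the test method; a minimal illustrative sketch, not the actual SeamlessM4T test code:)

```python
import unittest


class SeamlessM4TModelWithTextInputTest(unittest.TestCase):
    @unittest.skip(reason="Flaky: some hidden states can be None, so retain_grad fails")
    def test_retain_grad_hidden_states_attentions(self):
        pass
```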

@amyeroberts (Collaborator) commented:

@dwyatte Ah! Gotcha. Yes please, could you open a separate PR to skip the retain grad tests for all the SeamlessM4T models?

@dwyatte force-pushed the causal_classification_onnx2 branch from bb55859 to ebfcd71 on December 21, 2023 16:21
@dwyatte (Contributor, Author) commented on Dec 21, 2023

Ok @amyeroberts @ArthurZucker, after rebasing on the above, this is ready for merging. Thanks both!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@amyeroberts merged commit 548a8f6 into huggingface:main on Dec 22, 2023
19 checks passed
staghado pushed a commit to staghado/transformers that referenced this pull request on Jan 15, 2024:

Fix ONNX export for causal LM sequence classifiers by removing reverse indexing (huggingface#28144)

* normalize reverse indexing for causal lm sequence classifiers

* normalize reverse indexing for causal lm sequence classifiers

* normalize reverse indexing for causal lm sequence classifiers

* use modulo instead

* unify modulo-based sequence lengths
@sentialx commented:

Why not just have a shared util for this, instead of repeating the code all over the place?
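(No such shared helper exists in transformers; a hedged sketch of what one might look like, with a hypothetical name:)

```python
from typing import Optional

import torch


def last_token_indices(input_ids: torch.Tensor, pad_token_id: Optional[int]) -> torch.Tensor:
    """Hypothetical helper: index of the last non-pad token per row, using
    only non-negative indices so the graph exports cleanly to ONNX."""
    if pad_token_id is None:
        # No padding defined: take the final position for every row.
        return torch.full((input_ids.shape[0],), input_ids.shape[-1] - 1)
    sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
    return sequence_lengths % input_ids.shape[-1]
```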
