Fix ONNX export for causal LM sequence classifiers by removing reverse indexing #28144
Conversation
Thanks for fixing this!
General comment on the indexing assumptions made here
```python
sequence_lengths = torch.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1).to(
    logits.device
)
```
- Let's split this across lines - it'll use fewer lines overall
- Can we instead use modulo to convert the negative index to its positive equivalent (sketched below)? The logic at the moment assumes we only ever want to take the final index, and I think it'll be faster than torch.where
```diff
- sequence_lengths = torch.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1).to(
-     logits.device
- )
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
+ sequence_lengths = sequence_lengths.to(logits.device)
```
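For intuition, a quick standalone sketch (toy tensors and a toy pad_token_id, not from the PR itself) showing that the modulo form matches the torch.where form for the last-token case:

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],   # right-padded sequence
                          [1, 2, 3, 4, 5]])  # full-length sequence
logits = torch.randn(2, 5, 3)                # (batch, seq_len, num_labels)

# First pad position minus one; yields -1 when no pad token is present.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
print(sequence_lengths)                      # tensor([ 2, -1])

# Modulo maps -1 to the equivalent positive index (seq_len - 1).
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)                      # tensor([2, 4])

pooled_logits = logits[torch.arange(2), sequence_lengths]
print(pooled_logits.shape)                   # torch.Size([2, 3])
```

Since -1 is the only negative value this code path produces, `x % seq_len` agrees with the torch.where replacement everywhere while staying ONNX-friendly.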
That should work. Updated all changes
LGTM - thanks for iterating!
I'd like to have another approval from either @younesbelkada or @ArthurZucker before merging as they know the causal lm models well
Thanks both, looks good yep!
More models that could benefit from this, which use `sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1).to(logits.device)` (not sure if argmax is faster?):
- bloom
- falcon
- mpt

Let's try to unify this! (See the sketch below comparing the two patterns.)
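For reference, a small sketch (illustrative tensors only) of the two length computations and how the modulo normalization unifies them:

```python
import torch

pad = 0
input_ids = torch.tensor([[9, 8, 7, pad, pad],
                          [1, 2, 3, 4, 5]])

# Bloom/Falcon/MPT style: count non-pad tokens, minus one.
lengths_sum = torch.ne(input_ids, pad).sum(dim=-1) - 1          # tensor([2, 4])

# GPT-2 style: first pad position, minus one (-1 when no padding).
lengths_argmax = torch.eq(input_ids, pad).int().argmax(-1) - 1  # tensor([2, -1])

# After the modulo normalization, both agree for right-padded input.
assert torch.equal(lengths_argmax % input_ids.shape[-1], lengths_sum)
```

Note the sum-based variant counts every non-pad token, so the two only coincide when padding sits strictly on the right.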
(force-pushed from 02c782a to 55d4013)
Thanks @ArthurZucker @amyeroberts. I've unified the Bloom/Falcon/MPT implementations, but doing so triggered what looks to be an unrelated CI failure. Can someone take a look and fix/disable that test if it is indeed unrelated?
@dwyatte Yep, that's a flaky test. A patch to skip it in the testing suite was recently merged into main to prevent it affecting unrelated PRs like this one :) Could you rebase to include recent updates and trigger a new CI run?
(force-pushed from 55d4013 to b0db02c)
@amyeroberts Hm, b0db02c contains the latest commit on main, but the test is still failing.
@dwyatte hm, that's odd. The test shouldn't even be running as it's explicitly skipped. In your local, on this branch, do you see this skip condition in the test file?
(force-pushed from b0db02c to bb55859)
@amyeroberts I see what's going on -- the failure is on a SeamlessM4T variant that the existing skip doesn't cover.
@dwyatte Ah! Gotcha. Yes please, could you open another separate PR to skip the retain grad tests for all the SeamlessM4T models?
(force-pushed from bb55859 to ebfcd71)
Ok @amyeroberts @ArthurZucker, after rebasing on the above, this is ready for merging. Thanks both!
Fix ONNX export for causal LM sequence classifiers by removing reverse indexing (huggingface#28144)
* normalize reverse indexing for causal lm sequence classifiers
* use modulo instead
* unify modulo-based sequence lengths
Why not just have a shared util for this, instead of repeating the code all over the place?
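For illustration, a hypothetical shared helper along those lines (the name, signature, and placement are assumptions, not part of transformers):

```python
from typing import Optional

import torch

def last_token_index(input_ids: torch.Tensor, pad_token_id: Optional[int]) -> torch.Tensor:
    """Index of the last non-pad token per sequence, without negative indices."""
    if pad_token_id is None:
        # No padding info: assume every sequence uses its full length.
        return torch.full((input_ids.shape[0],), input_ids.shape[-1] - 1,
                          device=input_ids.device)
    # First pad position minus one; -1 (no pad found) wraps to seq_len - 1.
    lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
    return lengths % input_ids.shape[-1]
```

Each classifier's forward could then call this once instead of re-deriving sequence_lengths inline.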
What does this PR do?
Follow-up to #27450 and another step toward fixing huggingface/optimum#1527. ONNX implements indexing using a combination of its own operators, and when reverse indexing is used (e.g., -1 to select the last element of an array), the export can produce incorrect results (see PyTorch's ONNX export code). In practice, this can cause the batch dimension to get shuffled.
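To make the issue concrete, here is a minimal, self-contained sketch (a toy module with toy shapes, not the actual transformers code) of the pattern after this fix; the gather index is always non-negative, so the traced graph avoids the reverse-indexing path:

```python
import torch
import torch.nn as nn

class LastTokenPooler(nn.Module):
    """Toy module mirroring the pooling pattern this PR fixes (pad_token_id = 0)."""
    def forward(self, input_ids, logits):
        lengths = torch.eq(input_ids, 0).int().argmax(-1) - 1
        lengths = lengths % input_ids.shape[-1]  # no negative indices
        batch = torch.arange(input_ids.shape[0], device=logits.device)
        return logits[batch, lengths]

input_ids = torch.tensor([[9, 8, 0, 0], [1, 2, 3, 4]])
logits = torch.randn(2, 4, 3)
torch.onnx.export(
    LastTokenPooler(), (input_ids, logits), "pooler.onnx",
    input_names=["input_ids", "logits"], output_names=["pooled_logits"],
    dynamic_axes={"input_ids": {0: "batch", 1: "seq"},
                  "logits": {0: "batch", 1: "seq"}},
)
```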
Causal LM sequence classifiers were previously using `-1` to index the last token. Adding `sequence_lengths = torch.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)` effectively removes reverse indexing.

While this could be fixed in https://github.com/huggingface/optimum by forcing the inputs used to trace the graph to contain a pad token (thereby avoiding reverse indexing), it seems better to fix it in `transformers`, with the added benefit of bringing the code in line with the TensorFlow implementations of the same logic (e.g., https://github.com/huggingface/transformers/pull/25085/files#diff-7c6fdd54ac4b8ce0c09bb17da15f176d3e5827df39dd8234fd802631e99ef38dR801-R804).

Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@ArthurZucker, @amyeroberts, @younesbelkada (CC @fxmarty)