[Whisper] Use Attention Cache #28931
Conversation
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
# decoder uni-directional self-attention cached key/values states are at position 0
self_attn_past_key_value = past_key_value[0] if past_key_value is not None else None
The difficulty here comes from the fact that we're dealing with two sets of past key-values per decoder layer: one from the self-attention, and one from the cross-attention. The current solution uses a separate cache for each.
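To make the indexing change above concrete, here is a small, purely illustrative sketch (not the PR's actual code): it contrasts the legacy per-layer key/value tuple with a layout that keeps one cache object per attention type, which is what the `past_key_value[0]` indexing assumes. The stand-in "cache objects" and all shapes below are examples only.

```python
# Illustrative sketch only -- the cache "objects" are stand-in dicts, and all shapes are examples.
import torch

def kv(seq_len):
    return torch.randn(1, 8, seq_len, 64)  # (batch, heads, seq_len, head_dim)

# Legacy layout: each decoder layer carries a 4-tuple (self_k, self_v, cross_k, cross_v).
legacy_past_key_value = (kv(5), kv(5), kv(1500), kv(1500))
self_attn_past_key_value = legacy_past_key_value[:2]    # first two entries -> self-attention k/v
cross_attn_past_key_value = legacy_past_key_value[-2:]  # last two entries  -> cross-attention k/v

# Proposed layout: one cache object per attention type, so indexing picks a whole cache.
self_attn_cache = {"key": kv(5), "value": kv(5)}         # dynamic, grows during generation
cross_attn_cache = {"key": kv(1500), "value": kv(1500)}  # fixed, computed once from the encoder
past_key_value = (self_attn_cache, cross_attn_cache)
self_attn_past_key_value = past_key_value[0]   # self-attention cache now sits at position 0
cross_attn_past_key_value = past_key_value[1]  # cross-attention cache at position 1
```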
@@ -801,7 +799,7 @@ def forward(
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Whisper, MBART->WHISPER
I'll propagate the changes to all other MBart-derived modules once we're happy with the design.
I have a fully functional solution for static-shaped Whisper, which we have tested extensively on the LibriSpeech dataset, getting the same accuracy as the original model.
@sanchit-gandhi FYI, I'm going to make the change in the PR in progress here: #29005. Once that PR is done, we can expand its usage through the same interface, e.g. for encoder-decoder models 🤗
Thanks for the context @gante! Is there anything I can do to help with the static cache refactor? Pretty keen to implement a compile-compatible cache for Whisper!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
What does this PR do?

Refactors the Whisper model to use the attention cache abstraction proposed in #26681. This is required for consistency with the `StaticCache` attention class proposed in #27931.

The complexity with the current `Cache` abstraction comes from the fact that Whisper is an encoder-decoder model, meaning each decoder attention layer consists of two blocks: a self-attention block, whose key/value cache grows with each generated token, and a cross-attention block, whose key/values are computed once from the encoder hidden states. The problematic layer for static generation is the dynamic k/v cache in the self-attention layer. In anticipation of using a static cache for this module, the proposed design uses a separate cache for each. We can't build the k/v cache into a single `Cache` abstraction, as the shapes for the self-attention and cross-attention key-values are different (which would break compile).

The design is therefore to pass `past_self_attn_key_values` and `past_cross_attn_key_values`, where each is a `Cache` abstraction (a rough sketch is given below). This is not the most elegant design, but it is compatible with the current `Cache` abstraction. Another option would be to refactor the `Cache`/`DynamicCache`/`StaticCache` classes for better compatibility with encoder-decoder models.

cc @ArthurZucker @tomaarsen @gante
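As a rough illustration of the interface described above: the argument names `past_self_attn_key_values` and `past_cross_attn_key_values` come from the PR description, but everything else, including the use of `DynamicCache` as a stand-in for both caches, is an assumption rather than the PR's actual implementation.

```python
# Illustrative sketch, not the PR's implementation. One Cache object per attention type,
# each holding entries for every decoder layer. The self-attention cache grows with each
# generated token (the candidate for StaticCache); the cross-attention cache is filled
# once from the encoder output and keeps a fixed shape.
import torch
from transformers.cache_utils import DynamicCache  # available in recent transformers versions

num_layers, num_heads, head_dim, encoder_len = 2, 8, 64, 1500

past_self_attn_key_values = DynamicCache()
past_cross_attn_key_values = DynamicCache()

for step in range(3):  # three toy decoding steps
    for layer_idx in range(num_layers):
        # Self-attention: append the new token's key/value states for this layer.
        self_k, self_v = past_self_attn_key_values.update(
            torch.randn(1, num_heads, 1, head_dim),
            torch.randn(1, num_heads, 1, head_dim),
            layer_idx,
        )
        # Cross-attention: compute the key/value states from the encoder output only once.
        if step == 0:
            past_cross_attn_key_values.update(
                torch.randn(1, num_heads, encoder_len, head_dim),
                torch.randn(1, num_heads, encoder_len, head_dim),
                layer_idx,
            )
        cross_k, cross_v = past_cross_attn_key_values[layer_idx]

print(self_k.shape)   # torch.Size([1, 8, 3, 64])    -> dynamic shape, grows per step
print(cross_k.shape)  # torch.Size([1, 8, 1500, 64]) -> fixed shape, set by the encoder
```

Keeping the two caches in separate objects is what allows the self-attention side to be swapped for a static, compile-friendly cache later without forcing the (differently shaped) cross-attention key/values into the same container.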