
[Whisper] Use Attention Cache #28931

Closed · wants to merge 4 commits

Conversation
Conversation

@sanchit-gandhi (Contributor) commented Feb 8, 2024

What does this PR do?

Refactors the Whisper model to use the attention cache abstraction proposed in #26681. This is required to have consistency with the StaticCache attention class proposed in #27931.

The complexity with the current Cache abstraction comes from the fact that Whisper is an encoder-decoder model, meaning each decoder layer consists of two attention blocks:

  1. A self-attention layer (k/v cache over the previous decoder input ids)
  2. A cross-attention layer (k/v cache from the encoder hidden-states)

=> the problematic component for static generation is the dynamic k/v cache in the self-attention layer. In anticipation of using a static cache for this module, the proposed design uses a separate cache for each attention type. We can't merge both sets of key-values into a single Cache abstraction, because the self-attention and cross-attention key-values have different shapes (which would break compile).
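To make the shape mismatch concrete, here is a minimal sketch (illustrative only; the batch/head/length values are assumptions, not taken from this PR):

```python
# Illustrative shapes only, not the PR's code.
import torch

batch, heads, head_dim = 1, 8, 64
encoder_seq_len = 1500   # Whisper encoder output length for a 30s segment
generated_len = 25       # decoder tokens generated so far

# Self-attention k/v: grows by one position per generated token (dynamic length).
self_attn_keys = torch.zeros(batch, heads, generated_len, head_dim)

# Cross-attention k/v: projected once from the encoder hidden-states (fixed length).
cross_attn_keys = torch.zeros(batch, heads, encoder_seq_len, head_dim)

# Different sequence dimensions -> a single statically-shaped cache can't hold both.
print(self_attn_keys.shape, cross_attn_keys.shape)
```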

The design is therefore:

past_key_values: Tuple[Cache, Cache] = (past_self_attn_key_values, past_cross_attn_key_values)

Where past_self_attn_key_values and past_cross_attn_key_values are each Cache abstractions. This is not the most elegant design, but is compatible with the current Cache abstraction. Another option would be to do a refactor of the Cache / DynamicCache / StaticCache for better compatibility with encoder-decoder models.
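To make the layout concrete, here is a rough sketch of what consuming this pair could look like during one decoding step. DynamicCache stands in for whichever Cache subclass each role ends up using, and the layer count, shapes, and encoder length are illustrative assumptions rather than the PR's actual code:

```python
import torch
from transformers.cache_utils import DynamicCache

num_layers, batch, heads, head_dim = 2, 1, 8, 64
encoder_seq_len = 1500  # illustrative Whisper encoder output length

# One Cache per role, bundled as a tuple, as proposed above.
past_self_attn_key_values = DynamicCache()
past_cross_attn_key_values = DynamicCache()
past_key_values = (past_self_attn_key_values, past_cross_attn_key_values)

for layer_idx in range(num_layers):
    # Self-attention: one new key/value position per decoding step (dynamic).
    new_k = torch.zeros(batch, heads, 1, head_dim)
    new_v = torch.zeros(batch, heads, 1, head_dim)
    past_key_values[0].update(new_k, new_v, layer_idx)

    # Cross-attention: key/values projected once from the encoder hidden-states,
    # then re-read on every subsequent step.
    enc_k = torch.zeros(batch, heads, encoder_seq_len, head_dim)
    enc_v = torch.zeros(batch, heads, encoder_seq_len, head_dim)
    past_key_values[1].update(enc_k, enc_v, layer_idx)
```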

cc @ArthurZucker @tomaarsen @gante

# old: add present self-attn cache to positions 1,2 of present_key_value tuple
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None

# new: decoder uni-directional self-attention cached key/value states are at position 0
self_attn_past_key_value = past_key_value[0] if past_key_value is not None else None
@sanchit-gandhi (Contributor, Author) commented on the diff:

The difficulty comes here from the fact that we're dealing with two sets of past key-values per decoder layer: one from the self-attention, and one from the cross-attention. The current solution uses a separate cache for each.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

sanchit-gandhi added 2 commits February 8, 2024 17:29
@@ -801,7 +799,7 @@ def forward(

# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Whisper, MBART->WHISPER
@sanchit-gandhi (Contributor, Author) commented on the diff:

I'll propagate the changes to all other MBart-derived modules once we're happy with the design.

@soumendukrg

I have a fully functional solution for static-shaped Whisper, which we have tested extensively on the LibriSpeech dataset; it achieves the same accuracy as the original model.

@gante (Member) commented Feb 14, 2024

@sanchit-gandhi FYI I'm going to change the Cache structure a bit, while it's not widespread in the codebase. In a nutshell, given the hard constraints of the static cache (and its obvious benefits), all caches will have an interface similar to the new static cache (which differs from the original Cache implementation).

PR in progress here: #29005

Once this PR is done, we can expand its usage with the same interface, e.g. for encoder-decoder models 🤗
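For reference, a hedged sketch of the kind of shared interface being described: model code only calls update(), and the cache subclass decides whether to concatenate (dynamic) or write into a pre-allocated buffer (static). The method name follows the existing transformers Cache API; whether #29005 keeps exactly this shape is an assumption, not a claim about that PR:

```python
import torch
from transformers.cache_utils import DynamicCache

def self_attention_step(cache, key_states, value_states, layer_idx):
    # Same call regardless of the concrete cache type.
    return cache.update(key_states, value_states, layer_idx)

cache = DynamicCache()
k = v = torch.zeros(1, 8, 1, 64)
keys, values = self_attention_step(cache, k, v, layer_idx=0)
print(keys.shape)  # torch.Size([1, 8, 1, 64])
```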

@sanchit-gandhi (Contributor, Author) commented:

Thanks for the context @gante! Is there anything I can do to help with the static cache refactor? Pretty keen to implement a compile-compatible cache for Whisper!


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
