Feat: Support Whisper + FlashAttention2 #2014

Open · hongziqi opened this issue Apr 8, 2025 · 0 comments

hongziqi commented Apr 8, 2025

Is your feature request related to a problem? Please describe.
Currently, the Whisper models integrated in mindnlp do not leverage optimized attention backends like FlashAttention2, which are widely adopted in large-scale Transformer architectures for both speed and memory efficiency.

Describe the solution you'd like
We would like to introduce FlashAttention2 support for Whisper models in mindnlp, leveraging MindSpore's built-in operations and adapting the relevant components for compatibility. Specifically, the solution includes the following steps (illustrative sketches of each step follow the list):

  1. Replace flash-attn package functions with MindSpore APIs:
    Use mindspore.ops.flash_attention_score as a substitute for the original flash_attn_func and flash_attn_varlen_func provided by the flash-attn library, enabling FlashAttention functionality directly within the MindSpore framework.

  2. Re-implement bert_padding utilities in MindSpore:
    Convert the bert_padding utility functions used in flash-attn (such as index_first_axis and index_put_first_axis) into equivalent implementations based on MindSpore, ensuring correct behavior for batched variable-length input.

  3. Add a new utility module model_flash_attention_utils:
    Introduce a new helper module that includes functions like _flash_attention_forward and other FlashAttention2-specific utilities, adapted for Whisper and written using MindSpore APIs.

  4. Provide a new model class WhisperFlashAttention2:
    Create a new Whisper variant that integrates the above FlashAttention2 optimizations, preserving the original model architecture while improving runtime efficiency and scalability.
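
A minimal sketch of step 1, assuming the `mindspore.ops.flash_attention_score` API as documented for recent MindSpore 2.x releases (argument names such as `head_num`, `keep_prob`, `scalar_value`, and `input_layout`, and the `"BSND"` layout, should be re-checked against the target version); the wrapper name `flash_attn_func_ms` is only for illustration:

```python
from mindspore import ops

def flash_attn_func_ms(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False):
    """Drop-in style replacement for flash_attn_func.

    q, k, v: (batch, seq_len, num_heads, head_dim), mirroring flash_attn_func.
    """
    num_heads, head_dim = q.shape[2], q.shape[3]
    scale = softmax_scale if softmax_scale is not None else head_dim ** -0.5
    # "BSND" corresponds to (batch, seq, num_heads, head_dim); dropout is applied
    # via keep_prob and scaling via scalar_value inside the fused kernel.
    return ops.flash_attention_score(
        q, k, v,
        head_num=num_heads,
        keep_prob=1.0 - dropout_p,
        scalar_value=scale,
        input_layout="BSND",
        # Causal masking would be configured through attn_mask / sparse_mode
        # when causal=True; omitted here to keep the sketch minimal.
    )
```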
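
A sketch of step 2, re-implementing the two bert_padding helpers with MindSpore gather/scatter ops. The flash-attn originals are torch.autograd.Function subclasses; here they are plain functions relying on MindSpore's automatic differentiation:

```python
from mindspore import ops

def index_first_axis(tensor, indices):
    """Gather rows along axis 0 at `indices` (used when unpadding a batch)."""
    return ops.gather(tensor, indices, 0)

def index_put_first_axis(values, indices, first_axis_dim):
    """Scatter `values` back into a zero tensor of length `first_axis_dim` (re-padding)."""
    output = ops.zeros((first_axis_dim,) + values.shape[1:], values.dtype)
    # tensor_scatter_update expects indices of shape (N, 1) for axis-0 updates.
    return ops.tensor_scatter_update(output, indices.reshape(-1, 1), values)
```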
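
A sketch of step 3: a minimal `_flash_attention_forward` for the proposed model_flash_attention_utils module. Only the dense (no-padding) path is shown; the variable-length path would unpad with the helpers from step 2, attend over the packed tokens, and re-pad the output:

```python
def _flash_attention_forward(query_states, key_states, value_states,
                             attention_mask, query_length,
                             dropout=0.0, softmax_scale=None, is_causal=False):
    """Dispatch between the dense and variable-length FlashAttention2 paths."""
    if attention_mask is None:
        # Dense path: every sequence in the batch has the same length.
        return flash_attn_func_ms(query_states, key_states, value_states,
                                  dropout_p=dropout, softmax_scale=softmax_scale,
                                  causal=is_causal)
    # Variable-length path: unpad with index_first_axis, run attention on the
    # packed tokens, then restore padding with index_put_first_axis.
    raise NotImplementedError("varlen path omitted from this sketch")
```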
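
A sketch of step 4: a `WhisperFlashAttention2` variant that keeps the projection layers of the existing Whisper attention block and only swaps the score computation for `_flash_attention_forward`. The base class `WhisperAttention`, its attribute names (`q_proj`, `k_proj`, `v_proj`, `out_proj`, `num_heads`, `head_dim`, `dropout`, `is_causal`), and the `forward` method name are assumed to match the current mindnlp Whisper implementation; KV-cache handling and output-attention flags are omitted:

```python
class WhisperFlashAttention2(WhisperAttention):  # WhisperAttention: existing mindnlp class (assumed)
    """Whisper attention block that routes score computation through FlashAttention2."""

    def forward(self, hidden_states, key_value_states=None, attention_mask=None, **kwargs):
        bsz, tgt_len, _ = hidden_states.shape
        # Project and reshape to (batch, seq, num_heads, head_dim) for the FA2 kernel.
        query_states = self.q_proj(hidden_states).reshape(bsz, tgt_len, self.num_heads, self.head_dim)
        # Cross-attention uses encoder states as keys/values; self-attention reuses hidden_states.
        kv_source = key_value_states if key_value_states is not None else hidden_states
        key_states = self.k_proj(kv_source).reshape(bsz, -1, self.num_heads, self.head_dim)
        value_states = self.v_proj(kv_source).reshape(bsz, -1, self.num_heads, self.head_dim)

        attn_output = _flash_attention_forward(
            query_states, key_states, value_states,
            attention_mask, tgt_len,
            dropout=self.dropout if self.training else 0.0,
            is_causal=self.is_causal,
        )
        attn_output = attn_output.reshape(bsz, tgt_len, -1)
        # FlashAttention2 does not expose attention weights, hence the None.
        return self.out_proj(attn_output), None
```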
