Description
**Is your feature request related to a problem? Please describe.**
Currently, the Whisper models integrated in mindnlp do not leverage optimized attention backends like FlashAttention2, which are widely adopted in large-scale Transformer architectures for both speed and memory efficiency.
**Describe the solution you'd like**
We would like to introduce FlashAttention2 support for Whisper models in mindnlp, leveraging MindSpore’s built-in operations and adapting relevant components for compatibility. Specifically, the solution includes:
- **Replace `flash-attn` package functions with MindSpore APIs:** use `mindspore.ops.flash_attention_score` as a substitute for the original `flash_attn_func` and `flash_attn_varlen_func` provided by the `flash-attn` library, enabling FlashAttention functionality directly within the MindSpore framework.
- **Re-implement `bert_padding` utilities in MindSpore:** convert the `bert_padding` utility functions used in `flash-attn` (such as `index_first_axis` and `index_put_first_axis`) into equivalent MindSpore-based implementations, ensuring correct behavior for batched variable-length input.
- **Add a new utility module `model_flash_attention_utils`:** introduce a new helper module that includes functions like `_flash_attention_forward` and other FlashAttention2-specific utilities, adapted for Whisper and written using MindSpore APIs.
- **Provide a new model class `WhisperFlashAttention2`:** create a new Whisper variant that integrates the above FlashAttention2 optimizations, preserving the original model architecture while improving runtime efficiency and scalability.
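To make the proposal concrete, rough sketches of the four pieces above follow. First, a minimal stand-in for `flash_attn_func` built on `mindspore.ops.flash_attention_score`; the wrapper name, the chosen layout, and the causal handling via `next_tokens` are assumptions to be validated against the op documentation, not a final design.

```python
from mindspore import ops

def flash_attn_func_ms(q, k, v, dropout_p=0.0, causal=False):
    """Hypothetical substitute for flash_attn.flash_attn_func.

    q, k, v: (batch, seq_len, num_heads, head_dim), as in flash-attn.
    Returns the attention output in the same layout.
    """
    # The "BNSD" layout expects (batch, num_heads, seq_len, head_dim).
    q, k, v = (x.transpose(0, 2, 1, 3) for x in (q, k, v))
    out = ops.flash_attention_score(
        q, k, v,
        head_num=q.shape[1],
        keep_prob=1.0 - dropout_p,
        scalar_value=q.shape[-1] ** -0.5,
        input_layout="BNSD",
        # Assumption: limiting next_tokens to 0 produces a causal
        # (lower-triangular) mask; verify against the op documentation.
        next_tokens=0 if causal else 2147483647,
    )
    return out.transpose(0, 2, 1, 3)  # back to (batch, seq_len, num_heads, head_dim)
```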
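The `bert_padding` helpers can likely be expressed with existing MindSpore ops. A minimal sketch, assuming `ops.gather` and `ops.tensor_scatter_update` cover the gather/scatter semantics of `index_first_axis` and `index_put_first_axis`:

```python
import numpy as np
import mindspore as ms
from mindspore import ops

def index_first_axis(tensor, indices):
    """Gather rows along the first axis (mirrors flash-attn's bert_padding.index_first_axis)."""
    return ops.gather(tensor, indices, 0)

def index_put_first_axis(values, indices, first_axis_dim):
    """Scatter rows back into a zero tensor of length `first_axis_dim` along the first axis."""
    output = ops.zeros((first_axis_dim,) + values.shape[1:], values.dtype)
    # tensor_scatter_update addresses rows with indices of shape (N, 1).
    return ops.tensor_scatter_update(output, indices.reshape(-1, 1), values)

# Round-trip example: unpad a flattened (batch*seq, dim) tensor and pad it back.
hidden = ms.Tensor(np.arange(12, dtype=np.float32).reshape(6, 2))
keep = ms.Tensor(np.array([0, 2, 5], dtype=np.int32))  # non-padding positions
unpadded = index_first_axis(hidden, keep)               # (3, 2)
repadded = index_put_first_axis(unpadded, keep, 6)      # (6, 2), zeros at padding rows
```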
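A possible shape for `_flash_attention_forward` in the proposed `model_flash_attention_utils` module, sketched only for the dense (no-padding) path. The signature is an assumption modeled on the analogous helper in HF transformers; the variable-length path built on the padding helpers above is left as a comment since its details depend on the final op layout.

```python
from mindspore import ops

def _flash_attention_forward(query_states, key_states, value_states,
                             attention_mask=None, dropout=0.0,
                             is_causal=False, softmax_scale=None):
    """Sketch of a shared FlashAttention2 entry point for Whisper.

    query/key/value_states: (batch, seq_len, num_heads, head_dim).
    """
    if softmax_scale is None:
        softmax_scale = query_states.shape[-1] ** -0.5

    if attention_mask is None:
        # Dense path: all sequences in the batch are full length.
        q = query_states.transpose(0, 2, 1, 3)
        k = key_states.transpose(0, 2, 1, 3)
        v = value_states.transpose(0, 2, 1, 3)
        out = ops.flash_attention_score(
            q, k, v,
            head_num=q.shape[1],
            keep_prob=1.0 - dropout,
            scalar_value=softmax_scale,
            input_layout="BNSD",
            next_tokens=0 if is_causal else 2147483647,
        )
        return out.transpose(0, 2, 1, 3)

    # Variable-length path (not shown): unpad with index_first_axis /
    # index_put_first_axis, call flash_attention_score with the "TND" layout and
    # actual_seq_qlen / actual_seq_kvlen, then re-pad the output.
    raise NotImplementedError("varlen path to be implemented with the padding helpers")
```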
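Finally, a sketch of how `WhisperFlashAttention2` could slot in. The import path and the attribute names (`q_proj`, `k_proj`, `v_proj`, `out_proj`, `num_heads`, `head_dim`, `dropout`, `is_causal`) are assumed to mirror the existing mindnlp `WhisperAttention`; the real class would also need to handle past key values, layer head masks, and `output_attentions`.

```python
# Assumes WhisperAttention lives at the mirrored HF path inside mindnlp and that
# _flash_attention_forward comes from the new utility module sketched above.
from mindnlp.transformers.models.whisper.modeling_whisper import WhisperAttention
from model_flash_attention_utils import _flash_attention_forward  # hypothetical module

class WhisperFlashAttention2(WhisperAttention):
    """Whisper attention that routes its inner product through FlashAttention2."""

    def construct(self, hidden_states, key_value_states=None,
                  attention_mask=None, **kwargs):
        bsz, q_len, _ = hidden_states.shape
        # Cross-attention reads keys/values from the encoder output when provided.
        kv_input = key_value_states if key_value_states is not None else hidden_states

        query = self.q_proj(hidden_states).reshape(bsz, q_len, self.num_heads, self.head_dim)
        key = self.k_proj(kv_input).reshape(bsz, -1, self.num_heads, self.head_dim)
        value = self.v_proj(kv_input).reshape(bsz, -1, self.num_heads, self.head_dim)

        attn_output = _flash_attention_forward(
            query, key, value,
            attention_mask=attention_mask,
            dropout=self.dropout if self.training else 0.0,
            # Only decoder self-attention is causal; cross-attention is not.
            is_causal=getattr(self, "is_causal", False) and key_value_states is None,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1)
        # FlashAttention does not materialize attention weights; the return
        # signature is trimmed here and must match WhisperAttention's outputs.
        return self.out_proj(attn_output), None
```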