Description
Is your feature request related to a problem? Please describe.
Currently, the Whisper models integrated in mindnlp do not leverage optimized attention backends like FlashAttention2, which are widely adopted in large-scale Transformer architectures for both speed and memory efficiency.
Describe the solution you'd like
We would like to introduce FlashAttention2 support for Whisper models in mindnlp, leveraging MindSpore’s built-in operations and adapting relevant components for compatibility. Specifically, the solution includes:
- **Replace `flash-attn` package functions with MindSpore APIs:**
  Use `mindspore.ops.flash_attention_score` as a substitute for the original `flash_attn_func` and `flash_attn_varlen_func` provided by the `flash-attn` library, enabling FlashAttention functionality directly within the MindSpore framework (see the first sketch below).
- **Re-implement the `bert_padding` utilities in MindSpore:**
  Convert the `bert_padding` utility functions used by `flash-attn` (such as `index_first_axis` and `index_put_first_axis`) into equivalent MindSpore implementations, ensuring correct behavior for batched variable-length input (see the second sketch below).
- **Add a new utility module `model_flash_attention_utils`:**
  Introduce a new helper module containing `_flash_attention_forward` and other FlashAttention2-specific utilities, adapted for Whisper and written with MindSpore APIs (see the third sketch below).
- **Provide a new model class `WhisperFlashAttention2`:**
  Create a new Whisper variant that integrates the above FlashAttention2 optimizations, preserving the original model architecture while improving runtime efficiency and scalability (see the fourth sketch below).
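For the first point, a minimal sketch of what a `flash_attn_func`-style wrapper around `mindspore.ops.flash_attention_score` could look like is below. `flash_attn_func_ms` is a hypothetical name; the exact keyword arguments (`scalar_value`, `keep_prob`) and the supported `input_layout` values depend on the MindSpore version, and the operator currently targets Ascend devices.

```python
import mindspore as ms
from mindspore import ops


def flash_attn_func_ms(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False):
    """Hypothetical stand-in for flash_attn.flash_attn_func.

    q, k, v follow the flash-attn layout (batch, seq_len, num_heads, head_dim),
    which corresponds to the "BSND" input_layout of flash_attention_score.
    """
    head_num = q.shape[2]
    scale = softmax_scale if softmax_scale is not None else q.shape[-1] ** -0.5

    attn_mask = None
    if causal:
        # True / 1 marks positions that are masked out (strict upper triangle).
        seq_len = q.shape[1]
        attn_mask = ops.triu(ops.ones((seq_len, seq_len), ms.uint8), diagonal=1)

    # Returns the attention output in the same layout as the query.
    return ops.flash_attention_score(
        q, k, v, head_num,
        attn_mask=attn_mask,
        keep_prob=1.0 - dropout_p,
        scalar_value=scale,
        input_layout="BSND",
    )
```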
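For the `bert_padding` utilities, one possible MindSpore re-implementation of `index_first_axis` and `index_put_first_axis` could rely on `ops.gather` and `ops.tensor_scatter_update`. The sketch below assumes `indices` is an int32/int64 tensor of non-padding token positions into the flattened `(total_tokens, ...)` tensor.

```python
import mindspore as ms
from mindspore import ops


def index_first_axis(tensor, indices):
    """Gather rows along the flattened first (token) axis, mirroring
    flash_attn.bert_padding.index_first_axis: `tensor` is (total_tokens, ...)
    and `indices` holds the positions of the non-padding tokens."""
    return ops.gather(tensor, indices, 0)


def index_put_first_axis(values, indices, first_axis_dim):
    """Scatter `values` back into a zero tensor of length `first_axis_dim`
    along axis 0; the inverse of index_first_axis."""
    output = ops.zeros((first_axis_dim,) + values.shape[1:], values.dtype)
    # tensor_scatter_update expects indices of shape (num_updates, 1)
    # when writing along the first axis.
    return ops.tensor_scatter_update(output, indices.reshape(-1, 1), values)
```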
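For the proposed `model_flash_attention_utils` module, a simplified `_flash_attention_forward` might look like the following. It assumes inputs in the `(batch, seq_len, num_heads, head_dim)` layout and converts a 0/1 padding mask (1 = keep) into the masked-out convention used by `flash_attention_score`; the unpad/repad path for variable-length batches is omitted.

```python
import mindspore as ms
from mindspore import ops


def _flash_attention_forward(query, key, value, attention_mask=None,
                             dropout=0.0, softmax_scale=None, is_causal=False):
    """Dispatch (batch, seq_len, num_heads, head_dim) inputs to FlashAttention.

    A 0/1 padding `attention_mask` (1 = keep) is converted to the boolean
    convention of flash_attention_score, where True marks masked positions.
    """
    head_num = query.shape[2]
    scale = softmax_scale if softmax_scale is not None else query.shape[-1] ** -0.5

    attn_mask = None
    if attention_mask is not None:
        # (batch, kv_len) padding mask -> (batch, 1, q_len, kv_len) boolean mask.
        pad = (attention_mask == 0)[:, None, None, :]
        attn_mask = ops.broadcast_to(
            pad, (pad.shape[0], 1, query.shape[1], pad.shape[-1]))
    elif is_causal:
        q_len = query.shape[1]
        attn_mask = ops.triu(ops.ones((q_len, q_len), ms.uint8), diagonal=1)

    return ops.flash_attention_score(
        query, key, value, head_num,
        attn_mask=attn_mask,
        keep_prob=1.0 - dropout,
        scalar_value=scale,
        input_layout="BSND",
    )
```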
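Finally, a skeleton of what `WhisperFlashAttention2` could look like as a MindSpore `nn.Cell`. This is sketched at the attention-layer level, following the FlashAttention2 class convention in transformers-style codebases; projection names mirror the existing `WhisperAttention` layer, and cross-attention, key/value caching, and layer-head masking are omitted for brevity.

```python
import mindspore as ms
from mindspore import nn, ops


class WhisperFlashAttention2(nn.Cell):
    """Whisper attention block routed through flash_attention_score
    instead of the eager matmul/softmax path (self-attention only)."""

    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.dropout = dropout
        self.q_proj = nn.Dense(embed_dim, embed_dim)
        self.k_proj = nn.Dense(embed_dim, embed_dim, has_bias=False)
        self.v_proj = nn.Dense(embed_dim, embed_dim)
        self.out_proj = nn.Dense(embed_dim, embed_dim)

    def construct(self, hidden_states, attention_mask=None):
        bsz, seq_len, _ = hidden_states.shape
        # (B, S, H) -> (B, S, N, D), the layout flash_attention_score reads as "BSND".
        q = self.q_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim)
        k = self.k_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim)
        v = self.v_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim)

        # attention_mask, if given, is assumed to already be a boolean
        # (B, 1, S, S) mask where True marks masked-out positions.
        attn_output = ops.flash_attention_score(
            q, k, v, self.num_heads,
            attn_mask=attention_mask,
            keep_prob=1.0 - self.dropout,
            scalar_value=self.head_dim ** -0.5,
            input_layout="BSND",
        )
        attn_output = attn_output.view(bsz, seq_len, self.embed_dim)
        return self.out_proj(attn_output)
```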