Is your feature request related to a problem? Please describe.
Currently, the Whisper models integrated in mindnlp do not leverage optimized attention backends like FlashAttention2, which are widely adopted in large-scale Transformer architectures for both speed and memory efficiency.
Describe the solution you'd like
We would like to introduce FlashAttention2 support for Whisper models in mindnlp, leveraging MindSpore’s built-in operations and adapting relevant components for compatibility. Specifically, the solution includes:
Replace `flash-attn` package functions with MindSpore APIs:
Use `mindspore.ops.flash_attention_score` as a substitute for the original `flash_attn_func` and `flash_attn_varlen_func` provided by the `flash-attn` library, enabling FlashAttention functionality directly within the MindSpore framework.
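As a rough illustration, a call replacing `flash_attn_func` might look like the sketch below. Parameter names follow the MindSpore 2.3 documentation as we understand it and should be verified against the installed version; the op runs only on Ascend devices, and the shapes are hypothetical (Whisper-base-like) values chosen for the example.

```python
import mindspore as ms
from mindspore import ops

# Hypothetical Whisper-base-like shapes: batch=2, seq_len=1500, 8 heads, head_dim=64.
B, S, N, D = 2, 1500, 8, 64
q = ops.randn(B, S, N * D, dtype=ms.float16)
k = ops.randn(B, S, N * D, dtype=ms.float16)
v = ops.randn(B, S, N * D, dtype=ms.float16)

# With input_layout='BSH' the last dimension is num_heads * head_dim.
attn_out = ops.flash_attention_score(
    q, k, v,
    head_num=N,
    scalar_value=1.0 / D ** 0.5,  # softmax scaling, analogous to flash_attn_func's softmax_scale
    keep_prob=1.0,                # 1 - attention dropout probability
    input_layout='BSH',
)
print(attn_out.shape)  # expected (B, S, N * D)
```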
Re-implement `bert_padding` utilities in MindSpore:
Convert the `bert_padding` utility functions used in `flash-attn` (such as `index_first_axis` and `index_put_first_axis`) into equivalent MindSpore implementations, ensuring correct behavior for batched variable-length inputs.
Add a new utility module `model_flash_attention_utils`:
Introduce a new helper module that includes functions like `_flash_attention_forward` and other FlashAttention2-specific utilities, adapted for Whisper and written using MindSpore APIs.
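A hedged sketch of what such a helper could look like, assuming inputs in (batch, seq, heads, head_dim) layout and an optional precomputed padding/causal mask; it simply reshapes to the 'BSH' layout and dispatches to `mindspore.ops.flash_attention_score`, so keyword names again need to match the installed MindSpore version:

```python
from mindspore import ops

def _flash_attention_forward(query, key, value, attention_mask=None,
                             dropout=0.0, softmax_scale=None):
    """Hypothetical helper playing the same role as the `_flash_attention_forward`
    utility in the Hugging Face integration. `attention_mask` is an optional
    precomputed mask in the format expected by flash_attention_score."""
    batch, q_len, num_heads, head_dim = query.shape
    scale = softmax_scale if softmax_scale is not None else head_dim ** -0.5

    # Collapse heads into the hidden dimension to match the 'BSH' layout.
    q = query.reshape(batch, q_len, num_heads * head_dim)
    k = key.reshape(batch, key.shape[1], num_heads * head_dim)
    v = value.reshape(batch, value.shape[1], num_heads * head_dim)

    out = ops.flash_attention_score(
        q, k, v,
        head_num=num_heads,
        attn_mask=attention_mask,
        scalar_value=scale,
        keep_prob=1.0 - dropout,
        input_layout='BSH',
    )
    # Restore the (batch, seq, heads, head_dim) shape expected by the caller.
    return out.reshape(batch, q_len, num_heads, head_dim)
```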
Provide a new model class `WhisperFlashAttention2`:
Create a new Whisper attention variant that integrates the above FlashAttention2 optimizations, preserving the original model architecture while improving runtime efficiency and scalability.
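Finally, a rough sketch of the attention class itself, assuming mindnlp mirrors the Hugging Face Whisper module layout (`WhisperAttention` with `q_proj`/`k_proj`/`v_proj`/`out_proj`, `num_heads`, `head_dim`, `embed_dim`). The import path, the method name (`forward` vs. `construct`), and the handling of KV cache and cross-attention caching are simplified here and would need to follow the actual mindnlp source:

```python
# Assumed import path, mirroring the Hugging Face layout; verify against mindnlp.
from mindnlp.transformers.models.whisper.modeling_whisper import WhisperAttention

class WhisperFlashAttention2(WhisperAttention):
    """Drop-in variant of WhisperAttention that routes the attention math through
    FlashAttention2 instead of the eager matmul/softmax path."""

    def forward(self, hidden_states, key_value_states=None, attention_mask=None, **kwargs):
        bsz, tgt_len, _ = hidden_states.shape
        # Cross-attention uses encoder states as the K/V source; self-attention reuses hidden_states.
        kv_input = key_value_states if key_value_states is not None else hidden_states

        query = self.q_proj(hidden_states).reshape(bsz, tgt_len, self.num_heads, self.head_dim)
        key = self.k_proj(kv_input).reshape(bsz, -1, self.num_heads, self.head_dim)
        value = self.v_proj(kv_input).reshape(bsz, -1, self.num_heads, self.head_dim)

        # _flash_attention_forward is the helper sketched above (hypothetical module).
        attn_output = _flash_attention_forward(
            query, key, value,
            attention_mask=attention_mask,
            softmax_scale=self.head_dim ** -0.5,
        )
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
        # FlashAttention does not materialize attention weights; return None for them.
        return self.out_proj(attn_output), None, None
```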