
Feat: Support Whisper + FlashAttention2 #2014

Closed
@hongziqi

Description


Is your feature request related to a problem? Please describe.
Currently, the Whisper models integrated in mindnlp do not leverage optimized attention backends like FlashAttention2, which are widely adopted in large-scale Transformer architectures for both speed and memory efficiency.

Describe the solution you'd like
We would like to introduce FlashAttention2 support for Whisper models in mindnlp, leveraging MindSpore’s built-in operations and adapting relevant components for compatibility. Specifically, the solution includes:

  1. Replace flash-attn package functions with MindSpore APIs:
    Use mindspore.ops.flash_attention_score as a substitute for the original flash_attn_func and flash_attn_varlen_func provided by the flash-attn library, enabling FlashAttention functionality directly within the MindSpore framework (see the first sketch after this list).

  2. Re-implement bert_padding utilities in MindSpore:
    Convert the bert_padding utility functions used in flash-attn (such as index_first_axis and index_put_first_axis) into equivalent implementations based on MindSpore, ensuring correct behavior for batched variable-length inputs (see the second sketch after this list).

  3. Add a new utility module model_flash_attention_utils:
    Introduce a new helper module that provides _flash_attention_forward and other FlashAttention2-specific utilities, adapted for Whisper and written with MindSpore APIs (see the third sketch after this list).

  4. Provide a new attention class WhisperFlashAttention2:
    Create a new Whisper attention variant that integrates the above FlashAttention2 optimizations, preserving the original model architecture while improving runtime efficiency and scalability (see the final sketch after this list).
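
For step 1, here is a minimal sketch of what the substitution could look like, assuming flash-attn's flash_attn_func calling convention on the caller side. The wrapper name ms_flash_attn_func is made up for illustration; the mindspore.ops.flash_attention_score parameters follow the MindSpore 2.x documentation (the op currently targets Ascend) and may differ between versions:

```python
import mindspore
from mindspore import ops


def ms_flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False):
    """Stand-in for flash_attn.flash_attn_func.

    q, k, v: (batch, seq_len, num_heads, head_dim), the layout flash-attn uses.
    """
    _, q_len, num_heads, head_dim = q.shape
    kv_len = k.shape[1]
    if softmax_scale is None:
        softmax_scale = head_dim ** -0.5

    # flash_attention_score expects (batch, num_heads, seq_len, head_dim)
    # when input_layout="BNSD", so transpose from the flash-attn layout.
    q = q.transpose(0, 2, 1, 3)
    k = k.transpose(0, 2, 1, 3)
    v = v.transpose(0, 2, 1, 3)

    attn_mask = None
    if causal:
        # Assumed mask convention: 1 marks a position to be masked out.
        attn_mask = ops.triu(ops.ones((q_len, kv_len), mindspore.uint8), diagonal=1)

    out = ops.flash_attention_score(
        q, k, v,
        head_num=num_heads,
        attn_mask=attn_mask,
        keep_prob=1.0 - dropout_p,
        scalar_value=softmax_scale,
        input_layout="BNSD",
    )
    # Back to (batch, seq_len, num_heads, head_dim).
    return out.transpose(0, 2, 1, 3)
```

For Whisper, the decoder self-attention would call such a wrapper with causal=True, while the encoder self-attention and the cross-attention would use causal=False.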
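
For step 2, a rough sketch of how the two named bert_padding helpers could be expressed with MindSpore ops. The flash-attn originals are custom autograd Functions, so the gradient behaviour of these plain-function versions would still need to be verified:

```python
from mindspore import ops


def index_first_axis(tensor, indices):
    """Gather rows of a (total_tokens, ...) tensor along the first axis,
    mirroring flash-attn's bert_padding.index_first_axis."""
    return ops.gather(tensor, indices, axis=0)


def index_put_first_axis(values, indices, first_axis_dim):
    """Inverse of index_first_axis: scatter `values` back into a zero tensor
    of length `first_axis_dim` along the first axis.  `indices` is a 1-D
    int32/int64 tensor of row positions."""
    output = ops.zeros((first_axis_dim,) + values.shape[1:], values.dtype)
    # tensor_scatter_update expects indices with a trailing dimension of 1
    # when whole rows are replaced along axis 0.
    return ops.tensor_scatter_update(output, indices.reshape(-1, 1), values)
```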
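
For step 3, a hedged skeleton of what a model_flash_attention_utils._flash_attention_forward helper could look like. The signature loosely mirrors the Hugging Face helper of the same name, it builds on ms_flash_attn_func from the first sketch, and the variable-length path is only outlined in comments rather than implemented:

```python
def _flash_attention_forward(query_states, key_states, value_states,
                             attention_mask, query_length,
                             dropout=0.0, softmax_scale=None, is_causal=False):
    """Route (batch, seq_len, num_heads, head_dim) tensors through FlashAttention."""
    if attention_mask is None:
        # No padding: every sequence is full length, so the dense kernel suffices.
        return ms_flash_attn_func(query_states, key_states, value_states,
                                  dropout_p=dropout, softmax_scale=softmax_scale,
                                  causal=is_causal)
    # Padded batch: the intended variable-length path would
    #   1. unpad q/k/v with index_first_axis at the positions where
    #      attention_mask == 1,
    #   2. call ops.flash_attention_score with input_layout="TND" plus
    #      actual_seq_qlen / actual_seq_kvlen describing the packed lengths,
    #   3. re-pad the output with index_put_first_axis.
    raise NotImplementedError("variable-length path is only outlined above")
```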
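
For step 4, an illustrative skeleton of how WhisperFlashAttention2 could slot into the existing Whisper code: it subclasses the eager attention layer and swaps only the score computation. Attribute names (q_proj, num_heads, head_dim, ...) follow the Hugging Face Whisper implementation that mindnlp ports; the import path, method name, and return signature are assumptions here, and KV-cache / head-mask handling is omitted:

```python
from mindnlp.transformers.models.whisper.modeling_whisper import WhisperAttention


class WhisperFlashAttention2(WhisperAttention):
    """Whisper attention layer that computes scores with FlashAttention2
    instead of the eager softmax attention."""

    def forward(self, hidden_states, key_value_states=None,
                attention_mask=None, **kwargs):
        bsz, q_len, _ = hidden_states.shape
        # Projections are inherited from WhisperAttention; reshape to the
        # (batch, seq_len, num_heads, head_dim) layout used above.
        query = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
        kv_input = hidden_states if key_value_states is None else key_value_states
        key = self.k_proj(kv_input).view(bsz, -1, self.num_heads, self.head_dim)
        value = self.v_proj(kv_input).view(bsz, -1, self.num_heads, self.head_dim)

        attn_output = _flash_attention_forward(
            query, key, value, attention_mask, q_len,
            dropout=self.dropout if self.training else 0.0,
            # Causal only for decoder self-attention, not for cross-attention.
            is_causal=self.is_causal and key_value_states is None,
        )
        attn_output = attn_output.reshape(bsz, q_len, -1)
        return self.out_proj(attn_output), None, None
```

Selecting this class at runtime could then follow the same attn_implementation switch used by the upstream transformers Whisper port, but that wiring is beyond the scope of these sketches.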
