[Feat] Add WhisperFlashAttention2 #2018

Merged: 2 commits, Apr 23, 2025

hongziqi
Contributor

@hongziqi hongziqi commented Apr 14, 2025

Test Report

Hardware Environment:
Ascend(snt9b|32G)

Software Environment (Mandatory):
-- MindSpore version (e.g., 20250414 master) : 2.6.0 (conda env - mindnlp)
-- Python version (e.g., Python 3.7.5) : 3.10.0 (conda env - mindnlp)
-- Transformers version : 4.51.0 (conda env - torch)
-- PyTorch version : 2.1.0 (conda env - torch)
-- CANN Toolkit version : 8.1.RC1.alpha002
-- OS platform and distribution (e.g., Linux Ubuntu 16.04): Ubuntu 22.04.4 LTS
-- GCC/Compiler version (if compiled from source): 11.04
-- Docker image : swr.cn-central-221.ovaijisuan.com/mindformers/deepseek_v3_mindspore2.5.0-infer:20250217

Recognition time comparison

Note: each figure is the average of three runs, in seconds.
Example 1 - nihao.mp3 (1 s):

| Recognition time (s) | eager | flash-attention-2 |
|---|---|---|
| mindnlp (before) | 13.0729 | Not supported |
| mindnlp (after, flash + conv1d↑) | 11.2208 | 10.4567 |
| PTA | 2.3707 | 3.8498 |

With flash-attention-2 enabled, recognition time for the short clip drops by about (13.0729 - 10.4567) / 13.0729 ≈ 20%. A clear gap remains compared with the PTA (PyTorch + Ascend) implementation.
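The percentage above can be reproduced with a few lines of Python. This is only a sketch of the arithmetic: the helper name `relative_improvement` and the variable names are hypothetical, and the numbers are the averaged timings from the Example 1 table.

```python
def relative_improvement(before: float, after: float) -> float:
    """Fraction of the original time saved: (before - after) / before."""
    return (before - after) / before


# Averaged timings in seconds for nihao.mp3, copied from the table above.
mindnlp_eager_before = 13.0729
mindnlp_flash_after = 10.4567
pta_eager = 2.3707

# Time saved by the new flash-attention-2 path relative to the old eager path.
saving = relative_improvement(mindnlp_eager_before, mindnlp_flash_after)

# Remaining gap: how many times longer MindNLP (flash) takes than PTA (eager).
gap_vs_pta = mindnlp_flash_after / pta_eager

print(f"time saved by flash-attention-2: {saving:.1%}")
print(f"MindNLP flash vs PTA eager: {gap_vs_pta:.2f}x")
```

The second ratio puts a number on the remaining gap to PTA that the text mentions.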

Example 2 - tianlong0925.mp3(91s):

| Recognition time (s) | eager | flash-attention-2 |
|---|---|---|
| mindnlp (before) | 72.8655 | Not supported |
| mindnlp (after, flash + conv1d↑) | 71.6170 | 64.7135 |
| PTA | 42.9799 | 269.3955 |

For long audio, FlashAttention2 reduces recognition time by about (72.8655 - 64.7135) / 72.8655 ≈ 11.2%. Notably, PTA degrades severely in flash mode (about 6 times slower than its own eager mode), while the MindNLP implementation remains stable.
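The long-audio figures can be broken down the same way. The sketch below assumes that the "after" eager column reflects only the conv1d change and that the "after" flash column reflects both changes. That reading of the row label is an assumption, not something the PR states. All names are hypothetical.

```python
def saved(before: float, after: float) -> float:
    """Fraction of the original time saved: (before - after) / before."""
    return (before - after) / before


# Averaged timings in seconds for tianlong0925.mp3, copied from the table above.
before_eager = 72.8655
after_eager = 71.6170
after_flash = 64.7135
pta_eager = 42.9799
pta_flash = 269.3955

total_saving = saved(before_eager, after_flash)  # old eager -> new flash
conv1d_saving = saved(before_eager, after_eager)  # assumed: conv1d change only
flash_saving = saved(after_eager, after_flash)    # assumed: flash attention only
pta_slowdown = pta_flash / pta_eager              # PTA flash vs PTA eager

print(f"total saving:       {total_saving:.1%}")
print(f"eager before/after: {conv1d_saving:.1%}")
print(f"flash vs new eager: {flash_saving:.1%}")
print(f"PTA flash slowdown: {pta_slowdown:.2f}x")
```

Under that assumption, most of the long-audio gain comes from the attention path rather than from the conv1d change.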

Test Code:

MindSpore

import mindspore
from mindnlp.transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import time

mindspore.set_device("Ascend", 2)  # run on Ascend device id 2

def generate_with_time(pipe, file_path):
    start_time = time.time()
    result = pipe(file_path)
    generation_time = time.time() - start_time
    return result, generation_time

model_id = "openai/whisper-large-v3"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, 
    ms_dtype=mindspore.float16, 
    low_cpu_mem_usage=True,
    use_safetensors=True,
    # Toggle the two lines below to choose the attention backend.
    # attn_implementation="eager",
    attn_implementation="flash_attention_2",
)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    ms_dtype=mindspore.float16,
    return_timestamps=True,
)


# ------ eager mode test result (ms2.5.0 + mindnlp0.4.0) ------ 
# generation_time: 93.65066742897034, result: 青光闪动一柄青钢剑疏地刺出指向中年汉子左肩使肩少年不带剑招用劳外斗
# 剑斜剑锋以削向那汉子右颈哪中年汉子竖剑挡格张来一声响双剑相击嗡嗡作声震声未竭双刃剑功复合你拆了三招中年汉子长剑猛地击落直转少年顶
# 门那少年臂向右侧左手剑绝学隐青钢剑鞠刺呐喊子大腿威两人剑法迅绝全力相搏威徒练武厅东边坐着爱人上手是个四十左右的中年道姑铁青着脸嘴唇
# 紧闭下手是个五十余岁的老者右手掠着长须神情甚是得意两人的座位相距一丈有余身后各站着二十余名男女弟子西边一排椅子上坐着十余位宾客东西
# 双方的目光都集中于场中二人的相斗眼下眼尖的少年与中年汉子已拆到七十余招前招越来越紧物资未分胜败突然周年汉子长剑挥出用力猛了身子微晃肆意摔跌席边
# 宾客中一个身穿青衫的年轻男子忍不住吃得一声笑他随即指导师太忙伸手按住了口

# generation_time: 9.72678017616272, result: 你好

# ------ flash_attention_2 mode test result (ms2.5.0 + mindnlp0.4.0 + flash) ------ 
# generation_time: 79.74670958518982, result: 青光闪动一柄青钢剑疏地刺出指向中年汉子左肩使肩少年不带剑招用劳外斗
# 剑斜剑锋以削向那汉子右颈哪中年汉子竖剑挡格张来一声响双剑相击嗡嗡作声震声未竭双刃剑功复合你拆了三招中年汉子长剑猛地击落直转少年顶
# 门那少年臂向右侧左手剑绝学隐青钢剑鞠刺呐喊子大腿威两人剑法迅绝全力相搏威徒练武厅东边坐着爱人上手是个四十左右的中年道姑铁青着脸嘴唇
# 紧闭下手是个五十余岁的老者右手掠着长须神情甚是得意两人的座位相距一丈有余身后各站着二十余名男女弟子西边一排椅子上坐着十余位宾客东西
# 双方的目光都集中于场中二人的相斗眼下眼尖的少年与中年汉子已拆到七十余招前招越来越紧物资未分胜败突然周年汉子长剑挥出用力猛了身子微晃肆意摔跌席边
# 宾客中一个身穿青衫的年轻男子忍不住吃得一声笑他随即指导师太忙伸手按住了口

# generation_time: 8.643609762191772, result: 你好

result, generation_time = generate_with_time(pipe, "/home/candyhong/workspace/whisper_large/tianlong0925.mp3")
# result, generation_time = generate_with_time(pipe, "/home/candyhong/workspace/whisper_large/nihao.mp3")
print(f"generation_time: {generation_time}, result: {result['text']}")

PyTorch + Ascend

import torch
import torch_npu
import time
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

# Disable JIT operator compilation and internal tensor formats on the NPU.
torch_npu.npu.set_compile_mode(jit_compile=False)
torch_npu.npu.config.allow_internal_format = False

def generate_with_time(pipe, file_path):
    start_time = time.time()
    result = pipe(file_path)
    generation_time = time.time() - start_time
    return result, generation_time


device = "npu:0"
torch_dtype = torch.float16

model_id = "openai/whisper-large-v3"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    # attn_implementation="eager",
    attn_implementation="flash_attention_2",
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
    return_timestamps=True,
)

# ------ eager mode test result ------ 
# generation_time: 58.42126774787903 result: 青光闪动一柄青钢剑疏地刺出指向中年汉子左肩使肩少年不带剑招用劳外斗
# 剑斜剑锋以削向那汉子右颈哪中年汉子竖剑挡格张来一声响双剑相击嗡嗡作声震声未竭双刃剑功复合你拆了三招中年汉子长剑猛地击落直展少年顶
# 门那少年臂向右侧左手剑绝学隐青钢剑鞠刺呐喊子大腿威两人剑法迅绝全力相搏威徒练武厅东边坐着爱人上手是个四十左右的中年道姑铁青着脸嘴唇
# 紧闭下手是个五十余岁的老者右手掠着长须神情甚是得意两人的座位相距一丈有余身后各站着二十余名男女弟子西边一排椅子上坐着十余位宾客东西
# 双方的目光都集中于场中二人的相斗眼下眼尖的少年与中年汉子已拆到七十余招前招越来越紧物资未分胜败突然周年汉子长剑挥出用力猛了身子微晃肆意摔跌席边
# 宾客中一个身穿青衫的年轻男子忍不住吃得一声笑他随即知道失态忙伸手按住了口

# generation_time: 2.7732555866241455, result: 你好

# ------ flash_attention_2 mode test result ------ 
# generation_time: 252.1833713054657, result: 青光闪动一柄青钢剑疏地刺出指向中年汉子左肩使肩少年不带剑招用劳外斗
# 剑斜剑锋以削向那汉子右颈哪中年汉子竖剑挡格张来一声响双剑相击嗡嗡作声震声未竭双刃剑功复合你拆了三招中年汉子长剑猛地击落直展少年顶
# 门那少年臂向右侧左手剑绝学隐青钢剑鞠刺呐喊子大腿威两人剑法迅绝全力相搏威徒练武厅东边坐着爱人上手是个四十左右的中年道姑铁青着脸嘴唇
# 紧闭下手是个五十余岁的老者右手掠着长须神情甚是得意两人的座位相距一丈有余身后各站着二十余名男女弟子西边一排椅子上坐着十余位宾客东西
# 双方的目光都集中于场中二人的相斗眼下眼尖的少年与中年汉子已拆到七十余招前招越来越紧物资未分胜败突然周年汉子长剑挥出用力猛了身子微晃肆意摔跌席边
# 宾客中一个身穿青衫的年轻男子忍不住吃得一声笑他随即知道失态忙伸手按住了口

# generation_time: 4.741170883178711, result: 你好

result, generation_time = generate_with_time(pipe, "/home/candyhong/workspace/whisper_large/tianlong0925.mp3")
# result, generation_time = generate_with_time(pipe, "/home/candyhong/workspace/whisper_large/nihao.mp3")
print(f"generation_time: {generation_time}, result: {result['text']}")
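The tables report the average of three runs, while the scripts above time a single call. One way to get the averaged figure is a small wrapper like the sketch below. `average_time` and `fake_pipe` are hypothetical names introduced here for illustration; `fake_pipe` merely stands in for the real `pipe` so the snippet runs without a model or an NPU. The optional warm-up run is an assumption, since the PR does not say whether warm-up runs were excluded.

```python
import time
from statistics import mean
from typing import Any, Callable, List, Tuple


def average_time(
    fn: Callable[..., Any], *args: Any, runs: int = 3, warmup: int = 0
) -> Tuple[Any, float, List[float]]:
    """Call fn(*args) `runs` times and return (last result, mean seconds, all durations)."""
    # Optional untimed calls, e.g. to exclude one-off compilation cost.
    for _ in range(warmup):
        fn(*args)

    durations: List[float] = []
    result = None
    for _ in range(runs):
        start = time.perf_counter()
        result = fn(*args)
        durations.append(time.perf_counter() - start)
    return result, mean(durations), durations


def fake_pipe(file_path: str) -> dict:
    """Stand-in for the ASR pipeline; returns a dict shaped like its output."""
    time.sleep(0.01)
    return {"text": f"transcript of {file_path}"}


result, avg_seconds, durations = average_time(fake_pipe, "nihao.mp3", runs=3)
print(f"average generation_time: {avg_seconds:.4f}, result: {result['text']}")
```

With the real scripts, `fake_pipe` would be replaced by `pipe` and the file path by one of the test clips.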

Related Issues

Fixes #2014

@lvyufeng lvyufeng merged commit 814b71b into mindspore-lab:0.4 Apr 23, 2025