Skip to content

Commit 7c8afd8

Browse files
committed
[bugfix] fix flash_attention_2 unavailable error on Ascend NPU
1 parent c962f15 commit 7c8afd8

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

src/transformers/integrations/npu_flash_attention.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,3 +267,8 @@ def npu_apply_rotary_emb(x, cos, sin, **kwargs):
267267
sin = sin.unsqueeze(0).unsqueeze(2)
268268

269269
return npu_rotary_mul(x, cos, sin)
270+
271+
272+
def get_npu_flash_attn_funcs():
    """Return the flash-attention helpers used on Ascend NPU as a fixed-order tuple.

    The order is significant because callers unpack the result positionally:

    1. ``npu_flash_attn_func``
    2. ``npu_flash_attn_varlen_func``
    3. ``pad_input``
    4. ``unpad_input``
    5. the literal ``False``

    NOTE(review): the trailing ``False`` presumably is the "is flash attention 3"
    flag expected by the consumer of this tuple -- confirm against the caller.
    """
    npu_attention_bundle = (
        npu_flash_attn_func,
        npu_flash_attn_varlen_func,
        pad_input,
        unpad_input,
        False,
    )
    return npu_attention_bundle

src/transformers/modeling_flash_attention_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def fa_peft_integration_check(q, k, v, target_dtype: Optional[torch.dtype] = Non
261261

262262
def _lazy_imports(impl: Optional[str]):
263263
# returns funcs and pad/unpad based on impl
264-
is_fa2 = is_flash_attn_2_available() or is_torch_npu_available()
264+
is_fa2 = is_flash_attn_2_available()
265265
is_fa3 = is_flash_attn_3_available()
266266
if impl == "flash_attention_2" or (impl is None and is_fa2 and not is_fa3):
267267
try:
@@ -299,6 +299,10 @@ def _lazy_imports(impl: Optional[str]):
299299
raise ImportError(
300300
"Failed to import flash attention 2, please install it or use another implementation."
301301
) from e
302+
if impl == "flash_attention_2" and is_torch_npu_available():
303+
# get flash attention related functions from `.integrations.npu_flash_attention` module for Ascend NPU
304+
from .integrations.npu_flash_attention import get_npu_flash_attn_funcs
305+
return get_npu_flash_attn_funcs()
302306
if impl == "flash_attention_3" or (impl is None and is_fa3):
303307
from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
304308

0 commit comments

Comments
 (0)