
Commit 4c6f26e

frozenleaves authored and committed

Adapt to the SDPA interface to enable the NPU to call FlashAttentionScore (#41143)

Co-authored-by: frozenleaves <frozen@Mac.local>
1 parent 99fbb87 commit 4c6f26e

File tree

1 file changed: +14 -1 lines changed


src/transformers/integrations/sdpa_attention.py

Lines changed: 14 additions & 1 deletion
@@ -2,7 +2,7 @@
 
 import torch
 
-from ..utils import is_torch_xpu_available, logging
+from ..utils import is_torch_npu_available, is_torch_xpu_available, logging
 from ..utils.import_utils import is_torch_greater_or_equal
 
 
@@ -12,6 +12,7 @@
 _is_torch_greater_or_equal_than_2_5 = is_torch_greater_or_equal("2.5", accept_dev=True)
 _is_torch_greater_or_equal_than_2_8 = is_torch_greater_or_equal("2.8", accept_dev=True)
 _is_torch_xpu_available = is_torch_xpu_available()
+_is_torch_npu_available = is_torch_npu_available()
 
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -35,8 +36,12 @@ def use_gqa_in_sdpa(attention_mask: Optional[torch.Tensor], key: torch.Tensor) -
     # 2.xpu
     # - torch version >= 2.8
     # - key is not a torch.fx.Proxy (otherwise it will fail with a tracing error)
+    # 3.npu
+    # - npu does not support gqa currently
     if _is_torch_xpu_available:
         return _is_torch_greater_or_equal_than_2_8 and not isinstance(key, torch.fx.Proxy)
+    if _is_torch_npu_available:
+        return False
     return _is_torch_greater_or_equal_than_2_5 and attention_mask is None and not isinstance(key, torch.fx.Proxy)
 
 
@@ -80,6 +85,14 @@ def sdpa_attention_forward(
     if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
         is_causal = is_causal.item()
 
+    # When `is_causal = False` and the `attention_mask` is not of boolean type, the Ascend NPU's SDPA interface cannot
+    # use the FlashAttentionScore operator and falls back to small-operator concatenation. To invoke FlashAttentionScore,
+    # the `attention_mask` must be converted to boolean type. This adaptation ensures the `attention_mask` meets that requirement.
+    if _is_torch_npu_available:
+        if attention_mask is not None and attention_mask.dtype != torch.bool:
+            # Convert to boolean type, forcing SDPA to call FlashAttentionScore and improve performance.
+            attention_mask = torch.logical_not(attention_mask.bool()).to(query.device)
+
     attn_output = torch.nn.functional.scaled_dot_product_attention(
         query,
         key,

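For context on the `use_gqa_in_sdpa` hunk above: when it returns False (as it now always does on NPU), the caller presumably falls back to expanding the key/value heads with `repeat_kv` before calling SDPA, instead of passing `enable_gqa=True` and letting the kernel handle grouped-query attention itself. Below is a minimal sketch of the two equivalent paths; the shapes, the `n_rep` value, and the re-implementation of `repeat_kv` are illustrative, not copied from the file.

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Same idea as the helper referenced in the diff: expand (batch, num_kv_heads, seq, head_dim)
    # to (batch, num_kv_heads * n_rep, seq, head_dim) so every query head gets a matching KV head.
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

# Illustrative shapes: 8 query heads grouped over 2 KV heads, so n_rep = 4.
query = torch.randn(1, 8, 16, 64)
key = torch.randn(1, 2, 16, 64)
value = torch.randn(1, 2, 16, 64)

# Path taken when use_gqa_in_sdpa(...) returns False (e.g. on NPU): repeat KV heads up front.
out_repeat = torch.nn.functional.scaled_dot_product_attention(
    query, repeat_kv(key, 4), repeat_kv(value, 4), is_causal=True
)

# Path taken when it returns True (torch >= 2.5 and a supported backend): let SDPA handle GQA.
out_gqa = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, is_causal=True, enable_gqa=True
)

torch.testing.assert_close(out_repeat, out_gqa)

Disabling the second path on NPU trades some extra memory for the explicit KV repetition, in exchange for compatibility while grouped-query SDPA is unsupported there.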
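The boolean conversion in the `sdpa_attention_forward` hunk relies on the two mask conventions of `scaled_dot_product_attention`: a non-boolean `attention_mask` is additive (0 for positions to attend, a large negative value for masked positions), while a boolean mask uses True for positions to attend. A small sketch of the equivalence, using a made-up mask and shapes:

import torch

# Hypothetical additive float mask over four key positions:
# 0.0 means "attend", -inf means "masked out".
float_mask = torch.tensor([0.0, 0.0, float("-inf"), float("-inf")])

# The conversion from the diff: nonzero entries become True, then the result is inverted,
# so True ends up meaning "attend" -- the boolean-mask convention of SDPA.
bool_mask = torch.logical_not(float_mask.bool())
print(bool_mask)  # tensor([ True,  True, False, False])

query = torch.randn(1, 1, 2, 8)
key = torch.randn(1, 1, 4, 8)
value = torch.randn(1, 1, 4, 8)

# Both mask forms produce the same attention output (the mask broadcasts over batch, heads, queries).
out_float = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=float_mask)
out_bool = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=bool_mask)
torch.testing.assert_close(out_float, out_bool)

Per the comment in the diff, only the boolean form lets the Ascend NPU dispatch the SDPA call to FlashAttentionScore; the trailing `.to(query.device)` keeps the converted mask on the same device as the query.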
0 commit comments
