Commit ab92534

FightingZhen and vasqu authored

enable sdpa enable gqa logic for Ascend NPU (#41601)

* enable gqa logic for Ascend NPU
* remove redundant comments
* fix comments about Ascend NPU

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>

1 parent 56a727d commit ab92534

File tree

1 file changed: +1 −5 lines changed

src/transformers/integrations/sdpa_attention.py

Lines changed: 1 addition & 5 deletions
@@ -29,19 +29,15 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
 def use_gqa_in_sdpa(attention_mask: Optional[torch.Tensor], key: torch.Tensor) -> bool:
     # GQA can only be used under the following conditions
-    # 1.cuda
+    # 1.cuda or Ascend NPU
     #  - torch version >= 2.5
     #  - attention_mask is None (otherwise it will fall back to the math kernel)
     #  - key is not a torch.fx.Proxy (otherwise it will fail with a tracing error)
     # 2.xpu
     #  - torch version >= 2.8
     #  - key is not a torch.fx.Proxy (otherwise it will fail with a tracing error)
-    # 3.npu
-    #  - npu is not supported gqa currently
     if _is_torch_xpu_available:
         return _is_torch_greater_or_equal_than_2_8 and not isinstance(key, torch.fx.Proxy)
-    if _is_torch_npu_available:
-        return False
     return _is_torch_greater_or_equal_than_2_5 and attention_mask is None and not isinstance(key, torch.fx.Proxy)
