5 changes: 5 additions & 0 deletions src/transformers/integrations/npu_flash_attention.py
@@ -267,3 +267,8 @@ def npu_apply_rotary_emb(x, cos, sin, **kwargs):
     sin = sin.unsqueeze(0).unsqueeze(2)
 
     return npu_rotary_mul(x, cos, sin)
+
+
+def get_npu_flash_attn_funcs():
+    # Return the flash attention related functions used on Ascend NPU, in the order expected by the caller.
+    return npu_flash_attn_func, npu_flash_attn_varlen_func, pad_input, unpad_input, False
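A minimal sketch of how a caller might unpack the five-element tuple returned by `get_npu_flash_attn_funcs()`. The local variable names below are illustrative assumptions, not the actual unpacking code in transformers, and the meaning of the trailing boolean is not shown in this diff; it is simply forwarded to the caller.

# Illustrative only: unpacking the tuple returned for Ascend NPU (names are assumptions).
from transformers.integrations.npu_flash_attention import get_npu_flash_attn_funcs

flash_fn, flash_varlen_fn, pad_fn, unpad_fn, extra_flag = get_npu_flash_attn_funcs()
print(flash_fn.__name__, extra_flag)  # e.g. "npu_flash_attn_func", False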
9 changes: 7 additions & 2 deletions src/transformers/modeling_flash_attention_utils.py
@@ -261,7 +261,7 @@ def fa_peft_integration_check(q, k, v, target_dtype: Optional[torch.dtype] = None

 def _lazy_imports(impl: Optional[str]):
     # returns funcs and pad/unpad based on impl
-    is_fa2 = is_flash_attn_2_available() or is_torch_npu_available()
+    is_fa2 = is_flash_attn_2_available()
     is_fa3 = is_flash_attn_3_available()
     if impl == "flash_attention_2" or (impl is None and is_fa2 and not is_fa3):
         try:
@@ -299,7 +299,12 @@ def _lazy_imports(impl: Optional[str]):
             raise ImportError(
                 "Failed to import flash attention 2, please install it or use another implementation."
             ) from e
-    if impl == "flash_attention_3" or (impl is None and is_fa3):
+    elif is_torch_npu_available():
+        # On Ascend NPU, get the flash attention related functions from the `.integrations.npu_flash_attention` module
+        from .integrations.npu_flash_attention import get_npu_flash_attn_funcs
+
+        return get_npu_flash_attn_funcs()
+    elif impl == "flash_attention_3" or (impl is None and is_fa3):
         from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
 
         pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input
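To make the new dispatch order easier to follow, here is a simplified, self-contained sketch of the branch logic after this change. It is not the actual `_lazy_imports` implementation: the availability checks are passed in as plain booleans, the imports are reduced to strings, and the function name is made up for illustration.

# Sketch only: dispatch order after this change (FA2 -> Ascend NPU fallback -> FA3).
from typing import Optional

def _select_flash_backend(impl: Optional[str], fa2_available: bool, fa3_available: bool, npu_available: bool) -> str:
    if impl == "flash_attention_2" or (impl is None and fa2_available and not fa3_available):
        return "flash-attn (CUDA FlashAttention-2)"
    elif npu_available:
        # New branch: Ascend NPU no longer piggybacks on the FA2 availability flag.
        return "transformers.integrations.npu_flash_attention"
    elif impl == "flash_attention_3" or (impl is None and fa3_available):
        return "flash_attn_interface (FlashAttention-3)"
    raise ImportError("No flash attention backend available.")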
8 changes: 6 additions & 2 deletions src/transformers/modeling_utils.py
@@ -2446,8 +2446,12 @@ def _flash_attn_2_can_dispatch(self, is_init_check: bool = False) -> bool:
preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:"
install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."

# package `flash-attn` can not be installed on Ascend NPU, ignore related validation logi
if importlib.util.find_spec("flash_attn") is None and not is_torch_npu_available():
# package `flash-attn` can not be installed on Ascend NPU, following validation logics can be ignored.
if is_torch_npu_available():
logger.info("Detect using FlashAttention2 on Ascend NPU.")
return True

if importlib.util.find_spec("flash_attn") is None:
raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
else:
# Check FA2 installed version compatibility
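A condensed sketch of the new early return in `_flash_attn_2_can_dispatch`: on Ascend NPU the method reports that FlashAttention2 can dispatch without probing for the `flash_attn` package. The function name and the stubbed-out version checks below are assumptions for illustration; only the control flow mirrors the diff above.

# Sketch only: stands in for transformers' _flash_attn_2_can_dispatch, with the class removed.
import importlib.util

def can_dispatch_flash_attention_2(is_npu_available: bool) -> bool:
    if is_npu_available:
        # flash-attn cannot be installed on Ascend NPU, so skip the package checks entirely.
        return True
    if importlib.util.find_spec("flash_attn") is None:
        raise ImportError("flash_attn is not installed; see the FlashAttention-2 installation docs.")
    # ... FA2 version and hardware compatibility checks would follow here ...
    return True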