diff --git a/src/transformers/integrations/npu_flash_attention.py b/src/transformers/integrations/npu_flash_attention.py
index dd8a6dc5d07b..ed1b30d9a6b0 100644
--- a/src/transformers/integrations/npu_flash_attention.py
+++ b/src/transformers/integrations/npu_flash_attention.py
@@ -267,3 +267,8 @@ def npu_apply_rotary_emb(x, cos, sin, **kwargs):
         sin = sin.unsqueeze(0).unsqueeze(2)

     return npu_rotary_mul(x, cos, sin)
+
+
+def get_npu_flash_attn_funcs():
+    # return flash attention related functions used for Ascend NPU in order
+    return npu_flash_attn_func, npu_flash_attn_varlen_func, pad_input, unpad_input, False
diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py
index 47744eaca3f2..502c8bdff305 100644
--- a/src/transformers/modeling_flash_attention_utils.py
+++ b/src/transformers/modeling_flash_attention_utils.py
@@ -261,7 +261,7 @@ def fa_peft_integration_check(q, k, v, target_dtype: Optional[torch.dtype] = Non

 def _lazy_imports(impl: Optional[str]):
     # returns funcs and pad/unpad based on impl
-    is_fa2 = is_flash_attn_2_available() or is_torch_npu_available()
+    is_fa2 = is_flash_attn_2_available()
     is_fa3 = is_flash_attn_3_available()
     if impl == "flash_attention_2" or (impl is None and is_fa2 and not is_fa3):
         try:
@@ -299,7 +299,12 @@ def _lazy_imports(impl: Optional[str]):
             raise ImportError(
                 "Failed to import flash attention 2, please install it or use another implementation."
             ) from e
-    if impl == "flash_attention_3" or (impl is None and is_fa3):
+    elif is_torch_npu_available():
+        # get flash attention related functions from `.integrations.npu_flash_attention` module for Ascend NPU
+        from .integrations.npu_flash_attention import get_npu_flash_attn_funcs
+
+        return get_npu_flash_attn_funcs()
+    elif impl == "flash_attention_3" or (impl is None and is_fa3):
         from flash_attn_interface import flash_attn_func, flash_attn_varlen_func

         pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 6e86cb10026b..dc774ba76e29 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -2446,8 +2446,12 @@ def _flash_attn_2_can_dispatch(self, is_init_check: bool = False) -> bool:
         preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:"
         install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."

-        # package `flash-attn` can not be installed on Ascend NPU, ignore related validation logi
-        if importlib.util.find_spec("flash_attn") is None and not is_torch_npu_available():
+        # package `flash-attn` can not be installed on Ascend NPU, following validation logics can be ignored.
+        if is_torch_npu_available():
+            logger.info("Detect using FlashAttention2 on Ascend NPU.")
+            return True
+
+        if importlib.util.find_spec("flash_attn") is None:
             raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
         else:
             # Check FA2 installed version compatibility
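
For reference, below is a minimal, self-contained sketch of the backend-selection order the patched `_lazy_imports` follows: flash-attn 2 first, then the Ascend NPU integration as its own `elif` branch (no longer folded into the `is_fa2` flag), then flash-attn 3. The `fake_*` availability checks and `load_*` loaders are hypothetical stand-ins, not transformers or flash-attn APIs, so the sketch runs without transformers, flash-attn, or torch_npu installed.

# npu_dispatch_sketch.py -- illustration only; all names are hypothetical stand-ins
from typing import Optional


def fake_is_flash_attn_2_available() -> bool:
    # stand-in for transformers' is_flash_attn_2_available()
    return False


def fake_is_flash_attn_3_available() -> bool:
    # stand-in for transformers' is_flash_attn_3_available()
    return False


def fake_is_torch_npu_available() -> bool:
    # stand-in for transformers' is_torch_npu_available()
    return True


def load_fa2() -> str:
    # stands in for importing from the `flash_attn` package
    return "flash-attn 2 kernels"


def load_npu() -> str:
    # stands in for calling get_npu_flash_attn_funcs() from the NPU integration module
    return "Ascend NPU flash attention kernels"


def load_fa3() -> str:
    # stands in for importing from `flash_attn_interface`
    return "flash-attn 3 kernels"


def lazy_imports_sketch(impl: Optional[str]) -> str:
    # Mirrors the new control flow: FA2 branch, then a dedicated NPU branch,
    # then FA3, instead of piggybacking the NPU case on the is_fa2 flag.
    is_fa2 = fake_is_flash_attn_2_available()
    is_fa3 = fake_is_flash_attn_3_available()
    if impl == "flash_attention_2" or (impl is None and is_fa2 and not is_fa3):
        return load_fa2()
    elif fake_is_torch_npu_available():
        return load_npu()
    elif impl == "flash_attention_3" or (impl is None and is_fa3):
        return load_fa3()
    raise ImportError("no flash attention backend available")


if __name__ == "__main__":
    # With the stubs above (no flash-attn installed, NPU "available"),
    # the NPU branch is selected.
    print(lazy_imports_sketch(None))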