diff --git a/src/transformers/integrations/npu_flash_attention.py b/src/transformers/integrations/npu_flash_attention.py
index 716a3481a82a..f4c4e98da9b3 100644
--- a/src/transformers/integrations/npu_flash_attention.py
+++ b/src/transformers/integrations/npu_flash_attention.py
@@ -19,7 +19,7 @@
 
 
 if is_torch_npu_available():
-    from torch_npu import npu_fusion_attention, npu_rotary_mul
+    from torch_npu import npu_fusion_attention
 
 
 # FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default.
@@ -136,19 +136,3 @@ def npu_flash_attn_varlen_func(
     )[0]
 
     return output
-
-
-def npu_apply_rotary_emb(x, cos, sin, **kwargs):
-    # cos tensor after chunk should be repeated through chunked dimension to original shape on Ascend NPU
-    if len(cos.shape) == 2 and cos.shape[-1] == x.shape[-1] // 2:
-        cos = cos.repeat(1, 2)
-        # cos tensor with [S,D] shape should be unsqueezed to 4-d tensor with shape [1,S,1,D]
-        cos = cos.unsqueeze(0).unsqueeze(2)
-
-    # sin tensor after chunk should be repeated through chunked dimension to original shape on Ascend NPU
-    if len(sin.shape) == 2 and sin.shape[-1] == x.shape[-1] // 2:
-        sin = sin.repeat(1, 2)
-        # sin tensor with [S,D] shape should be unsqueezed to 4-d tensor with shape [1,S,1,D]
-        sin = sin.unsqueeze(0).unsqueeze(2)
-
-    return npu_rotary_mul(x, cos, sin)
diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py
index 0d8906076829..6b9e091c1a69 100644
--- a/src/transformers/modeling_flash_attention_utils.py
+++ b/src/transformers/modeling_flash_attention_utils.py
@@ -77,16 +77,20 @@ def _lazy_imports(implementation: Optional[str]):
     """
     is_fa2 = is_flash_attn_2_available()
     is_fa3 = is_flash_attn_3_available()
-    if implementation == "flash_attention_2" or (implementation is None and is_fa2 and not is_fa3):
+
+    pad_input, unpad_input = _pad_input, _unpad_input
+
+    if (implementation == "flash_attention_2" and is_fa2) or (implementation is None and is_fa2 and not is_fa3):
         from flash_attn import flash_attn_func, flash_attn_varlen_func
         from flash_attn.bert_padding import pad_input, unpad_input
+    elif is_torch_npu_available():
+        # Package `flash-attn` is unavailable on Ascend NPU, which will cause ImportError
+        # Flash-Attention2 related apis for Ascend NPU must be imported from `.integrations.npu_flash_attention` module
+        from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
+        from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
     else:
-        pad_input, unpad_input = _pad_input, _unpad_input
         if implementation == "flash_attention_3" or (implementation is None and is_fa3):
             from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
-        elif is_torch_npu_available():
-            from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
-            from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
         # Kernels fallback
         else:
             flash_attn_func = getattr(implementation, "flash_attn_func", None)
diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
index bdb983df8487..498f8db1d610 100644
--- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
@@ -44,7 +44,6 @@
 from ...cache_utils import Cache
 from ...configuration_utils import PretrainedConfig, layer_type_validation
 from ...generation import GenerationMixin
-from ...modeling_flash_attention_utils import is_flash_attn_available
 from ...modeling_outputs import BaseModelOutput, ModelOutput
 from ...modeling_rope_utils import rope_config_validation
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
@@ -58,13 +57,6 @@
 from ...utils.hub import cached_file
 
 
-if is_flash_attn_available():
-    from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func
-else:
-    flash_attn_varlen_func = None
-    apply_rotary_emb = None
-
-
 logger = logging.get_logger(__name__)
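
Note (not part of the patch): the plain-Python sketch below summarizes the backend selection order that the patched `_lazy_imports` encodes. The helper name `pick_backend` and the returned string labels are hypothetical and exist only for illustration; the conditions and branch order mirror the diff above.

    # Minimal sketch (not library code): the branch order encoded by the patched
    # `_lazy_imports`. Padding helpers default to the internal `_pad_input`/`_unpad_input`
    # before any branch runs, so every path has them defined.
    def pick_backend(implementation, is_fa2, is_fa3, npu_available):
        if (implementation == "flash_attention_2" and is_fa2) or (
            implementation is None and is_fa2 and not is_fa3
        ):
            return "flash_attn (FA2) + bert_padding helpers"
        elif npu_available:
            # `flash-attn` cannot be installed on Ascend NPU, so the kernels from
            # `.integrations.npu_flash_attention` are selected before the FA3 /
            # kernels-fallback branch is evaluated.
            return "npu_flash_attention (Ascend NPU)"
        elif implementation == "flash_attention_3" or (implementation is None and is_fa3):
            return "flash_attn_interface (FA3)"
        else:
            return "kernels fallback via getattr(implementation, ...)"


    # Example: requesting flash_attention_2 on an Ascend NPU host without `flash-attn`
    # installed now resolves to the NPU kernels instead of hitting an ImportError.
    print(pick_backend("flash_attention_2", is_fa2=False, is_fa3=False, npu_available=True))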