Merged
18 changes: 1 addition & 17 deletions src/transformers/integrations/npu_flash_attention.py
@@ -19,7 +19,7 @@


if is_torch_npu_available():
from torch_npu import npu_fusion_attention, npu_rotary_mul
from torch_npu import npu_fusion_attention


# FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default.
@@ -136,19 +136,3 @@ def npu_flash_attn_varlen_func(
)[0]

return output


def npu_apply_rotary_emb(x, cos, sin, **kwargs):
# cos tensor after chunk should be repeated through chunked dimension to original shape on Ascend NPU
if len(cos.shape) == 2 and cos.shape[-1] == x.shape[-1] // 2:
cos = cos.repeat(1, 2)
# cos tensor with [S,D] shape should be unsqueezed to 4-d tensor with shape [1,S,1,D]
cos = cos.unsqueeze(0).unsqueeze(2)

# sin tensor after chunk should be repeated through chunked dimension to original shape on Ascend NPU
if len(sin.shape) == 2 and sin.shape[-1] == x.shape[-1] // 2:
sin = sin.repeat(1, 2)
# sin tensor with [S,D] shape should be unsqueezed to 4-d tensor with shape [1,S,1,D]
sin = sin.unsqueeze(0).unsqueeze(2)

return npu_rotary_mul(x, cos, sin)
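
For context, the deleted `npu_apply_rotary_emb` helper only normalized the `cos`/`sin` shapes before handing off to `npu_rotary_mul`. Below is a minimal pure-PyTorch sketch of the same convention, assuming `npu_rotary_mul` implements the standard `x * cos + rotate_half(x) * sin` formula and the usual `[batch, seq, heads, head_dim]` layout; the function names here are illustrative and not part of this diff.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard "rotate half": split the last dim and build (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_emb_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [batch, seq, heads, head_dim]; cos/sin: [seq, head_dim // 2] or [seq, head_dim].
    if cos.dim() == 2 and cos.shape[-1] == x.shape[-1] // 2:
        # Chunked cos/sin only cover half the head dim; tile back to full width,
        # mirroring what the removed NPU helper did before calling npu_rotary_mul.
        cos = cos.repeat(1, 2)
        sin = sin.repeat(1, 2)
    # Broadcast [seq, head_dim] -> [1, seq, 1, head_dim] over batch and heads.
    cos = cos[None, :, None, :]
    sin = sin[None, :, None, :]
    return x * cos + rotate_half(x) * sin


# Quick shape check.
x = torch.randn(2, 5, 4, 8)
cos = torch.randn(5, 4)
sin = torch.randn(5, 4)
print(apply_rotary_emb_reference(x, cos, sin).shape)  # torch.Size([2, 5, 4, 8])
```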
14 changes: 9 additions & 5 deletions src/transformers/modeling_flash_attention_utils.py
@@ -77,16 +77,20 @@ def _lazy_imports(implementation: Optional[str]):
"""
is_fa2 = is_flash_attn_2_available()
is_fa3 = is_flash_attn_3_available()
if implementation == "flash_attention_2" or (implementation is None and is_fa2 and not is_fa3):

pad_input, unpad_input = _pad_input, _unpad_input

if (implementation == "flash_attention_2" and is_fa2) or (implementation is None and is_fa2 and not is_fa3):
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import pad_input, unpad_input
elif is_torch_npu_available():
# Package `flash-attn` is unavailable on Ascend NPU, which will cause ImportError
# Flash-Attention2 related apis for Ascend NPU must be imported from `.integrations.npu_flash_attention` module
from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
else:
pad_input, unpad_input = _pad_input, _unpad_input
if implementation == "flash_attention_3" or (implementation is None and is_fa3):
from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
elif is_torch_npu_available():
from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
# Kernels fallback
else:
flash_attn_func = getattr(implementation, "flash_attn_func", None)
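The refactor above comes down to one dispatch question: which `(flash_attn_func, flash_attn_varlen_func)` pair to return given what is installed. Here is a simplified, self-contained sketch of that selection pattern, probing availability with plain `ImportError` checks instead of transformers' `is_flash_attn_*_available()` helpers; it is not the real `_lazy_imports`, and the exact branch ordering is precisely what this hunk changes.

```python
from typing import Callable, Tuple


def pick_flash_backend() -> Tuple[Callable, Callable]:
    """Illustrative dispatcher: prefer flash-attn 2, then flash-attn 3,
    then the Ascend NPU shims shipped with transformers."""
    try:
        # flash-attn 2.x
        from flash_attn import flash_attn_func, flash_attn_varlen_func
        return flash_attn_func, flash_attn_varlen_func
    except ImportError:
        pass
    try:
        # flash-attn 3 exposes its kernels through a separate package.
        from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
        return flash_attn_func, flash_attn_varlen_func
    except ImportError:
        pass
    # `flash-attn` cannot be installed on Ascend NPU, so fall back to the
    # torch_npu-backed wrappers from transformers' NPU integration. The real
    # helper additionally supports a kernels-based fallback, omitted here.
    from transformers.integrations.npu_flash_attention import (
        npu_flash_attn_func as flash_attn_func,
        npu_flash_attn_varlen_func as flash_attn_varlen_func,
    )
    return flash_attn_func, flash_attn_varlen_func
```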
8 changes: 0 additions & 8 deletions src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
Collaborator (review comment): good catch here!

@@ -44,7 +44,6 @@
from ...cache_utils import Cache
from ...configuration_utils import PretrainedConfig, layer_type_validation
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import is_flash_attn_available
from ...modeling_outputs import BaseModelOutput, ModelOutput
from ...modeling_rope_utils import rope_config_validation
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
@@ -58,13 +57,6 @@
from ...utils.hub import cached_file


if is_flash_attn_available():
from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func
else:
flash_attn_varlen_func = None
apply_rotary_emb = None


logger = logging.get_logger(__name__)


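The Qwen2.5-Omni change drops the module-level `if is_flash_attn_available():` import block. In current transformers, attention layers generally resolve their kernel at call time from the `ALL_ATTENTION_FUNCTIONS` registry (imported as context a few lines above) rather than importing flash-attn directly. The sketch below shows that lookup pattern as an assumption about how the model dispatches after this change; this diff itself does not add it.

```python
from typing import Callable

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def resolve_attention_fn(attn_implementation: str) -> Callable:
    # The registry maps implementation names such as "sdpa" or
    # "flash_attention_2" to kernels sharing one signature, so model code
    # never imports flash-attn (or its NPU substitute) directly.
    return ALL_ATTENTION_FUNCTIONS[attn_implementation]


# e.g. resolve_attention_fn("sdpa")
```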