
Commit 03f4376

remove redundant apply_rotary_emb usage
1 parent eee0f00 · commit 03f4376


2 files changed: +0 additions, -23 deletions


src/transformers/integrations/npu_flash_attention.py

Lines changed: 0 additions & 15 deletions
@@ -252,18 +252,3 @@ def npu_flash_attn_varlen_func(
 
     return output
 
-
-def npu_apply_rotary_emb(x, cos, sin, **kwargs):
-    # cos tensor after chunk should be repeated through chunked dimension to original shape on Ascend NPU
-    if len(cos.shape) == 2 and cos.shape[-1] == x.shape[-1] // 2:
-        cos = cos.repeat(1, 2)
-        # cos tensor with [S,D] shape should be unsqueezed to 4-d tensor with shape [1,S,1,D]
-        cos = cos.unsqueeze(0).unsqueeze(2)
-
-    # sin tensor after chunk should be repeated through chunked dimension to original shape on Ascend NPU
-    if len(sin.shape) == 2 and sin.shape[-1] == x.shape[-1] // 2:
-        sin = sin.repeat(1, 2)
-        # sin tensor with [S,D] shape should be unsqueezed to 4-d tensor with shape [1,S,1,D]
-        sin = sin.unsqueeze(0).unsqueeze(2)
-
-    return npu_rotary_mul(x, cos, sin)
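
For reference, the helper deleted above existed to massage rotary cos/sin tables into the layout torch_npu.npu_rotary_mul expects. Below is a minimal plain-PyTorch sketch of that shape handling; the rotate-half formula standing in for the fused NPU kernel and the [B, S, H, D] input layout are assumptions for illustration, and apply_rotary_emb_sketch/rotate_half are hypothetical names, not transformers APIs.

# Minimal sketch of the deleted helper's shape handling (assumption: plain
# PyTorch stand-in for torch_npu.npu_rotary_mul, inputs laid out as [B, S, H, D]).
import torch


def rotate_half(t: torch.Tensor) -> torch.Tensor:
    # Standard rotary rotate-half: split the head dim and map (x1, x2) to (-x2, x1).
    x1, x2 = t.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_emb_sketch(x, cos, sin):
    # A 2-d [S, D/2] half-table is repeated back to [S, D], then unsqueezed to
    # [1, S, 1, D] so it broadcasts against x of shape [B, S, H, D], the same
    # normalization the removed npu_apply_rotary_emb performed.
    if cos.dim() == 2 and cos.shape[-1] == x.shape[-1] // 2:
        cos = cos.repeat(1, 2).unsqueeze(0).unsqueeze(2)
    if sin.dim() == 2 and sin.shape[-1] == x.shape[-1] // 2:
        sin = sin.repeat(1, 2).unsqueeze(0).unsqueeze(2)
    # npu_rotary_mul fuses this elementwise combination on Ascend hardware.
    return x * cos + rotate_half(x) * sin


# Shape check: [batch=2, seq=5, heads=4, head_dim=8] with [5, 4] half-tables.
q = torch.randn(2, 5, 4, 8)
cos, sin = torch.randn(5, 4), torch.randn(5, 4)
assert apply_rotary_emb_sketch(q, cos, sin).shape == q.shape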

src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py

Lines changed: 0 additions & 8 deletions
@@ -44,7 +44,6 @@
 from ...cache_utils import Cache
 from ...configuration_utils import PretrainedConfig, layer_type_validation
 from ...generation import GenerationMixin
-from ...modeling_flash_attention_utils import is_flash_attn_available
 from ...modeling_outputs import BaseModelOutput, ModelOutput
 from ...modeling_rope_utils import rope_config_validation
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
@@ -58,13 +57,6 @@
 from ...utils.hub import cached_file
 
 
-if is_flash_attn_available():
-    from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func
-else:
-    flash_attn_varlen_func = None
-    apply_rotary_emb = None
-
-
 logger = logging.get_logger(__name__)
 
 
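
The block deleted by the second hunk is a standard optional-dependency guard. A self-contained sketch of the idiom follows, assuming a package probe captures the essence of the availability check (the real transformers helper also validates versions, so treat this as an approximation):

# Sketch of the guarded optional import this hunk deletes (assumption: probing
# for the flash_attn package stands in for transformers' full availability check).
import importlib.util


def is_flash_attn_available() -> bool:
    # True if the flash_attn package can be found, without importing it.
    return importlib.util.find_spec("flash_attn") is not None


if is_flash_attn_available():
    from flash_attn import flash_attn_varlen_func
else:
    # Callers must check for None before dispatching to the fast path.
    flash_attn_varlen_func = None

Since the modular file no longer calls apply_rotary_emb after this commit, both the guard and its None fallbacks were dead code, which is why the hunk removes the whole block rather than just the unused name.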
