
Commit dd16906

[perf] optimize rope in deepseek
Signed-off-by: David9857 <985700846@qq.com>
1 parent ebb2a70 commit dd16906

File tree

2 files changed: +36 −18 lines changed


vllm_ascend/attention/mla_v1.py

Lines changed: 4 additions & 18 deletions
@@ -1077,15 +1077,8 @@ def forward(
             decode_k_nope = None
             assert attn_metadata.decode is not None
             if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                cos = cos[attn_metadata.decode.input_positions]
-                sin = sin[attn_metadata.decode.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.decode.cos
+                sin = attn_metadata.decode.sin
                 # Without explicitly controlling the order, IndexByTensor operations
                 # would be placed after `matmul W_KV_T` hindering the overlapping of
                 # KvRmsNormRopeCache and SingleRope.
@@ -1122,15 +1115,8 @@ def forward(
             prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
             if self.torchair_graph_enabled:
                 num_tokens = prefill_hs_or_q_c.shape[0]
-                seq_len = self.rotary_emb.max_position_embeddings
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                cos = cos[attn_metadata.prefill.input_positions]
-                sin = sin[attn_metadata.prefill.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.prefill.cos
+                sin = attn_metadata.prefill.sin

                 prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
                 prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
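
Before this change, both the decode and prefill paths re-sliced and re-indexed the full cos/sin caches inside every attention layer; afterwards they simply read the per-token tensors gathered once by the model. The sketch below shows how a single-tensor rotary application can consume such precomputed tables. It is a minimal illustration only: the helper name apply_rope_single and the rotate-half convention are assumptions, not the actual rope_single implementation in mla_v1.py.

import torch

def apply_rope_single(x: torch.Tensor, cos: torch.Tensor,
                      sin: torch.Tensor) -> torch.Tensor:
    # Illustrative stand-in (assumption) for a rope_single-style helper:
    # x has rotary_dim as its last dimension, and cos/sin are already
    # gathered per token and shaped so they broadcast against x,
    # e.g. [num_tokens, 1, 1, rotary_dim].
    x1, x2 = x.chunk(2, dim=-1)
    rotated = torch.cat((-x2, x1), dim=-1)  # rotate-half convention (assumed)
    return x * cos + rotated * sin

Because cos and sin arrive pre-gathered from the attention metadata, the per-layer cost reduces to element-wise operations like the ones above instead of slicing and indexing the full caches in every layer.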

vllm_ascend/models/deepseek_v2.py

Lines changed: 32 additions & 0 deletions
@@ -69,6 +69,7 @@
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
                                npu_wait_tensor)

@@ -671,6 +672,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
+
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.max_position_embeddings = self.layers[0].self_attn.rotary_emb.max_position_embeddings
+        self.cos_cached = self.layers[0].self_attn.rotary_emb.cos_cached
+        self.sin_cached = self.layers[0].self_attn.rotary_emb.sin_cached

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)
@@ -695,6 +702,31 @@ def forward(
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

+        # Gather cos and sin once here so every decoder layer can reuse them.
+        self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
+            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
+        ]
+        if attn_metadata.num_decodes > 0 and self.running_in_graph:
+            seq_len = self.max_position_embeddings
+            cos = self.cos_cached[:seq_len].to(
+                dtype=hidden_states.dtype)
+            sin = self.sin_cached[:seq_len].to(
+                dtype=hidden_states.dtype)
+            cos = cos[attn_metadata.decode.input_positions]
+            sin = sin[attn_metadata.decode.input_positions]
+            attn_metadata.decode.cos = cos[:, None, None, :]
+            attn_metadata.decode.sin = sin[:, None, None, :]
+        if attn_metadata.num_prefills > 0 and self.torchair_graph_enabled:
+            seq_len = self.max_position_embeddings
+            cos = self.cos_cached[:seq_len].to(
+                dtype=hidden_states.dtype)
+            sin = self.sin_cached[:seq_len].to(
+                dtype=hidden_states.dtype)
+            cos = cos[attn_metadata.prefill.input_positions]
+            sin = sin[attn_metadata.prefill.input_positions]
+            attn_metadata.prefill.cos = cos[:, None, None, :]
+            attn_metadata.prefill.sin = sin[:, None, None, :]
+
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
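
The model-level change follows a compute-once, reuse-everywhere pattern: the forward pass gathers cos/sin for the current positions a single time, stashes them on the attention metadata, and every decoder layer then reads them back instead of indexing the rotary caches itself. A minimal, self-contained sketch of the pattern is below; DecodeMeta and precompute_rope are simplified hypothetical names, not the actual vllm-ascend classes.

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class DecodeMeta:
    # Hypothetical, simplified stand-in for attn_metadata.decode.
    input_positions: torch.Tensor        # [num_tokens] position ids
    cos: Optional[torch.Tensor] = None   # filled once per forward pass
    sin: Optional[torch.Tensor] = None

def precompute_rope(meta: DecodeMeta, cos_cached: torch.Tensor,
                    sin_cached: torch.Tensor, dtype: torch.dtype) -> None:
    # Gather cos/sin for the current positions once, cast to the activation
    # dtype, and reshape to [num_tokens, 1, 1, rotary_dim] so they broadcast
    # against the rope portion of the queries and keys in every layer.
    cos = cos_cached.to(dtype)[meta.input_positions]
    sin = sin_cached.to(dtype)[meta.input_positions]
    meta.cos = cos[:, None, None, :]
    meta.sin = sin[:, None, None, :]

With N decoder layers this trades N slice-and-gather operations per step for one, which is why the gathering is gated on torchair_graph_enabled: in graph mode the extra gathers would otherwise be scheduled inside every layer's attention.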
