Commit b84e6bc

[perf] optimize rope in deepseek
Signed-off-by: David9857 <985700846@qq.com>
1 parent: 0e43813

File tree

1 file changed

+32 -0 lines changed


vllm_ascend/models/deepseek_v2.py

Lines changed: 32 additions & 0 deletions
@@ -74,6 +74,7 @@
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
                                npu_wait_tensor, vllm_version_is)

@@ -794,6 +795,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
+
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.max_position_embeddings = self.layers[0].self_attn.rotary_emb.max_position_embeddings
+        self.scaling_factor = self.layers[0].self_attn.rotary_emb.scaling_factor
+        self.cos_cached = self.layers[0].self_attn.rotary_emb.cos_cached
+        self.sin_cached = self.layers[0].self_attn.rotary_emb.sin_cached

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)

@@ -819,6 +827,30 @@ def forward(
             residual = intermediate_tensors["residual"]

         replace_allreduce = hidden_states.shape[0] % self.tp_size == 0
+        # get cos, sin before layers
+        self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
+            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
+        ]
+        if attn_metadata.num_decodes > 0 and self.running_in_graph:
+            seq_len = self.max_position_embeddings * self.scaling_factor
+            cos = self.cos_cached[:seq_len].to(
+                dtype=hidden_states.dtype)
+            sin = self.sin_cached[:seq_len].to(
+                dtype=hidden_states.dtype)
+            cos = cos[attn_metadata.decode.input_positions]
+            sin = sin[attn_metadata.decode.input_positions]
+            attn_metadata.decode.cos = cos[:, None, None, :]
+            attn_metadata.decode.sin = sin[:, None, None, :]
+        if attn_metadata.num_prefills > 0 and self.torchair_graph_enabled:
+            seq_len = self.max_position_embeddings * self.scaling_factor
+            cos = self.cos_cached[:seq_len].to(
+                dtype=hidden_states.dtype)
+            sin = self.sin_cached[:seq_len].to(
+                dtype=hidden_states.dtype)
+            cos = cos[attn_metadata.prefill.input_positions]
+            sin = sin[attn_metadata.prefill.input_positions]
+            attn_metadata.prefill.cos = cos[:, None, None, :]
+            attn_metadata.prefill.sin = sin[:, None, None, :]

         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]

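For context, the diff precomputes the per-token cos/sin once at the model level (before the decoder-layer loop) and stashes them on attn_metadata.decode / attn_metadata.prefill, so every layer can reuse them instead of re-gathering them inside each self_attn.rotary_emb call. The sketch below is a minimal illustration, not vllm-ascend code, of how a downstream attention backend might consume those cached tensors with a standard rotate-half RoPE; the helper names and toy shapes are assumptions, and the actual Ascend kernel may use a different rotary layout.

# Minimal sketch (illustrative only): applying precomputed cos/sin with a
# rotate-half RoPE. Shapes follow the diff above, where cos/sin are gathered
# per token and reshaped to [num_tokens, 1, 1, rot_dim].
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # (x1, x2) -> (-x2, x1) on the last dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_cached_rope(q: torch.Tensor, k: torch.Tensor,
                      cos: torch.Tensor, sin: torch.Tensor):
    # q/k: [num_tokens, num_heads, 1, rot_dim]; cos/sin broadcast over heads.
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin


# Usage with hypothetical toy sizes:
positions = torch.tensor([5, 9])               # decode token positions
cos_cached = torch.randn(4096, 64)             # [max_seq_len, rot_dim]
sin_cached = torch.randn(4096, 64)
cos = cos_cached[positions][:, None, None, :]  # [2, 1, 1, 64]
sin = sin_cached[positions][:, None, None, :]
q = torch.randn(2, 16, 1, 64)                  # [tokens, heads, 1, rot_dim]
k = torch.randn(2, 16, 1, 64)
q_rot, k_rot = apply_cached_rope(q, k, cos, sin)

The broadcast shape produced by cos[:, None, None, :] is what allows a single per-batch gather and dtype cast to serve every head in every layer, which is the per-forward-pass saving the commit title refers to.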