22 changes: 4 additions & 18 deletions vllm_ascend/attention/mla_v1.py
@@ -1082,15 +1082,8 @@ def forward(
         decode_k_nope = None
         assert attn_metadata.decode is not None
         if self.running_in_graph:
-            seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-            cos = self.rotary_emb.cos_cached[:seq_len].to(
-                dtype=decode_hs_or_q_c.dtype)
-            sin = self.rotary_emb.sin_cached[:seq_len].to(
-                dtype=decode_hs_or_q_c.dtype)
-            cos = cos[attn_metadata.decode.input_positions]
-            sin = sin[attn_metadata.decode.input_positions]
-            cos = cos[:, None, None, :]
-            sin = sin[:, None, None, :]
+            cos = attn_metadata.decode.cos
+            sin = attn_metadata.decode.sin
             # Without explicitly controlling the order, IndexByTensor operations
             # would be placed after `matmul W_KV_T` hindering the overlapping of
             # KvRmsNormRopeCache and SingleRope.
@@ -1127,15 +1120,8 @@ def forward(
         prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
         if self.torchair_graph_enabled:
             num_tokens = prefill_hs_or_q_c.shape[0]
-            seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-            cos = self.rotary_emb.cos_cached[:seq_len].to(
-                dtype=prefill_q_pe.dtype)
-            sin = self.rotary_emb.sin_cached[:seq_len].to(
-                dtype=prefill_q_pe.dtype)
-            cos = cos[attn_metadata.prefill.input_positions]
-            sin = sin[attn_metadata.prefill.input_positions]
-            cos = cos[:, None, None, :]
-            sin = sin[:, None, None, :]
+            cos = attn_metadata.prefill.cos
+            sin = attn_metadata.prefill.sin

             prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
             prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
35 changes: 35 additions & 0 deletions vllm_ascend/models/deepseek_v2.py
@@ -74,6 +74,7 @@
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
                                npu_wait_tensor, vllm_version_is)

@@ -794,6 +795,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))

+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.max_position_embeddings = self.layers[0].self_attn.rotary_emb.max_position_embeddings
+        self.scaling_factor = self.layers[0].self_attn.rotary_emb.scaling_factor
+        self.cos_cached = self.layers[0].self_attn.rotary_emb.cos_cached
+        self.sin_cached = self.layers[0].self_attn.rotary_emb.sin_cached
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
@@ -819,6 +827,33 @@ def forward(
             residual = intermediate_tensors["residual"]

         replace_allreduce = hidden_states.shape[0] % self.tp_size == 0
+        # get cos, sin before layers
+        self.running_in_graph = self.torchair_graph_enabled and \
+            attn_metadata and \
+            attn_metadata.attn_state in [
+                AscendAttentionState.DecodeOnly,
+                AscendAttentionState.SpecDecoding
+            ]
+        if attn_metadata.num_decodes > 0 and self.running_in_graph:
+            seq_len = self.max_position_embeddings * self.scaling_factor
+            cos = self.cos_cached[:seq_len].to(
+                dtype=hidden_states.dtype)
+            sin = self.sin_cached[:seq_len].to(
+                dtype=hidden_states.dtype)
+            cos = cos[attn_metadata.decode.input_positions]
+            sin = sin[attn_metadata.decode.input_positions]
+            attn_metadata.decode.cos = cos[:, None, None, :]
+            attn_metadata.decode.sin = sin[:, None, None, :]
+        if attn_metadata.num_prefills > 0 and self.torchair_graph_enabled:
+            seq_len = self.max_position_embeddings * self.scaling_factor
+            cos = self.cos_cached[:seq_len].to(
+                dtype=hidden_states.dtype)
+            sin = self.sin_cached[:seq_len].to(
+                dtype=hidden_states.dtype)
+            cos = cos[attn_metadata.prefill.input_positions]
+            sin = sin[attn_metadata.prefill.input_positions]
+            attn_metadata.prefill.cos = cos[:, None, None, :]
+            attn_metadata.prefill.sin = sin[:, None, None, :]

         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
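
For readers skimming the diff, here is a small standalone sketch of the gather-and-broadcast step that the model forward now performs once per batch instead of once per layer. The helper name, dummy shapes, and the dimensions used are illustrative assumptions, not code from this PR, and the `[:seq_len]` truncation from the PR is omitted for brevity.

# Illustrative sketch only; not part of the PR. Assumes cos/sin tables of shape
# [max_seq_len, rot_dim], as produced by a rotary embedding's cached buffers.
import torch

def gather_cos_sin(cos_cached: torch.Tensor,
                   sin_cached: torch.Tensor,
                   positions: torch.Tensor,
                   dtype: torch.dtype):
    # Index the cached tables by token position once, then broadcast to
    # [num_tokens, 1, 1, rot_dim] so every decoder layer can reuse the result.
    cos = cos_cached.to(dtype)[positions][:, None, None, :]
    sin = sin_cached.to(dtype)[positions][:, None, None, :]
    return cos, sin

# Example with dummy data: 8 decode tokens, 64 rotary dimensions.
cos_cached = torch.randn(4096, 64)
sin_cached = torch.randn(4096, 64)
positions = torch.randint(0, 4096, (8,))
cos, sin = gather_cos_sin(cos_cached, sin_cached, positions, torch.float16)
assert cos.shape == sin.shape == (8, 1, 1, 64)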