Skip to content

Commit 2fed175

Browse files
committed
MLA layer: eliminate redundant indexing operations
Signed-off-by: huiying <chenhuiying4@huawei.com>
1 parent 908a851 commit 2fed175

File tree

2 files changed

+21
-15
lines changed

2 files changed

+21
-15
lines changed

vllm_ascend/attention/mla_v1.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,10 @@ def __init__(
446446
ascend_config = get_ascend_config()
447447
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
448448

449+
self.cos = None
450+
self.sin = None
451+
self.debug_layer_idx = kwargs.get('debug_layer_idx', 0)
452+
449453
def _v_up_proj_and_o_proj(self, x):
450454
# Convert from (B, N, L) to (N, B, L)
451455
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
@@ -738,19 +742,20 @@ def forward(
738742
decode_ql_nope, decode_q_pe = \
739743
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
740744
if self.running_in_graph:
741-
seq_len = self.rotary_emb.max_position_embeddings
742-
cos = self.rotary_emb.cos_cached[:seq_len].to(
743-
dtype=decode_q_pe.dtype)
744-
sin = self.rotary_emb.sin_cached[:seq_len].to(
745-
dtype=decode_q_pe.dtype)
746-
cos = cos[attn_metadata.decode.input_positions]
747-
sin = sin[attn_metadata.decode.input_positions]
748-
cos = cos[:, None, None, :]
749-
sin = sin[:, None, None, :]
750-
751-
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
745+
# During autoregressive decoding, the cos and sin values are identical across all layers, so they are computed once (on layer 0) and reused by subsequent layers
746+
if self.debug_layer_idx == 0 or self.cos is None or self.sin is None:
747+
seq_len = self.rotary_emb.max_position_embeddings
748+
self.cos = self.rotary_emb.cos_cached[:seq_len].to(
749+
dtype=decode_q_pe.dtype)
750+
self.sin = self.rotary_emb.sin_cached[:seq_len].to(
751+
dtype=decode_q_pe.dtype)
752+
self.cos = self.cos[attn_metadata.decode.input_positions]
753+
self.sin = self.sin[attn_metadata.decode.input_positions]
754+
self.cos = self.cos[:, None, None, :]
755+
self.sin = self.sin[:, None, None, :]
756+
decode_q_pe = self.rope_single(decode_q_pe, self.cos, self.sin)
752757
decode_k_pe, decode_k_nope = self.exec_kv(
753-
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
758+
hidden_states_or_kv_c_normed, self.cos, self.sin, kv_cache,
754759
attn_metadata.slot_mapping)
755760
else:
756761
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(

vllm_ascend/models/deepseek_v2.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,9 @@ def __init__(
390390
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
391391
self.scaling = self.scaling * mscale * mscale
392392

393+
self.prefix = prefix
394+
self.debug_layer_idx = int(self.prefix.split(".")[-2])
395+
393396
# In the MLA backend, kv_cache includes both k_c and
394397
# pe (i.e. decoupled position embeddings). In particular,
395398
# the concat_and_cache_mla op requires
@@ -418,11 +421,9 @@ def __init__(
418421
kv_a_layernorm=self.kv_a_layernorm,
419422
kv_b_proj=self.kv_b_proj,
420423
o_proj=self.o_proj,
424+
debug_layer_idx=self.debug_layer_idx,
421425
)
422426

423-
self.prefix = prefix
424-
self.debug_layer_idx = int(self.prefix.split(".")[-2])
425-
426427
ascend_config = get_ascend_config()
427428
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
428429

0 commit comments

Comments (0)