
Commit 4f08667

huiyingCCCCzqh0923 authored and committed
MLA layer eliminates redundant index operators
Signed-off-by: huiying <chenhuiying4@huawei.com>
1 parent 5178114 commit 4f08667
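
In graph-mode decode, every MLA layer previously re-sliced and re-indexed the rotary cos/sin tables for the current input positions. Since those positions are identical for all layers within one decode step, this change computes the indexed cos/sin tensors once on the first layer, caches them on the attention implementation, and reuses them in every subsequent layer.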

File tree

3 files changed: +38 -13 lines


tests/multicard/test_offline_inference_distributed.py

Lines changed: 17 additions & 0 deletions
@@ -63,3 +63,20 @@ def test_models_distributed_DeepSeek():
             distributed_executor_backend="mp",
     ) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
+
+def test_models_eliminates_index_DeepSeek():
+    os.environ["VLLM_USE_V1"] = "1"
+    example_prompts = [
+        "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
+        "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
+        "Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
+    ]
+    dtype = "half"
+    max_tokens = 5
+    with VllmRunner(
+            "deepseek-ai/DeepSeek-V2-Lite",
+            dtype=dtype,
+            tensor_parallel_size=4,
+            distributed_executor_backend="mp",
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
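
The new test mirrors test_models_distributed_DeepSeek above but pins VLLM_USE_V1=1 so the V1 decode path with the cached cos/sin tensors is exercised end to end. Presumably it needs a multi-card host (tensor_parallel_size=4) and can be selected with something like pytest tests/multicard/test_offline_inference_distributed.py -k eliminates_index.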

vllm_ascend/attention/mla_v1.py

Lines changed: 17 additions & 11 deletions
@@ -465,6 +465,10 @@ def __init__(
         self.enable_graph_mode = additional_config.get(
             "enable_graph_mode", False)
 
+        self.cos = None
+        self.sin = None
+        self.debug_layer_idx = kwargs.get('debug_layer_idx', 0)
+
     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)

@@ -757,18 +761,20 @@ def forward(
         decode_ql_nope, decode_q_pe = \
             self._q_proj_and_k_up_proj(decode_hs_or_q_c)
         if self.running_in_graph:
-            seq_len = self.rotary_emb.max_position_embeddings
-            cos = self.rotary_emb.cos_cached[:seq_len].to(
-                dtype=decode_q_pe.dtype)
-            sin = self.rotary_emb.sin_cached[:seq_len].to(
-                dtype=decode_q_pe.dtype)
-            cos = cos[attn_metadata.decode.input_positions]
-            sin = sin[attn_metadata.decode.input_positions]
-            cos = cos[:, None, None, :]
-            sin = sin[:, None, None, :]
-            decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+            # During the autoregressive decoding process, the cos and sin values are exactly the same for each layer
+            if self.debug_layer_idx == 0 or self.cos is None or self.sin is None:
+                seq_len = self.rotary_emb.max_position_embeddings
+                self.cos = self.rotary_emb.cos_cached[:seq_len].to(
+                    dtype=decode_q_pe.dtype)
+                self.sin = self.rotary_emb.sin_cached[:seq_len].to(
+                    dtype=decode_q_pe.dtype)
+                self.cos = self.cos[attn_metadata.decode.input_positions]
+                self.sin = self.sin[attn_metadata.decode.input_positions]
+                self.cos = self.cos[:, None, None, :]
+                self.sin = self.sin[:, None, None, :]
+            decode_q_pe = self.rope_single(decode_q_pe, self.cos, self.sin)
             decode_k_pe, decode_k_nope = self.exec_kv(
-                hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                hidden_states_or_kv_c_normed, self.cos, self.sin, kv_cache,
                 attn_metadata.slot_mapping)
         else:
             decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
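
The change above memoizes the position-indexed rotary tables: within one decode step every layer sees the same attn_metadata.decode.input_positions, so the gather and reshape only need to happen once, on layer 0, and later layers reuse the cached tensors. A minimal standalone sketch of the pattern, with made-up names (RotaryCache, max_pos, rot_dim) and toy dimensions that are not part of this commit:

    import torch

    class RotaryCache:
        """Toy model of the optimization: index the cos/sin tables once per
        decode step (on layer 0) and reuse the result for every later layer."""

        def __init__(self, max_pos: int = 4096, rot_dim: int = 64):
            inv_freq = 1.0 / (10000**(torch.arange(0, rot_dim, 2).float() / rot_dim))
            freqs = torch.outer(torch.arange(max_pos).float(), inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            self.cos_cached, self.sin_cached = emb.cos(), emb.sin()
            self.cos = self.sin = None

        def get(self, layer_idx, positions):
            # Only layer 0 pays the gather/reshape cost; the "redundant index
            # operators" on layers 1..N-1 are skipped.
            if layer_idx == 0 or self.cos is None:
                self.cos = self.cos_cached[positions][:, None, None, :]
                self.sin = self.sin_cached[positions][:, None, None, :]
            return self.cos, self.sin

    cache = RotaryCache()
    positions = torch.tensor([17, 42])  # current position of each decoding sequence
    for layer_idx in range(4):          # every layer reuses the same cos/sin
        cos, sin = cache.get(layer_idx, positions)

The reuse is safe only because positions do not change between layers of the same forward pass; a new decode step enters through layer 0 again, which refreshes the cache.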

vllm_ascend/models/deepseek_v2.py

Lines changed: 4 additions & 2 deletions
@@ -370,6 +370,9 @@ def __init__(
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
 
+        self.prefix = prefix
+        self.debug_layer_idx = int(self.prefix.split(".")[-2])
+
         # In the MLA backend, kv_cache includes both k_c and
         # pe (i.e. decoupled position embeddings). In particular,
         # the concat_and_cache_mla op requires

@@ -398,10 +401,9 @@ def __init__(
             kv_a_layernorm=self.kv_a_layernorm,
             kv_b_proj=self.kv_b_proj,
             o_proj=self.o_proj,
+            debug_layer_idx=self.debug_layer_idx,
         )
 
-        self.prefix = prefix
-        self.debug_layer_idx = int(self.prefix.split(".")[-2])
         self.enable_graph_mode = False
         additional_config = get_current_vllm_config().additional_config
         if additional_config:
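
Here the layer index is recovered from the module's dotted prefix, whose second-to-last component is the decoder layer number under vLLM's usual naming scheme, and the assignment is moved before the attention backend is constructed so the index can be passed down as a keyword argument. A quick illustration with a hypothetical prefix string:

    prefix = "model.layers.7.self_attn"           # hypothetical vLLM module prefix
    debug_layer_idx = int(prefix.split(".")[-2])  # -> 7, the decoder layer index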
