
Commit 1c68685

Author: zhenghaojiang.zhj (committed)
use chunkprefill mla with torchair graph
Signed-off-by: zhenghaojiang.zhj <zhenghaojiang.zhj@antgroup.com>
1 parent d3c6dd9 commit 1c68685
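
The heart of the change is a new per-call flag on the MLA attention implementation that marks chunked prefill running under the TorchAir graph. Below is a minimal, self-contained sketch of that gating logic, condensed from the forward() hunk further down; the enum stand-in and the helper function name are illustrative and not part of the patch.

from enum import Enum, auto


class AscendAttentionState(Enum):
    # Minimal stand-in for the enum used in vllm_ascend; the real one
    # defines more states than are listed here.
    PrefillNoCache = auto()
    ChunkedPrefill = auto()
    DecodeOnly = auto()
    SpecDecoding = auto()


def torchair_mla_flags(torchair_graph_enabled: bool,
                       attn_state: AscendAttentionState) -> tuple[bool, bool]:
    # Decode and speculative decode keep the existing full-graph flag;
    # chunked prefill now gets its own TorchAir-aware flag, matching the
    # new `running_chunkprefilll_with_torchair` attribute in the diff.
    running_in_graph = torchair_graph_enabled and attn_state in (
        AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding)
    running_chunkprefill_with_torchair = (
        torchair_graph_enabled
        and attn_state == AscendAttentionState.ChunkedPrefill)
    return running_in_graph, running_chunkprefill_with_torchair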

File tree

2 files changed: +51, -36 lines


vllm_ascend/attention/mla_v1.py

Lines changed: 50 additions & 35 deletions
@@ -679,21 +679,28 @@ def _compute_prefill_context(
         q_nope = query[..., :self.qk_nope_head_dim]
 
         seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
-        latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim
-        cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim]
-        cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:]
+        if not self.running_chunkprefilll_with_torchair:
+            latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim
+            cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim]
+            cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:]
+            num_heads = kv_c_and_k_pe_cache.size(2)
+        else:
+            latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
+            cache_kv_c = kv_c_and_k_pe_cache[0]
+            cache_k_pe = kv_c_and_k_pe_cache[1]
+            num_heads = cache_k_pe.size(2)
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]
 
             seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
             seq_len = torch.stack([seq_len1, seq_len2])
             kv_c_normed = torch.empty(toks,
-                                      kv_c_and_k_pe_cache.size(2),
+                                      num_heads,
                                       latent_kv_dim,
                                       dtype=query.dtype,
                                       device=query.device)
             k_pe = torch.empty(toks,
-                               kv_c_and_k_pe_cache.size(2),
+                               num_heads,
                                rope_dim,
                                dtype=query.dtype,
                                device=query.device)
@@ -952,7 +959,7 @@ def _forward_decode(
             [num_tokens, self.num_heads, self.kv_lora_rank],
             dtype=q.dtype,
             device=q.device)
-        if self.running_in_graph:
+        if self.running_in_graph or self.running_chunkprefilll_with_torchair:
             # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
             if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
                 assert num_tokens % self.spec_token_num == 0
@@ -1049,13 +1056,13 @@ def forward(
         self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
+        self.running_chunkprefilll_with_torchair = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
-            if not self.torchair_graph_enabled:
-                kv_c, k_pe = self.kv_a_proj_with_mqa(
-                    hidden_states_or_kv_c_normed)[0].split(
-                        [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-                kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+            kv_c, k_pe = self.kv_a_proj_with_mqa(
+                hidden_states_or_kv_c_normed)[0].split(
+                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+            kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
         else:
             kv_c_normed = hidden_states_or_kv_c_normed
         assert attn_metadata.num_decodes is not None and \
@@ -1068,24 +1075,23 @@ def forward(
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
         output = output[:num_actual_toks, ...]
-        if not self.torchair_graph_enabled:
-            kv_c_normed = kv_c_normed[:num_actual_toks, ...]
-            prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
+        kv_c_normed = kv_c_normed[:num_actual_toks, ...]
+        prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
         if not self.running_in_graph:
            hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
            prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
-            if not self.torchair_graph_enabled:
-                decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
-            k_pe = k_pe[:num_actual_toks, ...]
-            k_pe = k_pe.unsqueeze(1)
-            decode_k_pe = k_pe[:num_decode_tokens]
-            prefill_k_pe = k_pe[num_decode_tokens:]
+            decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
+            prefill_hs = hidden_states_or_kv_c_normed[num_decode_tokens:]
+            k_pe = k_pe[:num_actual_toks, ...]
+            k_pe = k_pe.unsqueeze(1)
+            decode_k_pe = k_pe[:num_decode_tokens]
+            prefill_k_pe = k_pe[num_decode_tokens:]
         else:
             decode_hs_or_q_c = hidden_states_or_q_c
         if has_decode:
             decode_k_nope = None
             assert attn_metadata.decode is not None
-            if self.running_in_graph:
+            if self.running_in_graph or self.running_chunkprefilll_with_torchair:
                 seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
                 cos = self.rotary_emb.cos_cached[:seq_len].to(
                     dtype=decode_hs_or_q_c.dtype)
@@ -1095,15 +1101,23 @@ def forward(
                 sin = sin[attn_metadata.decode.input_positions]
                 cos = cos[:, None, None, :]
                 sin = sin[:, None, None, :]
-                with npu_stream_switch("mla_secondary",
-                                       0,
-                                       enabled=enable_multistream_mla):
-                    npu_wait_tensor(hidden_states_or_kv_c_normed,
-                                    ckq,
-                                    enabled=enable_multistream_mla)
+                slots = attn_metadata.slot_mapping
+                if self.running_chunkprefilll_with_torchair:
+                    decode_hs = (
+                        hidden_states_or_kv_c_normed[:num_decode_tokens])
+                    slots = attn_metadata.slot_mapping[:num_decode_tokens]
                     decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
-                        hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                        attn_metadata.slot_mapping)
+                        decode_hs, cos, sin, kv_cache, slots)
+                else:
+                    with npu_stream_switch("mla_secondary",
+                                           0,
+                                           enabled=enable_multistream_mla):
+                        npu_wait_tensor(hidden_states_or_kv_c_normed,
+                                        ckq,
+                                        enabled=enable_multistream_mla)
+                        decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
+                            hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                            slots)
                 # Without explicitly controlling the order, IndexByTensor operations
                 # would be placed after `matmul W_KV_T` hindering the overlapping of
                 # KvRmsNormRopeCache and SingleRope.
@@ -1127,6 +1141,8 @@ def forward(
                                     decode_k_pe,
                                     enabled=enable_multistream_mla)
                     decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+            elif self.running_chunkprefilll_with_torchair:
+                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
@@ -1153,11 +1169,11 @@ def forward(
 
             prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
             prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
-                hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                attn_metadata.slot_mapping)
+                prefill_hs, cos, sin, kv_cache,
+                attn_metadata.slot_mapping[num_decode_tokens:])
 
             kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
-            prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
+            prefill_k_c_normed = prefill_k_nope
             prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
                                              -1)
             prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
@@ -1168,9 +1184,8 @@ def forward(
                 prefill_k_pe,
                 max_seq_len=attn_metadata.prefill.max_seq_lens)
             if self.torchair_graph_enabled:
-                if len(kv_cache) > 0 and kv_cache[0].numel(
-                ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-                    slots = attn_metadata.slot_mapping
+                if len(kv_cache) > 0 and kv_cache[0].numel() > 0 and has_prefill:
+                    slots = attn_metadata.slot_mapping[num_decode_tokens:]
                     # NOTE: Separate the kv cache in advance to avoid OOM or other issues
                     torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
                         num_tokens, self.num_kv_heads, -1),
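
For readers skimming the first hunk above: _compute_prefill_context now has to accept two KV-cache layouts. The following standalone sketch condenses that branch into a helper; the fused-cache shape and the (kv_c, k_pe) pair layout are assumptions read off the indexing in the diff, not verified against the rest of the repository.

import torch


def split_kv_cache(kv_c_and_k_pe_cache, rope_dim: int,
                   chunkprefill_with_torchair: bool):
    # Non-TorchAir chunked prefill: a single fused cache tensor whose last
    # dimension holds the latent KV part followed by `rope_dim` RoPE dims.
    if not chunkprefill_with_torchair:
        latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim
        cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim]
        cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:]
        num_heads = kv_c_and_k_pe_cache.size(2)
    # TorchAir chunked prefill: the cache arrives as a (kv_c, k_pe) pair.
    else:
        cache_kv_c = kv_c_and_k_pe_cache[0]
        cache_k_pe = kv_c_and_k_pe_cache[1]
        latent_kv_dim = cache_kv_c.size(-1)
        num_heads = cache_k_pe.size(2)
    return cache_kv_c, cache_k_pe, num_heads, latent_kv_dim


# Illustrative shapes only: [num_blocks, block_size, num_heads, latent + rope].
fused = torch.zeros(4, 128, 1, 512 + 64)
kv_c, k_pe, heads, latent = split_kv_cache(fused, rope_dim=64,
                                           chunkprefill_with_torchair=False)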

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 1 deletion
@@ -697,7 +697,7 @@ def get_model(self) -> nn.Module:
     def _make_attention_mask(self, seq_lens, query_lens, position,
                              attn_state) -> torch.Tensor:
         # Chunk Prefill situation.
-        if attn_state == AscendAttentionState.ChunkedPrefill:
+        if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla:
             return self.attn_mask_builder.get_splitfuse_attn_mask(
                 seq_lens, query_lens, position, self.dtype, self.device)
         # Prefill without cache situation.
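
The runner-side change is a one-line guard: the splitfuse chunked-prefill mask is now built only for non-MLA models, since MLA chunked prefill is handled on the TorchAir path in mla_v1.py. A small sketch of the updated condition, with an illustrative enum stand-in and helper name that are not from the repository:

from enum import Enum, auto


class AscendAttentionState(Enum):
    # Stand-in with just the members this check needs.
    ChunkedPrefill = auto()
    DecodeOnly = auto()


def needs_splitfuse_mask(attn_state: AscendAttentionState,
                         use_mla: bool) -> bool:
    # Mirrors the updated check in _make_attention_mask: build the splitfuse
    # chunked-prefill attention mask only when the model does not use MLA.
    return attn_state == AscendAttentionState.ChunkedPrefill and not use_mla


assert needs_splitfuse_mask(AscendAttentionState.ChunkedPrefill, use_mla=False)
assert not needs_splitfuse_mask(AscendAttentionState.ChunkedPrefill, use_mla=True)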
