
Commit 8069fef

chunkprefill mla with torchair graph

Signed-off-by: haojiangzheng <justineric096@gmail.com>

1 parent 853efb9
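
Despite the additive-sounding title, the hunks below back the dedicated chunked-prefill handling out of the TorchAir-graph MLA path: the running_chunkprefilll_with_torchair flag and every branch keyed on it are deleted, decode-side KV extraction returns to a single multistream exec_kv call, and the prefill-side KV-cache write is gated on PrefillNoCache instead of the broader has_prefill check.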

File tree: 1 file changed (+18, -27 lines)

vllm_ascend/attention/mla_v1.py

Lines changed: 18 additions & 27 deletions
@@ -964,7 +964,7 @@ def _forward_decode(
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
         num_tokens = q_nope.size(0)
-        if self.running_in_graph or self.running_chunkprefilll_with_torchair:
+        if self.running_in_graph:
             # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
             if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
                 assert num_tokens % self.spec_token_num == 0
@@ -1080,7 +1080,6 @@ def forward(
         self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
-        self.running_chunkprefilll_with_torchair = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
             kv_c, k_pe = self.kv_a_proj_with_mqa(
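
For orientation, a runnable sketch of the predicate that survives this hunk. The enum is a stub standing in for vllm_ascend's AscendAttentionState (only the values named in this diff are included); the point is that with the flag gone, the TorchAir graph decode path fires only for pure-decode and spec-decode batches, while ChunkedPrefill falls through to the regular path.

    from enum import Enum, auto

    class AscendAttentionState(Enum):  # stub of the states named in this diff
        DecodeOnly = auto()
        SpecDecoding = auto()
        ChunkedPrefill = auto()
        PrefillNoCache = auto()

    def running_in_graph(torchair_graph_enabled: bool,
                         attn_state: AscendAttentionState) -> bool:
        # Post-commit predicate: ChunkedPrefill no longer has its own flag,
        # so it no longer reaches the TorchAir graph decode kernels.
        return torchair_graph_enabled and attn_state in (
            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding)

    assert running_in_graph(True, AscendAttentionState.DecodeOnly)
    assert not running_in_graph(True, AscendAttentionState.ChunkedPrefill)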
@@ -1117,25 +1116,18 @@ def forward(
         if has_decode:
             decode_k_nope = None
             assert attn_metadata.decode is not None
-            if self.running_in_graph or self.running_chunkprefilll_with_torchair:
+            if self.running_in_graph:
                 cos = attn_metadata.decode.cos
                 sin = attn_metadata.decode.sin
-                if self.running_chunkprefilll_with_torchair:
-                    decode_hs = (
-                        hidden_states_or_kv_c_normed[:num_decode_tokens])
-                    slots = attn_metadata.slot_mapping[:num_decode_tokens]
+                with npu_stream_switch("mla_secondary",
+                                       0,
+                                       enabled=enable_multistream_mla):
+                    npu_wait_tensor(hidden_states_or_kv_c_normed,
+                                    ckq,
+                                    enabled=enable_multistream_mla)
                     decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
-                        decode_hs, cos, sin, kv_cache, slots)
-                else:
-                    with npu_stream_switch("mla_secondary",
-                                           0,
-                                           enabled=enable_multistream_mla):
-                        npu_wait_tensor(hidden_states_or_kv_c_normed,
-                                        ckq,
-                                        enabled=enable_multistream_mla)
-                        decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
-                            hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                            attn_metadata.slot_mapping)
+                        hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                        attn_metadata.slot_mapping)
                 # Without explicitly controlling the order, IndexByTensor operations
                 # would be placed after `matmul W_KV_T` hindering the overlapping of
                 # KvRmsNormRopeCache and SingleRope.
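
The context comment above is the reason the surviving branch keeps the npu_stream_switch/npu_wait_tensor pair: exec_kv is issued on the "mla_secondary" stream, but that stream must not read hidden_states_or_kv_c_normed before ckq has been produced on the default stream. A minimal sketch of that control-flow shape, with both helpers stubbed out (the real ones live in vllm_ascend and drive actual NPU streams; decode_kv_overlapped is a name invented for this sketch):

    from contextlib import contextmanager

    @contextmanager
    def npu_stream_switch(name, priority, enabled=True):
        # Stub: the real helper redirects subsequent ops to another NPU stream.
        yield

    def npu_wait_tensor(waited, trigger, enabled=True):
        # Stub: the real helper blocks the current stream until `trigger` is
        # ready, so `waited` is consumed only after its producer finishes.
        pass

    def decode_kv_overlapped(hidden_states, ckq, cos, sin, kv_cache,
                             slot_mapping, exec_kv, enable_multistream_mla):
        # Mirrors the post-commit decode path: one unconditional exec_kv call
        # on the secondary stream, overlapping the main stream's q projection.
        with npu_stream_switch("mla_secondary", 0,
                               enabled=enable_multistream_mla):
            npu_wait_tensor(hidden_states, ckq,
                            enabled=enable_multistream_mla)
            return exec_kv(hidden_states, cos, sin, kv_cache, slot_mapping)

The deleted chunked-prefill arm sliced its inputs (decode_hs and slots truncated to num_decode_tokens); the surviving arm, now unconditional, always consumes the full hidden states and slot mapping.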
@@ -1159,8 +1151,6 @@ def forward(
                                     decode_k_pe,
                                     enabled=enable_multistream_mla)
                     decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
-            elif self.running_chunkprefilll_with_torchair:
-                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
@@ -1199,15 +1189,16 @@ def forward(
                 kv_cache
             ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
             if self.torchair_graph_enabled:
-                if kv_cache[0].numel() > 0 and has_prefill:
+                if kv_cache[0].numel(
+                ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
                     slots = attn_metadata.slot_mapping
                     # NOTE: Separate the kv cache in advance to avoid OOM or other issues
-                    torch_npu._npu_reshape_and_cache(
-                        key=kv_c_normed.view(num_tokens, self.num_kv_heads, -1),
-                        value=prefill_k_pe,
-                        key_cache=kv_cache[0],
-                        value_cache=kv_cache[1],
-                        slot_indices=slots[num_decode_tokens:])
+                    torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
+                        num_tokens, self.num_kv_heads, -1),
+                                                     value=prefill_k_pe,
+                                                     key_cache=kv_cache[0],
+                                                     value_cache=kv_cache[1],
+                                                     slot_indices=slots)
                 else:
                     kv_c_normed = kv_c_normed.view(
                         [num_actual_toks, self.num_kv_heads, -1])
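
Dropping the [num_decode_tokens:] slice follows from the tightened guard: the old has_prefill condition admitted mixed decode-plus-prefill batches, whose decode slots had already been written by exec_kv and therefore had to be skipped. Under PrefillNoCache the batch is prefill-only, so num_decode_tokens is expected to be zero and the slice was a no-op anyway. A toy check of that equivalence:

    # Toy stand-ins: in a PrefillNoCache batch every token is a prefill token.
    slot_mapping = list(range(8))   # pretend slot indices for 8 tokens
    num_decode_tokens = 0           # a pure-prefill batch has no decode tokens

    # The old slice and the new full mapping select exactly the same slots.
    assert slot_mapping[num_decode_tokens:] == slot_mapping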
