
Commit 41730d3 (parent: 205eff2)

chunkprefill mla with torchair graph

Signed-off-by: haojiangzheng <justineric096@gmail.com>

File tree: 2 files changed (+23, -13)

  tests/ut/attention/test_mla_v1.py
  vllm_ascend/attention/mla_v1.py

tests/ut/attention/test_mla_v1.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -664,6 +664,7 @@ def test_rope_single(self, mock_rope):
     def test_forward_decode_without_graph(self, mock_page_attention_mla,
                                           mock_up_proj):
         self.impl.running_in_graph = False
+        self.impl.running_chunkprefilll_with_torchair = False
         num_tokens = 100
         num_blocks = 256
         block_size = 4
```
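This one-line test change guards against an AttributeError: the new flag is assigned in `forward()`, but this unit test drives `_forward_decode()` directly, so the attribute must be initialized by hand. A minimal stand-in (the `_Impl` class here is hypothetical, not the real `AscendMLAImpl`) shows the hazard:

```python
# Minimal stand-in for the pattern under test; `_Impl` is illustrative only.
class _Impl:
    def _forward_decode(self):
        # Mirrors the new check in mla_v1.py's _forward_decode.
        return self.running_in_graph or self.running_chunkprefilll_with_torchair

impl = _Impl()
impl.running_in_graph = False
try:
    impl._forward_decode()  # second operand is evaluated: AttributeError
except AttributeError:
    pass
impl.running_chunkprefilll_with_torchair = False  # what the test now does
assert impl._forward_decode() is False
```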

vllm_ascend/attention/mla_v1.py

Lines changed: 22 additions & 13 deletions
```diff
@@ -998,7 +998,7 @@ def _forward_decode(
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
         num_tokens = q_nope.size(0)
-        if self.running_in_graph:
+        if self.running_in_graph or self.running_chunkprefilll_with_torchair:
             # shape of knope/k_pe for npu graph mode should be:
             # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
```
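A note on the shape comment above: the code reads the block size from `kv_c_and_k_pe_cache[0].shape[1]`, which implies the block size is the second dimension of the paged caches (the in-code comment lists the dimensions in a different order). A toy sketch of that assumed layout; all sizes are illustrative:

```python
# Assumed paged-cache layout implied by `shape[1]` being the block size.
import torch

num_blocks, block_size, num_kv_heads = 256, 4, 1
kv_lora_rank, qk_rope_head_dim = 512, 64

nope_cache = torch.zeros(num_blocks, block_size, num_kv_heads, kv_lora_rank)
rope_cache = torch.zeros(num_blocks, block_size, num_kv_heads, qk_rope_head_dim)
kv_c_and_k_pe_cache = (nope_cache, rope_cache)

assert kv_c_and_k_pe_cache[0].shape[1] == block_size
```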
```diff
@@ -1112,6 +1112,7 @@ def forward(
         self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
+        self.running_chunkprefilll_with_torchair = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
             kv_c, k_pe = self.kv_a_proj_with_mqa(
```
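The two mode flags are mutually exclusive by construction, since `attn_state` holds a single value per batch. A self-contained sketch of the gating, with a stub enum standing in for vllm-ascend's `AscendAttentionState`:

```python
from enum import Enum, auto

class AscendAttentionState(Enum):  # stub for illustration only
    DecodeOnly = auto()
    SpecDecoding = auto()
    ChunkedPrefill = auto()
    PrefillNoCache = auto()

def classify(torchair_graph_enabled: bool, attn_state: AscendAttentionState):
    running_in_graph = torchair_graph_enabled and attn_state in (
        AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding)
    running_chunkprefilll_with_torchair = (
        torchair_graph_enabled
        and attn_state == AscendAttentionState.ChunkedPrefill)
    # One attn_state per batch, so at most one flag can be set.
    assert not (running_in_graph and running_chunkprefilll_with_torchair)
    return running_in_graph, running_chunkprefilll_with_torchair

assert classify(True, AscendAttentionState.ChunkedPrefill) == (False, True)
assert classify(True, AscendAttentionState.DecodeOnly) == (True, False)
assert classify(False, AscendAttentionState.ChunkedPrefill) == (False, False)
```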
```diff
@@ -1148,18 +1149,25 @@ def forward(
         if has_decode:
             decode_k_nope = None
             assert attn_metadata.decode is not None
-            if self.running_in_graph:
+            if self.running_in_graph or self.running_chunkprefilll_with_torchair:
                 cos = attn_metadata.decode.cos
                 sin = attn_metadata.decode.sin
-                with npu_stream_switch("mla_secondary",
-                                       0,
-                                       enabled=enable_multistream_mla):
-                    npu_wait_tensor(hidden_states_or_kv_c_normed,
-                                    ckq,
-                                    enabled=enable_multistream_mla)
+                if self.running_chunkprefilll_with_torchair:
+                    decode_hs = (
+                        hidden_states_or_kv_c_normed[:num_decode_tokens])
+                    slots = attn_metadata.slot_mapping[:num_decode_tokens]
                     decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
-                        hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                        attn_metadata.slot_mapping)
+                        decode_hs, cos, sin, kv_cache, slots)
+                else:
+                    with npu_stream_switch("mla_secondary",
+                                           0,
+                                           enabled=enable_multistream_mla):
+                        npu_wait_tensor(hidden_states_or_kv_c_normed,
+                                        ckq,
+                                        enabled=enable_multistream_mla)
+                        decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
+                            hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                            attn_metadata.slot_mapping)
             # Without explicitly controlling the order, IndexByTensor operations
             # would be placed after `matmul W_KV_T` hindering the overlapping of
             # KvRmsNormRopeCache and SingleRope.
```
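The slicing above relies on the token layout the diff itself encodes for mixed chunked-prefill batches: decode tokens occupy the first `num_decode_tokens` rows, prefill tokens the rest. Slicing before `exec_kv` keeps the decode tokens' cache writes out of the prefill path, so each slot is written exactly once per step. A toy partition (values and shapes are made up; only the slicing mirrors the commit):

```python
# Toy mixed batch: 2 decode tokens followed by 3 prefill tokens.
import torch

num_decode_tokens = 2
hidden_states = torch.randn(5, 16)                 # per-token hidden states
slot_mapping = torch.tensor([7, 12, 40, 41, 42])   # one cache slot per token

decode_hs = hidden_states[:num_decode_tokens]      # passed to exec_kv()
decode_slots = slot_mapping[:num_decode_tokens]    # written by exec_kv()
prefill_slots = slot_mapping[num_decode_tokens:]   # written later by
                                                   # _npu_reshape_and_cache

assert decode_hs.shape[0] == num_decode_tokens
assert decode_slots.tolist() == [7, 12]
assert prefill_slots.tolist() == [40, 41, 42]
```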
```diff
@@ -1183,6 +1191,8 @@ def forward(
                                     decode_k_pe,
                                     enabled=enable_multistream_mla)
                 decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+            elif self.running_chunkprefilll_with_torchair:
+                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
```
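In the new `elif` branch only the query's positional part gets RoPE: for chunked prefill under TorchAir, `decode_k_pe` comes back from `exec_kv()` already rotated and written to the cache (see the KvRmsNormRopeCache comment above), so there is nothing left to do for the key. A rough, self-contained sketch of what a `rope_single`-style helper computes — a standard rotate-half RoPE, which is our assumption rather than vllm-ascend's exact kernel:

```python
import torch

def rope_single_sketch(x: torch.Tensor, cos: torch.Tensor,
                       sin: torch.Tensor) -> torch.Tensor:
    # Rotate-half RoPE applied to one tensor only (assumed behaviour).
    x1, x2 = x.chunk(2, dim=-1)
    return x * cos + torch.cat((-x2, x1), dim=-1) * sin

# Apply to the decode queries' rope part; keys were handled in exec_kv().
decode_q_pe = torch.randn(2, 8, 64)   # [tokens, heads, qk_rope_head_dim]
cos = torch.ones(2, 1, 64)            # identity rotation as a sanity check
sin = torch.zeros(2, 1, 64)
assert torch.allclose(rope_single_sketch(decode_q_pe, cos, sin), decode_q_pe)
```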
```diff
@@ -1221,16 +1231,15 @@ def forward(
                 kv_cache
             ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
             if self.torchair_graph_enabled:
-                if kv_cache[0].numel(
-                ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+                if kv_cache[0].numel() > 0 and has_prefill:
                     slots = attn_metadata.slot_mapping
                     # NOTE: Separate the kv cache in advance to avoid OOM or other issues
                     torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
                         num_tokens, self.num_kv_heads, -1),
                                                      value=prefill_k_pe,
                                                      key_cache=kv_cache[0],
                                                      value_cache=kv_cache[1],
-                                                     slot_indices=slots)
+                                                     slot_indices=slots[num_decode_tokens:])
             else:
                 kv_c_normed = kv_c_normed.view(
                     [num_actual_toks, self.num_kv_heads, -1])
```
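Two things change in this last hunk: the cache write is now gated on `has_prefill` (previously only the `PrefillNoCache` state qualified, which excluded chunked-prefill batches), and the slot indices skip the first `num_decode_tokens` entries, which `exec_kv()` already wrote. A small illustrative helper capturing that gating; `has_prefill` and the decode-first layout come from the diff, everything else is made up:

```python
import torch

def prefill_cache_slots(slot_mapping: torch.Tensor, num_decode_tokens: int,
                        has_prefill: bool) -> torch.Tensor:
    """Slots _npu_reshape_and_cache should write for prefill tokens (sketch)."""
    if not has_prefill:
        return slot_mapping[:0]              # nothing to cache
    # Decode slots were already written by exec_kv(); write only the rest.
    return slot_mapping[num_decode_tokens:]

slots = torch.tensor([7, 12, 40, 41, 42])
assert prefill_cache_slots(slots, 2, True).tolist() == [40, 41, 42]
assert prefill_cache_slots(slots, 2, False).numel() == 0
```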
