Skip to content

Commit 3cf301a

Browse files
committed
Support chunked-prefill MLA with the TorchAir graph mode
Signed-off-by: haojiangzheng <justineric096@gmail.com>
1 parent 8cf97d8 commit 3cf301a

File tree

2 files changed

+28
-18
lines changed

2 files changed

+28
-18
lines changed

tests/ut/attention/test_mla_v1.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,7 @@ def test_rope_single(self, mock_rope):
657657
def test_forward_decode_without_graph(self, mock_page_attention_mla,
658658
mock_up_proj):
659659
self.impl.running_in_graph = False
660+
self.impl.running_chunkprefilll_with_torchair = False
660661
num_tokens = 100
661662
num_blocks = 256
662663
block_size = 4

vllm_ascend/attention/mla_v1.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -964,7 +964,7 @@ def _forward_decode(
964964
decode_meta = attn_metadata.decode
965965
assert decode_meta is not None
966966
num_tokens = q_nope.size(0)
967-
if self.running_in_graph:
967+
if self.running_in_graph or self.running_chunkprefilll_with_torchair:
968968
# TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
969969
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
970970
assert num_tokens % self.spec_token_num == 0
@@ -1080,6 +1080,7 @@ def forward(
10801080
self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
10811081
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
10821082
]
1083+
self.running_chunkprefilll_with_torchair = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill
10831084
num_actual_toks = attn_metadata.num_actual_tokens
10841085
if k_pe is None and not self.running_in_graph:
10851086
kv_c, k_pe = self.kv_a_proj_with_mqa(
@@ -1116,18 +1117,25 @@ def forward(
11161117
if has_decode:
11171118
decode_k_nope = None
11181119
assert attn_metadata.decode is not None
1119-
if self.running_in_graph:
1120+
if self.running_in_graph or self.running_chunkprefilll_with_torchair:
11201121
cos = attn_metadata.decode.cos
11211122
sin = attn_metadata.decode.sin
1122-
with npu_stream_switch("mla_secondary",
1123-
0,
1124-
enabled=enable_multistream_mla):
1125-
npu_wait_tensor(hidden_states_or_kv_c_normed,
1126-
ckq,
1127-
enabled=enable_multistream_mla)
1123+
if self.running_chunkprefilll_with_torchair:
1124+
decode_hs = (
1125+
hidden_states_or_kv_c_normed[:num_decode_tokens])
1126+
slots = attn_metadata.slot_mapping[:num_decode_tokens]
11281127
decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
1129-
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
1130-
attn_metadata.slot_mapping)
1128+
decode_hs, cos, sin, kv_cache, slots)
1129+
else:
1130+
with npu_stream_switch("mla_secondary",
1131+
0,
1132+
enabled=enable_multistream_mla):
1133+
npu_wait_tensor(hidden_states_or_kv_c_normed,
1134+
ckq,
1135+
enabled=enable_multistream_mla)
1136+
decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
1137+
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
1138+
attn_metadata.slot_mapping)
11311139
# Without explicitly controlling the order, IndexByTensor operations
11321140
# would be placed after `matmul W_KV_T` hindering the overlapping of
11331141
# KvRmsNormRopeCache and SingleRope.
@@ -1151,6 +1159,8 @@ def forward(
11511159
decode_k_pe,
11521160
enabled=enable_multistream_mla)
11531161
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
1162+
elif self.running_chunkprefilll_with_torchair:
1163+
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
11541164
else:
11551165
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
11561166
attn_metadata.decode.input_positions,
@@ -1189,16 +1199,15 @@ def forward(
11891199
kv_cache
11901200
) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
11911201
if self.torchair_graph_enabled:
1192-
if kv_cache[0].numel(
1193-
) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
1202+
if kv_cache[0].numel() > 0 and has_prefill:
11941203
slots = attn_metadata.slot_mapping
11951204
# NOTE: Separate the kv cache in advance to avoid OOM or other issues
1196-
torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
1197-
num_tokens, self.num_kv_heads, -1),
1198-
value=prefill_k_pe,
1199-
key_cache=kv_cache[0],
1200-
value_cache=kv_cache[1],
1201-
slot_indices=slots)
1205+
torch_npu._npu_reshape_and_cache(
1206+
key=kv_c_normed.view(num_tokens, self.num_kv_heads, -1),
1207+
value=prefill_k_pe,
1208+
key_cache=kv_cache[0],
1209+
value_cache=kv_cache[1],
1210+
slot_indices=slots[num_decode_tokens:])
12021211
else:
12031212
kv_c_normed = kv_c_normed.view(
12041213
[num_actual_toks, self.num_kv_heads, -1])

0 commit comments

Comments
 (0)