
Commit 11df6ad

Author: zhenghaojiang.zhj
Commit message: chunkprefill mla with torchair
Parent: d9f82eb

File tree: 2 files changed (+28, -17 lines)


tests/ut/attention/test_mla_v1.py

Lines changed: 1 addition & 0 deletions
@@ -657,6 +657,7 @@ def test_rope_single(self, mock_rope):
     def test_forward_decode_without_graph(self, mock_page_attention_mla,
                                           mock_up_proj):
         self.impl.running_in_graph = False
+        self.impl.running_chunkprefilll_with_torchair = False
         num_tokens = 100
         num_blocks = 256
         block_size = 4
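Note: `_forward_decode` now also reads the new flag (see the mla_v1.py hunks below), so the non-graph decode test pins it to False alongside `running_in_graph`. A minimal illustration of why the pin matters, assuming the test's impl attributes behave like MagicMock attributes:

import unittest.mock

# Attribute access on a MagicMock auto-creates a truthy child mock, so an
# unpinned flag would silently route the test into the TorchAir branch.
impl = unittest.mock.MagicMock()
assert impl.running_chunkprefilll_with_torchair   # auto-created, truthy
impl.running_chunkprefilll_with_torchair = False  # what the test now does
assert not impl.running_chunkprefilll_with_torchair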

vllm_ascend/attention/mla_v1.py

Lines changed: 27 additions & 17 deletions
@@ -961,7 +961,7 @@ def _forward_decode(
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
         num_tokens = q_nope.size(0)
-        if self.running_in_graph:
+        if self.running_in_graph or self.running_chunkprefilll_with_torchair:
             # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
             if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
                 assert num_tokens % self.spec_token_num == 0
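Note: with chunked prefill routed into this branch, its decode tokens get the TorchAir tensor layout as well. A self-contained sketch of the layout the comment names, with invented sizes; the exact view/transpose inside `_forward_decode` may differ:

import torch

num_tokens, num_heads, dim = 8, 16, 576
q_nope = torch.randn(num_tokens, num_heads, dim)  # eager layout

q_seq_len = 1                  # plain decode: one query token per sequence
bs = num_tokens // q_seq_len
# TorchAir layout: [bs, num_heads_per_rank, q_seq_len, dim]
q_torchair = q_nope.view(bs, q_seq_len, num_heads, dim).transpose(1, 2)
assert q_torchair.shape == (bs, num_heads, q_seq_len, dim)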
@@ -1077,6 +1077,7 @@ def forward(
         self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
+        self.running_chunkprefilll_with_torchair = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
             if not self.torchair_graph_enabled:
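Note: both flags require `torchair_graph_enabled`, but they match disjoint `attn_state` values, so at most one is True for any batch. A condensed sketch of that selection (strings stand in for the `AscendAttentionState` enum members):

def torchair_modes(torchair_graph_enabled: bool, attn_state: str):
    # Mirrors the two assignments in the hunk above, nothing more.
    running_in_graph = torchair_graph_enabled and attn_state in (
        "DecodeOnly", "SpecDecoding")
    running_chunkprefilll_with_torchair = (
        torchair_graph_enabled and attn_state == "ChunkedPrefill")
    return running_in_graph, running_chunkprefilll_with_torchair

assert torchair_modes(True, "DecodeOnly") == (True, False)
assert torchair_modes(True, "ChunkedPrefill") == (False, True)
assert torchair_modes(False, "ChunkedPrefill") == (False, False)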
@@ -1096,13 +1097,13 @@ def forward(
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
         output = output[:num_actual_toks, ...]
-        if not self.torchair_graph_enabled:
+        if not self.torchair_graph_enabled or self.running_chunkprefilll_with_torchair:
             kv_c_normed = kv_c_normed[:num_actual_toks, ...]
             prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
         if not self.running_in_graph:
             hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
             prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
-            if not self.torchair_graph_enabled:
+            if not self.torchair_graph_enabled or self.running_chunkprefilll_with_torchair:
                 decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
         k_pe = k_pe[:num_actual_toks, ...]
         k_pe = k_pe.unsqueeze(1)
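Note: all three widened conditions rely on the same batch layout: decode tokens fill the first `num_decode_tokens` rows, prefill tokens run up to `num_actual_toks`, and any rows beyond that are graph padding. A self-contained sketch with invented sizes:

import torch

num_decode_tokens, num_actual_toks, padded_len = 4, 10, 16
hidden = torch.randn(padded_len, 512)        # padded for the graph

hidden = hidden[:num_actual_toks]            # strip the padding
decode_rows = hidden[:num_decode_tokens]     # decode tokens lead the batch
prefill_rows = hidden[num_decode_tokens:]    # prefill tokens follow
assert prefill_rows.shape[0] == num_actual_toks - num_decode_tokens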
@@ -1113,18 +1114,25 @@ def forward(
         if has_decode:
             decode_k_nope = None
             assert attn_metadata.decode is not None
-            if self.running_in_graph:
+            if self.running_in_graph or self.running_chunkprefilll_with_torchair:
                 cos = attn_metadata.decode.cos
                 sin = attn_metadata.decode.sin
-                with npu_stream_switch("mla_secondary",
-                                       0,
-                                       enabled=enable_multistream_mla):
-                    npu_wait_tensor(hidden_states_or_kv_c_normed,
-                                    ckq,
-                                    enabled=enable_multistream_mla)
+                if self.running_chunkprefilll_with_torchair:
+                    decode_hs = (
+                        hidden_states_or_kv_c_normed[:num_decode_tokens])
+                    slots = attn_metadata.slot_mapping[:num_decode_tokens]
                     decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
-                        hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                        attn_metadata.slot_mapping)
+                        decode_hs, cos, sin, kv_cache, slots)
+                else:
+                    with npu_stream_switch("mla_secondary",
+                                           0,
+                                           enabled=enable_multistream_mla):
+                        npu_wait_tensor(hidden_states_or_kv_c_normed,
+                                        ckq,
+                                        enabled=enable_multistream_mla)
+                        decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
+                            hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                            attn_metadata.slot_mapping)
         # Without explicitly controlling the order, IndexByTensor operations
         # would be placed after `matmul W_KV_T` hindering the overlapping of
         # KvRmsNormRopeCache and SingleRope.
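Note: in the graphed DecodeOnly/SpecDecoding states every row is a decode token, so `exec_kv` can take the full tensors; under chunked prefill the batch also carries prefill rows, whose latent KV is written later by `exec_kv_prefill`, so both the hidden states and the slot mapping are cut at `num_decode_tokens`. A toy sketch of the row/slot pairing the matched slices preserve (sizes invented):

import torch

num_decode_tokens = 3
hidden = torch.randn(8, 512)                  # 3 decode rows + 5 prefill rows
slot_mapping = torch.tensor([7, 2, 5, 0, 1, 3, 4, 6])

# Slicing both tensors at the same boundary keeps row i paired with
# slot_mapping[i], so decode latents land in decode slots only.
decode_hs = hidden[:num_decode_tokens]
decode_slots = slot_mapping[:num_decode_tokens]
assert decode_hs.shape[0] == decode_slots.shape[0] == num_decode_tokens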
@@ -1148,6 +1156,8 @@ def forward(
                                     decode_k_pe,
                                     enabled=enable_multistream_mla)
                     decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+            elif self.running_chunkprefilll_with_torchair:
+                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
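Note: the new `elif` applies the same `rope_single` rotation to the decode `q_pe` as the graphed branch, just without the `npu_stream_switch`/`npu_wait_tensor` ordering, which only matters when multistream MLA computes `decode_k_pe` on a secondary stream; in the eager chunked-prefill path `decode_k_pe` already comes out of `exec_kv` above (the KvRmsNormRopeCache op named in the comment), so only the query side still needs rotating.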
@@ -1166,9 +1176,10 @@ def forward(
             sin = attn_metadata.prefill.sin

             prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
+            prefill_hs = hidden_states_or_kv_c_normed[num_decode_tokens:]
             prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
-                hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                attn_metadata.slot_mapping)
+                prefill_hs, cos, sin, kv_cache,
+                attn_metadata.slot_mapping[num_decode_tokens:])

             kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
             prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
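Note: the prefill side mirrors the decode split: `exec_kv_prefill` now sees only the prefill rows and the matching tail of the slot mapping rather than the whole batch. A minimal sketch of the complementary slices (sizes invented):

import torch

num_decode_tokens, num_actual_toks = 4, 10
slot_mapping = torch.randperm(32)[:num_actual_toks]

decode_slots = slot_mapping[:num_decode_tokens]   # used by exec_kv
prefill_slots = slot_mapping[num_decode_tokens:]  # used by exec_kv_prefill
# Together the slices cover each actual token's slot exactly once.
assert decode_slots.numel() + prefill_slots.numel() == num_actual_toks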
@@ -1186,9 +1197,8 @@ def forward(
                 kv_cache
             ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
             if self.torchair_graph_enabled:
-                if kv_cache[0].numel(
-                ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-                    slots = attn_metadata.slot_mapping
+                if kv_cache[0].numel() > 0 and has_prefill:
+                    slots = attn_metadata.slot_mapping[num_decode_tokens:]
                     # NOTE: Separate the kv cache in advance to avoid OOM or other issues
                     torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
                         num_tokens, self.num_kv_heads, -1),
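Note: the old guard wrote the latent KV cache only in the PrefillNoCache state; `has_prefill` covers chunked prefill as well, and starting the slot list at `num_decode_tokens` avoids re-writing decode slots that `exec_kv` already populated. A toy model of the scatter `_npu_reshape_and_cache` performs here, with invented sizes (the real op takes more arguments than the truncated diff context shows):

import torch

num_decode_tokens, num_actual_toks = 3, 8
kv_c = torch.randn(num_actual_toks, 1, 64)     # latent KV, one row per token
slot_mapping = torch.randperm(32)[:num_actual_toks]
cache = torch.zeros(32, 1, 64)

slots = slot_mapping[num_decode_tokens:]       # prefill slots only
cache[slots] = kv_c[num_decode_tokens:]        # decode slots written earlier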
