1 change: 1 addition & 0 deletions tests/ut/attention/test_mla_v1.py
@@ -664,6 +664,7 @@ def test_rope_single(self, mock_rope):
     def test_forward_decode_without_graph(self, mock_page_attention_mla,
                                           mock_up_proj):
         self.impl.running_in_graph = False
+        self.impl.running_chunkprefilll_with_torchair = False
         num_tokens = 100
         num_blocks = 256
         block_size = 4
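The one added line above is what keeps this test working: `_forward_decode` now reads the new flag (see the mla_v1.py hunks below), and the attribute is otherwise only assigned inside `forward()`. A minimal, self-contained sketch of the failure mode the extra setup line avoids; `FakeImpl` is a stand-in, not the real `AscendMLAImpl`:

```python
# Hedged sketch: FakeImpl only mirrors the new condition in _forward_decode.
class FakeImpl:

    def _forward_decode(self):
        # same check the PR adds to AscendMLAImpl._forward_decode
        if self.running_in_graph or self.running_chunkprefilll_with_torchair:
            return "torchair decode path"
        return "eager decode path"


impl = FakeImpl()
try:
    impl._forward_decode()  # neither flag set -> AttributeError
except AttributeError as exc:
    print("without the added setup line:", exc)

impl.running_in_graph = False
impl.running_chunkprefilll_with_torchair = False  # the line this hunk adds
print(impl._forward_decode())  # -> "eager decode path"
```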
45 changes: 27 additions & 18 deletions vllm_ascend/attention/mla_v1.py
@@ -998,7 +998,7 @@ def _forward_decode(
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
         num_tokens = q_nope.size(0)
-        if self.running_in_graph:
+        if self.running_in_graph or self.running_chunkprefilll_with_torchair:
             # shape of knope/k_pe for npu graph mode should be:
             # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
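The comment in this hunk pins down the cache layout the graph path expects. A toy-sized illustration of those views, assuming MLA's single kv head; the sizes are placeholders, not the model's real dimensions:

```python
# Hedged sketch of the [num_blocks, num_kv_heads, block_size, head_dim] views
# described by the comment above; head_dim is kv_lora_rank for the nope cache
# and qk_rope_head_dim for the rope cache. All sizes are illustrative.
import torch

num_blocks, block_size, num_kv_heads = 256, 4, 1
kv_lora_rank, qk_rope_head_dim = 512, 64

nope_cache = torch.zeros(num_blocks, block_size, num_kv_heads * kv_lora_rank)
rope_cache = torch.zeros(num_blocks, block_size, num_kv_heads * qk_rope_head_dim)

k_nope = nope_cache.view(num_blocks, num_kv_heads, block_size, kv_lora_rank)
k_pe = rope_cache.view(num_blocks, num_kv_heads, block_size, qk_rope_head_dim)
print(k_nope.shape, k_pe.shape)
# torch.Size([256, 1, 4, 512]) torch.Size([256, 1, 4, 64])
```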
@@ -1112,6 +1112,7 @@ def forward(
         self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
+        self.running_chunkprefilll_with_torchair = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
             kv_c, k_pe = self.kv_a_proj_with_mqa(
@@ -1148,18 +1149,25 @@
         if has_decode:
             decode_k_nope = None
             assert attn_metadata.decode is not None
-            if self.running_in_graph:
+            if self.running_in_graph or self.running_chunkprefilll_with_torchair:
                 cos = attn_metadata.decode.cos
                 sin = attn_metadata.decode.sin
-                with npu_stream_switch("mla_secondary",
-                                       0,
-                                       enabled=enable_multistream_mla):
-                    npu_wait_tensor(hidden_states_or_kv_c_normed,
-                                    ckq,
-                                    enabled=enable_multistream_mla)
+                if self.running_chunkprefilll_with_torchair:
+                    decode_hs = (
+                        hidden_states_or_kv_c_normed[:num_decode_tokens])
+                    slots = attn_metadata.slot_mapping[:num_decode_tokens]
                     decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
-                        hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                        attn_metadata.slot_mapping)
+                        decode_hs, cos, sin, kv_cache, slots)
+                else:
+                    with npu_stream_switch("mla_secondary",
+                                           0,
+                                           enabled=enable_multistream_mla):
+                        npu_wait_tensor(hidden_states_or_kv_c_normed,
+                                        ckq,
+                                        enabled=enable_multistream_mla)
+                        decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
+                            hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                            attn_metadata.slot_mapping)
             # Without explicitly controlling the order, IndexByTensor operations
             # would be placed after `matmul W_KV_T` hindering the overlapping of
             # KvRmsNormRopeCache and SingleRope.
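The new chunked-prefill branch slices both the hidden states passed to `exec_kv` and the slot mapping with `[:num_decode_tokens]`, which relies on decode tokens occupying the leading rows of the batch; the prefill cache write further down takes the complementary slice. A standalone sketch of that split (all shapes made up):

```python
# Hedged sketch of the decode/prefill token split the new branch assumes.
import torch

num_decode_tokens, num_prefill_tokens = 3, 5
num_actual_tokens = num_decode_tokens + num_prefill_tokens

hidden_states = torch.randn(num_actual_tokens, 16)  # stand-in for the kv_c input
slot_mapping = torch.arange(num_actual_tokens)      # stand-in slot indices

decode_hs = hidden_states[:num_decode_tokens]       # rows fed to exec_kv
decode_slots = slot_mapping[:num_decode_tokens]     # cache slots written by exec_kv
prefill_slots = slot_mapping[num_decode_tokens:]    # slots left for _npu_reshape_and_cache

assert decode_hs.shape[0] == num_decode_tokens
assert prefill_slots.shape[0] == num_prefill_tokens
```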
@@ -1183,6 +1191,8 @@ def forward(
                                     decode_k_pe,
                                     enabled=enable_multistream_mla)
                     decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+            elif self.running_chunkprefilll_with_torchair:
+                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
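On both torchair paths only the query still needs RoPE at this point (`exec_kv` already rotated the key while writing the cache), hence `rope_single` instead of `rotary_emb`. For reference, a generic rotate-half RoPE in plain PyTorch; the actual NPU kernel behind `rope_single` may use a different layout:

```python
# Generic reference RoPE, not the vllm-ascend kernel; shapes are illustrative.
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q_pe: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # q_pe: [num_tokens, num_heads, rope_dim]; cos/sin broadcast over heads
    return q_pe * cos + rotate_half(q_pe) * sin

q_pe = torch.randn(3, 16, 64)
cos, sin = torch.randn(3, 1, 64), torch.randn(3, 1, 64)
print(apply_rope(q_pe, cos, sin).shape)  # torch.Size([3, 16, 64])
```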
@@ -1221,16 +1231,15 @@ def forward(
             kv_cache
         ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
         if self.torchair_graph_enabled:
-            if kv_cache[0].numel(
-            ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+            if kv_cache[0].numel() > 0 and has_prefill:
                 slots = attn_metadata.slot_mapping
                 # NOTE: Separate the kv cache in advance to avoid OOM or other issues
-                torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
-                    num_tokens, self.num_kv_heads, -1),
-                                                 value=prefill_k_pe,
-                                                 key_cache=kv_cache[0],
-                                                 value_cache=kv_cache[1],
-                                                 slot_indices=slots)
+                torch_npu._npu_reshape_and_cache(
+                    key=kv_c_normed.view(num_tokens, self.num_kv_heads, -1),
+                    value=prefill_k_pe,
+                    key_cache=kv_cache[0],
+                    value_cache=kv_cache[1],
+                    slot_indices=slots[num_decode_tokens:])
             else:
                 kv_c_normed = kv_c_normed.view(
                     [num_actual_toks, self.num_kv_heads, -1])
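The cache-write guard widens from the `PrefillNoCache` state to any batch that contains prefill tokens, and the slot indices skip the decode entries that `exec_kv` already wrote. A small before/after sketch with a stand-in enum rather than the real `AscendAttentionState`:

```python
# Hedged sketch of the widened condition; AttnState is a stand-in enum.
from enum import Enum, auto

class AttnState(Enum):
    PrefillNoCache = auto()
    ChunkedPrefill = auto()
    DecodeOnly = auto()

def write_prefill_cache_old(attn_state: AttnState, cache_numel: int) -> bool:
    return cache_numel > 0 and attn_state == AttnState.PrefillNoCache

def write_prefill_cache_new(has_prefill: bool, cache_numel: int) -> bool:
    return cache_numel > 0 and has_prefill

# A torchair chunked-prefill batch now reaches the _npu_reshape_and_cache path:
print(write_prefill_cache_old(AttnState.ChunkedPrefill, cache_numel=1024))  # False
print(write_prefill_cache_new(has_prefill=True, cache_numel=1024))          # True
```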