diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index 2ecc3f7bd7..652cff3bf4 100644
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -664,6 +664,7 @@ def test_rope_single(self, mock_rope):
     def test_forward_decode_without_graph(self, mock_page_attention_mla,
                                           mock_up_proj):
         self.impl.running_in_graph = False
+        self.impl.running_chunkprefilll_with_torchair = False
         num_tokens = 100
         num_blocks = 256
         block_size = 4
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index a8f8ae8233..48713fc385 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -998,7 +998,7 @@ def _forward_decode(
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
         num_tokens = q_nope.size(0)
-        if self.running_in_graph:
+        if self.running_in_graph or self.running_chunkprefilll_with_torchair:
             # shape of knope/k_pe for npu graph mode should be:
             # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
@@ -1112,6 +1112,7 @@ def forward(
         self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
+        self.running_chunkprefilll_with_torchair = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
             kv_c, k_pe = self.kv_a_proj_with_mqa(
@@ -1148,18 +1149,25 @@ def forward(
         if has_decode:
             decode_k_nope = None
             assert attn_metadata.decode is not None
-            if self.running_in_graph:
+            if self.running_in_graph or self.running_chunkprefilll_with_torchair:
                 cos = attn_metadata.decode.cos
                 sin = attn_metadata.decode.sin
-                with npu_stream_switch("mla_secondary",
-                                       0,
-                                       enabled=enable_multistream_mla):
-                    npu_wait_tensor(hidden_states_or_kv_c_normed,
-                                    ckq,
-                                    enabled=enable_multistream_mla)
+                if self.running_chunkprefilll_with_torchair:
+                    decode_hs = (
+                        hidden_states_or_kv_c_normed[:num_decode_tokens])
+                    slots = attn_metadata.slot_mapping[:num_decode_tokens]
                     decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
-                        hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                        attn_metadata.slot_mapping)
+                        decode_hs, cos, sin, kv_cache, slots)
+                else:
+                    with npu_stream_switch("mla_secondary",
+                                           0,
+                                           enabled=enable_multistream_mla):
+                        npu_wait_tensor(hidden_states_or_kv_c_normed,
+                                        ckq,
+                                        enabled=enable_multistream_mla)
+                        decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
+                            hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                            attn_metadata.slot_mapping)
             # Without explicitly controlling the order, IndexByTensor operations
             # would be placed after `matmul W_KV_T` hindering the overlapping of
             # KvRmsNormRopeCache and SingleRope.
@@ -1183,6 +1191,8 @@ def forward(
                                     decode_k_pe,
                                     enabled=enable_multistream_mla)
                     decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+            elif self.running_chunkprefilll_with_torchair:
+                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
@@ -1221,16 +1231,15 @@ def forward(
                 kv_cache
             ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
             if self.torchair_graph_enabled:
-                if kv_cache[0].numel(
-                ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+                if kv_cache[0].numel() > 0 and has_prefill:
                     slots = attn_metadata.slot_mapping
                     # NOTE: Separate the kv cache in advance to avoid OOM or other issues
-                    torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
-                        num_tokens, self.num_kv_heads, -1),
-                                                     value=prefill_k_pe,
-                                                     key_cache=kv_cache[0],
-                                                     value_cache=kv_cache[1],
-                                                     slot_indices=slots)
+                    torch_npu._npu_reshape_and_cache(
+                        key=kv_c_normed.view(num_tokens, self.num_kv_heads, -1),
+                        value=prefill_k_pe,
+                        key_cache=kv_cache[0],
+                        value_cache=kv_cache[1],
+                        slot_indices=slots[num_decode_tokens:])
             else:
                 kv_c_normed = kv_c_normed.view(
                     [num_actual_toks, self.num_kv_heads, -1])
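For context, a minimal sketch of the branch selection this patch introduces: `running_chunkprefilll_with_torchair` is set whenever torchair graph mode is enabled and the batch is in the `ChunkedPrefill` state, and the decode tokens then go through `exec_kv` on sliced `hidden_states`/`slot_mapping` instead of the multistream `mla_secondary` path. The enum members and flag names mirror the diff; the helper `select_decode_kv_path` and its return strings are hypothetical, for illustration only.

```python
# Illustrative sketch only, not the vllm-ascend implementation.
from enum import Enum, auto


class AscendAttentionState(Enum):
    DecodeOnly = auto()
    SpecDecoding = auto()
    ChunkedPrefill = auto()
    PrefillNoCache = auto()


def select_decode_kv_path(torchair_graph_enabled: bool,
                          attn_state: AscendAttentionState,
                          num_decode_tokens: int) -> str:
    """Show which exec_kv path the decode branch takes after this change."""
    running_in_graph = torchair_graph_enabled and attn_state in (
        AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding)
    running_chunkprefilll_with_torchair = (
        torchair_graph_enabled
        and attn_state == AscendAttentionState.ChunkedPrefill)

    if running_chunkprefilll_with_torchair:
        # Only the leading decode tokens (and their slot_mapping entries) feed
        # exec_kv; the prefill tail is cached separately via
        # _npu_reshape_and_cache with slots[num_decode_tokens:].
        return (f"exec_kv(hidden_states[:{num_decode_tokens}], "
                f"slots[:{num_decode_tokens}])")
    if running_in_graph:
        return "exec_kv(all hidden states) on the multistream 'mla_secondary' path"
    return "eager path: rotary_emb + _npu_reshape_and_cache"


if __name__ == "__main__":
    print(select_decode_kv_path(True, AscendAttentionState.ChunkedPrefill, 4))
    print(select_decode_kv_path(True, AscendAttentionState.DecodeOnly, 4))
```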