@@ -964,7 +964,7 @@ def _forward_decode(
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
         num_tokens = q_nope.size(0)
-        if self.running_in_graph or self.running_chunkprefilll_with_torchair:
+        if self.running_in_graph:
             # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
             if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
                 assert num_tokens % self.spec_token_num == 0
@@ -1080,7 +1080,6 @@ def forward(
         self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
-        self.running_chunkprefilll_with_torchair = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
             kv_c, k_pe = self.kv_a_proj_with_mqa(
@@ -1117,25 +1116,18 @@ def forward(
11171116        if  has_decode :
11181117            decode_k_nope  =  None 
11191118            assert  attn_metadata .decode  is  not   None 
1120-             if  self .running_in_graph   or   self . running_chunkprefilll_with_torchair :
1119+             if  self .running_in_graph :
11211120                cos  =  attn_metadata .decode .cos 
11221121                sin  =  attn_metadata .decode .sin 
1123-                 if  self .running_chunkprefilll_with_torchair :
1124-                     decode_hs  =  (
1125-                         hidden_states_or_kv_c_normed [:num_decode_tokens ])
1126-                     slots  =  attn_metadata .slot_mapping [:num_decode_tokens ]
1122+                 with  npu_stream_switch ("mla_secondary" ,
1123+                                        0 ,
1124+                                        enabled = enable_multistream_mla ):
1125+                     npu_wait_tensor (hidden_states_or_kv_c_normed ,
1126+                                     ckq ,
1127+                                     enabled = enable_multistream_mla )
11271128                    decode_k_pe , decode_k_nope , decode_kv  =  self .exec_kv (
1128-                         decode_hs , cos , sin , kv_cache , slots )
1129-                 else :
1130-                     with  npu_stream_switch ("mla_secondary" ,
1131-                                            0 ,
1132-                                            enabled = enable_multistream_mla ):
1133-                         npu_wait_tensor (hidden_states_or_kv_c_normed ,
1134-                                         ckq ,
1135-                                         enabled = enable_multistream_mla )
1136-                         decode_k_pe , decode_k_nope , decode_kv  =  self .exec_kv (
1137-                             hidden_states_or_kv_c_normed , cos , sin , kv_cache ,
1138-                             attn_metadata .slot_mapping )
1129+                         hidden_states_or_kv_c_normed , cos , sin , kv_cache ,
1130+                         attn_metadata .slot_mapping )
11391131                # Without explicitly controlling the order, IndexByTensor operations 
11401132                # would be placed after `matmul W_KV_T` hindering the overlapping of 
11411133                # KvRmsNormRopeCache and SingleRope. 
@@ -1159,8 +1151,6 @@ def forward(
                                     decode_k_pe,
                                     enabled=enable_multistream_mla)
                     decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
-            elif self.running_chunkprefilll_with_torchair:
-                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
@@ -1199,15 +1189,16 @@ def forward(
             kv_cache
         ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
         if self.torchair_graph_enabled:
-            if kv_cache[0].numel() > 0 and has_prefill:
+            if kv_cache[0].numel(
+            ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
                 slots = attn_metadata.slot_mapping
                 # NOTE: Separate the kv cache in advance to avoid OOM or other issues
-                torch_npu._npu_reshape_and_cache(
-                    key=kv_c_normed.view(num_tokens, self.num_kv_heads, -1),
-                    value=prefill_k_pe,
-                    key_cache=kv_cache[0],
-                    value_cache=kv_cache[1],
-                    slot_indices=slots[num_decode_tokens:])
+                torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
+                    num_tokens, self.num_kv_heads, -1),
+                                                 value=prefill_k_pe,
+                                                 key_cache=kv_cache[0],
+                                                 value_cache=kv_cache[1],
+                                                 slot_indices=slots)
         else:
             kv_c_normed = kv_c_normed.view(
                 [num_actual_toks, self.num_kv_heads, -1])