@@ -964,7 +964,7 @@ def _forward_decode(
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
         num_tokens = q_nope.size(0)
-        if self.running_in_graph:
+        if self.running_in_graph or self.running_chunkprefilll_with_torchair:
             # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
             if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
                 assert num_tokens % self.spec_token_num == 0
@@ -1080,6 +1080,7 @@ def forward(
         self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
+        self.running_chunkprefilll_with_torchair = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
             kv_c, k_pe = self.kv_a_proj_with_mqa(
@@ -1116,18 +1117,25 @@ def forward(
         if has_decode:
             decode_k_nope = None
             assert attn_metadata.decode is not None
-            if self.running_in_graph:
+            if self.running_in_graph or self.running_chunkprefilll_with_torchair:
                 cos = attn_metadata.decode.cos
                 sin = attn_metadata.decode.sin
-                with npu_stream_switch("mla_secondary",
-                                       0,
-                                       enabled=enable_multistream_mla):
-                    npu_wait_tensor(hidden_states_or_kv_c_normed,
-                                    ckq,
-                                    enabled=enable_multistream_mla)
+                if self.running_chunkprefilll_with_torchair:
+                    decode_hs = (
+                        hidden_states_or_kv_c_normed[:num_decode_tokens])
+                    slots = attn_metadata.slot_mapping[:num_decode_tokens]
                     decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
-                        hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                        attn_metadata.slot_mapping)
+                        decode_hs, cos, sin, kv_cache, slots)
+                else:
+                    with npu_stream_switch("mla_secondary",
+                                           0,
+                                           enabled=enable_multistream_mla):
+                        npu_wait_tensor(hidden_states_or_kv_c_normed,
+                                        ckq,
+                                        enabled=enable_multistream_mla)
+                        decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
+                            hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                            attn_metadata.slot_mapping)
                 # Without explicitly controlling the order, IndexByTensor operations
                 # would be placed after `matmul W_KV_T` hindering the overlapping of
                 # KvRmsNormRopeCache and SingleRope.
@@ -1151,6 +1159,8 @@ def forward(
                                     decode_k_pe,
                                     enabled=enable_multistream_mla)
                     decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+            elif self.running_chunkprefilll_with_torchair:
+                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
@@ -1189,16 +1199,15 @@ def forward(
             kv_cache
         ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
         if self.torchair_graph_enabled:
-            if kv_cache[0].numel(
-            ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+            if kv_cache[0].numel() > 0 and has_prefill:
                 slots = attn_metadata.slot_mapping
                 # NOTE: Separate the kv cache in advance to avoid OOM or other issues
                 torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
                     num_tokens, self.num_kv_heads, -1),
                                                  value=prefill_k_pe,
                                                  key_cache=kv_cache[0],
                                                  value_cache=kv_cache[1],
-                                                 slot_indices=slots)
+                                                 slot_indices=slots[num_decode_tokens:])
         else:
             kv_c_normed = kv_c_normed.view(
                 [num_actual_toks, self.num_kv_heads, -1])
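
A minimal sketch of the batch layout the changes above assume: with TorchAir chunked prefill, decode tokens are packed ahead of prefill tokens in the flattened batch, so the decode path slices hidden_states_or_kv_c_normed[:num_decode_tokens] and slot_mapping[:num_decode_tokens] for exec_kv, while the prefill-side _npu_reshape_and_cache writes through slots[num_decode_tokens:], presumably so it does not re-write cache slots that exec_kv already filled for the decode tokens. The helper split_decode_prefill and the toy tensors below are illustrative, not part of the patch.

import torch

def split_decode_prefill(hidden_states: torch.Tensor,
                         slot_mapping: torch.Tensor,
                         num_decode_tokens: int):
    # Decode tokens come first in the flattened batch, so the decode
    # path takes the leading slice (as exec_kv does in the diff) and
    # the prefill reshape-and-cache takes the trailing slice of slots.
    decode_hs = hidden_states[:num_decode_tokens]
    decode_slots = slot_mapping[:num_decode_tokens]
    prefill_hs = hidden_states[num_decode_tokens:]
    prefill_slots = slot_mapping[num_decode_tokens:]
    return (decode_hs, decode_slots), (prefill_hs, prefill_slots)

# Toy batch: 2 decode tokens followed by 3 prefill tokens.
hs = torch.arange(5 * 4, dtype=torch.float32).view(5, 4)
slots = torch.tensor([7, 9, 100, 101, 102])
(dec_hs, dec_slots), (pre_hs, pre_slots) = split_decode_prefill(hs, slots, 2)
assert dec_slots.tolist() == [7, 9]
assert pre_slots.tolist() == [100, 101, 102]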