@@ -998,7 +998,7 @@ def _forward_decode(
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
         num_tokens = q_nope.size(0)
-        if self.running_in_graph:
+        if self.running_in_graph or self.running_chunkprefilll_with_torchair:
             # shape of knope/k_pe for npu graph mode should be:
             # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
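With this change the decode sub-batch of a torchair chunked-prefill step takes the same paged-cache path as graph mode. A minimal sketch of the `block_size` lookup with invented tensors; it assumes dimension 1 of the nope cache is the block size, matching the `shape[1]` access above (every size here is made up, not the real layout):

```python
import torch

# Toy paged caches standing in for kv_c_and_k_pe_cache; all dimensions
# are invented for illustration only.
num_blocks, block_size, num_kv_heads = 8, 128, 1
kv_lora_rank, qk_rope_head_dim = 512, 64

nope_cache = torch.zeros(num_blocks, block_size, num_kv_heads, kv_lora_rank)
rope_cache = torch.zeros(num_blocks, block_size, num_kv_heads, qk_rope_head_dim)
kv_c_and_k_pe_cache = (nope_cache, rope_cache)

# Same lookup as in the diff: the block size is read off the cache itself.
block_size = kv_c_and_k_pe_cache[0].shape[1]
assert block_size == 128
```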
@@ -1112,6 +1112,7 @@ def forward(
         self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
+        self.running_chunkprefilll_with_torchair = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
             kv_c, k_pe = self.kv_a_proj_with_mqa(
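The new flag is mutually exclusive with `running_in_graph`, since `attn_state` is a single enum value. A runnable sketch of the flag logic with a stub enum (only the state and flag names come from the diff; the enum values and helper are illustrative):

```python
from enum import Enum

# Stub of the attention-state enum referenced above.
class AscendAttentionState(Enum):
    PrefillNoCache = 0
    ChunkedPrefill = 1
    DecodeOnly = 2
    SpecDecoding = 3

def compute_flags(torchair_graph_enabled: bool,
                  attn_state: AscendAttentionState):
    running_in_graph = torchair_graph_enabled and attn_state in (
        AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding)
    running_chunkprefilll_with_torchair = (
        torchair_graph_enabled
        and attn_state == AscendAttentionState.ChunkedPrefill)
    return running_in_graph, running_chunkprefilll_with_torchair

# A chunked-prefill step with torchair enabled sets only the new flag.
assert compute_flags(True, AscendAttentionState.ChunkedPrefill) == (False, True)
```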
@@ -1148,18 +1149,25 @@ def forward(
         if has_decode:
             decode_k_nope = None
             assert attn_metadata.decode is not None
-            if self.running_in_graph:
+            if self.running_in_graph or self.running_chunkprefilll_with_torchair:
                 cos = attn_metadata.decode.cos
                 sin = attn_metadata.decode.sin
-                with npu_stream_switch("mla_secondary",
-                                       0,
-                                       enabled=enable_multistream_mla):
-                    npu_wait_tensor(hidden_states_or_kv_c_normed,
-                                    ckq,
-                                    enabled=enable_multistream_mla)
+                if self.running_chunkprefilll_with_torchair:
+                    decode_hs = (
+                        hidden_states_or_kv_c_normed[:num_decode_tokens])
+                    slots = attn_metadata.slot_mapping[:num_decode_tokens]
                     decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
-                        hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                        attn_metadata.slot_mapping)
+                        decode_hs, cos, sin, kv_cache, slots)
+                else:
+                    with npu_stream_switch("mla_secondary",
+                                           0,
+                                           enabled=enable_multistream_mla):
+                        npu_wait_tensor(hidden_states_or_kv_c_normed,
+                                        ckq,
+                                        enabled=enable_multistream_mla)
+                        decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
+                            hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                            attn_metadata.slot_mapping)
             # Without explicitly controlling the order, IndexByTensor operations
             # would be placed after `matmul W_KV_T` hindering the overlapping of
             # KvRmsNormRopeCache and SingleRope.
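The slicing relies on the token layout of a mixed step: decode tokens occupy the front of the token dimension and of `slot_mapping`, prefill tokens the tail, which is why `[:num_decode_tokens]` isolates the decode portion here. A toy illustration (all shapes invented):

```python
import torch

# Toy mixed batch: decode tokens first, prefill tokens after them.
num_decode_tokens, num_prefill_tokens, hidden = 3, 5, 16
num_actual_tokens = num_decode_tokens + num_prefill_tokens

hidden_states = torch.randn(num_actual_tokens, hidden)
slot_mapping = torch.arange(num_actual_tokens)

# Mirrors the diff: only the decode prefix is handed to exec_kv.
decode_hs = hidden_states[:num_decode_tokens]
slots = slot_mapping[:num_decode_tokens]
assert decode_hs.shape[0] == slots.shape[0] == num_decode_tokens
```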
@@ -1183,6 +1191,8 @@ def forward(
                                     decode_k_pe,
                                     enabled=enable_multistream_mla)
                     decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+            elif self.running_chunkprefilll_with_torchair:
+                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
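`rope_single` applies rotary embedding from the precomputed `cos`/`sin` in the decode metadata, so the chunked-prefill branch reuses it instead of calling `rotary_emb` with positions. A hedged sketch of that style of rope; the real op runs as an NPU kernel, and this is only the standard rotate-half formulation, not the actual implementation:

```python
import torch

def rope_single_sketch(x: torch.Tensor, cos: torch.Tensor,
                       sin: torch.Tensor) -> torch.Tensor:
    # Standard rotate-half rotary embedding from precomputed cos/sin.
    x1, x2 = x.chunk(2, dim=-1)
    rotated = torch.cat((-x2, x1), dim=-1)
    return x * cos + rotated * sin

# Invented shapes: [num_decode_tokens, num_heads, rope_dim].
num_decode_tokens, num_heads, rope_dim = 3, 8, 64
q_pe = torch.randn(num_decode_tokens, num_heads, rope_dim)
cos = torch.randn(num_decode_tokens, 1, rope_dim)
sin = torch.randn(num_decode_tokens, 1, rope_dim)
q_pe = rope_single_sketch(q_pe, cos, sin)  # shape is unchanged
```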
@@ -1221,16 +1231,15 @@ def forward(
                 kv_cache
             ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
             if self.torchair_graph_enabled:
-                if kv_cache[0].numel(
-                ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+                if kv_cache[0].numel() > 0 and has_prefill:
                     slots = attn_metadata.slot_mapping
                     # NOTE: Separate the kv cache in advance to avoid OOM or other issues
                     torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
                         num_tokens, self.num_kv_heads, -1),
                                                      value=prefill_k_pe,
                                                      key_cache=kv_cache[0],
                                                      value_cache=kv_cache[1],
-                                                     slot_indices=slots)
+                                                     slot_indices=slots[num_decode_tokens:])
             else:
                 kv_c_normed = kv_c_normed.view(
                     [num_actual_toks, self.num_kv_heads, -1])
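Because the condition now covers every prefill state (`has_prefill`) rather than only `PrefillNoCache`, `slot_mapping` can carry decode slots at the front that `exec_kv` already wrote; the prefill cache write therefore skips the first `num_decode_tokens` entries. A toy illustration of that tail slice (values invented):

```python
import torch

# Slot mapping for a mixed batch: decode slots first, prefill slots after.
num_decode_tokens, num_prefill_tokens = 3, 5
slot_mapping = torch.arange(num_decode_tokens + num_prefill_tokens)

# Mirrors slots[num_decode_tokens:]: only prefill tokens are written here,
# since the decode prefix was already cached by exec_kv.
prefill_slots = slot_mapping[num_decode_tokens:]
assert prefill_slots.tolist() == [3, 4, 5, 6, 7]
```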