@@ -964,7 +964,7 @@ def _forward_decode(
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
         num_tokens = q_nope.size(0)
-        if self.running_in_graph:
+        if self.running_in_graph or self.running_chunkprefill_with_torchair:
             # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
             if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
                 assert num_tokens % self.spec_token_num == 0
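The widened guard routes the decode slice of a chunked-prefill batch through the same TorchAir tensor layout as pure decode. Below is a minimal CPU sketch of that layout with hypothetical sizes; it only illustrates the reshape named in the comment above, not the fused attention op itself.

```python
import torch

# Hypothetical sizes; the real tensors live on NPU inside the TorchAir graph.
bs, num_heads_per_rank, dim = 4, 8, 128

# Flat layout on entry: one decode token per request.
q_nope = torch.randn(bs, num_heads_per_rank, dim)

# TorchAir's layout is [bs, num_heads_per_rank, q_seq_len, dim]; for plain
# decode q_seq_len == 1, so the flat tensor just gains a unit axis. (Spec
# decoding instead packs spec_token_num tokens per request, hence the
# divisibility assert in the hunk above.)
q_nope_torchair = q_nope.view(bs, num_heads_per_rank, 1, dim)
assert q_nope_torchair.shape == (bs, num_heads_per_rank, 1, dim)
```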
@@ -1080,6 +1080,7 @@ def forward(
         self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
+        self.running_chunkprefill_with_torchair = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
             kv_c, k_pe = self.kv_a_proj_with_mqa(
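For readability, here is a self-contained sketch of how the two mode flags relate; `compute_mode_flags` and the stub enum are hypothetical stand-ins for the real attributes set in `forward`.

```python
from enum import Enum, auto

class AscendAttentionState(Enum):  # stub of the real enum
    DecodeOnly = auto()
    SpecDecoding = auto()
    ChunkedPrefill = auto()
    PrefillNoCache = auto()

def compute_mode_flags(torchair_graph_enabled, attn_state):
    # Pure decode (and spec decode) batches run fully inside the TorchAir graph.
    running_in_graph = torchair_graph_enabled and attn_state in (
        AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding)
    # Chunked prefill mixes decode and prefill tokens in one batch; only the
    # decode slice takes the TorchAir path, so it gets a separate flag.
    running_chunkprefill_with_torchair = (
        torchair_graph_enabled
        and attn_state == AscendAttentionState.ChunkedPrefill)
    return running_in_graph, running_chunkprefill_with_torchair

assert compute_mode_flags(True, AscendAttentionState.ChunkedPrefill) == (False, True)
assert compute_mode_flags(True, AscendAttentionState.DecodeOnly) == (True, False)
```

The two flags are mutually exclusive, which is what lets the later branches dispatch with `elif` without re-checking the attention state.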
@@ -1116,18 +1117,25 @@ def forward(
         if has_decode:
             decode_k_nope = None
             assert attn_metadata.decode is not None
-            if self.running_in_graph:
+            if self.running_in_graph or self.running_chunkprefill_with_torchair:
                 cos = attn_metadata.decode.cos
                 sin = attn_metadata.decode.sin
-                with npu_stream_switch("mla_secondary",
-                                       0,
-                                       enabled=enable_multistream_mla):
-                    npu_wait_tensor(hidden_states_or_kv_c_normed,
-                                    ckq,
-                                    enabled=enable_multistream_mla)
+                if self.running_chunkprefill_with_torchair:
+                    decode_hs = (
+                        hidden_states_or_kv_c_normed[:num_decode_tokens])
+                    slots = attn_metadata.slot_mapping[:num_decode_tokens]
                     decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
-                        hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                        attn_metadata.slot_mapping)
+                        decode_hs, cos, sin, kv_cache, slots)
+                else:
+                    with npu_stream_switch("mla_secondary",
+                                           0,
+                                           enabled=enable_multistream_mla):
+                        npu_wait_tensor(hidden_states_or_kv_c_normed,
+                                        ckq,
+                                        enabled=enable_multistream_mla)
+                        decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
+                            hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                            attn_metadata.slot_mapping)
             # Without explicitly controlling the order, IndexByTensor operations
             # would be placed after `matmul W_KV_T` hindering the overlapping of
             # KvRmsNormRopeCache and SingleRope.
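The slicing in the new branch assumes the flattened batch places all decode tokens ahead of the prefill tokens. A minimal sketch of that split, with hypothetical tensors:

```python
import torch

# Hypothetical sizes for a mixed chunked-prefill batch.
num_decode_tokens, num_prefill_tokens, hidden = 3, 5, 16
hidden_states = torch.randn(num_decode_tokens + num_prefill_tokens, hidden)
slot_mapping = torch.arange(num_decode_tokens + num_prefill_tokens)

# Decode tokens sit at the front of the flattened batch, so plain slices
# isolate exactly the rows that exec_kv writes into the KV cache here.
decode_hs = hidden_states[:num_decode_tokens]
decode_slots = slot_mapping[:num_decode_tokens]

# The remaining rows belong to the prefill chunk and are cached separately
# (see the slot_indices slicing in the last hunk below).
prefill_slots = slot_mapping[num_decode_tokens:]
```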
@@ -1151,6 +1159,8 @@ def forward(
                                     decode_k_pe,
                                     enabled=enable_multistream_mla)
                     decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+            elif self.running_chunkprefill_with_torchair:
+                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
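The new `elif` applies rotary position embedding to the decode queries without the multistream synchronization, since the chunked-prefill path computed its KV outside the secondary stream. As a semantics-only sketch, a generic rotate-half rotary embedding is shown below; `rope_single_sketch` is hypothetical, and the production `rope_single` is a fused NPU op whose exact convention may differ.

```python
import torch

def rope_single_sketch(x, cos, sin):
    # x: [num_tokens, num_heads, rope_dim]; cos/sin: [num_tokens, 1, rope_dim].
    # Classic rotate-half rotary embedding applied to one tensor (queries),
    # mirroring how rope_single takes precomputed cos/sin tables.
    x1, x2 = x.chunk(2, dim=-1)
    rotated = torch.cat((-x2, x1), dim=-1)
    return x * cos + rotated * sin

num_tokens, num_heads, rope_dim = 3, 8, 64
decode_q_pe = torch.randn(num_tokens, num_heads, rope_dim)
cos = torch.randn(num_tokens, 1, rope_dim)
sin = torch.randn(num_tokens, 1, rope_dim)
assert rope_single_sketch(decode_q_pe, cos, sin).shape == decode_q_pe.shape
```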
@@ -1189,16 +1199,15 @@ def forward(
                 kv_cache
             ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
             if self.torchair_graph_enabled:
-                if kv_cache[0].numel(
-                ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+                if kv_cache[0].numel() > 0 and has_prefill:
                     slots = attn_metadata.slot_mapping
                     # NOTE: Separate the kv cache in advance to avoid OOM or other issues
-                    torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
-                        num_tokens, self.num_kv_heads, -1),
-                                                     value=prefill_k_pe,
-                                                     key_cache=kv_cache[0],
-                                                     value_cache=kv_cache[1],
-                                                     slot_indices=slots)
+                    torch_npu._npu_reshape_and_cache(
+                        key=kv_c_normed.view(num_tokens, self.num_kv_heads, -1),
+                        value=prefill_k_pe,
+                        key_cache=kv_cache[0],
+                        value_cache=kv_cache[1],
+                        slot_indices=slots[num_decode_tokens:])
             else:
                 kv_c_normed = kv_c_normed.view(
                     [num_actual_toks, self.num_kv_heads, -1])
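Assuming `_npu_reshape_and_cache` behaves as a per-slot scatter into the paged nope/rope caches, the `slots[num_decode_tokens:]` slice writes only the prefill rows, because `exec_kv` already cached the decode rows above. A CPU sketch of that scatter with hypothetical shapes:

```python
import torch

# Hypothetical cache geometry: compressed-KV ("nope") and rope caches.
num_prefill_tokens, num_kv_heads, nope_dim, rope_dim = 5, 1, 512, 64
num_slots = 32

key = torch.randn(num_prefill_tokens, num_kv_heads, nope_dim)    # kv_c_normed
value = torch.randn(num_prefill_tokens, num_kv_heads, rope_dim)  # prefill_k_pe
key_cache = torch.zeros(num_slots, num_kv_heads, nope_dim)
value_cache = torch.zeros(num_slots, num_kv_heads, rope_dim)

# Prefill portion of the slot mapping, i.e. slots[num_decode_tokens:].
prefill_slots = torch.tensor([7, 8, 9, 20, 21])

# Each prefill token lands in its assigned slot; decode slots are untouched.
key_cache[prefill_slots] = key
value_cache[prefill_slots] = value
```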