@@ -961,7 +961,7 @@ def _forward_decode(
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
         num_tokens = q_nope.size(0)
-        if self.running_in_graph:
+        if self.running_in_graph or self.running_chunkprefilll_with_torchair:
             # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
             if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
                 assert num_tokens % self.spec_token_num == 0
@@ -1077,6 +1077,7 @@ def forward(
         self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
+        self.running_chunkprefilll_with_torchair = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
             if not self.torchair_graph_enabled:
@@ -1096,13 +1097,13 @@ def forward(
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
         output = output[:num_actual_toks, ...]
-        if not self.torchair_graph_enabled:
+        if not self.torchair_graph_enabled or self.running_chunkprefilll_with_torchair:
             kv_c_normed = kv_c_normed[:num_actual_toks, ...]
             prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
         if not self.running_in_graph:
             hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
             prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
-            if not self.torchair_graph_enabled:
+            if not self.torchair_graph_enabled or self.running_chunkprefilll_with_torchair:
                 decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
         k_pe = k_pe[:num_actual_toks, ...]
         k_pe = k_pe.unsqueeze(1)
@@ -1113,18 +1114,25 @@ def forward(
         if has_decode:
             decode_k_nope = None
             assert attn_metadata.decode is not None
-            if self.running_in_graph:
+            if self.running_in_graph or self.running_chunkprefilll_with_torchair:
                 cos = attn_metadata.decode.cos
                 sin = attn_metadata.decode.sin
-                with npu_stream_switch("mla_secondary",
-                                       0,
-                                       enabled=enable_multistream_mla):
-                    npu_wait_tensor(hidden_states_or_kv_c_normed,
-                                    ckq,
-                                    enabled=enable_multistream_mla)
+                if self.running_chunkprefilll_with_torchair:
+                    decode_hs = (
+                        hidden_states_or_kv_c_normed[:num_decode_tokens])
+                    slots = attn_metadata.slot_mapping[:num_decode_tokens]
                     decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
-                        hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                        attn_metadata.slot_mapping)
+                        decode_hs, cos, sin, kv_cache, slots)
+                else:
+                    with npu_stream_switch("mla_secondary",
+                                           0,
+                                           enabled=enable_multistream_mla):
+                        npu_wait_tensor(hidden_states_or_kv_c_normed,
+                                        ckq,
+                                        enabled=enable_multistream_mla)
+                        decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
+                            hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                            attn_metadata.slot_mapping)
             # Without explicitly controlling the order, IndexByTensor operations
             # would be placed after `matmul W_KV_T` hindering the overlapping of
             # KvRmsNormRopeCache and SingleRope.
@@ -1148,6 +1156,8 @@ def forward(
                                     decode_k_pe,
                                     enabled=enable_multistream_mla)
                     decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+            elif self.running_chunkprefilll_with_torchair:
+                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
@@ -1166,9 +1176,10 @@ def forward(
                 sin = attn_metadata.prefill.sin

                 prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
+                prefill_hs = hidden_states_or_kv_c_normed[num_decode_tokens:]
                 prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
-                    hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                    attn_metadata.slot_mapping)
+                    prefill_hs, cos, sin, kv_cache,
+                    attn_metadata.slot_mapping[num_decode_tokens:])

                 kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
                 prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
@@ -1186,9 +1197,8 @@ def forward(
             kv_cache
         ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
         if self.torchair_graph_enabled:
-            if kv_cache[0].numel(
-            ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-                slots = attn_metadata.slot_mapping
+            if kv_cache[0].numel() > 0 and has_prefill:
+                slots = attn_metadata.slot_mapping[num_decode_tokens:]
                 # NOTE: Separate the kv cache in advance to avoid OOM or other issues
                 torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
                     num_tokens, self.num_kv_heads, -1),
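
The recurring pattern in these hunks: with TorchAir chunked prefill, decode tokens occupy the first num_decode_tokens rows of the mixed batch and prefill tokens the rest, so hidden states and slot_mapping are sliced per phase before the KV kernels and the cache-write run. A minimal sketch of that slicing (plain PyTorch; split_by_phase is a hypothetical helper for illustration, not part of vllm-ascend):

import torch

def split_by_phase(hidden_states: torch.Tensor,
                   slot_mapping: torch.Tensor,
                   num_decode_tokens: int):
    # Decode requests are packed first, prefill requests after them.
    decode_hs = hidden_states[:num_decode_tokens]
    decode_slots = slot_mapping[:num_decode_tokens]
    prefill_hs = hidden_states[num_decode_tokens:]
    prefill_slots = slot_mapping[num_decode_tokens:]
    return (decode_hs, decode_slots), (prefill_hs, prefill_slots)

# Example: a mixed batch of 3 decode tokens followed by 5 prefill tokens.
hs = torch.randn(8, 16)
slots = torch.arange(8)
(d_hs, d_slots), (p_hs, p_slots) = split_by_phase(hs, slots, 3)
assert d_hs.shape[0] == 3 and p_hs.shape[0] == 5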