@@ -679,21 +679,28 @@ def _compute_prefill_context(
         q_nope = query[..., :self.qk_nope_head_dim]

         seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
-        latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim
-        cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim]
-        cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:]
+        if not self.running_chunkprefilll_with_torchair:
+            latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim
+            cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim]
+            cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:]
+            num_heads = kv_c_and_k_pe_cache.size(2)
+        else:
+            latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
+            cache_kv_c = kv_c_and_k_pe_cache[0]
+            cache_k_pe = kv_c_and_k_pe_cache[1]
+            num_heads = cache_k_pe.size(2)
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]

             seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
             seq_len = torch.stack([seq_len1, seq_len2])
             kv_c_normed = torch.empty(toks,
-                                      kv_c_and_k_pe_cache.size(2),
+                                      num_heads,
                                       latent_kv_dim,
                                       dtype=query.dtype,
                                       device=query.device)
             k_pe = torch.empty(toks,
-                               kv_c_and_k_pe_cache.size(2),
+                               num_heads,
                                rope_dim,
                                dtype=query.dtype,
                                device=query.device)
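The hunk above makes `_compute_prefill_context` handle two KV-cache layouts: the existing single fused cache tensor, and the pair of separate latent / rope caches that the torchair chunked-prefill path passes in. The sketch below only illustrates that branching; the shapes (`num_blocks`, `block_size`, the dims) are assumptions for the example, not values taken from the diff.

```python
import torch

# Hypothetical layouts, inferred from the slicing in the hunk above:
# fused: [num_blocks, block_size, num_heads, latent_dim + rope_dim]
# split: ([..., latent_dim], [..., rope_dim]) as a tuple of two caches
num_blocks, block_size, num_heads = 4, 128, 1
latent_dim, rope_dim = 512, 64

fused = torch.zeros(num_blocks, block_size, num_heads, latent_dim + rope_dim)
split = (torch.zeros(num_blocks, block_size, num_heads, latent_dim),
         torch.zeros(num_blocks, block_size, num_heads, rope_dim))


def cache_views(kv_c_and_k_pe_cache, rope_dim, split_layout):
    """Return (cache_kv_c, cache_k_pe, latent_kv_dim, num_heads) for either layout."""
    if not split_layout:
        # Fused cache: slice the latent part and the rope part off the last dim.
        latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim
        cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim]
        cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:]
        heads = kv_c_and_k_pe_cache.size(2)
    else:
        # Split cache: the two parts already live in separate tensors.
        latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
        cache_kv_c, cache_k_pe = kv_c_and_k_pe_cache
        heads = cache_k_pe.size(2)
    return cache_kv_c, cache_k_pe, latent_kv_dim, heads


assert cache_views(fused, rope_dim, False)[2] == latent_dim
assert cache_views(split, rope_dim, True)[2] == latent_dim
```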
@@ -952,7 +959,7 @@ def _forward_decode(
             [num_tokens, self.num_heads, self.kv_lora_rank],
             dtype=q.dtype,
             device=q.device)
-        if self.running_in_graph:
+        if self.running_in_graph or self.running_chunkprefilll_with_torchair:
             # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
             if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
                 assert num_tokens % self.spec_token_num == 0
@@ -1049,13 +1056,13 @@ def forward(
         self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
+        self.running_chunkprefilll_with_torchair = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
-            if not self.torchair_graph_enabled:
-                kv_c, k_pe = self.kv_a_proj_with_mqa(
-                    hidden_states_or_kv_c_normed)[0].split(
-                        [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-                kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+            kv_c, k_pe = self.kv_a_proj_with_mqa(
+                hidden_states_or_kv_c_normed)[0].split(
+                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+            kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
         else:
             kv_c_normed = hidden_states_or_kv_c_normed
             assert attn_metadata.num_decodes is not None and \
@@ -1068,24 +1075,23 @@ def forward(
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
         output = output[:num_actual_toks, ...]
-        if not self.torchair_graph_enabled:
-            kv_c_normed = kv_c_normed[:num_actual_toks, ...]
-            prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
+        kv_c_normed = kv_c_normed[:num_actual_toks, ...]
+        prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
         if not self.running_in_graph:
             hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
             prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
-            if not self.torchair_graph_enabled:
-                decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
-                k_pe = k_pe[:num_actual_toks, ...]
-                k_pe = k_pe.unsqueeze(1)
-                decode_k_pe = k_pe[:num_decode_tokens]
-                prefill_k_pe = k_pe[num_decode_tokens:]
+            decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
+            prefill_hs = hidden_states_or_kv_c_normed[num_decode_tokens:]
+            k_pe = k_pe[:num_actual_toks, ...]
+            k_pe = k_pe.unsqueeze(1)
+            decode_k_pe = k_pe[:num_decode_tokens]
+            prefill_k_pe = k_pe[num_decode_tokens:]
         else:
             decode_hs_or_q_c = hidden_states_or_q_c
         if has_decode:
             decode_k_nope = None
             assert attn_metadata.decode is not None
-            if self.running_in_graph:
+            if self.running_in_graph or self.running_chunkprefilll_with_torchair:
                 seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
                 cos = self.rotary_emb.cos_cached[:seq_len].to(
                     dtype=decode_hs_or_q_c.dtype)
@@ -1095,15 +1101,23 @@ def forward(
                 sin = sin[attn_metadata.decode.input_positions]
                 cos = cos[:, None, None, :]
                 sin = sin[:, None, None, :]
-                with npu_stream_switch("mla_secondary",
-                                       0,
-                                       enabled=enable_multistream_mla):
-                    npu_wait_tensor(hidden_states_or_kv_c_normed,
-                                    ckq,
-                                    enabled=enable_multistream_mla)
+                slots = attn_metadata.slot_mapping
+                if self.running_chunkprefilll_with_torchair:
+                    decode_hs = (
+                        hidden_states_or_kv_c_normed[:num_decode_tokens])
+                    slots = attn_metadata.slot_mapping[:num_decode_tokens]
                     decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
-                        hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                        attn_metadata.slot_mapping)
+                        decode_hs, cos, sin, kv_cache, slots)
+                else:
+                    with npu_stream_switch("mla_secondary",
+                                           0,
+                                           enabled=enable_multistream_mla):
+                        npu_wait_tensor(hidden_states_or_kv_c_normed,
+                                        ckq,
+                                        enabled=enable_multistream_mla)
+                        decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
+                            hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                            slots)
                 # Without explicitly controlling the order, IndexByTensor operations
                 # would be placed after `matmul W_KV_T` hindering the overlapping of
                 # KvRmsNormRopeCache and SingleRope.
@@ -1127,6 +1141,8 @@ def forward(
                                     decode_k_pe,
                                     enabled=enable_multistream_mla)
                     decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+            elif self.running_chunkprefilll_with_torchair:
+                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
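Most of the `forward` changes above amount to splitting every token-indexed tensor into a decode prefix and a prefill suffix when chunked prefill runs with torchair, instead of handing the whole batch to `exec_kv`. Below is a minimal stand-alone illustration of that slicing; all names and sizes are invented for the example.

```python
import torch

# Decode tokens occupy the first `num_decode_tokens` rows of every
# token-indexed tensor; prefill tokens occupy the rest.
num_decode_tokens, num_prefill_tokens, hidden = 3, 5, 16
num_actual_toks = num_decode_tokens + num_prefill_tokens

hidden_states = torch.randn(num_actual_toks, hidden)
slot_mapping = torch.arange(num_actual_toks, dtype=torch.int32)

decode_hs = hidden_states[:num_decode_tokens]       # would feed exec_kv
prefill_hs = hidden_states[num_decode_tokens:]      # would feed exec_kv_prefill
decode_slots = slot_mapping[:num_decode_tokens]     # decode KV written to these slots
prefill_slots = slot_mapping[num_decode_tokens:]    # prefill KV written to these slots

assert decode_hs.shape[0] + prefill_hs.shape[0] == num_actual_toks
assert decode_slots.shape[0] == num_decode_tokens
```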
@@ -1153,11 +1169,11 @@ def forward(

             prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
             prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
-                hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                attn_metadata.slot_mapping)
+                prefill_hs, cos, sin, kv_cache,
+                attn_metadata.slot_mapping[num_decode_tokens:])

             kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
-            prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
+            prefill_k_c_normed = prefill_k_nope
             prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
                                              -1)
             prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
@@ -1168,9 +1184,8 @@ def forward(
                 prefill_k_pe,
                 max_seq_len=attn_metadata.prefill.max_seq_lens)
         if self.torchair_graph_enabled:
-            if len(kv_cache) > 0 and kv_cache[0].numel(
-            ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-                slots = attn_metadata.slot_mapping
+            if len(kv_cache) > 0 and kv_cache[0].numel() > 0 and has_prefill:
+                slots = attn_metadata.slot_mapping[num_decode_tokens:]
                 # NOTE: Separate the kv cache in advance to avoid OOM or other issues
                 torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
                     num_tokens, self.num_kv_heads, -1),
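In this last hunk the prefill keys are written into the paged cache using only the prefill portion of `slot_mapping`. As a rough mental model of what `torch_npu._npu_reshape_and_cache` does with those slot indices, here is a plain-torch approximation (not the NPU kernel; shapes and slot values are illustrative):

```python
import torch

# Block-paged key cache: [num_blocks, block_size, num_kv_heads, head_dim].
num_blocks, block_size, num_kv_heads, head_dim = 2, 4, 1, 8
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_dim)

num_prefill_tokens = 3
key = torch.randn(num_prefill_tokens, num_kv_heads, head_dim)
slots = torch.tensor([5, 1, 6])  # prefill slice of slot_mapping (made up)

# Each token's key row is scattered into a flat view of the cache at its slot
# index; the view shares storage, so key_cache is updated in place.
flat = key_cache.view(num_blocks * block_size, num_kv_heads, head_dim)
flat[slots] = key

assert torch.equal(key_cache.view(-1, num_kv_heads, head_dim)[5], key[0])
```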