@@ -1,3 +1,4 @@
+from copy import deepcopy
 from typing import Any, List, Optional
 
 import numpy as np
@@ -95,6 +96,14 @@ def model_input_split_v1_mla_attn(
     seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens
     [seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
 
+    query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1]
+    query_start_loc_post = deepcopy(
+        attn_metadata.query_start_loc[seq_index:]
+    ) - attn_metadata.query_start_loc[seq_index]
+    [block_table_pre,
+     block_table_post] = split_attn_tensor_type(attn_metadata.block_tables,
+                                                seq_index)
+
     if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
         # the attn_mla kernel in torch npu only accepts a 128*128 attn mask
         attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
@@ -131,6 +140,18 @@ def model_input_split_v1_mla_attn(
         [prefill_query_lens_pre, prefill_query_lens_post
          ] = split_attn_tensor_type(attn_metadata.prefill.query_lens,
                                     seq_index - attn_metadata.num_decodes)
+        prefill_query_start_loc_pre = attn_metadata.prefill.query_start_loc[
+            :seq_index + 1 - attn_metadata.num_decodes]
+        prefill_query_start_loc_post = deepcopy(
+            attn_metadata.prefill.query_start_loc[
+                seq_index - attn_metadata.num_decodes:]
+        ) - attn_metadata.prefill.query_start_loc[
+            seq_index - attn_metadata.num_decodes]
         context_len_pre = seq_lens_pre[attn_metadata.num_decodes:]
         context_len_post = seq_lens_post
         prefill_max_query_len_pre = max(prefill_query_lens_pre)
@@ -139,6 +160,7 @@ def model_input_split_v1_mla_attn(
             attn_mask=attn_mask_pre,
             query_lens=prefill_query_lens_pre,
             seq_lens=seq_lens_pre,
+            query_start_loc=prefill_query_start_loc_pre,
             input_positions=input_positions_pre,
             context_lens=context_len_pre,
             block_table=block_tables_pre,
@@ -149,6 +171,7 @@ def model_input_split_v1_mla_attn(
             attn_mask=attn_mask_post,
             query_lens=prefill_query_lens_post,
             seq_lens=seq_lens_post,
+            query_start_loc=prefill_query_start_loc_post,
             input_positions=input_positions_post,
             context_lens=context_len_post,
             block_table=block_tables_post,
@@ -190,6 +213,9 @@ def model_input_split_v1_mla_attn(
         num_input_tokens=token_index,
         head_dim=attn_metadata.head_dim,
         slot_mapping=slot_mapping_pre,
+        seq_lens=seq_lens_pre,
+        query_start_loc=query_start_loc_pre,
+        block_tables=block_table_pre,
         num_decodes=num_decodes_pre,
         num_prefills=num_prefills_pre,
         num_decode_tokens=num_decode_tokens_pre,
@@ -203,6 +229,9 @@ def model_input_split_v1_mla_attn(
         num_input_tokens=attn_metadata.num_input_tokens - token_index,
         head_dim=attn_metadata.head_dim,
         slot_mapping=slot_mapping_post,
+        seq_lens=seq_lens_post,
+        query_start_loc=query_start_loc_post,
+        block_tables=block_table_post,
         num_decodes=num_decodes_post,
         num_prefills=num_prefills_post,
         num_decode_tokens=num_decode_tokens_post,
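
The subtle piece of arithmetic in this diff is the query_start_loc split: query_start_loc is a cumulative-offset tensor (request i owns tokens [query_start_loc[i], query_start_loc[i + 1])), so the pre half can keep the original offsets while the post half must be copied and rebased to start at zero for its own token buffer. Below is a minimal standalone sketch of that rebasing, assuming query_start_loc is a torch tensor; the values and split point are made up for illustration, and this is not the vllm-ascend code itself.

from copy import deepcopy

import torch

# Cumulative token offsets for a batch of 4 requests: request i owns
# tokens [query_start_loc[i], query_start_loc[i + 1]).
query_start_loc = torch.tensor([0, 2, 5, 9, 12])
seq_index = 2  # hypothetical split point: first 2 requests go to the "pre" half

# Pre half keeps the original offsets, including the boundary entry.
query_start_loc_pre = query_start_loc[:seq_index + 1]          # tensor([0, 2, 5])

# Post half is copied (leaving the original tensor untouched) and shifted down
# by the offset at the split point, so it starts at 0 again.
query_start_loc_post = deepcopy(
    query_start_loc[seq_index:]) - query_start_loc[seq_index]  # tensor([0, 4, 7])

assert query_start_loc_pre[-1].item() == query_start_loc[seq_index].item()
assert query_start_loc_post[0].item() == 0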