@@ -49,14 +49,20 @@ def get_kv_cache_stride_order() -> tuple[int, ...]:
 
 
 @dataclass
-class DeepseekV32IndexerPrefillMetadata:
+class DeepseekV32IndexerPrefillChunkMetadata:
     block_table: torch.Tensor
-    query_start_loc: torch.Tensor
-    max_query_len: int
     cu_seqlen_ks: torch.Tensor
     cu_seqlen_ke: torch.Tensor
     cu_seq_lens: torch.Tensor
     total_seq_lens: int
+    token_start: int
+    token_end: int
+    num_reqs: int
+
+
+@dataclass
+class DeepseekV32IndexerPrefillMetadata:
+    chunks: list[DeepseekV32IndexerPrefillChunkMetadata]
 
 
 @dataclass
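The single flat prefill metadata is split into per-chunk entries so that each chunk's flattened KV fits the prefill buffer. A self-contained sketch of the new nesting (simplified stand-in classes with invented names, illustrative fields only, not part of this patch):

from dataclasses import dataclass
import torch

@dataclass
class Chunk:  # stand-in for DeepseekV32IndexerPrefillChunkMetadata
    block_table: torch.Tensor  # block-table rows for this chunk's requests
    total_seq_lens: int        # flattened KV tokens, at most the buffer size
    token_start: int           # first packed query token the chunk covers
    token_end: int             # one past the last packed query token
    num_reqs: int              # contiguous requests in the chunk

@dataclass
class PrefillMeta:  # stand-in for DeepseekV32IndexerPrefillMetadata
    chunks: list[Chunk]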
@@ -98,8 +104,8 @@ class DeepseekV32IndexerMetadata:
 
 # TODO (zyongye) optimize this, this is now vibe coded
 def kv_spans_from_batches(
-        start_seq_loc: torch.Tensor,
-        seq_len_per_batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor,
+        device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Args:
       start_seq_loc: 1D long tensor [B+1], cumulative counts of
@@ -122,15 +128,14 @@ def kv_spans_from_batches(
     are the **last** `counts[i]` positions of that sequence.
     """
     q = start_seq_loc.to(dtype=torch.long)
-    L = seq_len_per_batch.to(dtype=torch.long, device=q.device)
+    L = seq_len_per_batch.to(dtype=torch.long)
     assert q.dim() == 1 and L.dim() == 1
     assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1"
 
     # Selected tokens per batch and totals
     counts = q[1:] - q[:-1]  # [B]
     N = int(q[-1].item())  # total selected tokens
     B = L.numel()
-    device = L.device
 
     if N == 0:
         return (torch.empty(0, dtype=torch.long, device=device),
@@ -140,8 +145,7 @@ def kv_spans_from_batches(
     kv_starts_per_batch = torch.cumsum(L, dim=0) - L  # [B]
 
     # For each selected token, which batch does it belong to?
-    batch_id = torch.repeat_interleave(torch.arange(B, device=device),
-                                       counts)  # [N]
+    batch_id = torch.repeat_interleave(torch.arange(B), counts)  # [N]
 
     # Map batch KV start to each token
     start_tensor = kv_starts_per_batch[batch_id]  # [N]
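With the device argument made explicit, the span arithmetic runs entirely on CPU tensors and only the final result is moved to the target device. A hand-worked example of the function's contract (a sketch assuming the patched kv_spans_from_batches above is in scope; values chosen by hand):

import torch

# Two prefill requests with total KV lengths 4 and 7; this step processes
# the last 2 and last 3 query tokens of them, respectively.
q = torch.tensor([0, 2, 5])  # start_seq_loc: cumulative selected-token counts [B+1]
L = torch.tensor([4, 7])     # seq_len_per_batch [B]
ks, ke = kv_spans_from_batches(q, L, torch.device("cpu"))
print(ks.tolist())  # [0, 0, 4, 4, 4]   - each span starts at its batch's flattened-KV offset
print(ke.tolist())  # [3, 4, 9, 10, 11] - causal exclusive ends inside the flattened KV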
@@ -151,22 +155,51 @@ def kv_spans_from_batches(
     L_expand = torch.repeat_interleave(L, counts)  # [N]
     m_expand = torch.repeat_interleave(counts, counts)  # [N]
     # position within the selected block: 1..counts[b]
-    pos_within = (torch.arange(N, device=device, dtype=torch.long) -
+    pos_within = (torch.arange(N, dtype=torch.long) -
                   torch.repeat_interleave(q[:-1], counts) + 1)
 
     local_pos = L_expand - m_expand + pos_within  # [N], 1-based
     end_location = start_tensor + local_pos  # exclusive end
 
-    return start_tensor.int(), end_location.int()
+    return start_tensor.int().to(device), end_location.int().to(device)
 
 
 def get_max_prefill_buffer_size(vllm_config: VllmConfig):
     max_model_len = vllm_config.model_config.max_model_len
-    # max_num_batched_tokens = \
-    #     vllm_config.scheduler_config.max_num_batched_tokens
-    max_num_seq = vllm_config.scheduler_config.max_num_seqs
-    # NOTE(Chen): an estimated max size of flattened_kv. Need to double check.
-    return max_model_len * max_num_seq
+    # NOTE(Chen): 2 is a magic number for controlling the prefill buffer size.
+    # May be tuned later.
+    return max_model_len * 2
+
+
+def split_prefill_chunks(seq_lens_cpu: torch.Tensor,
+                         max_prefill_buffer_size: int,
+                         reqs_start: int) -> list[tuple[int, int]]:
+    """
+    Split the prefill requests into a list of (reqs_start, reqs_end)
+    chunks such that the total sequence length of each chunk does not
+    exceed the maximum prefill buffer size.
+
+    Args:
+        seq_lens_cpu: The sequence lengths of the prefill requests.
+        max_prefill_buffer_size: The maximum prefill buffer size.
+        reqs_start: The start index of the prefill requests.
+
+    Returns:
+        A list of (reqs_start, reqs_end) tuples.
+    """
+    chunk_seq_ids = []
+    total_seq_lens = 0
+    for i in range(reqs_start, len(seq_lens_cpu)):
+        cur_seq_len = seq_lens_cpu[i].item()
+        assert cur_seq_len <= max_prefill_buffer_size
+        total_seq_lens += cur_seq_len
+        if total_seq_lens > max_prefill_buffer_size:
+            chunk_seq_ids.append((reqs_start, i))
+            reqs_start = i
+            total_seq_lens = cur_seq_len
+    if total_seq_lens > 0:
+        chunk_seq_ids.append((reqs_start, len(seq_lens_cpu)))
+    return chunk_seq_ids
 
 
 class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
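The buffer cap drops from max_model_len * max_num_seqs to max_model_len * 2, so one batch's flattened KV can now overflow it and must be split greedily. A worked call of split_prefill_chunks (assuming the function above is in scope; lengths chosen by hand):

import torch

# Five prefill requests packed into chunks of at most 10 flattened KV tokens.
seq_lens = torch.tensor([4, 5, 3, 7, 2])
print(split_prefill_chunks(seq_lens, max_prefill_buffer_size=10, reqs_start=0))
# [(0, 2), (2, 4), (4, 5)]  - chunk totals are 9, 10 and 2, none exceeding the cap

Passing reqs_start = num_decodes (as build() does below) simply skips the decode requests at the front of the batch.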
@@ -201,6 +234,33 @@ def __init__(self, *args, **kwargs):
                                                     dtype=torch.int32,
                                                     device=self.device)
 
+    def build_one_prefill_chunk(self, reqs_start, reqs_end,
+                                query_start_loc_cpu, seq_lens_cpu,
+                                block_table):
+        prefill_query_start_loc = query_start_loc_cpu[
+            reqs_start:reqs_end + 1] - query_start_loc_cpu[reqs_start]
+        cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
+            prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end],
+            self.device)
+        token_start = query_start_loc_cpu[reqs_start].item()
+        token_end = query_start_loc_cpu[reqs_end].item()
+        total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
+        assert total_seq_lens <= self.max_prefill_buffer_size
+        cu_seq_lens = torch.cat([
+            torch.zeros(1, dtype=torch.int32),
+            seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0)
+        ]).to(torch.int32).to(self.device)
+        return DeepseekV32IndexerPrefillChunkMetadata(
+            cu_seqlen_ks=cu_seqlen_ks,
+            cu_seqlen_ke=cu_seqlen_ke,
+            cu_seq_lens=cu_seq_lens,
+            total_seq_lens=total_seq_lens,
+            block_table=block_table[reqs_start:reqs_end],
+            token_start=token_start,
+            token_end=token_end,
+            num_reqs=reqs_end - reqs_start,
+        )
+
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
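build_one_prefill_chunk re-bases the chunk's query offsets to start at zero and reads the chunk's global token range straight off query_start_loc_cpu. The index arithmetic, mirrored by hand (not calling the method; tensors invented for illustration):

import torch

# Cumulative query-token offsets for 3 requests with 2, 3 and 4 new tokens;
# the chunk covers requests [1, 3).
q_cpu = torch.tensor([0, 2, 5, 9])
reqs_start, reqs_end = 1, 3
rebased = q_cpu[reqs_start:reqs_end + 1] - q_cpu[reqs_start]
print(rebased.tolist())  # [0, 3, 7]  - chunk-local query_start_loc
print(q_cpu[reqs_start].item(), q_cpu[reqs_end].item())  # token_start=2, token_end=9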
@@ -209,11 +269,7 @@ def build(self,
         num_reqs = common_attn_metadata.num_reqs
         num_tokens = common_attn_metadata.num_actual_tokens
 
-        device = self.device
-        block_table_tensor = common_attn_metadata.block_table_tensor
-
-        query_start_loc = common_attn_metadata.query_start_loc
-
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
             split_decodes_and_prefills(
                 common_attn_metadata,
@@ -224,27 +280,20 @@ def build(self,
 
         prefill_metadata = None
         if num_prefills > 0:
-            reqs_start = num_decodes
-            prefill_query_start_loc = query_start_loc[
-                reqs_start:] - query_start_loc[reqs_start]
-            cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
-                prefill_query_start_loc,
-                common_attn_metadata.seq_lens[reqs_start:])
-            total_seq_lens = common_attn_metadata.seq_lens[reqs_start:].sum()
-            assert total_seq_lens < self.max_prefill_buffer_size
-            cu_seq_lens = torch.cat([
-                torch.zeros(1, dtype=torch.int32, device=device),
-                common_attn_metadata.seq_lens[reqs_start:].cumsum(dim=0)
-            ]).to(torch.int32).cuda()
-            prefill_metadata = DeepseekV32IndexerPrefillMetadata(
-                block_table=block_table_tensor[reqs_start:, ...],
-                query_start_loc=prefill_query_start_loc,
-                max_query_len=common_attn_metadata.max_query_len,
-                cu_seqlen_ks=cu_seqlen_ks,
-                cu_seqlen_ke=cu_seqlen_ke,
-                cu_seq_lens=cu_seq_lens,
-                total_seq_lens=total_seq_lens,
+            chunk_seq_ids = split_prefill_chunks(
+                common_attn_metadata.seq_lens_cpu,
+                self.max_prefill_buffer_size,
+                num_decodes,
             )
+            chunks = [
+                self.build_one_prefill_chunk(
+                    reqs_start, reqs_end, query_start_loc_cpu,
+                    common_attn_metadata.seq_lens_cpu,
+                    common_attn_metadata.block_table_tensor)
+                for reqs_start, reqs_end in chunk_seq_ids
+            ]
+            prefill_metadata = DeepseekV32IndexerPrefillMetadata(
+                chunks=chunks, )
 
         decode_metadata = None
         if num_decodes > 0:
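Downstream, each chunk is self-describing: token_start and token_end select the chunk's slice of the packed query tokens, and the cu_* tensors drive the indexer kernel for that slice. A hypothetical consumer loop (run_prefill and process_chunk are invented names, not part of this patch):

# Hypothetical sketch: iterate the chunked prefill metadata, feeding each
# chunk its contiguous slice of the packed query tokens.
def run_prefill(prefill_metadata, packed_q, process_chunk):
    outputs = []
    for chunk in prefill_metadata.chunks:
        q_slice = packed_q[chunk.token_start:chunk.token_end]
        outputs.append(process_chunk(q_slice, chunk))
    return outputs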