 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
-from vllm.v1.core.sched.output import SchedulerOutput
+                                              CommonAttentionMetadata,
+                                              split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.v1.worker.gpu_input_batch import InputBatch
 
 try:
     import intel_extension_for_pytorch.llm.modules as ipex_modules
@@ -102,16 +101,16 @@ class TorchSDPAMetadata(AttentionMetadata):
     """Metadata for PagedAttention."""
     # (batch_size,). The length of sequences (entire tokens seen so far) per
     # sequence.
-    seq_lens_tensor: Optional[torch.Tensor]
+    decode_seq_lens_tensor: Optional[torch.Tensor]
     # Maximum sequence length in the batch. 0 if it is prefill-only batch.
-    max_decode_seq_len: int
+    decode_max_seq_len: int
     # (batch_size, max_blocks_per_seq).
     # Block addresses per sequence. (Seq id -> list of physical block)
     # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
     # in the kv cache. Each block can contain up to block_size tokens.
     # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
     # captured.
-    block_tables: Optional[torch.Tensor]
+    decode_block_tables: Optional[torch.Tensor]
     """Metadata for TorchSDPABackend.
     """
     # Currently, input sequences can only contain all prompts
@@ -121,9 +120,9 @@ class TorchSDPAMetadata(AttentionMetadata):
 
     # For chunked prefill only
     max_query_len: Optional[int] = None
-    max_kv_len: Optional[int] = None
+    prefill_max_seq_len: Optional[int] = None
     prefill_query_start_loc: Optional[torch.Tensor] = None
-    kv_start_loc: Optional[torch.Tensor] = None
+    prefill_seq_start_loc: Optional[torch.Tensor] = None
     prefill_block_tables: Optional[torch.Tensor] = None
 
     # For V1 logits index only
@@ -307,8 +306,8 @@ def get_seq_len_block_table_args(
                 or attn_type == AttentionType.ENCODER_ONLY):
             # Decoder self-attention
             # Choose max_seq_len based on whether we are in prompt_run
-            return (self.seq_lens_tensor, self.max_decode_seq_len,
-                    self.block_tables)
+            return (self.decode_seq_lens_tensor, self.decode_max_seq_len,
+                    self.decode_block_tables)
         elif attn_type == AttentionType.ENCODER_DECODER:
             # Enc/dec cross-attention KVs match encoder sequence length;
             # cross-attention utilizes special "cross" block tables
@@ -323,19 +322,14 @@ def get_seq_len_block_table_args(
 
 
 class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
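+    # Requests with at most this many query tokens are treated as decodes
+    # when the batch is reordered (decodes first, then prefills).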
+    reorder_batch_threshold: int = 1
 
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device) -> None:
         super().__init__(kv_cache_spec, layer_names, vllm_config, device)
 
         self.scheduler_config = vllm_config.scheduler_config
-
-        # For reorder
-        self.reorder_prompt_req_index_list = np.empty(
-            vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
-        self.reorder_decode_req_index_list = np.empty(
-            vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
-        self.num_prompt_req: int = 0
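+        # Batch reordering is handled by the base class now (decodes are moved
+        # ahead of prefills), replacing the hand-rolled reorder_batch() below.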
+        self._init_reorder_batch_threshold(1, False)
 
         self.seq_start_loc_cpu = torch.zeros(
             vllm_config.scheduler_config.max_num_seqs + 1,
@@ -344,50 +338,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         )
         self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
 
-    def reorder_batch(self, input_batch: InputBatch,
-                      scheduler_output: SchedulerOutput) -> bool:
-        prompt_list_idx = 0
-        decode_list_idx = 0
-        for req_index in range(input_batch.num_reqs):
-            if input_batch.num_computed_tokens_cpu[
-                    req_index] < input_batch.num_prompt_tokens[req_index]:
-                # prompt stage
-                self.reorder_prompt_req_index_list[prompt_list_idx] = req_index
-                prompt_list_idx += 1
-            else:
-                # decode stage
-                self.reorder_decode_req_index_list[decode_list_idx] = req_index
-                decode_list_idx += 1
-        assert decode_list_idx + prompt_list_idx == input_batch.num_reqs
-
-        # Update prompt requests number
-        self.num_prompt_req = prompt_list_idx
-
-        reorder_req_num = 0
-        for req_index in range(decode_list_idx):
-            if self.reorder_decode_req_index_list[req_index] < prompt_list_idx:
-                reorder_req_num += 1
-            else:
-                break
-
-        if reorder_req_num == 0:
-            return False
-
-        reorder_prompt_list = (
-            self.reorder_prompt_req_index_list[:prompt_list_idx]
-            [-reorder_req_num:])
-        reorder_decode_list = (
-            self.reorder_decode_req_index_list[:decode_list_idx]
-            [:reorder_req_num])
-        assert reorder_decode_list.size == reorder_prompt_list.size
-
-        for idx in range(reorder_req_num):
-            prompt_req_index = reorder_prompt_list[idx].item()
-            decode_req_index = reorder_decode_list[idx].item()
-            input_batch.swap_states(prompt_req_index, decode_req_index)
-
-        return True
-
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
@@ -397,41 +347,46 @@ def build(self,
 
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         seq_lens_np = seq_lens_cpu.numpy()
-        num_prompt_req = self.num_prompt_req
-        max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item(
-        ) if num_prompt_req > 0 else 0
-        max_decode_seq_len = seq_lens_np[num_prompt_req:num_reqs].max().item(
-        ) if num_prompt_req < num_reqs else 0
-        self.seq_start_loc_np[0] = 0
-        np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])
 
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
-        num_prefill_tokens = int(query_start_loc_cpu[num_prompt_req].item())
-        num_decode_tokens = int(query_start_loc_cpu[num_reqs].item() -
-                                num_prefill_tokens)
+        query_start_loc_np = query_start_loc_cpu.numpy()
+
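+        # After batch reordering, decodes occupy indices [0, num_decodes) and
+        # prefills occupy [num_decodes, num_reqs); split the request and token
+        # counts along that boundary.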
+        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
+            split_decodes_and_prefills(common_attn_metadata,
+                                       decode_threshold=self.reorder_batch_threshold,
+                                       require_uniform=True)
+
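+        # Decode seq lens sit at the front of seq_lens_np and prefill seq lens
+        # at the tail, so the per-phase maxima come from disjoint slices.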
+        max_prefill_seq_len = seq_lens_np[num_decodes:num_reqs].max().item(
+        ) if num_prefills > 0 else 0
+        max_decode_seq_len = seq_lens_np[:num_decodes].max().item(
+        ) if num_prefills < num_reqs else 0
+        self.seq_start_loc_np[0] = 0
+        np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])
 
         slot_mapping = common_attn_metadata.slot_mapping.long()
         block_table_tensor = common_attn_metadata.block_table_tensor
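+        # Rebase the prefill entries of query_start_loc so that prefill query
+        # offsets are relative to the first prefill token.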
+        query_start_loc_np[num_decodes:num_reqs + 1] -= num_decode_tokens
 
         attn_metadata = TorchSDPAMetadata(
-            num_prefills=num_prompt_req,
+            num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
             slot_mapping=slot_mapping,
             # to ensure inference when chunked_prefill is disabled
             seq_lens=seq_lens_cpu.tolist(),
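+            # Decodes occupy the front of the reordered batch and prefills the
+            # tail, so each tensor below is sliced accordingly.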
-            seq_lens_tensor=seq_lens_cpu[num_prompt_req:num_reqs],  # decode
-            max_decode_seq_len=max_decode_seq_len,  # decode
-            block_tables=block_table_tensor[num_prompt_req:num_reqs],  # decode
+            decode_seq_lens_tensor=seq_lens_cpu[:num_decodes],  # decode
+            decode_max_seq_len=max_decode_seq_len,  # decode
+            decode_block_tables=block_table_tensor[:num_decodes],  # decode
             chunked_prefill=self.scheduler_config.chunked_prefill_enabled,
             max_query_len=max_query_len,
-            max_kv_len=max_prefill_seq_len,
-            prefill_query_start_loc=query_start_loc_cpu[:num_prompt_req +
+            prefill_max_seq_len=max_prefill_seq_len,
+            prefill_query_start_loc=query_start_loc_cpu[num_decodes:num_reqs +
                                                         1],  # prefill
-            kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req +
-                                                1],  # prefill
-            prefill_block_tables=block_table_tensor[:
-                                                    num_prompt_req],  # prefill
+            prefill_seq_start_loc=self.seq_start_loc_cpu[num_decodes:num_reqs +
+                                                         1],  # prefill
+            prefill_block_tables=block_table_tensor[
+                num_decodes:num_reqs],  # prefill
             query_start_loc=query_start_loc_cpu[:num_reqs +
                                                 1],  # for logits index
         )
@@ -596,14 +551,14 @@ def forward(
             import intel_extension_for_pytorch.llm.modules as ipex_modules
             output = torch.empty_like(query)
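+            # Prefill tokens follow the decode tokens in the flattened batch,
+            # so the prefill kernel consumes the tail of query/output.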
             ipex_modules.PagedAttention.flash_attn_varlen_func(
-                output[:prefill_meta.num_prefill_tokens, :, :],
-                query[:prefill_meta.num_prefill_tokens, :, :],
+                output[prefill_meta.num_decode_tokens:, :, :],
+                query[prefill_meta.num_decode_tokens:, :, :],
                 key_cache,
                 value_cache,
                 prefill_meta.prefill_query_start_loc,
-                prefill_meta.kv_start_loc,
+                prefill_meta.prefill_seq_start_loc,
                 prefill_meta.max_query_len,
-                prefill_meta.max_kv_len,
+                prefill_meta.prefill_max_seq_len,
                 self.scale,
                 True,
                 prefill_meta.prefill_block_tables,
@@ -621,8 +576,8 @@ def forward(
             ) = decode_meta.get_seq_len_block_table_args(attn_type)
 
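+            # Decode tokens sit at the head of the flattened batch after
+            # reordering, so the decode kernel consumes the front slice.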
             self.paged_attn_impl.forward_decode(
-                output[attn_metadata.num_prefill_tokens:, :, :],
-                query[attn_metadata.num_prefill_tokens:, :, :],
+                output[:attn_metadata.num_decode_tokens, :, :],
+                query[:attn_metadata.num_decode_tokens, :, :],
                 key_cache,
                 value_cache,
                 block_tables_arg,