4949from vllm .v1 .attention .backends .utils import (
5050 AttentionCGSupport , AttentionMetadataBuilder , CommonAttentionMetadata ,
5151 make_kv_sharing_fast_prefill_attention_metadata ,
52- make_local_attention_virtual_batches )
52+ make_local_attention_virtual_batches ,
53+ reorder_batch_to_split_decodes_and_prefills )
5354from vllm .v1 .core .encoder_cache_manager import compute_encoder_budget
5455from vllm .v1 .kv_cache_interface import (AttentionSpec ,
5556 ChunkedLocalAttentionSpec ,
@@ -329,6 +330,8 @@ def __init__(
329330 self .kv_sharing_fast_prefill_logits_indices = torch .zeros (
330331 self .max_num_tokens , dtype = torch .int32 , device = self .device )
331332
333+ self .reorder_batch_threshold : Optional [int ] = None
334+
332335 def _may_reorder_batch (self , scheduler_output : "SchedulerOutput" ) -> None :
333336 """
334337 Update the order of requests in the batch based on the attention
@@ -347,20 +350,11 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
347350 if len (self .kv_cache_config .kv_cache_groups ) == 0 :
348351 return
349352
350- self .attn_metadata_builders [0 ].reorder_batch (self .input_batch ,
351- scheduler_output )
352-
353- # For models with multiple KV cache groups, the groups should agree on
354- # the same order of requests. We ensure this by only allowing the first
355- # group to reorder the batch and asserting that all other groups do not
356- # reorder the batch.
357- # TODO(tdoublep): make this more flexible so that any group can
358- # re-order the batch (not only the first).
359- # TODO(tdoublep): verify this during engine init instead of at runtime
360- for i in range (1 , len (self .kv_cache_config .kv_cache_groups )):
361- batch_reordered = self .attn_metadata_builders [i ].reorder_batch (
362- self .input_batch , scheduler_output )
363- assert not batch_reordered
353+ if self .reorder_batch_threshold is not None :
354+ reorder_batch_to_split_decodes_and_prefills (
355+ self .input_batch ,
356+ scheduler_output ,
357+ decode_threshold = self .reorder_batch_threshold )
364358
365359 # Note: used for model runner override.
366360 def _init_device_properties (self ) -> None :
@@ -2654,6 +2648,9 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
26542648 self .attn_backends .append (attn_backend_i )
26552649 self .attn_metadata_builders .append (attn_metadata_builder_i )
26562650
2651+ # Calculate reorder batch threshold (if needed)
2652+ self .calculate_reorder_batch_threshold ()
2653+
26572654 if len (self .attn_backends ) > 0 :
26582655 return
26592656
@@ -2688,6 +2685,28 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
26882685 self .attn_metadata_builders .append (attn_metadata_builder )
26892686 self .is_encoder_only_model = True
26902687
2688+ def calculate_reorder_batch_threshold (self ) -> None :
2689+ """
2690+ Check that, if any backends reorder batches, the reordering
2691+ is compatible (e.g., the decode threshold is the same).
2692+ """
2693+ for attn_metadata_builder_i in self .attn_metadata_builders :
2694+ # Check that, if any backends reorder batches, the reordering
2695+ # is compatible (e.g., the decode threshold is the same).
2696+ reorder_batch_threshold_i = (
2697+ attn_metadata_builder_i .reorder_batch_threshold )
2698+ if reorder_batch_threshold_i is not None :
2699+ if self .reorder_batch_threshold is not None :
2700+ if reorder_batch_threshold_i != \
2701+ self .reorder_batch_threshold :
2702+ raise ValueError (
2703+ f"Attention backend reorders decodes with "
2704+ f"threshold { reorder_batch_threshold_i } but other "
2705+ f"backend uses threshold "
2706+ f"{ self .reorder_batch_threshold } " )
2707+ else :
2708+ self .reorder_batch_threshold = reorder_batch_threshold_i
2709+
26912710 def may_reinitialize_input_batch (self ,
26922711 kv_cache_config : KVCacheConfig ) -> None :
26932712 """
0 commit comments