 import enum
 import functools
 from abc import abstractmethod
-from collections import deque
 from dataclasses import dataclass, field, fields, make_dataclass
-from enum import Enum
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -51,85 +49,6 @@ def is_valid_kv_cache_layout(value: str) -> bool:
     return value in get_args(KVCacheLayoutType)


-class QueryLenSupport(Enum):
-    """Defines the level of query length support for an attention backend's
-    decode pipeline.
-
-    - SINGLE_ONLY: Decode pipeline only supports single-token queries
-      (query_len=1)
-    - UNIFORM: Decode pipeline supports uniform multi-token queries
-      (all requests must have same query_len > 1)
-    - VARLEN: Decode pipeline supports variable-length queries
-      (mixed query lengths in same batch)
-    """
-
-    SINGLE_ONLY = "single_only"
-    UNIFORM = "uniform"
-    VARLEN = "varlen"
-
-
-@dataclass
-class ReorderSpec:
-    """
-    Defines how the model runner reorders requests within a batch for attention
-    backends that distinguish between prefill, extend, and decode phases.
-
-    Core controls
-    - decode_threshold: Query lengths ≤ this value are treated as decode. If
-      None, no reordering is applied.
-    - split_extend: If True, split prefill into [extend_prefill, pure_prefill].
-    - decode_query_len_support:
-      - SINGLE_ONLY: single-token only (no spec decode)
-      - UNIFORM: uniform multi-token queries (spec decode, equal lengths)
-      - VARLEN: variable-length queries (spec decode, mixed lengths)
-
-    Example input
-        query_len: [7, 10, 3, 1, 2, 5, 2, 15]
-        seq_len:   [10, 10, 8, 9, 8, 5, 10, 20]
-
-    Case 1: decode_threshold=3
-        query_len: [3, 1, 2, 2, 7, 10, 5, 15]
-        seq_len:   [8, 9, 8, 10, 10, 10, 5, 20]
-                   └──── dec ────┘└──── pre ─────┘
-        → Reordered as [decode, prefill].
-
-    Case 2: decode_threshold=3, split_extend=True
-        query_len: [3, 1, 2, 2, 7, 15, 10, 5, 8]
-        seq_len:   [8, 9, 8, 10, 10, 20, 10, 5, 8]
-                   └──── dec ────┘ └── ext ──┘ └pre┘
-        → Reordered as [decode, extend_prefill, pure_prefill].
-
-    Case 3 (Future/TODO):
-        decode_threshold=3, split_extend=True, decode_query_len_support=UNIFORM
-        (Move the most common query length ≤ decode_threshold to the front to
-        form the largest *uniform* decode region. Here, the uniform decode
-        region is the two q_len=2's.)
-        query_len: [2, 2, 3, 1, 7, 15, 10, 5]
-        seq_len:   [8, 10, 8, 9, 10, 20, 10, 5]
-                   └u dec┘ └─dec─┘└──── pre ─────┘
-        → Reordered as [uniform-decode(2s), decode, prefill].
-    """
-
-    decode_threshold: int | None = None
-    """The threshold for reordering the batch into decode and prefill
-    requests. If `decode_threshold` is not None, prefill and decode
-    requests are reordered according to this value; requests with query
-    length <= threshold are treated as decode."""
-
-    split_extend: bool = False
-    """Whether to further split the prefill requests into pure prefill
-    (query_len == context_len) and extend prefill (query_len < context_len)
-    within a single batch. Once this flag is set, the requests are reordered
-    to [decode:extend_prefill:pure_prefill]."""
-
-    decode_query_len_support: QueryLenSupport = QueryLenSupport.SINGLE_ONLY
-    """Defines the level of query length support for this backend.
-    - SINGLE_ONLY: only single-token queries (no spec decode support)
-    - UNIFORM: supports uniform multi-token queries (spec decode with uniform lengths)
-    - VARLEN: supports variable-length queries (spec decode with mixed lengths)
-    If set to UNIFORM or VARLEN, `decode_threshold` is increased when
-    speculative decoding is enabled."""
-
-
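For concreteness, Case 1 of the removed docstring can be reproduced with a stable sort on the decode predicate; a minimal standalone sketch (not the runner's actual implementation, which swaps requests in place):

query_len = [7, 10, 3, 1, 2, 5, 2, 15]
seq_len = [10, 10, 8, 9, 8, 5, 10, 20]
decode_threshold = 3

# Stable sort: decodes (key False) keep their relative order and move to the front.
order = sorted(range(len(query_len)), key=lambda i: query_len[i] > decode_threshold)
print([query_len[i] for i in order])  # [3, 1, 2, 2, 7, 10, 5, 15]
print([seq_len[i] for i in order])  # [8, 9, 8, 10, 10, 10, 5, 20]
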
 @dataclass
 class CommonAttentionMetadata:
     """
@@ -326,10 +245,10 @@ class AttentionCGSupport(enum.Enum):
 class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     # Does this backend/builder support CUDA Graphs for attention (default: no).
     cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
-    # Attention backend's reorder spec, which controls whether and how to
-    # reorder requests before actually executing the model
-    # (default: no reorder).
-    reorder_spec: ClassVar[ReorderSpec] = ReorderSpec(None)
+    # Does this backend/builder reorder the batch?
+    # If not, set this to None. Otherwise set it to the query
+    # length that will be pulled into the front of the batch.
+    reorder_batch_threshold: int | None = None
 
     @abstractmethod
     def __init__(
@@ -344,21 +263,21 @@ def __init__(
         self.vllm_config = vllm_config
         self.device = device
 
-    def _init_decode_threshold(
-        self, decode_threshold: int = 1, supports_spec_as_decode: bool = False
+    def _init_reorder_batch_threshold(
+        self, reorder_batch_threshold: int = 1, supports_spec_as_decode: bool = False
     ) -> None:
-        self.reorder_spec.decode_threshold = decode_threshold
-        if self.reorder_spec.decode_threshold is not None and supports_spec_as_decode:
+        self.reorder_batch_threshold = reorder_batch_threshold
+        if self.reorder_batch_threshold is not None and supports_spec_as_decode:
             # If the backend supports spec-as-decode kernels, then we can set
-            # the decode_threshold based on the number of speculative
+            # the reorder_batch_threshold based on the number of speculative
            # tokens from the config.
             speculative_config = self.vllm_config.speculative_config
             if (
                 speculative_config is not None
                 and speculative_config.num_speculative_tokens is not None
             ):
-                self.reorder_spec.decode_threshold = max(
-                    self.reorder_spec.decode_threshold,
+                self.reorder_batch_threshold = max(
+                    self.reorder_batch_threshold,
                     1 + speculative_config.num_speculative_tokens,
                 )
 
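For reference, the widening rule above in isolation: a minimal standalone sketch (hypothetical helper name; the real builder reads these values from vllm_config.speculative_config):

def resolve_reorder_batch_threshold(
    base_threshold: int,
    supports_spec_as_decode: bool,
    num_speculative_tokens: int | None,
) -> int:
    # A spec-as-decode step verifies 1 bonus token plus the speculative
    # tokens in one pass, so such queries must still count as "decode".
    if supports_spec_as_decode and num_speculative_tokens is not None:
        return max(base_threshold, 1 + num_speculative_tokens)
    return base_threshold


assert resolve_reorder_batch_threshold(1, False, 3) == 1
assert resolve_reorder_batch_threshold(1, True, 3) == 4  # 1 + 3 spec tokens
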
@@ -430,10 +349,6 @@ def use_cascade_attention(
     ) -> bool:
         return False
 
-    @classmethod
-    def reset_reorder_spec(cls, reorder_spec):
-        cls.reorder_spec = reorder_spec
-
 
 @functools.lru_cache
 def get_kv_cache_layout():
@@ -939,109 +854,6 @@ def split_decodes_and_prefills(
     return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)


-def reorder_batch_to_split_decodes_prefills_and_extends(
-    input_batch: "InputBatch",
-    scheduler_output: "SchedulerOutput",
-    decode_threshold: int = 1,
-) -> bool:
-    """
-    Reorders the batch to split it into prefill, extend,
-    and decode requests; places all requests in the order
-    [decode:extend:prefill].
-
-    Returns:
-        True if the batch was modified, False otherwise.
-    """
-
-    decodes = []
-    prefills = []
-    extends = []
-
-    for i, req_id in enumerate(input_batch.req_ids):
-        num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-        if num_tokens <= decode_threshold:
-            decodes.append(i)
-        elif input_batch.num_computed_tokens_cpu[i] > 0:
-            extends.append(i)
-        else:
-            prefills.append(i)
-
-    num_decodes = len(decodes)
-    num_extends = len(extends)
-    # We define a reorder matrix here to help with the request reorder:
-    # reorder_matrix[(i, j)] holds the ids of the requests that are supposed
-    # to be in zone i but currently sit in zone j.
-    # The decode, extend, and prefill requests are separated into 3
-    # different zones here: 0 for decode, 1 for extend, and 2 for
-    # prefill.
-    reorder_matrix: dict[tuple[int, int], deque[int]] = {
-        (i, j): deque() for i in range(3) for j in range(3) if i != j
-    }
-
-    # collect mismatches
-
-    def target_idx(idx):
-        if idx < num_decodes:
-            # decode as zone 0
-            return 0
-        elif idx < num_decodes + num_extends:
-            # extend as zone 1
-            return 1
-        else:
-            # prefill as zone 2
-            return 2
-
-    def fill_reorder_matrix(request_lists, reorder_sequence):
-        for idx, seq in enumerate(reorder_sequence):
-            request_list = request_lists[idx]
-            for req_idx in request_list:
-                req_target_id = target_idx(req_idx)
-                if seq != req_target_id:
-                    reorder_matrix[(seq, req_target_id)].append(req_idx)
-
-    def direct_zone_swap(i, j):
-        assert i != j
-        modified_batch = False
-        while reorder_matrix[(i, j)] and reorder_matrix[(j, i)]:
-            swap_req1 = reorder_matrix[(i, j)].pop()
-            swap_req2 = reorder_matrix[(j, i)].pop()
-            input_batch.swap_states(swap_req1, swap_req2)
-            modified_batch = True
-
-        return modified_batch
-
-    # resolve a 3-cycle of misplacements: in order 1,2,3; out order 3,1,2
-    def indirect_zone_swap(zone_list):
-        assert len(zone_list) == 3
-        modified_batch = False
-        while (
-            reorder_matrix[zone_list[0]]
-            and reorder_matrix[zone_list[1]]
-            and reorder_matrix[zone_list[2]]
-        ):
-            swap_req1 = reorder_matrix[zone_list[0]].pop()
-            swap_req2 = reorder_matrix[zone_list[1]].pop()
-            swap_req3 = reorder_matrix[zone_list[2]].pop()
-
-            input_batch.swap_states(swap_req1, swap_req2)
-            input_batch.swap_states(swap_req2, swap_req3)
-            modified_batch = True
-        return modified_batch
-
-    fill_reorder_matrix([decodes, extends, prefills], [0, 1, 2])
-
-    modified_batch = False
-    # do direct pairwise swaps first
-    modified_batch |= direct_zone_swap(0, 1)  # decode <--> extend
-    modified_batch |= direct_zone_swap(0, 2)  # decode <--> prefill
-    modified_batch |= direct_zone_swap(1, 2)  # extend <--> prefill
-
-    modified_batch |= indirect_zone_swap(((0, 1), (1, 2), (2, 0)))
-    modified_batch |= indirect_zone_swap(((2, 1), (0, 2), (1, 0)))
-
-    return modified_batch
-
-
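The least obvious part of the removed helper is `indirect_zone_swap`, which resolves a 3-cycle of misplaced requests with two pairwise swaps; a toy standalone sketch of that rotation (a plain list stands in for `InputBatch.swap_states`):

# Positions 0, 1, 2 correspond to zones dec, ext, pre; each request is
# labeled with the zone it wants, forming a 3-cycle of misplacements.
batch = ["wants_pre", "wants_dec", "wants_ext"]

def swap_states(i, j):  # stand-in for InputBatch.swap_states
    batch[i], batch[j] = batch[j], batch[i]

swap_states(1, 2)  # wants_ext reaches the ext slot
swap_states(2, 0)  # wants_dec and wants_pre reach their slots
print(batch)  # ['wants_dec', 'wants_ext', 'wants_pre']
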
 def reorder_batch_to_split_decodes_and_prefills(
     input_batch: "InputBatch",
     scheduler_output: "SchedulerOutput",
@@ -1054,50 +866,47 @@ def reorder_batch_to_split_decodes_and_prefills(
     Returns:
         True if the batch was modified, False otherwise.
     """
-    # We now want to reorder the batch so that the "decode" requests are at
-    # the front and the "prefill" requests are at the back, using the least
-    # number of swaps possible. (NOTE: for now we loosely use "decode" to mean
-    # requests where attention is likely memory-bound and "prefill" to mean
-    # requests where attention is likely compute-bound; TODO(lucas): figure out
-    # better naming here)
-    decodes = []
-    prefills = []
-    num_decode_tokens = 0
-    num_prefill_tokens = 0
-
-    for i, req_id in enumerate(input_batch.req_ids):
-        num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-        if num_tokens <= decode_threshold:
-            decodes.append(i)
-            num_decode_tokens += num_tokens
-        else:
-            prefills.append(i)
-            num_prefill_tokens += num_tokens
-
-    # We hope that this is fairly minimal since decodes should be around
-    # for a number of iterations, so hopefully they are relatively
-    # stationary (and new requests are generally appended to the persistent
-    # batch, so they should already be at the back).
-    # To achieve this we loop over the decodes in descending order and
-    # the prefills in ascending order. We swap decodes from the "back",
-    # i.e. past where the last decode should be in the reordered batch,
-    # with prefills from the front of the batch.
-    # `decodes` and `prefills` are already in ascending order just based on
-    # the above loop.
-    num_decodes = len(decodes)
-    num_prefills = len(prefills)
-    modified_batch = False
+    # We now want to reorder the batch into decode → extend → prefill order,
+    # where:
+    #   decode:  request with num_scheduled_tokens <= decode_threshold
+    #   extend:  non-decode request with existing context
+    #   prefill: non-decode request with no existing context
+    # NOTE: for now we loosely use "decode" to mean requests where attention
+    # is likely memory-bound and "prefill" to mean requests where attention
+    # is likely compute-bound.
+    num_reqs = len(input_batch.req_ids)
+    num_scheduled_tokens_np = np.array(
+        [
+            scheduler_output.num_scheduled_tokens[req_id]
+            for req_id in input_batch.req_ids
+        ]
+    )
+
+    num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs]
+
+    is_decode = num_scheduled_tokens_np <= decode_threshold
+    is_extend = (~is_decode) & (num_computed_tokens_np > num_scheduled_tokens_np)
+    is_prefill = (~is_decode) & (num_computed_tokens_np == num_scheduled_tokens_np)
 
-    for i in range(1, min(num_decodes, num_prefills) + 1):
-        # If the decode is at the "back" of the batch, i, we can swap it
-        # with the prefill closest to the front of the batch
-        decode_idx = decodes[num_decodes - i]
-        if decode_idx < num_decodes:
-            break
+    # Desired order: decode → extend → prefill
+    order_key = np.zeros(is_decode.shape, dtype=np.int32)  # 0 = decode by default
+    order_key[is_extend] = 1
+    order_key[is_prefill] = 2
 
-        input_batch.swap_states(prefills[i - 1], decode_idx)
-        modified_batch = True
+    # Get a permutation of the indices that sorts order_key; if we reordered
+    # the batch as requests[perm], we would be in the desired order.
+    perm = np.argsort(order_key, kind="stable")
+    # old_idx -> new_pos
+    dest = np.empty_like(perm)
+    dest[perm] = np.arange(num_reqs)
 
+    modified_batch = False
+    for i in range(num_reqs):
+        while dest[i] != i:
+            j = dest[i]  # destination index for the element currently at i
+            input_batch.swap_states(i, j)
+            dest[i], dest[j] = dest[j], dest[i]
+            modified_batch = True
     return modified_batch


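The replacement builds a stable sort key and then applies the resulting permutation in place with swaps; the same dest-table cycle walk on toy data (a plain list stands in for the batch and `swap_states`):

import numpy as np

order_key = np.array([2, 0, 1, 0, 2])  # 0 = decode, 1 = extend, 2 = prefill
batch = ["p0", "d0", "e0", "d1", "p1"]  # stand-in for per-request state

perm = np.argsort(order_key, kind="stable")  # new order, as old indices
dest = np.empty_like(perm)
dest[perm] = np.arange(len(batch))  # old index -> new position

for i in range(len(batch)):
    while dest[i] != i:
        j = dest[i]
        batch[i], batch[j] = batch[j], batch[i]  # swap_states(i, j)
        dest[i], dest[j] = dest[j], dest[i]

print(batch)  # ['d0', 'd1', 'e0', 'p0', 'p1']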