Commit 407e56f

remove the changes in ReorderSpec
Signed-off-by: ganyi <ygan@amd.com>
1 parent: 4eccb9b

File tree: 2 files changed (+78, -285 lines)

vllm/v1/attention/backends/utils.py

Lines changed: 49 additions & 240 deletions
@@ -4,9 +4,7 @@
 import enum
 import functools
 from abc import abstractmethod
-from collections import deque
 from dataclasses import dataclass, field, fields, make_dataclass
-from enum import Enum
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -51,85 +49,6 @@ def is_valid_kv_cache_layout(value: str) -> bool:
     return value in get_args(KVCacheLayoutType)


-class QueryLenSupport(Enum):
-    """Defines the level of query length support for an attention backend's
-    decode pipeline.
-
-    - SINGLE_ONLY: Decode pipeline only supports single-token queries
-      (query_len=1)
-    - UNIFORM: Decode pipeline supports uniform multi-token queries
-      (all requests must have the same query_len > 1)
-    - VARLEN: Decode pipeline supports variable-length queries
-      (mixed query lengths in the same batch)
-    """
-
-    SINGLE_ONLY = "single_only"
-    UNIFORM = "uniform"
-    VARLEN = "varlen"
-
-
-@dataclass
-class ReorderSpec:
-    """
-    Defines how the model runner reorders requests within a batch for attention
-    backends that distinguish between prefill, extend, and decode phases.
-
-    Core controls
-    - decode_threshold: Query lengths <= this value are treated as decode. If
-      None, no reorder will be applied.
-    - split_extend: If True, split prefill into [extend_prefill, pure_prefill].
-    - decode_query_len_support:
-      - SINGLE_ONLY: single-token only (no spec decode)
-      - UNIFORM: uniform multi-token queries (spec decode, equal lengths)
-      - VARLEN: variable-length queries (spec decode, mixed lengths)
-
-    Example input
-        query_len: [7, 10, 3, 1, 2, 5, 2, 15]
-        seq_len:   [10, 10, 8, 9, 8, 5, 10, 20]
-
-    Case 1: decode_threshold=3
-        query_len: [3, 1, 2, 2, 7, 10, 5, 15]
-        seq_len:   [8, 9, 8, 10, 10, 10, 5, 20]
-                   └──── dec ────┘└──── pre ─────┘
-        → Reordered as [decode, prefill].
-
-    Case 2: decode_threshold=3, split_extend=True
-        query_len: [3, 1, 2, 2, 7, 15, 10, 5, 8]
-        seq_len:   [8, 9, 8, 10, 10, 20, 10, 5, 8]
-                   └──── dec ────┘ └── ext ──┘ └pre┘
-        → Reordered as [decode, extend_prefill, pure_prefill].
-
-    Case 3 (Future/TODO):
-        decode_threshold=3, split_extend=True, query_len_support=UNIFORM
-        (Move the most common q_len <= decode_threshold to the front to form
-        the largest *uniform* decode region; here, the two q_len=2's.)
-        query_len: [2, 2, 3, 1, 7, 15, 10, 5]
-        seq_len:   [8, 10, 8, 9, 10, 20, 10, 5]
-                   └u dec┘ └─dec─┘└──── pre ─────┘
-        → Reordered as [uniform-decode(2s), decode, prefill].
-    """
-
-    decode_threshold: int | None = None
-    """The threshold for reordering the batch into decode and prefill
-    requests. If `decode_threshold` is not None, prefill and
-    decode requests will be reordered according to this value; query
-    lengths <= threshold are considered decode."""
-
-    split_extend: bool = False
-    """Whether to further split the prefill requests into pure prefill
-    (query_len == context_len) and extend prefill (query_len < context_len)
-    in a single batch. Once this flag is set, the requests will be reordered
-    to [decode:extend_prefill:pure_prefill]."""
-
-    decode_query_len_support: QueryLenSupport = QueryLenSupport.SINGLE_ONLY
-    """Defines the level of query length support for this backend.
-    - SINGLE_ONLY: Only single-token queries (no spec decode support)
-    - UNIFORM: Supports uniform multi-token queries (spec decode with uniform lengths)
-    - VARLEN: Supports variable-length queries (spec decode with mixed lengths)
-    If set to UNIFORM or VARLEN, this will increase `decode_threshold` when
-    speculative decoding is enabled."""
-
-
 @dataclass
 class CommonAttentionMetadata:
     """
@@ -326,10 +245,10 @@ class AttentionCGSupport(enum.Enum):
 class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     # Does this backend/builder support CUDA Graphs for attention (default: no).
     cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
-    # Attention backend's reorder spec, which controls if and
-    # how to reorder the requests before actually executing the
-    # model (default: no reorder).
-    reorder_spec: ClassVar[ReorderSpec] = ReorderSpec(None)
+    # Does this backend/builder reorder the batch?
+    # If not, set this to None. Otherwise set it to the query
+    # length that will be pulled into the front of the batch.
+    reorder_batch_threshold: int | None = None

     @abstractmethod
     def __init__(
@@ -344,21 +263,21 @@ def __init__(
         self.vllm_config = vllm_config
         self.device = device

-    def _init_decode_threshold(
-        self, decode_threshold: int = 1, supports_spec_as_decode: bool = False
+    def _init_reorder_batch_threshold(
+        self, reorder_batch_threshold: int = 1, supports_spec_as_decode: bool = False
     ) -> None:
-        self.reorder_spec.decode_threshold = decode_threshold
-        if self.reorder_spec.decode_threshold is not None and supports_spec_as_decode:
+        self.reorder_batch_threshold = reorder_batch_threshold
+        if self.reorder_batch_threshold is not None and supports_spec_as_decode:
             # If the backend supports spec-as-decode kernels, then we can set
-            # the decode_threshold based on the number of speculative
+            # the reorder_batch_threshold based on the number of speculative
             # tokens from the config.
             speculative_config = self.vllm_config.speculative_config
             if (
                 speculative_config is not None
                 and speculative_config.num_speculative_tokens is not None
             ):
-                self.reorder_spec.decode_threshold = max(
-                    self.reorder_spec.decode_threshold,
+                self.reorder_batch_threshold = max(
+                    self.reorder_batch_threshold,
                     1 + speculative_config.num_speculative_tokens,
                 )

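The renamed helper keeps the same spec-decode interaction as before: when the
backend can treat speculative queries as decodes, the threshold is raised to
cover one verified token plus the draft tokens. A minimal sketch of just that
arithmetic (num_speculative_tokens stands in for the value read from
vllm_config.speculative_config; the function name is illustrative):

def effective_threshold(reorder_batch_threshold: int | None,
                        num_speculative_tokens: int | None,
                        supports_spec_as_decode: bool) -> int | None:
    if (reorder_batch_threshold is not None
            and supports_spec_as_decode
            and num_speculative_tokens is not None):
        # A spec-decode step scores 1 target token plus k draft tokens,
        # so a "decode" query can be up to 1 + k tokens long.
        return max(reorder_batch_threshold, 1 + num_speculative_tokens)
    return reorder_batch_threshold

assert effective_threshold(1, 3, True) == 4   # q_len <= 4 still counts as decode
assert effective_threshold(1, 3, False) == 1  # kernel only handles q_len == 1
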
@@ -430,10 +349,6 @@ def use_cascade_attention(
     ) -> bool:
         return False

-    @classmethod
-    def reset_reorder_spec(cls, reorder_sepc):
-        cls.reorder_spec = reorder_sepc
-

 @functools.lru_cache
 def get_kv_cache_layout():
@@ -939,109 +854,6 @@ def split_decodes_and_prefills(
     return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)


-def reorder_batch_to_split_decodes_prefills_and_extends(
-    input_batch: "InputBatch",
-    scheduler_output: "SchedulerOutput",
-    decode_threshold: int = 1,
-) -> bool:
-    """
-    Reorders the batch to split it into prefill, extend,
-    and decode requests; places all requests in the order
-    [decodes:extend:prefill].
-
-    Returns:
-        True if the batch was modified, False otherwise.
-    """
-
-    decodes = []
-    prefills = []
-    extends = []
-
-    for i, req_id in enumerate(input_batch.req_ids):
-        num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-        if num_tokens <= decode_threshold:
-            decodes.append(i)
-        elif input_batch.num_computed_tokens_cpu[i] > 0:
-            extends.append(i)
-        else:
-            prefills.append(i)
-
-    num_decodes = len(decodes)
-    num_extends = len(extends)
-    # We define the reorder matrix here to help with the request reorder.
-    # reorder_matrix[(i, j)] holds the ids of the requests that are supposed
-    # to be in zone i but actually sit in zone j.
-    # The decode, extend and prefill requests are separated into 3
-    # different zones here: 0 for decode, 1 for extend and 2 for
-    # prefill.
-    reorder_matrix: dict[tuple[int, int], deque[int]] = {
-        (i, j): deque() for i in range(3) for j in range(3) if i != j
-    }
-
-    # collect mismatches
-
-    def target_idx(idx):
-        if idx < num_decodes:
-            # decode as zone 0
-            return 0
-        elif idx < num_decodes + num_extends:
-            # extend as zone 1
-            return 1
-        else:
-            # prefill as zone 2
-            return 2
-
-    def fill_reorder_matrix(request_lists, reorder_sequence):
-        for idx, seq in enumerate(reorder_sequence):
-            request_list = request_lists[idx]
-            for req_idx in request_list:
-                req_target_id = target_idx(req_idx)
-                if seq != req_target_id:
-                    reorder_matrix[(seq, req_target_id)].append(req_idx)
-
-    def direct_zone_swap(i, j):
-        assert i != j
-        modified_batch = False
-        while reorder_matrix[(i, j)] and reorder_matrix[(j, i)]:
-            swap_req1 = reorder_matrix[(i, j)].pop()
-            swap_req2 = reorder_matrix[(j, i)].pop()
-            input_batch.swap_states(swap_req1, swap_req2)
-            modified_batch = True

-        return modified_batch
-
-    # in order 1,2,3; out order 3,1,2
-    def indirect_zone_swap(zone_list):
-        assert len(zone_list) == 3
-        modified_batch = False
-        while (
-            reorder_matrix[zone_list[0]]
-            and reorder_matrix[zone_list[1]]
-            and reorder_matrix[zone_list[2]]
-        ):
-            swap_req1 = reorder_matrix[zone_list[0]].pop()
-            swap_req2 = reorder_matrix[zone_list[1]].pop()
-            swap_req3 = reorder_matrix[zone_list[2]].pop()
-
-            input_batch.swap_states(swap_req1, swap_req2)
-            input_batch.swap_states(swap_req2, swap_req3)
-            modified_batch = True
-        return modified_batch
-
-    fill_reorder_matrix([decodes, extends, prefills], [0, 1, 2])
-
-    modified_batch = False
-    # do direct swaps first
-    modified_batch |= direct_zone_swap(0, 1)  # decode <--> extend
-    modified_batch |= direct_zone_swap(0, 2)  # decode <--> prefill
-    modified_batch |= direct_zone_swap(1, 2)  # extend <--> prefill
-
-    modified_batch |= indirect_zone_swap(((0, 1), (1, 2), (2, 0)))
-    modified_batch |= indirect_zone_swap(((2, 1), (0, 2), (1, 0)))
-
-    return modified_batch
-
-
 def reorder_batch_to_split_decodes_and_prefills(
     input_batch: "InputBatch",
     scheduler_output: "SchedulerOutput",
@@ -1054,50 +866,47 @@ def reorder_batch_to_split_decodes_and_prefills(
     Returns:
         True if the batch was modified, False otherwise.
     """
-    # We now want to reorder the batch so that the "decode" requests are at
-    # the front and the "prefill" requests are at the back using the least
-    # amount of swaps possible. (NOTE: for now we loosely use "decode" to mean
-    # requests where attention is likely memory-bound and "prefill" to mean
-    # requests where attention is likely compute-bound; TODO(lucas): figure out
-    # a better naming here)
-    decodes = []
-    prefills = []
-    num_decode_tokens = 0
-    num_prefill_tokens = 0
-
-    for i, req_id in enumerate(input_batch.req_ids):
-        num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-        if num_tokens <= decode_threshold:
-            decodes.append(i)
-            num_decode_tokens += num_tokens
-        else:
-            prefills.append(i)
-            num_prefill_tokens += num_tokens
-
-    # We hope that this is fairly minimal since decodes
-    # should be around for a number of iterations so hopefully they are
-    # relatively stationary (and new requests are generally appended to the
-    # persistent batch so they should already be at the back).
-    # To achieve this we loop over the decodes in descending order and
-    # the prefills in ascending order. We swap decodes from the "back",
-    # i.e. past where the last decode should be in the reordered batch, with
-    # prefills from the front of the batch.
-    # `decodes` and `prefills` are already in ascending order just based on
-    # the above loop.
-    num_decodes = len(decodes)
-    num_prefills = len(prefills)
-    modified_batch = False
+    # We now want to reorder the batch into decode → extend → prefill order,
+    # where:
+    #   decode:  request with num_scheduled_tokens <= decode_threshold
+    #   extend:  non-decode request with existing context
+    #   prefill: non-decode request with no existing context
+    # NOTE: for now we loosely use "decode" to mean requests where attention is
+    # likely memory-bound and "prefill" to mean requests where attention is
+    # likely compute-bound.
+    num_reqs = len(input_batch.req_ids)
+    num_scheduled_tokens_np = np.array(
+        [
+            scheduler_output.num_scheduled_tokens[req_id]
+            for req_id in input_batch.req_ids
+        ]
+    )
+
+    num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs]
+
+    is_decode = num_scheduled_tokens_np <= decode_threshold
+    is_extend = (~is_decode) & (num_computed_tokens_np > num_scheduled_tokens_np)
+    is_prefill = (~is_decode) & (num_computed_tokens_np == num_scheduled_tokens_np)

-    for i in range(1, min(num_decodes, num_prefills) + 1):
-        # If the decode is at the "back" of the batch, i, we can swap it
-        # with the prefill closest to the front of the batch
-        decode_idx = decodes[num_decodes - i]
-        if decode_idx < num_decodes:
-            break
+    # Desired order: decode → extend → prefill
+    order_key = np.zeros(is_decode.shape, dtype=np.int32)  # 0 = decode by default
+    order_key[is_extend] = 1
+    order_key[is_prefill] = 2

-        input_batch.swap_states(prefills[i - 1], decode_idx)
-        modified_batch = True
+    # Get a permutation of the indices that sorts the order_key; if we
+    # reordered the batch as request[perm], we'd be in the desired order.
+    perm = np.argsort(order_key, kind="stable")
+    # old_idx -> new_pos
+    dest = np.empty_like(perm)
+    dest[perm] = np.arange(num_reqs)

+    modified_batch = False
+    for i in range(num_reqs):
+        while dest[i] != i:
+            j = dest[i]  # destination index for the element currently at i
+            input_batch.swap_states(i, j)
+            dest[i], dest[j] = dest[j], dest[i]
+            modified_batch = True
     return modified_batch

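The rewritten reorder_batch_to_split_decodes_and_prefills applies the argsort
permutation in place by following cycles: each swap_states call parks one
request at its final position, so a cycle of length k costs k - 1 swaps. A
self-contained sketch of the same technique, with a plain list standing in
for the persistent batch and element swaps standing in for
InputBatch.swap_states:

import numpy as np

def reorder_by_key(batch: list, order_key: np.ndarray) -> bool:
    perm = np.argsort(order_key, kind="stable")  # new_pos -> old_idx
    dest = np.empty_like(perm)
    dest[perm] = np.arange(len(batch))           # old_idx -> new_pos
    modified = False
    for i in range(len(batch)):
        # Follow the cycle through position i; each swap finalizes one slot.
        while dest[i] != i:
            j = dest[i]  # final position of the element currently at i
            batch[i], batch[j] = batch[j], batch[i]  # stand-in for swap_states
            dest[i], dest[j] = dest[j], dest[i]
            modified = True
    return modified

reqs = ["prefill0", "decode0", "extend0", "decode1"]
key = np.array([2, 0, 1, 0])  # 0 = decode, 1 = extend, 2 = prefill
reorder_by_key(reqs, key)
assert reqs == ["decode0", "decode1", "extend0", "prefill0"]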