Skip to content

Commit f5cc752

Browse files
committed
fix reorder
Signed-off-by: ganyi <ygan@amd.com>
1 parent ce8d3e7 commit f5cc752

File tree

3 files changed

+62
-10
lines changed

3 files changed

+62
-10
lines changed

tests/v1/attention/test_batch_reordering.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,21 @@ class ReorderTestCase:
8383
expected_order=[3, 1, 2, 0], # Only swap 0↔3, keep 1 and 2 in place
8484
expected_modified=True,
8585
),
86+
"complicated_mixed_interleaved": ReorderTestCase(
87+
requests=[
88+
(1, 20),
89+
(1, 50),
90+
(374, 0),
91+
(300, 20),
92+
(1, 20),
93+
(256, 0),
94+
(1, 5),
95+
(27, 0),
96+
(1, 4),
97+
],
98+
expected_order=[0, 1, 6, 8, 4, 3, 2, 7, 5],
99+
expected_modified=True,
100+
),
86101
}
87102

88103

vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from vllm.config import VllmConfig
1818
from vllm.logger import init_logger
1919
from vllm.platforms import current_platform
20-
from vllm.utils import cdiv
20+
from vllm.utils.math_utils import cdiv
2121
from vllm.v1.attention.backends.utils import (
2222
AttentionCGSupport,
2323
AttentionMetadataBuilder,

vllm/v1/attention/backends/utils.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
import abc
4+
import copy
45
import enum
56
import functools
67
from abc import abstractmethod
@@ -874,6 +875,29 @@ def reorder_batch_to_split_decodes_and_prefills(
874875
# NOTE for now we loosely use "decode" to mean requests where attention is
875876
# likely memory-bound and "prefill" to mean requests where attention is
876877
# likely compute-bound,
878+
# rid = dist.get_rank()
879+
rid = 0
880+
881+
def print_order():
882+
if rid == 0:
883+
num_scheduled_tokens = [
884+
scheduler_output.num_scheduled_tokens[id] for id in input_batch.req_ids
885+
]
886+
num_scheduled_tokens_np = np.array(num_scheduled_tokens)
887+
num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs]
888+
print("num scheduled tokens: ", num_scheduled_tokens_np, flush=True)
889+
print("num computed tokens: ", num_computed_tokens_np, flush=True)
890+
is_decode = num_scheduled_tokens_np <= decode_threshold
891+
is_extend = (~is_decode) & (num_computed_tokens_np > 0)
892+
is_prefill = (~is_decode) & (num_computed_tokens_np == 0)
893+
idx = np.arange(0, is_decode.shape[0])
894+
decodes = idx[is_decode]
895+
extends = idx[is_extend]
896+
prefills = idx[is_prefill]
897+
print("decode: ", decodes, flush=True)
898+
print("extends: ", extends, flush=True)
899+
print("prefills: ", prefills, flush=True)
900+
877901
num_reqs = len(input_batch.req_ids)
878902
num_scheduled_tokens = [
879903
scheduler_output.num_scheduled_tokens[id] for id in input_batch.req_ids
@@ -907,16 +931,29 @@ def reorder_batch_to_split_decodes_and_prefills(
907931
sorted_order = np.argsort(req_regions[needs_swap], kind="stable")
908932
dest_indices = swap_indices[sorted_order]
909933

910-
src_dest_map = {int(src): int(dst) for src, dst in zip(swap_indices, dest_indices)}
934+
idx_mapping = {val: idx for idx, val in enumerate(swap_indices)}
935+
# Record the original positions of idx in input_batch, the further
936+
# reorder manipulate the position_ids, so we need this variable helps
937+
# ping out the real position in input_batch
938+
indices_positions = copy.deepcopy(swap_indices)
939+
940+
# Then we reorder the swap_indices to dest_indices
941+
for i in range(len(swap_indices)):
942+
dst = dest_indices[i]
943+
src = swap_indices[i]
944+
if dst != src:
945+
# Get the real index position in input_batch to swap
946+
dst_pos = indices_positions[idx_mapping[dst]]
947+
src_pos = indices_positions[idx_mapping[src]]
948+
949+
input_batch.swap_states(dst_pos, src_pos)
950+
951+
dst_idx = idx_mapping[dst]
952+
swap_indices[i] = dst
953+
swap_indices[dst_idx] = src
954+
idx_mapping[dst] = i
955+
idx_mapping[src] = dst_idx
911956

912-
for src in src_dest_map:
913-
dst = src_dest_map[src]
914-
while src != dst:
915-
input_batch.swap_states(src, dst)
916-
# Mark dst as done by updating its destination to itself
917-
next_dst = src_dest_map.get(dst, dst)
918-
src_dest_map[dst] = dst
919-
dst = next_dst
920957
return True
921958

922959

0 commit comments

Comments
 (0)