Skip to content

Commit ab3d132

Browse files
committed
refactor: simplify split_decodes_prefills_and_extends
Signed-off-by: ganyi <ygan@amd.com>
1 parent f5cc752 commit ab3d132

File tree

1 file changed

+11
-44
lines changed

1 file changed

+11
-44
lines changed

vllm/v1/attention/backends/utils.py

Lines changed: 11 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -761,35 +761,24 @@ def split_decodes_prefills_and_extends(
761761

762762
query_lens = query_start_loc[1:] - query_start_loc[:-1]
763763
is_prefill = query_lens > decode_threshold
764-
if not torch.any(is_prefill):
765-
return num_reqs, 0, 0, num_tokens, 0, 0
766-
764+
is_pure_prefill = (seq_lens == query_lens) & is_prefill
767765
first_prefill = is_prefill.int().argmax(dim=-1).item()
768-
assert torch.all(query_lens[first_prefill:] > decode_threshold)
769-
assert torch.all(query_lens[:first_prefill] <= decode_threshold)
770-
766+
first_pure_prefill = is_pure_prefill.int().argmax(dim=-1).item()
771767
num_decodes = first_prefill
772768
num_decode_tokens = query_start_loc[first_prefill].item()
773-
774-
query_lens_prefill = query_lens[first_prefill:]
775-
seq_lens_prefill = seq_lens[first_prefill:]
776-
is_extend = seq_lens_prefill != query_lens_prefill
777-
778-
if torch.all(is_extend):
779-
num_extends = num_reqs - num_decodes
780-
num_extend_tokens = num_tokens - num_decode_tokens
781-
return (num_decodes, num_extends, 0, num_decode_tokens, num_extend_tokens, 0)
769+
if not torch.any(is_prefill):
770+
return (num_decodes, 0, 0, num_decode_tokens, 0, 0)
782771

783772
num_prefills = num_reqs - num_decodes
784-
first_extend = is_extend.int().argmax(dim=-1).item()
773+
num_prefill_tokens = num_tokens - num_decode_tokens
774+
if not torch.any(is_pure_prefill):
775+
return (num_decodes, num_prefills, 0, num_decode_tokens, num_prefill_tokens, 0)
785776

786-
num_extends = first_extend
787-
num_pure_prefills = num_prefills - first_extend
777+
num_extends = first_pure_prefill - num_decodes
778+
num_pure_prefills = num_reqs - first_pure_prefill
788779

789-
num_extend_tokens = (
790-
query_start_loc[num_extends + num_decodes].item() - num_decode_tokens
791-
)
792-
num_pure_prefill_tokens = num_tokens - num_decode_tokens - num_extend_tokens
780+
num_pure_prefill_tokens = num_tokens - query_start_loc[first_pure_prefill]
781+
num_extend_tokens = num_prefill_tokens - num_pure_prefill_tokens
793782
return (
794783
num_decodes,
795784
num_extends,
@@ -875,28 +864,6 @@ def reorder_batch_to_split_decodes_and_prefills(
875864
# NOTE for now we loosely use "decode" to mean requests where attention is
876865
# likely memory-bound and "prefill" to mean requests where attention is
877866
# likely compute-bound,
878-
# rid = dist.get_rank()
879-
rid = 0
880-
881-
def print_order():
882-
if rid == 0:
883-
num_scheduled_tokens = [
884-
scheduler_output.num_scheduled_tokens[id] for id in input_batch.req_ids
885-
]
886-
num_scheduled_tokens_np = np.array(num_scheduled_tokens)
887-
num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs]
888-
print("num scheduled tokens: ", num_scheduled_tokens_np, flush=True)
889-
print("num computed tokens: ", num_computed_tokens_np, flush=True)
890-
is_decode = num_scheduled_tokens_np <= decode_threshold
891-
is_extend = (~is_decode) & (num_computed_tokens_np > 0)
892-
is_prefill = (~is_decode) & (num_computed_tokens_np == 0)
893-
idx = np.arange(0, is_decode.shape[0])
894-
decodes = idx[is_decode]
895-
extends = idx[is_extend]
896-
prefills = idx[is_prefill]
897-
print("decode: ", decodes, flush=True)
898-
print("extends: ", extends, flush=True)
899-
print("prefills: ", prefills, flush=True)
900867

901868
num_reqs = len(input_batch.req_ids)
902869
num_scheduled_tokens = [

0 commit comments

Comments
 (0)