|
1 | 1 | # SPDX-License-Identifier: Apache-2.0 |
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
3 | 3 | import abc |
| 4 | +import copy |
4 | 5 | import enum |
5 | 6 | import functools |
6 | 7 | from abc import abstractmethod |
@@ -874,6 +875,29 @@ def reorder_batch_to_split_decodes_and_prefills( |
874 | 875 | # NOTE for now we loosely use "decode" to mean requests where attention is |
875 | 876 | # likely memory-bound and "prefill" to mean requests where attention is |
876 | 877 | # likely compute-bound, |
| 878 | + # rid = dist.get_rank() |
| 879 | + rid = 0 |
| 880 | + |
| 881 | + def print_order(): |
| 882 | + if rid == 0: |
| 883 | + num_scheduled_tokens = [ |
| 884 | + scheduler_output.num_scheduled_tokens[id] for id in input_batch.req_ids |
| 885 | + ] |
| 886 | + num_scheduled_tokens_np = np.array(num_scheduled_tokens) |
| 887 | + num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs] |
| 888 | + print("num scheduled tokens: ", num_scheduled_tokens_np, flush=True) |
| 889 | + print("num computed tokens: ", num_computed_tokens_np, flush=True) |
| 890 | + is_decode = num_scheduled_tokens_np <= decode_threshold |
| 891 | + is_extend = (~is_decode) & (num_computed_tokens_np > 0) |
| 892 | + is_prefill = (~is_decode) & (num_computed_tokens_np == 0) |
| 893 | + idx = np.arange(0, is_decode.shape[0]) |
| 894 | + decodes = idx[is_decode] |
| 895 | + extends = idx[is_extend] |
| 896 | + prefills = idx[is_prefill] |
| 897 | + print("decode: ", decodes, flush=True) |
| 898 | + print("extends: ", extends, flush=True) |
| 899 | + print("prefills: ", prefills, flush=True) |
| 900 | + |
877 | 901 | num_reqs = len(input_batch.req_ids) |
878 | 902 | num_scheduled_tokens = [ |
879 | 903 | scheduler_output.num_scheduled_tokens[id] for id in input_batch.req_ids |
@@ -907,16 +931,29 @@ def reorder_batch_to_split_decodes_and_prefills( |
907 | 931 | sorted_order = np.argsort(req_regions[needs_swap], kind="stable") |
908 | 932 | dest_indices = swap_indices[sorted_order] |
909 | 933 |
|
910 | | - src_dest_map = {int(src): int(dst) for src, dst in zip(swap_indices, dest_indices)} |
| 934 | + idx_mapping = {val: idx for idx, val in enumerate(swap_indices)} |
| 935 | + # Record the original positions of idx in input_batch, the further |
| 936 | + # reorder manipulate the position_ids, so we need this variable helps |
| 937 | + # ping out the real position in input_batch |
| 938 | + indices_positions = copy.deepcopy(swap_indices) |
| 939 | + |
| 940 | + # Then we reorder the swap_indices to dest_indices |
| 941 | + for i in range(len(swap_indices)): |
| 942 | + dst = dest_indices[i] |
| 943 | + src = swap_indices[i] |
| 944 | + if dst != src: |
| 945 | + # Get the real index position in input_batch to swap |
| 946 | + dst_pos = indices_positions[idx_mapping[dst]] |
| 947 | + src_pos = indices_positions[idx_mapping[src]] |
| 948 | + |
| 949 | + input_batch.swap_states(dst_pos, src_pos) |
| 950 | + |
| 951 | + dst_idx = idx_mapping[dst] |
| 952 | + swap_indices[i] = dst |
| 953 | + swap_indices[dst_idx] = src |
| 954 | + idx_mapping[dst] = i |
| 955 | + idx_mapping[src] = dst_idx |
911 | 956 |
|
912 | | - for src in src_dest_map: |
913 | | - dst = src_dest_map[src] |
914 | | - while src != dst: |
915 | | - input_batch.swap_states(src, dst) |
916 | | - # Mark dst as done by updating its destination to itself |
917 | | - next_dst = src_dest_map.get(dst, dst) |
918 | | - src_dest_map[dst] = dst |
919 | | - dst = next_dst |
920 | 957 | return True |
921 | 958 |
|
922 | 959 |
|
|
0 commit comments