|
1 | 1 | # SPDX-License-Identifier: Apache-2.0 |
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
3 | 3 | import abc |
4 | | -import copy |
5 | 4 | import enum |
6 | 5 | import functools |
7 | 6 | from abc import abstractmethod |
@@ -864,7 +863,6 @@ def reorder_batch_to_split_decodes_and_prefills( |
864 | 863 | # NOTE for now we loosely use "decode" to mean requests where attention is |
865 | 864 | # likely memory-bound and "prefill" to mean requests where attention is |
866 | 865 | # likely compute-bound, |
867 | | - |
868 | 866 | num_reqs = len(input_batch.req_ids) |
869 | 867 | num_scheduled_tokens = [ |
870 | 868 | scheduler_output.num_scheduled_tokens[id] for id in input_batch.req_ids |
@@ -900,22 +898,14 @@ def reorder_batch_to_split_decodes_and_prefills( |
900 | 898 |
|
901 | 899 | src_dest_map = {int(src): int(dst) for src, dst in zip(src_indices, orig_indices)} |
902 | 900 |
|
903 | | - # Then we reorder the swap_indices to dest_indices |
904 | | - for i in range(len(swap_indices)): |
905 | | - dst = dest_indices[i] |
906 | | - src = swap_indices[i] |
907 | | - if dst != src: |
908 | | - # Get the real index position in input_batch to swap |
909 | | - dst_pos = indices_positions[idx_mapping[dst]] |
910 | | - src_pos = indices_positions[idx_mapping[src]] |
911 | | - |
912 | | - input_batch.swap_states(dst_pos, src_pos) |
913 | | - |
914 | | - dst_idx = idx_mapping[dst] |
915 | | - swap_indices[i] = dst |
916 | | - swap_indices[dst_idx] = src |
917 | | - idx_mapping[dst] = i |
918 | | - idx_mapping[src] = dst_idx |
| 901 | + for src in src_dest_map: |
| 902 | + dst = src_dest_map[src] |
| 903 | + while src != dst: |
| 904 | + input_batch.swap_states(src, dst) |
| 905 | + # Mark dst as done by updating its destination to itself |
| 906 | + next_dst = src_dest_map.get(dst, dst) |
| 907 | + src_dest_map[dst] = dst |
| 908 | + dst = next_dst |
919 | 909 |
|
920 | 910 | return True |
921 | 911 |
|
|
0 commit comments