|
@@ -55,7 +55,7 @@
 from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        GiB_bytes, check_use_alibi, get_dtype_size,
+                        GiB_bytes, cdiv, check_use_alibi, get_dtype_size,
                         is_pin_memory_available,
                         length_from_prompt_token_ids_or_embeds, round_up,
                         supports_dynamo)
@@ -2913,12 +2913,13 @@ def _dummy_run(
             # Note: Overriding max_query_len to be the prefill tokens
             max_query_len = num_prefill_tokens
         elif uniform_decode:
-            num_reqs = num_tokens // max_query_len
+            assert not create_mixed_batch
+            num_reqs = cdiv(num_tokens, max_query_len)
             assert num_reqs <= max_num_reqs, \
                 "Do not capture num_reqs > max_num_reqs for uniform batch"
             num_scheduled_tokens_list = [max_query_len] * num_reqs
             if num_tokens % max_query_len != 0:
-                num_scheduled_tokens_list[-1] += num_tokens % max_query_len
+                num_scheduled_tokens_list[-1] = num_tokens % max_query_len
         else:
             num_reqs = min(num_tokens, max_num_reqs)
             min_tokens_per_req = num_tokens // num_reqs
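
As context for the change above, the following is a minimal standalone sketch, not the vLLM code path itself: the helpers split_tokens_old and split_tokens_new are hypothetical names, and cdiv is assumed to be ceiling division as provided by vllm.utils. It illustrates why switching to ceiling division and assigning (rather than adding) the remainder keeps every dummy request at or below max_query_len.

def cdiv(a: int, b: int) -> int:
    # Ceiling division, assumed to match vllm.utils.cdiv.
    return -(a // -b)

def split_tokens_old(num_tokens: int, max_query_len: int) -> list[int]:
    # Pre-change behavior: floor division, remainder folded into the last request.
    num_reqs = num_tokens // max_query_len
    tokens = [max_query_len] * num_reqs
    if num_tokens % max_query_len != 0:
        tokens[-1] += num_tokens % max_query_len
    return tokens

def split_tokens_new(num_tokens: int, max_query_len: int) -> list[int]:
    # Post-change behavior: ceiling division, last request takes only the remainder.
    num_reqs = cdiv(num_tokens, max_query_len)
    tokens = [max_query_len] * num_reqs
    if num_tokens % max_query_len != 0:
        tokens[-1] = num_tokens % max_query_len
    return tokens

if __name__ == "__main__":
    print(split_tokens_old(10, 4))  # [4, 6]    -> last request exceeds max_query_len
    print(split_tokens_new(10, 4))  # [4, 4, 2] -> every request stays within max_query_len

With 10 tokens and max_query_len of 4, the old logic yields [4, 6], so the last request exceeds max_query_len; the new logic yields [4, 4, 2], which still sums to num_tokens while keeping each request within the uniform-decode bound.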
|