
Commit 25dd155

MatthewBonanni, robertgshaw2-redhat, and Bam4d authored and committed
[BugFix] [DP/EP] Fix slow execution when BS <= DP (#25407)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Chris Bamford <chrisbam4d@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent 864bbe3 commit 25dd155

2 files changed: 5 additions, 4 deletions

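Context, inferred from the commit title and the diff: under data parallelism with expert parallelism (DP/EP), every rank must take a forward step each iteration so the expert all-to-all collectives see all ranks; when the batch size is at most the DP size, some ranks therefore execute a filler "dummy" batch. The two changes below make that dummy batch run as a uniform decode batch and fix the request-count arithmetic that the uniform-decode path relies on.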

vllm/v1/worker/gpu_model_runner.py

Lines changed: 4 additions & 3 deletions
@@ -55,7 +55,7 @@
 from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        GiB_bytes, check_use_alibi, get_dtype_size,
+                        GiB_bytes, cdiv, check_use_alibi, get_dtype_size,
                         is_pin_memory_available,
                         length_from_prompt_token_ids_or_embeds, round_up,
                         supports_dynamo)
@@ -2913,12 +2913,13 @@ def _dummy_run(
             # Note: Overriding max_query_len to be the prefill tokens
             max_query_len = num_prefill_tokens
         elif uniform_decode:
-            num_reqs = num_tokens // max_query_len
+            assert not create_mixed_batch
+            num_reqs = cdiv(num_tokens, max_query_len)
             assert num_reqs <= max_num_reqs, \
                 "Do not capture num_reqs > max_num_reqs for uniform batch"
             num_scheduled_tokens_list = [max_query_len] * num_reqs
             if num_tokens % max_query_len != 0:
-                num_scheduled_tokens_list[-1] += num_tokens % max_query_len
+                num_scheduled_tokens_list[-1] = num_tokens % max_query_len
         else:
             num_reqs = min(num_tokens, max_num_reqs)
             min_tokens_per_req = num_tokens // num_reqs
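To see why the switch from floor division to cdiv, and from += to = for the remainder, matters, here is a minimal standalone sketch (not the vLLM source; split_old and split_new are illustrative names) of the old and new token-splitting logic in the uniform-decode branch:

# Minimal sketch, not the vLLM source: contrast the old and new splitting
# of num_tokens across requests in _dummy_run's uniform-decode branch.

def cdiv(a: int, b: int) -> int:
    # Ceiling division, matching what vllm.utils.cdiv computes.
    return -(-a // b)

def split_old(num_tokens: int, max_query_len: int) -> list[int]:
    num_reqs = num_tokens // max_query_len           # floor division
    tokens = [max_query_len] * num_reqs
    if num_tokens % max_query_len != 0:
        tokens[-1] += num_tokens % max_query_len     # IndexError if num_reqs == 0
    return tokens

def split_new(num_tokens: int, max_query_len: int) -> list[int]:
    num_reqs = cdiv(num_tokens, max_query_len)       # round up instead
    tokens = [max_query_len] * num_reqs
    if num_tokens % max_query_len != 0:
        tokens[-1] = num_tokens % max_query_len      # assign, don't add
    return tokens

print(split_new(9, 4))  # [4, 4, 1]: sums to 9, every entry <= max_query_len
print(split_new(1, 4))  # [1]: the one-token dummy batch from gpu_worker.py
# split_old(1, 4) raises IndexError (zero requests for one token), and
# split_old(9, 4) returns [4, 5], whose last request exceeds max_query_len.

With floor division, num_tokens=1 and max_query_len=4 yields zero requests, which is exactly the shape of the one-token dummy batch; and the += variant could push the last request past max_query_len, breaking uniformity. The new code keeps every request at most max_query_len and the total exactly num_tokens.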

vllm/v1/worker/gpu_worker.py

Lines changed: 1 addition & 1 deletion
@@ -487,7 +487,7 @@ def profile(self, is_start: bool = True):
                 sort_by="self_cuda_time_total"))
 
     def execute_dummy_batch(self) -> None:
-        self.model_runner._dummy_run(1)
+        self.model_runner._dummy_run(1, uniform_decode=True)
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.model_runner.add_lora(lora_request)
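For context on the one-line worker change, here is a hedged, standalone sketch (not vLLM code; StubWorker and dp_step are invented for illustration) of why an idle DP rank calls execute_dummy_batch at all, and what uniform_decode=True buys it:

# Hedged, standalone sketch, not vLLM code: StubWorker and dp_step are
# invented names illustrating why an idle DP rank still takes a forward step.

class StubWorker:
    def execute_model(self, num_scheduled_tokens: int) -> None:
        print(f"real step with {num_scheduled_tokens} tokens")

    def execute_dummy_batch(self) -> None:
        # After this commit, vLLM's Worker.execute_dummy_batch calls
        # self.model_runner._dummy_run(1, uniform_decode=True), so the
        # filler step uses the same decode-shaped batch (and hence the same
        # fast path) as the ranks doing real decoding.
        print("dummy step: 1-token uniform-decode batch")

def dp_step(worker: StubWorker, num_scheduled_tokens: int) -> None:
    # Every rank must step each iteration so that expert-parallel
    # all-to-all collectives see all ranks; idle ranks run a dummy batch.
    if num_scheduled_tokens > 0:
        worker.execute_model(num_scheduled_tokens)
    else:
        worker.execute_dummy_batch()

# With batch size <= DP size, some ranks are idle in a given step:
for rank_tokens in (4, 0, 3, 0):
    dp_step(StubWorker(), rank_tokens)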
