
Commit 9254401

tomasruizt authored and skyloevil committed
Enable --profile in 'vllm bench throughput' (vllm-project#24575)
Signed-off-by: Tomas Ruiz <tomas.ruiz.te@gmail.com>
1 parent 9bed487 commit 9254401

File tree

1 file changed: +34 -4 lines changed

vllm/benchmarks/throughput.py

Lines changed: 34 additions & 4 deletions
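
Editor's note: the diff threads a new do_profile flag through each benchmark runner (run_vllm, run_vllm_chat, run_vllm_async) and exposes it as a --profile CLI flag, so a run such as VLLM_TORCH_PROFILER_DIR=/tmp/vllm_profile vllm bench throughput --profile ... (the directory path is only an example) collects Torch profiler traces around the timed generation. The same wrap can be reproduced outside the benchmark with the offline LLM API; a minimal sketch, assuming VLLM_TORCH_PROFILER_DIR is already set and using a placeholder model name, not the benchmark code itself:

# Minimal sketch of the pattern the diff adds to run_vllm.
# Assumes VLLM_TORCH_PROFILER_DIR points at a writable directory
# (start_profile() requires the profiler to be configured).
# The model name below is a placeholder, not taken from the commit.
import time

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model
params = SamplingParams(max_tokens=64)

start = time.perf_counter()
llm.start_profile()   # begin collecting Torch profiler traces
outputs = llm.generate(["Hello, my name is"], params, use_tqdm=True)
llm.stop_profile()    # stop and flush traces to VLLM_TORCH_PROFILER_DIR
end = time.perf_counter()
print(f"elapsed: {end - start:.2f} s, {len(outputs)} request(s)")
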
@@ -37,6 +37,7 @@ def run_vllm(
     requests: list[SampleRequest],
     n: int,
     engine_args: EngineArgs,
+    do_profile: bool,
     disable_detokenize: bool = False,
 ) -> tuple[float, Optional[list[RequestOutput]]]:
     from vllm import LLM, SamplingParams
@@ -75,10 +76,14 @@ def run_vllm(
     outputs = None
     if not use_beam_search:
         start = time.perf_counter()
+        if do_profile:
+            llm.start_profile()
         outputs = llm.generate(prompts,
                                sampling_params,
                                lora_request=lora_requests,
                                use_tqdm=True)
+        if do_profile:
+            llm.stop_profile()
         end = time.perf_counter()
     else:
         assert lora_requests is None, "BeamSearch API does not support LoRA"
@@ -88,13 +93,17 @@ def run_vllm(
         for request in requests:
             assert request.expected_output_len == output_len
         start = time.perf_counter()
+        if do_profile:
+            llm.start_profile()
         llm.beam_search(
             prompts,
             BeamSearchParams(
                 beam_width=n,
                 max_tokens=output_len,
                 ignore_eos=True,
             ))
+        if do_profile:
+            llm.stop_profile()
         end = time.perf_counter()
     return end - start, outputs

@@ -103,6 +112,7 @@ def run_vllm_chat(
     requests: list[SampleRequest],
     n: int,
     engine_args: EngineArgs,
+    do_profile: bool,
     disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]:
     """
     Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
@@ -133,7 +143,11 @@ def run_vllm_chat(
                 detokenize=not disable_detokenize,
             ))
     start = time.perf_counter()
+    if do_profile:
+        llm.start_profile()
     outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
+    if do_profile:
+        llm.stop_profile()
     end = time.perf_counter()
     return end - start, outputs

@@ -142,6 +156,7 @@ async def run_vllm_async(
     requests: list[SampleRequest],
     n: int,
     engine_args: AsyncEngineArgs,
+    do_profile: bool,
     disable_frontend_multiprocessing: bool = False,
     disable_detokenize: bool = False,
 ) -> float:
@@ -185,6 +200,8 @@ async def run_vllm_async(
 
         generators = []
         start = time.perf_counter()
+        if do_profile:
+            await llm.start_profile()
         for i, (prompt, sp,
                 lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
             generator = llm.generate(prompt,
@@ -195,6 +212,8 @@ async def run_vllm_async(
         all_gens = merge_async_iterators(*generators)
         async for i, res in all_gens:
             pass
+        if do_profile:
+            await llm.stop_profile()
         end = time.perf_counter()
         return end - start
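
Editor's note: in the async runner the hooks are awaited, since the async engine client exposes start_profile and stop_profile as coroutines. The shape of the change, reduced to a self-contained asyncio sketch in which DummyAsyncLLM is a hypothetical stand-in for the real client (illustration only, not the vLLM API):

# Illustrative only: DummyAsyncLLM is a made-up stand-in for the async
# engine client used by run_vllm_async.
import asyncio
import time


class DummyAsyncLLM:
    async def start_profile(self) -> None:
        print("profiler started")

    async def stop_profile(self) -> None:
        print("profiler stopped")

    async def generate(self, prompt: str):
        await asyncio.sleep(0.01)  # pretend to stream a result
        yield f"output for {prompt!r}"


async def run(do_profile: bool) -> float:
    llm = DummyAsyncLLM()
    start = time.perf_counter()
    if do_profile:
        await llm.start_profile()  # awaited, unlike the sync LLM hooks
    gens = [llm.generate(p) for p in ["a", "b", "c"]]
    for gen in gens:
        async for _ in gen:        # drain every request to completion
            pass
    if do_profile:
        await llm.stop_profile()
    return time.perf_counter() - start


print(f"elapsed: {asyncio.run(run(do_profile=True)):.3f} s")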

@@ -543,6 +562,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
                         type=str,
                         default=None,
                         help="Split of the HF dataset.")
+    parser.add_argument(
+        "--profile",
+        action="store_true",
+        default=False,
+        help="Use Torch Profiler. The env variable "
+        "VLLM_TORCH_PROFILER_DIR must be set to enable profiler.")
 
     # prefix repetition dataset
     prefix_repetition_group = parser.add_argument_group(
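
Editor's note: --profile is a plain store_true flag, so it stays False unless passed and existing invocations of vllm bench throughput are unaffected. The behaviour, reduced to a standalone argparse snippet (the parser here is built only for illustration):

# Standalone illustration of the store_true flag added by the diff.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--profile",
    action="store_true",
    default=False,
    help="Use Torch Profiler. The env variable "
    "VLLM_TORCH_PROFILER_DIR must be set to enable profiler.")

print(parser.parse_args([]).profile)             # False: flag omitted
print(parser.parse_args(["--profile"]).profile)  # True: flag present
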
@@ -600,22 +625,27 @@ def main(args: argparse.Namespace):
                     requests,
                     args.n,
                     AsyncEngineArgs.from_cli_args(args),
-                    args.disable_frontend_multiprocessing,
-                    args.disable_detokenize,
+                    disable_frontend_multiprocessing=args.disable_frontend_multiprocessing,
+                    disable_detokenize=args.disable_detokenize,
+                    do_profile=args.profile,
                 ))
         else:
             elapsed_time, request_outputs = run_vllm(
                 requests, args.n, EngineArgs.from_cli_args(args),
-                args.disable_detokenize)
+                disable_detokenize=args.disable_detokenize,
+                do_profile=args.profile)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
+        if args.profile:
+            raise NotImplementedError(
+                "Profiling not implemented yet for backend='hf'.")
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                               args.hf_max_batch_size, args.trust_remote_code,
                               args.disable_detokenize)
     elif args.backend == "vllm-chat":
         elapsed_time, request_outputs = run_vllm_chat(
             requests, args.n, EngineArgs.from_cli_args(args),
-            args.disable_detokenize)
+            disable_detokenize=args.disable_detokenize, do_profile=args.profile)
     else:
         raise ValueError(f"Unknown backend: {args.backend}")
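
Editor's note: the call sites also switch from positional to keyword arguments, which is what makes it safe to insert do_profile ahead of the existing optional parameters: a positional caller written against the old signature would otherwise feed its disable_detokenize value into the new parameter without any error. A toy illustration (made-up functions, not the benchmark code):

# Toy illustration of the hazard the keyword-argument call sites avoid.
# The old signature was f(requests, n, engine_args, disable_detokenize=False);
# do_profile is inserted before the optional parameter.
def run_new(requests, n, engine_args, do_profile, disable_detokenize=False):
    return {"do_profile": do_profile, "disable_detokenize": disable_detokenize}


# A positional call written against the old signature now feeds its
# disable_detokenize value into do_profile with no error raised.
print(run_new([], 1, None, True))
# -> {'do_profile': True, 'disable_detokenize': False}

# The keyword form stays correct even as parameters are inserted or reordered.
print(run_new([], 1, None, do_profile=True, disable_detokenize=True))
# -> {'do_profile': True, 'disable_detokenize': True}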
