@@ -37,6 +37,7 @@ def run_vllm(
     requests: list[SampleRequest],
     n: int,
     engine_args: EngineArgs,
+    do_profile: bool,
     disable_detokenize: bool = False,
 ) -> tuple[float, Optional[list[RequestOutput]]]:
     from vllm import LLM, SamplingParams
@@ -75,10 +76,14 @@ def run_vllm(
     outputs = None
     if not use_beam_search:
         start = time.perf_counter()
+        if do_profile:
+            llm.start_profile()
         outputs = llm.generate(prompts,
                                sampling_params,
                                lora_request=lora_requests,
                                use_tqdm=True)
+        if do_profile:
+            llm.stop_profile()
         end = time.perf_counter()
     else:
         assert lora_requests is None, "BeamSearch API does not support LoRA"
@@ -88,13 +93,17 @@ def run_vllm(
         for request in requests:
             assert request.expected_output_len == output_len
         start = time.perf_counter()
+        if do_profile:
+            llm.start_profile()
         llm.beam_search(
             prompts,
             BeamSearchParams(
                 beam_width=n,
                 max_tokens=output_len,
                 ignore_eos=True,
             ))
+        if do_profile:
+            llm.stop_profile()
         end = time.perf_counter()
     return end - start, outputs

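The two hunks above wrap both the `llm.generate` and the `llm.beam_search` timed regions in the same conditional start/stop guard. Purely as an illustrative sketch (this helper is hypothetical and not part of the commit, which keeps the explicit `if do_profile:` guards at each call site), the pattern is equivalent to a small context manager:

```python
from contextlib import contextmanager


@contextmanager
def maybe_profile(llm, do_profile: bool):
    """Hypothetical helper (not part of this commit): start/stop the engine
    profiler around a block only when profiling was requested."""
    if do_profile:
        llm.start_profile()
    try:
        yield
    finally:
        if do_profile:
            llm.stop_profile()
```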
@@ -103,6 +112,7 @@ def run_vllm_chat(
     requests: list[SampleRequest],
     n: int,
     engine_args: EngineArgs,
+    do_profile: bool,
     disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]:
     """
     Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
@@ -133,7 +143,11 @@ def run_vllm_chat(
                 detokenize=not disable_detokenize,
             ))
     start = time.perf_counter()
+    if do_profile:
+        llm.start_profile()
     outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
+    if do_profile:
+        llm.stop_profile()
     end = time.perf_counter()
     return end - start, outputs

@@ -142,6 +156,7 @@ async def run_vllm_async(
     requests: list[SampleRequest],
     n: int,
     engine_args: AsyncEngineArgs,
+    do_profile: bool,
     disable_frontend_multiprocessing: bool = False,
     disable_detokenize: bool = False,
 ) -> float:
@@ -185,6 +200,8 @@ async def run_vllm_async(

         generators = []
         start = time.perf_counter()
+        if do_profile:
+            await llm.start_profile()
         for i, (prompt, sp,
                 lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
             generator = llm.generate(prompt,
@@ -195,6 +212,8 @@ async def run_vllm_async(
         all_gens = merge_async_iterators(*generators)
         async for i, res in all_gens:
             pass
+        if do_profile:
+            await llm.stop_profile()
         end = time.perf_counter()
     return end - start

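In the async path the profiler hooks on the engine client are coroutines, so the new guards `await` them. An async analogue of the sketch above (again hypothetical, not in the patch) would look like:

```python
from contextlib import asynccontextmanager


@asynccontextmanager
async def maybe_profile_async(llm, do_profile: bool):
    """Hypothetical async variant (not part of this commit): the async engine
    client exposes start_profile()/stop_profile() as coroutines."""
    if do_profile:
        await llm.start_profile()
    try:
        yield
    finally:
        if do_profile:
            await llm.stop_profile()
```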
@@ -543,6 +562,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
                         type=str,
                         default=None,
                         help="Split of the HF dataset.")
+    parser.add_argument(
+        "--profile",
+        action="store_true",
+        default=False,
+        help="Use Torch Profiler. The env variable "
+        "VLLM_TORCH_PROFILER_DIR must be set to enable profiler.")

     # prefix repetition dataset
     prefix_repetition_group = parser.add_argument_group(
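As the new help text notes, `--profile` only takes effect when the engine is launched with `VLLM_TORCH_PROFILER_DIR` set. A fail-fast check along these lines (hypothetical, not part of the commit) could surface a misconfiguration before the benchmark starts:

```python
import os


def check_profile_env(do_profile: bool) -> None:
    """Hypothetical pre-flight check: --profile has no effect unless
    VLLM_TORCH_PROFILER_DIR points at a directory for the profiler traces."""
    if do_profile and not os.getenv("VLLM_TORCH_PROFILER_DIR"):
        raise ValueError(
            "--profile requires the VLLM_TORCH_PROFILER_DIR environment "
            "variable to be set.")
```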
@@ -600,22 +625,27 @@ def main(args: argparse.Namespace):
                     requests,
                     args.n,
                     AsyncEngineArgs.from_cli_args(args),
-                    args.disable_frontend_multiprocessing,
-                    args.disable_detokenize,
+                    disable_frontend_multiprocessing=args.disable_frontend_multiprocessing,
+                    disable_detokenize=args.disable_detokenize,
+                    do_profile=args.profile,
                 ))
         else:
             elapsed_time, request_outputs = run_vllm(
                 requests, args.n, EngineArgs.from_cli_args(args),
-                args.disable_detokenize)
+                disable_detokenize=args.disable_detokenize,
+                do_profile=args.profile)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
+        if args.profile:
+            raise NotImplementedError(
+                "Profiling not implemented yet for backend='hf'.")
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                               args.hf_max_batch_size, args.trust_remote_code,
                               args.disable_detokenize)
     elif args.backend == "vllm-chat":
         elapsed_time, request_outputs = run_vllm_chat(
             requests, args.n, EngineArgs.from_cli_args(args),
-            args.disable_detokenize)
+            disable_detokenize=args.disable_detokenize, do_profile=args.profile)
     else:
         raise ValueError(f"Unknown backend: {args.backend}")

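Taken together, `--profile` flows from the CLI into the `do_profile` parameter of each vLLM runner, while the HF backend rejects it explicitly. After a run with `--profile` and `VLLM_TORCH_PROFILER_DIR` exported, a helper like the one below (hypothetical; the directory layout of the traces is an assumption, not defined by this diff) can gather the resulting trace files for a viewer:

```python
import os
from pathlib import Path


def list_profiler_traces() -> list[Path]:
    """Hypothetical post-run helper: collect whatever the Torch profiler
    wrote under VLLM_TORCH_PROFILER_DIR so the traces can be opened in a
    trace viewer."""
    trace_dir = os.getenv("VLLM_TORCH_PROFILER_DIR")
    if trace_dir is None:
        return []
    return sorted(Path(trace_dir).rglob("*"))
```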