 import os
 import random
 import time
-from functools import cache
-from typing import Dict, List, Optional, Tuple
 from contextlib import contextmanager, nullcontext
+from functools import cache
 from pathlib import Path
-
+from typing import Dict, List, Optional, Tuple
 
 import torch
 import uvloop
@@ -93,15 +92,15 @@ def get_random_lora_request(
 
 
 def sample_requests(tokenizer: PreTrainedTokenizerBase,
-                    args: argparse.Namespace) -> List[SampleRequest]:
-
+                    args: argparse.Namespace) -> List[SampleRequest]:
+
     dataset_path: str = args.dataset
     num_requests: int = args.num_prompts
     fixed_output_len: Optional[int] = args.output_len
     model: str = args.model
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
-
+
     # Load the dataset.
     with open(dataset_path) as f:
         dataset = json.load(f)
@@ -227,8 +226,8 @@ def get_profiling_context(profile_dir: Optional[str] = None):
     sampling_params: List[SamplingParams] = []
     for request in requests:
         prompts.append(
-            TextPrompt(prompt=request.prompt,
-                       multi_modal_data=request.multi_modal_data))
+            TextPrompt(prompt=request.prompt,
+                       multi_modal_data=request.multi_modal_data))
         sampling_params.append(
             SamplingParams(
                 n=n,
@@ -244,7 +243,10 @@ def get_profiling_context(profile_dir: Optional[str] = None):
     use_beam_search = False
 
     if not use_beam_search:
-        execute = lambda: llm.generate(prompts, sampling_params, lora_request=lora_requests, use_tqdm=True)
+        execute = lambda: llm.generate(prompts,
+                                       sampling_params,
+                                       lora_request=lora_requests,
+                                       use_tqdm=True)
     else:
         assert lora_requests is None, "BeamSearch API does not support LoRA"
         prompts = [request.prompt for request in requests]
@@ -253,12 +255,12 @@ def get_profiling_context(profile_dir: Optional[str] = None):
         for request in requests:
             assert request.expected_output_len == output_len
         execute = lambda: llm.beam_search(
-            prompts,
-            BeamSearchParams(
-                beam_width=n,
-                max_tokens=output_len,
-                ignore_eos=True,
-            ))
+            prompts,
+            BeamSearchParams(
+                beam_width=n,
+                max_tokens=output_len,
+                ignore_eos=True,
+            ))
 
     if args.profile_torch or args.profile_rpd:
         with get_profiling_context(profile_dir):
@@ -268,7 +270,7 @@ def get_profiling_context(profile_dir: Optional[str] = None):
         start = time.perf_counter()
         execute()
         end = time.perf_counter()
-        return end - start
+        return end - start
 
 
 async def run_vllm_async(
@@ -288,8 +290,8 @@ async def run_vllm_async(
     lora_requests: List[Optional[LoRARequest]] = []
     for request in requests:
         prompts.append(
-            TextPrompt(prompt=request.prompt,
-                       multi_modal_data=request.multi_modal_data))
+            TextPrompt(prompt=request.prompt,
+                       multi_modal_data=request.multi_modal_data))
         sampling_params.append(
             SamplingParams(
                 n=n,
@@ -304,7 +306,7 @@ async def run_vllm_async(
     start = time.perf_counter()
     for i, (prompt, sp,
             lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
-        generator = llm.generate(prompt,
+        generator = llm.generate(prompt,
                                  sp,
                                  lora_request=lr,
                                  request_id=f"test{i}")
@@ -400,51 +402,51 @@ def main(args: argparse.Namespace):
     tokenizer = AutoTokenizer.from_pretrained(
         args.tokenizer, trust_remote_code=args.trust_remote_code)
     if args.dataset is None:
-        vocab_size = tokenizer.vocab_size
-        requests = []
-        for _ in range(args.num_prompts):
-
-            request_tokenizer = tokenizer
-            lora_request: Optional[LoRARequest] = None
-            if args.enable_lora:
-                lora_request, lora_tokenizer = get_random_lora_request(args)
-                if lora_tokenizer:
-                    request_tokenizer = lora_tokenizer
-
-            # Synthesize a prompt with the given input length.
-            candidate_ids = [
-                random.randint(0, vocab_size - 1)
-                for _ in range(args.input_len)
-            ]
-            # As tokenizer may add additional tokens like BOS, we need to try
-            # different lengths to get the desired input length.
-            for _ in range(5):  # Max attempts to correct
-                candidate_prompt = request_tokenizer.decode(candidate_ids)
-                tokenized_len = len(request_tokenizer.encode(candidate_prompt))
-
-                if tokenized_len == args.input_len:
-                    break
-
-                # Adjust length based on difference
-                diff = args.input_len - tokenized_len
-                if diff > 0:
-                    candidate_ids.extend([
-                        random.randint(100, vocab_size - 100)
-                        for _ in range(diff)
-                    ])
-                else:
-                    candidate_ids = candidate_ids[:diff]
-            requests.append(
-                SampleRequest(prompt=candidate_prompt,
-                              prompt_len=args.input_len,
-                              expected_output_len=args.output_len,
-                              lora_request=lora_request))
+        vocab_size = tokenizer.vocab_size
+        requests = []
+        for _ in range(args.num_prompts):
+
+            request_tokenizer = tokenizer
+            lora_request: Optional[LoRARequest] = None
+            if args.enable_lora:
+                lora_request, lora_tokenizer = get_random_lora_request(args)
+                if lora_tokenizer:
+                    request_tokenizer = lora_tokenizer
+
+            # Synthesize a prompt with the given input length.
+            candidate_ids = [
+                random.randint(0, vocab_size - 1)
+                for _ in range(args.input_len)
+            ]
+            # As tokenizer may add additional tokens like BOS, we need to try
+            # different lengths to get the desired input length.
+            for _ in range(5):  # Max attempts to correct
+                candidate_prompt = request_tokenizer.decode(candidate_ids)
+                tokenized_len = len(request_tokenizer.encode(candidate_prompt))
+
+                if tokenized_len == args.input_len:
+                    break
+
+                # Adjust length based on difference
+                diff = args.input_len - tokenized_len
+                if diff > 0:
+                    candidate_ids.extend([
+                        random.randint(100, vocab_size - 100)
+                        for _ in range(diff)
+                    ])
+                else:
+                    candidate_ids = candidate_ids[:diff]
+            requests.append(
+                SampleRequest(prompt=candidate_prompt,
+                              prompt_len=args.input_len,
+                              expected_output_len=args.output_len,
+                              lora_request=lora_request))
     else:
         requests = sample_requests(tokenizer, args)
 
     is_multi_modal = any(request.multi_modal_data is not None
                          for request in requests)
-
+
     if args.backend == "vllm":
         if args.async_engine:
             elapsed_time = uvloop.run(
@@ -470,15 +472,16 @@ def main(args: argparse.Namespace):
                            for request in requests)
     total_output_tokens = sum(request.expected_output_len
                               for request in requests)
-
+
     if args.profile_torch or args.profile_rpd:
         # Profiling complete
         pass
     else:
         if is_multi_modal:
-            print("\033[91mWARNING\033[0m: Multi-modal request detected. The "
-                  "following metrics are not accurate because image tokens are not"
-                  " counted. See vllm-project/vllm/issues/9778 for details.")
+            print(
+                "\033[91mWARNING\033[0m: Multi-modal request detected. The "
+                "following metrics are not accurate because image tokens are"
+                " not counted. See vllm-project/vllm/issues/9778 for details.")
         # TODO(vllm-project/vllm/issues/9778): Count molti-modal token length.
         print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
               f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "