Commit 1228eb0

AdrianAbeyta authored and committed
Fix linter errors
1 parent ad55c7f, commit 1228eb0

1 file changed: 66 additions, 63 deletions
benchmarks/profiling/benchmark_throughput.py (66 additions, 63 deletions)
@@ -6,11 +6,10 @@
 import os
 import random
 import time
-from functools import cache
-from typing import Dict, List, Optional, Tuple
 from contextlib import contextmanager, nullcontext
+from functools import cache
 from pathlib import Path
-
+from typing import Dict, List, Optional, Tuple
 
 import torch
 import uvloop
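
Note: the contextlib imports kept above feed the get_profiling_context helper named in the hunk headers below. Its body is not part of this diff, so the following is only a minimal sketch of that pattern, not the commit's code:

from contextlib import contextmanager, nullcontext
from typing import Optional


@contextmanager
def get_profiling_context(profile_dir: Optional[str] = None):
    # Sketch only: profile when a directory is given, otherwise run as a no-op.
    if profile_dir is None:
        with nullcontext():
            yield
    else:
        print(f"profiling to {profile_dir}")  # stand-in for starting a profiler
        yield
        print("profiling finished")           # stand-in for stopping/saving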
@@ -93,15 +92,15 @@ def get_random_lora_request(
 
 
 def sample_requests(tokenizer: PreTrainedTokenizerBase,
-                    args: argparse.Namespace) -> List[SampleRequest]:
-
+                    args: argparse.Namespace) -> List[SampleRequest]:
+
     dataset_path: str = args.dataset
     num_requests: int = args.num_prompts
     fixed_output_len: Optional[int] = args.output_len
     model: str = args.model
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
-
+
     # Load the dataset.
     with open(dataset_path) as f:
         dataset = json.load(f)
@@ -227,8 +226,8 @@ def get_profiling_context(profile_dir: Optional[str] = None):
     sampling_params: List[SamplingParams] = []
     for request in requests:
         prompts.append(
-            TextPrompt(prompt=request.prompt,
-                       multi_modal_data=request.multi_modal_data))
+            TextPrompt(prompt=request.prompt,
+                       multi_modal_data=request.multi_modal_data))
         sampling_params.append(
             SamplingParams(
                 n=n,
@@ -244,7 +243,10 @@ def get_profiling_context(profile_dir: Optional[str] = None):
     use_beam_search = False
 
     if not use_beam_search:
-        execute = lambda: llm.generate(prompts, sampling_params, lora_request=lora_requests, use_tqdm=True)
+        execute = lambda: llm.generate(prompts,
+                                       sampling_params,
+                                       lora_request=lora_requests,
+                                       use_tqdm=True)
     else:
         assert lora_requests is None, "BeamSearch API does not support LoRA"
         prompts = [request.prompt for request in requests]
@@ -253,12 +255,12 @@ def get_profiling_context(profile_dir: Optional[str] = None):
         for request in requests:
             assert request.expected_output_len == output_len
         execute = lambda: llm.beam_search(
-            prompts,
-            BeamSearchParams(
-                beam_width=n,
-                max_tokens=output_len,
-                ignore_eos=True,
-            ))
+            prompts,
+            BeamSearchParams(
+                beam_width=n,
+                max_tokens=output_len,
+                ignore_eos=True,
+            ))
 
     if args.profile_torch or args.profile_rpd:
         with get_profiling_context(profile_dir):
@@ -268,7 +270,7 @@ def get_profiling_context(profile_dir: Optional[str] = None):
         start = time.perf_counter()
         execute()
         end = time.perf_counter()
-        return end - start
+        return end - start
 
 
 async def run_vllm_async(
@@ -288,8 +290,8 @@ async def run_vllm_async(
     lora_requests: List[Optional[LoRARequest]] = []
     for request in requests:
         prompts.append(
-            TextPrompt(prompt=request.prompt,
-                       multi_modal_data=request.multi_modal_data))
+            TextPrompt(prompt=request.prompt,
+                       multi_modal_data=request.multi_modal_data))
         sampling_params.append(
             SamplingParams(
                 n=n,
@@ -304,7 +306,7 @@ async def run_vllm_async(
     start = time.perf_counter()
     for i, (prompt, sp,
             lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
-        generator = llm.generate(prompt,
+        generator = llm.generate(prompt,
                                  sp,
                                  lora_request=lr,
                                  request_id=f"test{i}")
@@ -400,51 +402,51 @@ def main(args: argparse.Namespace):
     tokenizer = AutoTokenizer.from_pretrained(
         args.tokenizer, trust_remote_code=args.trust_remote_code)
     if args.dataset is None:
-        vocab_size = tokenizer.vocab_size
-        requests = []
-        for _ in range(args.num_prompts):
-
-            request_tokenizer = tokenizer
-            lora_request: Optional[LoRARequest] = None
-            if args.enable_lora:
-                lora_request, lora_tokenizer = get_random_lora_request(args)
-                if lora_tokenizer:
-                    request_tokenizer = lora_tokenizer
-
-            # Synthesize a prompt with the given input length.
-            candidate_ids = [
-                random.randint(0, vocab_size - 1)
-                for _ in range(args.input_len)
-            ]
-            # As tokenizer may add additional tokens like BOS, we need to try
-            # different lengths to get the desired input length.
-            for _ in range(5):  # Max attempts to correct
-                candidate_prompt = request_tokenizer.decode(candidate_ids)
-                tokenized_len = len(request_tokenizer.encode(candidate_prompt))
-
-                if tokenized_len == args.input_len:
-                    break
-
-                # Adjust length based on difference
-                diff = args.input_len - tokenized_len
-                if diff > 0:
-                    candidate_ids.extend([
-                        random.randint(100, vocab_size - 100)
-                        for _ in range(diff)
-                    ])
-                else:
-                    candidate_ids = candidate_ids[:diff]
-            requests.append(
-                SampleRequest(prompt=candidate_prompt,
-                              prompt_len=args.input_len,
-                              expected_output_len=args.output_len,
-                              lora_request=lora_request))
+        vocab_size = tokenizer.vocab_size
+        requests = []
+        for _ in range(args.num_prompts):
+
+            request_tokenizer = tokenizer
+            lora_request: Optional[LoRARequest] = None
+            if args.enable_lora:
+                lora_request, lora_tokenizer = get_random_lora_request(args)
+                if lora_tokenizer:
+                    request_tokenizer = lora_tokenizer
+
+            # Synthesize a prompt with the given input length.
+            candidate_ids = [
+                random.randint(0, vocab_size - 1)
+                for _ in range(args.input_len)
+            ]
+            # As tokenizer may add additional tokens like BOS, we need to try
+            # different lengths to get the desired input length.
+            for _ in range(5):  # Max attempts to correct
+                candidate_prompt = request_tokenizer.decode(candidate_ids)
+                tokenized_len = len(request_tokenizer.encode(candidate_prompt))
+
+                if tokenized_len == args.input_len:
+                    break
+
+                # Adjust length based on difference
+                diff = args.input_len - tokenized_len
+                if diff > 0:
+                    candidate_ids.extend([
+                        random.randint(100, vocab_size - 100)
+                        for _ in range(diff)
+                    ])
+                else:
+                    candidate_ids = candidate_ids[:diff]
+            requests.append(
+                SampleRequest(prompt=candidate_prompt,
+                              prompt_len=args.input_len,
+                              expected_output_len=args.output_len,
+                              lora_request=lora_request))
     else:
         requests = sample_requests(tokenizer, args)
 
     is_multi_modal = any(request.multi_modal_data is not None
                          for request in requests)
-
+
     if args.backend == "vllm":
         if args.async_engine:
             elapsed_time = uvloop.run(
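
The re-indented block above is the prompt-synthesis loop: draw random token ids, decode them, re-encode, and nudge the id list until the tokenized length matches args.input_len. A standalone illustration of the same idea (the function name, defaults, and model choice are mine, not part of the commit):

import random

from transformers import AutoTokenizer, PreTrainedTokenizerBase


def synthesize_prompt(tokenizer: PreTrainedTokenizerBase,
                      input_len: int,
                      max_attempts: int = 5) -> str:
    """Build a random prompt whose tokenized length is close to input_len."""
    vocab_size = tokenizer.vocab_size
    candidate_ids = [
        random.randint(0, vocab_size - 1) for _ in range(input_len)
    ]
    candidate_prompt = tokenizer.decode(candidate_ids)
    for _ in range(max_attempts):
        tokenized_len = len(tokenizer.encode(candidate_prompt))
        if tokenized_len == input_len:
            break
        # Decode/encode round-trips can add or merge tokens (e.g. BOS), so pad
        # or trim the id list by the difference and try again.
        diff = input_len - tokenized_len
        if diff > 0:
            candidate_ids.extend(
                random.randint(100, vocab_size - 100) for _ in range(diff))
        else:
            candidate_ids = candidate_ids[:diff]
        candidate_prompt = tokenizer.decode(candidate_ids)
    return candidate_prompt


if __name__ == "__main__":
    # Model choice is illustrative; any Hugging Face tokenizer works.
    tok = AutoTokenizer.from_pretrained("gpt2")
    print(len(tok.encode(synthesize_prompt(tok, input_len=128))))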
@@ -470,15 +472,16 @@ def main(args: argparse.Namespace):
                            for request in requests)
     total_output_tokens = sum(request.expected_output_len
                               for request in requests)
-
+
     if args.profile_torch or args.profile_rpd:
         # Profiling complete
         pass
    else:
         if is_multi_modal:
-            print("\033[91mWARNING\033[0m: Multi-modal request detected. The "
-                  "following metrics are not accurate because image tokens are not"
-                  " counted. See vllm-project/vllm/issues/9778 for details.")
+            print(
+                "\033[91mWARNING\033[0m: Multi-modal request detected. The "
+                "following metrics are not accurate because image tokens are"
+                " not counted. See vllm-project/vllm/issues/9778 for details.")
         # TODO(vllm-project/vllm/issues/9778): Count molti-modal token length.
         print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
               f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
