Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix latency benchmark script #118

Merged
merged 2 commits into from
May 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 33 additions & 29 deletions benchmark/benchmark_latency.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,75 @@
import argparse
import time
from typing import List

from tqdm import tqdm
import numpy as np
import torch
from tqdm import tqdm

from cacheflow.core.server import (
add_server_arguments, process_server_arguments,
init_local_server_and_frontend_with_arguments)
from cacheflow.sampling_params import SamplingParams
from cacheflow import LLM, SamplingParams


def main(args: argparse.Namespace):
server, frontend = init_local_server_and_frontend_with_arguments(args)
print(args)

# Process all the requests in a single batch if possible.
# NOTE(woosuk): If the request cannot be processed in a single batch,
# the server will automatically process the request in multiple batches.
llm = LLM(
model=args.model,
tensor_parallel_size=args.tensor_parallel_size,
max_num_seqs=args.batch_size,
max_num_batched_tokens=args.batch_size * args.input_len,
)

sampling_params = SamplingParams(
n=args.n,
temperature=0.0 if args.use_beam_search else 1.0,
top_p=1.0,
use_beam_search=args.use_beam_search,
stop_token_ids=set(),
ignore_eos=True,
max_tokens=args.output_len,
)
print(sampling_params)
input_token_ids = [0] * args.input_len
dummy_prompts = [""] * args.batch_size
dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size

def profile_step(profile=False):
def run_to_completion(profile: bool = False):
if profile:
torch.cuda.cudart().cudaProfilerStart()
for _ in range(args.batch_size):
dummy_prompt = ""
frontend._add_query(dummy_prompt, input_token_ids, sampling_params)
server.add_sequence_groups(frontend.get_inputs())
start_time = time.time()
while True:
server.step()
if not server.has_unfinished_requests():
break

llm.generate(dummy_prompts, sampling_params, dummy_prompt_token_ids,
use_tqdm=False)

end_time = time.time()
latency = end_time - start_time
if profile:
torch.cuda.cudart().cudaProfilerStop()
return latency

print("Warm up step")
profile_step()
print("Warming up...")
run_to_completion(profile=False)

# Benchmark.
latencies = []
for _ in tqdm(range(3), desc="Profile step"):
latencies.append(profile_step())
for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
latencies.append(run_to_completion(profile=False))
print(f'Avg latency: {np.mean(latencies)} seconds')


if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Benchmark the latency of decoding a single sentence.')
parser = add_server_arguments(parser)
description='Benchmark the latency of processing a single batch of '
'requests till completion.')
parser.add_argument('--model', type=str, default='facebook/opt-125m')
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--input-len', type=int, default=32)
parser.add_argument('--output-len', type=int, default=128)
parser.add_argument('--batch-size', type=int, default=8)
parser.add_argument('--n', type=int, default=1)
parser.add_argument('--n', type=int, default=1,
help='Number of generated sequences per prompt.')
parser.add_argument('--use-beam-search', action='store_true')
parser.add_argument('--num-iters', type=int, default=3,
help='Number of iterations to run.')
args = parser.parse_args()
args = process_server_arguments(args)
args.max_num_batched_tokens = max(
args.max_num_batched_tokens, args.batch_size * args.input_len)
print(args)
main(args)
12 changes: 10 additions & 2 deletions cacheflow/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,26 @@ def generate(
self,
prompts: List[str],
sampling_params: Optional[SamplingParams] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
) -> List[RequestOutput]:
if sampling_params is None:
# Use default sampling params.
sampling_params = SamplingParams()
# Initialize tqdm.
if use_tqdm:
pbar = tqdm(total=len(prompts), desc="Processed prompts")

# Add requests to the server.
for prompt in prompts:
for i in range(len(prompts)):
prompt = prompts[i]
if prompt_token_ids is None:
token_ids = None
else:
token_ids = prompt_token_ids[i]
request_id = str(next(self.request_counter))
self.llm_server.add_request(request_id, prompt, sampling_params)
self.llm_server.add_request(request_id, prompt, sampling_params,
token_ids)

# Run the server.
outputs: List[RequestOutput] = []
Expand Down