[Benchmark] Add --async-engine option to benchmark_throughput.py (vllm-project#7964)

Signed-off-by: Alvant <alvasian@yandex.ru>
njhill authored and Alvant committed Oct 26, 2024
1 parent ee73bea commit f80915f
Showing 3 changed files with 143 additions and 19 deletions.
113 changes: 109 additions & 4 deletions benchmarks/benchmark_throughput.py
@@ -6,13 +6,16 @@
from typing import List, Optional, Tuple

import torch
import uvloop
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser
from vllm.utils import FlexibleArgumentParser, merge_async_iterators


def sample_requests(
@@ -135,6 +138,93 @@ def run_vllm(
    return end - start


async def run_vllm_async(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    quantization_param_path: Optional[str],
    device: str,
    enable_prefix_caching: bool,
    enable_chunked_prefill: bool,
    max_num_batched_tokens: int,
    distributed_executor_backend: Optional[str],
    gpu_memory_utilization: float = 0.9,
    num_scheduler_steps: int = 1,
    use_v2_block_manager: bool = False,
    download_dir: Optional[str] = None,
    load_format: str = EngineArgs.load_format,
    disable_async_output_proc: bool = False,
    disable_frontend_multiprocessing: bool = False,
) -> float:
    from vllm import SamplingParams
    engine_args = AsyncEngineArgs(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        quantization_param_path=quantization_param_path,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
        distributed_executor_backend=distributed_executor_backend,
        load_format=load_format,
        num_scheduler_steps=num_scheduler_steps,
        use_v2_block_manager=use_v2_block_manager,
        disable_async_output_proc=disable_async_output_proc,
        worker_use_ray=False,
        engine_use_ray=False,
        disable_log_requests=True,
    )

    async with build_async_engine_client_from_engine_args(
            engine_args, disable_frontend_multiprocessing) as llm:

        # Add the requests to the engine.
        prompts: List[str] = []
        sampling_params: List[SamplingParams] = []
        for prompt, _, output_len in requests:
            prompts.append(prompt)
            sampling_params.append(
                SamplingParams(
                    n=n,
                    temperature=0.0 if use_beam_search else 1.0,
                    top_p=1.0,
                    use_beam_search=use_beam_search,
                    ignore_eos=True,
                    max_tokens=output_len,
                ))

        generators = []
        start = time.perf_counter()
        for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
            generator = llm.generate(prompt, sp, request_id=f"test{i}")
            generators.append(generator)
        all_gens = merge_async_iterators(*generators)
        async for i, res in all_gens:
            pass
        end = time.perf_counter()
        return end - start

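The timing loop above creates a generator per request up front, then drains a single merged stream, so all requests run concurrently and the measured interval spans them end to end. As a rough, self-contained illustration of that pattern — fake_generate, drain, and timed_drain_all are invented stand-ins, and asyncio.gather stands in for vllm.utils.merge_async_iterators since only the total elapsed time matters here:

import asyncio
import time
from typing import AsyncIterator


async def fake_generate(request_id: str, delay: float) -> AsyncIterator[str]:
    # Stand-in for llm.generate(): emits a few incremental outputs per request.
    for step in range(3):
        await asyncio.sleep(delay)
        yield f"{request_id}-step{step}"


async def drain(gen: AsyncIterator[str]) -> None:
    # The benchmark only needs the streams exhausted, not the outputs.
    async for _ in gen:
        pass


async def timed_drain_all(num_requests: int = 4) -> float:
    gens = [fake_generate(f"test{i}", delay=0.01) for i in range(num_requests)]
    start = time.perf_counter()
    # Drain every stream concurrently; the merged-iterator version interleaves
    # outputs into one stream, but the end-to-end time is the same.
    await asyncio.gather(*(drain(g) for g in gens))
    return time.perf_counter() - start


if __name__ == "__main__":
    print(f"elapsed: {asyncio.run(timed_drain_all()):.3f}s")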

def run_hf(
    requests: List[Tuple[str, int, int]],
    model: str,
@@ -230,7 +320,7 @@ def main(args: argparse.Namespace):
                                   args.output_len)

    if args.backend == "vllm":
        elapsed_time = run_vllm(
        run_args = [
            requests, args.model, args.tokenizer, args.quantization,
            args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
            args.trust_remote_code, args.dtype, args.max_model_len,
@@ -240,7 +330,14 @@
            args.max_num_batched_tokens, args.distributed_executor_backend,
            args.gpu_memory_utilization, args.num_scheduler_steps,
            args.use_v2_block_manager, args.download_dir, args.load_format,
            args.disable_async_output_proc)
            args.disable_async_output_proc
        ]

        if args.async_engine:
            run_args.append(args.disable_frontend_multiprocessing)
            elapsed_time = uvloop.run(run_vllm_async(*run_args))
        else:
            elapsed_time = run_vllm(*run_args)
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -426,6 +523,14 @@ def main(args: argparse.Namespace):
        action='store_true',
        default=False,
        help="Disable async output processor for vLLM backend.")
    parser.add_argument("--async-engine",
                        action='store_true',
                        default=False,
                        help="Use vLLM async engine rather than LLM class.")
    parser.add_argument("--disable-frontend-multiprocessing",
                        action='store_true',
                        default=False,
                        help="Disable decoupled async engine frontend.")
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model
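Taken together, benchmark_throughput.py now has two vLLM paths: the existing synchronous LLM class, and, when --async-engine is passed, an AsyncLLMEngine driven through the same client builder the OpenAI API server uses (with --disable-frontend-multiprocessing forcing the in-process engine instead of the RPC frontend). Assuming the script's pre-existing --backend and --model options plus the usual dataset/length settings, which this commit does not change, an invocation of the new path might look like:

python benchmarks/benchmark_throughput.py --backend vllm --model <model> --async-engine
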
45 changes: 30 additions & 15 deletions vllm/entrypoints/openai/api_server.py
@@ -67,7 +67,7 @@


def model_is_embedding(model_name: str, trust_remote_code: bool,
                       quantization: str) -> bool:
                       quantization: Optional[str]) -> bool:
    return ModelConfig(model=model_name,
                       tokenizer=model_name,
                       tokenizer_mode="auto",
@@ -96,13 +96,6 @@ async def _force_log():
@asynccontextmanager
async def build_async_engine_client(
        args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
    """
    Create AsyncEngineClient, either:
        - in-process using the AsyncLLMEngine Directly
        - multiprocess using AsyncLLMEngine RPC
    Returns the Client or None if the creation failed.
    """

    # Context manager to handle async_engine_client lifecycle
    # Ensures everything is shutdown and cleaned up on error/exit
@@ -112,14 +105,37 @@ async def build_async_engine_client(
    # Backend itself still global for the silly lil' health handler
    global async_engine_client

    async with build_async_engine_client_from_engine_args(
            engine_args, args.disable_frontend_multiprocessing) as engine:

        async_engine_client = engine  # type: ignore[assignment]
        yield engine


@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
) -> AsyncIterator[Optional[AsyncEngineClient]]:
    """
    Create AsyncEngineClient, either:
        - in-process using the AsyncLLMEngine Directly
        - multiprocess using AsyncLLMEngine RPC
    Returns the Client or None if the creation failed.
    """

    # If manually triggered or embedding model, use AsyncLLMEngine in process.
    # TODO: support embedding model via RPC.
    if (model_is_embedding(args.model, args.trust_remote_code,
                           args.quantization)
            or args.disable_frontend_multiprocessing):
        async_engine_client = AsyncLLMEngine.from_engine_args(
    if (model_is_embedding(engine_args.model, engine_args.trust_remote_code,
                           engine_args.quantization)
            or disable_frontend_multiprocessing):
        engine_client = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
        yield async_engine_client
        try:
            yield engine_client
        finally:
            engine_client.shutdown_background_loop()
        return

    # Otherwise, use the multiprocessing AsyncLLMEngine.
@@ -148,7 +164,6 @@ async def build_async_engine_client(
    # NOTE: Actually, this is not true yet. We still need to support
    # embedding models via RPC (see TODO above)
    rpc_client = AsyncEngineRPCClient(rpc_path)
    async_engine_client = rpc_client  # type: ignore

    # Start RPCServer in separate process (holds the AsyncLLMEngine).
    context = multiprocessing.get_context("spawn")
@@ -174,7 +189,7 @@
                    yield None
                    return

        yield async_engine_client
        yield rpc_client  # type: ignore[misc]
    finally:
        # Ensure rpc server process was terminated
        rpc_server_process.terminate()
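The substantive change in this file is that the in-process branch now wraps its yield in try/finally, so AsyncLLMEngine.shutdown_background_loop() runs no matter how the caller's block exits — which is what makes the context manager safe to reuse from the benchmark. A generic sketch of that pattern (the names here are illustrative, not vLLM APIs):

from contextlib import asynccontextmanager
from typing import AsyncIterator, Callable, TypeVar

T = TypeVar("T")


@asynccontextmanager
async def managed_client(build: Callable[[], T],
                         shutdown: Callable[[T], None]) -> AsyncIterator[T]:
    # Hand the freshly built client to the caller and guarantee cleanup
    # on normal exit, exception, or cancellation alike.
    client = build()
    try:
        yield client
    finally:
        shutdown(client)
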
4 changes: 4 additions & 0 deletions vllm/entrypoints/openai/rpc/client.py
@@ -7,6 +7,7 @@
import cloudpickle
import zmq
import zmq.asyncio
from zmq import Frame # type: ignore[attr-defined]
from zmq.asyncio import Socket

from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
@@ -214,6 +215,7 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,

            # Await the data from the Server.
            frame = await socket.recv(copy=False)
            assert isinstance(frame, Frame)
            data = pickle.loads(frame.buffer)

        if isinstance(data, Exception):
@@ -247,6 +249,7 @@ async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE):
f"{self._data_timeout} ms")

frame = await socket.recv(copy=False)
assert isinstance(frame, Frame)
return pickle.loads(frame.buffer)

# Make a new socket connection.
@@ -395,6 +398,7 @@ async def generate(
            # Stream back the results from the RPC Server.
            while not finished:
                message = await socket.recv(copy=False)
                assert isinstance(message, Frame)
                request_output = pickle.loads(message.buffer)

                if isinstance(request_output, Exception):
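Context for the new asserts: with copy=False, pyzmq's asyncio sockets hand back zero-copy zmq.Frame objects rather than bytes, and it is the frame's .buffer memoryview that gets unpickled. The isinstance asserts narrow that union for the type checker and document the assumption. A standalone sketch of the receive side (an illustration, not the vLLM client itself):

import pickle

import zmq
import zmq.asyncio
from zmq import Frame  # type: ignore[attr-defined]


async def recv_pickled(socket: zmq.asyncio.Socket):
    # copy=False returns a Frame (zero-copy); the assert narrows the type
    # before the underlying buffer is handed to pickle.
    frame = await socket.recv(copy=False)
    assert isinstance(frame, Frame)
    return pickle.loads(frame.buffer)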
