diff --git a/docs/contributing/benchmarks.md b/docs/contributing/benchmarks.md index 52a16d7bdbff..89524ed3bc63 100644 --- a/docs/contributing/benchmarks.md +++ b/docs/contributing/benchmarks.md @@ -7,7 +7,7 @@ toc_depth: 4 vLLM provides comprehensive benchmarking tools for performance testing and evaluation: - **[Benchmark CLI](#benchmark-cli)**: `vllm bench` CLI tools and specialized benchmark scripts for interactive performance testing -- **[Batch Scripts](#batch-scripts)**: Run `vllm bench` against multiple configurations conveniently +- **[Parameter sweeps](#parameter-sweeps)**: Automate `vllm bench` runs for multiple configurations - **[Performance benchmarks](#performance-benchmarks)**: Automated CI benchmarks for development - **[Nightly benchmarks](#nightly-benchmarks)**: Comparative benchmarks against alternatives @@ -925,15 +925,13 @@ throughput numbers correctly is also adjusted. -## Batch Scripts +## Parameter Sweeps -### Batch Serving Script +### Online Benchmark -[`vllm/benchmarks/serve_multi.py`](../../vllm/benchmarks/serve_multi.py) automatically starts `vllm serve` and runs `vllm bench serve` over multiple configurations. +[`vllm/benchmarks/sweep/serve.py`](../../vllm/benchmarks/sweep/serve.py) automatically starts `vllm serve` and runs `vllm bench serve` to evaluate vLLM over multiple configurations. -#### Batch Mode - -The basic purpose of this script is to evaluate vLLM under different settings. Follows these steps to run the script: +Follow these steps to run the script: 1. Construct the base command to `vllm serve`, and pass it to the `--serve-cmd` option. 2. Construct the base command to `vllm bench serve`, and pass it to the `--bench-cmd` option. @@ -996,7 +994,7 @@ The basic purpose of this script is to evaluate vLLM under different settings. F Example command: ```bash -python vllm/benchmarks/serve_multi.py \ +python -m vllm.benchmarks.sweep.serve \ --serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \ --bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json' \ --serve-params benchmarks/serve_hparams.json \ @@ -1018,9 +1016,9 @@ python vllm/benchmarks/serve_multi.py \ !!! tip You can use the `--resume` option to continue the parameter sweep if one of the runs failed. -#### SLA Mode +### SLA Auto-Tuner -By passing SLA constraints via `--sla-params`, you can run this script in SLA mode, causing it to adjust either the request rate or concurrency (choose using `--sla-variable`) in order to satisfy the SLA constraints. +[`vllm/benchmarks/sweep/serve_sla.py`](../../vllm/benchmarks/sweep/serve_sla.py) is a wrapper over [`vllm/benchmarks/sweep/serve.py`](../../vllm/benchmarks/sweep/serve.py) that tunes either the request rate or concurrency (choose using `--sla-variable`) in order to satisfy the SLA constraints given by `--sla-params`. 
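+
+Each entry in the `--sla-params` JSON file is a dictionary that maps a metric reported by `vllm bench serve` to a threshold string beginning with a comparison operator (`<=`, `>=`, `<`, `>`). As a minimal sketch with an illustrative target value, the file could contain a single constraint:
+
+```json
+[
+    {"p99_e2el_ms": "<=500"}
+]
+```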
For example, to ensure E2E latency within different target values for 99% of requests: @@ -1044,7 +1042,7 @@ For example, to ensure E2E latency within different target values for 99% of req Example command: ```bash -python vllm/benchmarks/serve_multi.py \ +python -m vllm.benchmarks.sweep.serve_sla \ --serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \ --bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json' \ --serve-params benchmarks/serve_hparams.json \ @@ -1066,6 +1064,24 @@ The algorithm for adjusting the SLA variable is as follows: For a given combination of `--serve-params` and `--bench-params`, we share the benchmark results across `--sla-params` to avoid rerunning benchmarks with the same SLA variable value. +### Visualizer + +[`vllm/benchmarks/sweep/plot.py`](../../vllm/benchmarks/sweep/plot.py) can be used to plot performance curves from parameter sweep results. + +Example command: + +```bash +python -m vllm.benchmarks.sweep.plot benchmarks/results/ \ + --var-x max_concurrency \ + --row-by random_input_len \ + --col-by random_output_len \ + --curve-by api_server_count,max_num_batched_tokens \ + --filter-by 'max_concurrency<=1024' +``` + +!!! tip + You can use `--dry-run` to preview the figures to be plotted. + ## Performance Benchmarks The performance benchmarks are used for development to confirm whether new changes improve performance under various workloads. They are triggered on every commit with both the `perf-benchmarks` and `ready` labels, and when a PR is merged into vLLM. diff --git a/vllm/benchmarks/serve_multi.py b/vllm/benchmarks/serve_multi.py deleted file mode 100644 index e8524473aedd..000000000000 --- a/vllm/benchmarks/serve_multi.py +++ /dev/null @@ -1,1157 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import argparse -import contextlib -import json -import math -import os -import shlex -import signal -import subprocess -from abc import ABC, abstractmethod -from datetime import datetime -from pathlib import Path -from typing import Literal, get_args - -import pandas as pd -import requests -import seaborn as sns -from typing_extensions import assert_never, override - -_BAD_PARAMS_TYPE_MSG = ( - "The parameters to vary should be expressed as a JSON list of dictionaries." 
-) - - -def _parse_params(params: list[dict[str, object]]): - if not isinstance(params, list): - raise TypeError(f"{_BAD_PARAMS_TYPE_MSG} Found JSON type {type(params)}") - - for comb in params: - if not isinstance(comb, dict): - raise TypeError(f"{_BAD_PARAMS_TYPE_MSG} Found item type {type(comb)}") - - return params - - -class SLACriterionBase(ABC): - def __init__(self, target: float) -> None: - super().__init__() - - self.target = target - - @abstractmethod - def validate(self, actual: float) -> bool: - """Return `True` if this criterion is met; otherwise `False`.""" - raise NotImplementedError - - @abstractmethod - def format_cond(self, lhs: str) -> str: - raise NotImplementedError - - def print_and_validate( - self, - metrics: dict[str, float], - metrics_key: str, - ) -> bool: - metric = metrics[metrics_key] - result = self.validate(metric) - - cond = self.format_cond(f"{metrics_key} = {metric:.2f}") - print(f"Validating SLA: {cond} | " + ("PASSED" if result else "FAILED")) - - return result - - -class SLALessThan(SLACriterionBase): - @override - def validate(self, actual: float) -> bool: - return actual < self.target - - @override - def format_cond(self, lhs: str) -> str: - return f"{lhs}<{self.target:.2f}" - - -class SLALessThanOrEqual(SLACriterionBase): - @override - def validate(self, actual: float) -> bool: - return actual <= self.target - - @override - def format_cond(self, lhs: str) -> str: - return f"{lhs}<={self.target:.2f}" - - -class SLAGreaterThan(SLACriterionBase): - @override - def validate(self, actual: float) -> bool: - return actual > self.target - - @override - def format_cond(self, lhs: str) -> str: - return f"{lhs}>{self.target:.2f}" - - -class SLAGreaterThanOrEqual(SLACriterionBase): - @override - def validate(self, actual: float) -> bool: - return actual >= self.target - - @override - def format_cond(self, lhs: str) -> str: - return f"{lhs}>={self.target:.2f}" - - -# NOTE: The ordering is important! Match longer op_keys first -SLA_CRITERIA: dict[str, type[SLACriterionBase]] = { - "<=": SLALessThanOrEqual, - ">=": SLAGreaterThanOrEqual, - "<": SLALessThan, - ">": SLAGreaterThan, -} - - -def _parse_sla_item(sla_item: dict[str, str]): - sla_criteria: dict[str, SLACriterionBase] = {} - - for metric_key, metric_value in sla_item.items(): - for op_key in SLA_CRITERIA: - if metric_value.startswith(op_key): - sla_criteria[metric_key] = SLA_CRITERIA[op_key]( - float(metric_value.removeprefix(op_key)) - ) - break - else: - raise ValueError( - f"Invalid operator for SLA constraint '{metric_key}={metric_value}'. 
" - f"Valid operators are: {set(SLA_CRITERIA)}", - ) - - return sla_criteria - - -def _parse_sla(sla: list[dict[str, str]]): - return [_parse_sla_item(item) for item in sla] - - -# In JSON, we prefer "_" -def _iter_param_key_candidates(param_key: str): - yield param_key - yield param_key.replace("-", "_") - yield param_key.replace("_", "-") - - -# In CLI, we prefer "-" -def _iter_cmd_key_candidates(param_key: str): - for k in reversed(tuple(_iter_param_key_candidates(param_key))): - yield "--" + k - - -def _normalize_cmd_key(param_key: str): - return next(_iter_cmd_key_candidates(param_key)) - - -def _override_args(cmd: list[str], params: dict[str, object]): - cmd = list(cmd) - - for k, v in params.items(): - for k_candidate in _iter_cmd_key_candidates(k): - try: - k_idx = cmd.index(k_candidate) - - if isinstance(v, bool): - cmd[k_idx] = _normalize_cmd_key(k if v else "no-" + k) - else: - cmd[k_idx + 1] = str(v) - - break - except ValueError: - continue - else: - if isinstance(v, bool): - cmd.append(_normalize_cmd_key(k if v else "no-" + k)) - else: - cmd.extend([_normalize_cmd_key(k), str(v)]) - - return cmd - - -class ServerWrapper: - def __init__( - self, - server_cmd: list[str], - after_bench_cmd: list[str], - *, - show_stdout: bool, - ) -> None: - super().__init__() - - self.server_cmd = server_cmd - self.after_bench_cmd = after_bench_cmd - self.show_stdout = show_stdout - - def run_subcommand(self, cmd: list[str]): - return subprocess.run( - cmd, - stdout=None if self.show_stdout else subprocess.DEVNULL, - check=True, - ) - - def after_bench(self) -> None: - if not self.after_bench_cmd: - self.reset_caches() - return - - self.run_subcommand(self.after_bench_cmd) - - def _get_vllm_server_address(self) -> str: - server_cmd = self.server_cmd - - for host_key in ("--host",): - if host_key in server_cmd: - host = server_cmd[server_cmd.index(host_key) + 1] - break - else: - host = "localhost" - - for port_key in ("-p", "--port"): - if port_key in server_cmd: - port = int(server_cmd[server_cmd.index(port_key) + 1]) - break - else: - port = 8000 # The default value in vllm serve - - return f"http://{host}:{port}" - - def reset_caches(self) -> None: - server_cmd = self.server_cmd - - # Use `.endswith()` to match `/bin/...` - if server_cmd[0].endswith("vllm"): - server_address = self._get_vllm_server_address() - print(f"Resetting caches at {server_address}") - - res = requests.post(f"{server_address}/reset_prefix_cache") - res.raise_for_status() - - res = requests.post(f"{server_address}/reset_mm_cache") - res.raise_for_status() - elif server_cmd[0].endswith("infinity_emb"): - if "--vector-disk-cache" in server_cmd: - raise NotImplementedError( - "Infinity server uses caching but does not expose a method " - "to reset the cache" - ) - else: - raise NotImplementedError( - f"No implementation of `reset_caches` for `{server_cmd[0]}` server. " - "Please specify a custom command via `--after-bench-cmd`." 
- ) - - -@contextlib.contextmanager -def _run_server( - serve_cmd: list[str], - after_bench_cmd: list[str], - *, - show_stdout: bool, - serve_overrides: dict[str, object], - dry_run: bool, -): - server_cmd = _override_args(serve_cmd, serve_overrides) - - print("[BEGIN SERVER]") - print(f"Server overrides: {serve_overrides}") - print(f"Server command: {server_cmd}") - - if dry_run: - yield None - print("[END SERVER]") - return - - # Create new process group for clean termination - server_process = subprocess.Popen( - server_cmd, - start_new_session=True, - stdout=None if show_stdout else subprocess.DEVNULL, - # Need VLLM_SERVER_DEV_MODE=1 for `_reset_caches` - env={**os.environ, "VLLM_SERVER_DEV_MODE": "1"}, - ) - - try: - yield ServerWrapper( - server_cmd, - after_bench_cmd, - show_stdout=show_stdout, - ) - finally: - if server_process.poll() is None: - # In case only some processes have been terminated - with contextlib.suppress(ProcessLookupError): - # We need to kill both API Server and Engine processes - os.killpg(os.getpgid(server_process.pid), signal.SIGKILL) - - print("[END SERVER]") - - -def _run_benchmark( - server: ServerWrapper | None, - bench_cmd: list[str], - *, - serve_overrides: dict[str, object], - bench_overrides: dict[str, object], - run_number: int, - output_path: Path, - dry_run: bool, -): - benchmark_cmd = [ - *_override_args(bench_cmd, bench_overrides), - "--save-result", - "--result-dir", - str(output_path.parent), - "--result-filename", - output_path.name, - ] - - print("[BEGIN BENCHMARK]") - print(f"Benchmark overrides: {bench_overrides}") - print(f"Run Number: {run_number}") - print(f"Benchmark command: {benchmark_cmd}") - print(f"Output file: {output_path}") - - run_data: dict[str, object] - - if output_path.exists(): - print("Found existing results. 
Skipping.") - - with output_path.open("rb") as f: - run_data = json.load(f) - return run_data - - if server is None: - assert dry_run - print("[END BENCHMARK]") - return None - - output_path.parent.mkdir(parents=True, exist_ok=True) - - server.run_subcommand(benchmark_cmd) - server.after_bench() - - with output_path.open("rb") as f: - run_data = json.load(f) - - run_data["run_number"] = run_number - run_data.update(serve_overrides) - - with output_path.open("w") as f: - json.dump(run_data, f, indent=4) - - print("[END BENCHMARK]") - - return run_data - - -def _get_comb_base_path( - output_dir: Path, - serve_comb: dict[str, object], - bench_comb: dict[str, object], -): - return output_dir / "-".join( - ( - "SERVE", - *(f"{k}={v}" for k, v in serve_comb.items()), - "BENCH", - *(f"{k}={v}" for k, v in bench_comb.items()), - ) - ).replace("/", "_").replace("..", "__") # Sanitize - - -def _get_comb_run_path(base_path: Path, run_number: int | None): - if run_number is None: - return base_path / "summary.json" - - return base_path / f"run={run_number}.json" - - -def _comb_needs_server( - serve_comb: dict[str, object], - bench_combs: list[dict[str, object]], - output_dir: Path, -): - for bench_comb in bench_combs: - base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb) - if not _get_comb_run_path(base_path, run_number=None).exists(): - return True - - return False - - -def _run_comb( - server: ServerWrapper | None, - bench_cmd: list[str], - *, - serve_comb: dict[str, object], - bench_comb: dict[str, object], - base_path: Path, - num_runs: int, - dry_run: bool, -): - comb_data = list[dict[str, object]]() - - for run_number in range(num_runs): - run_data = _run_benchmark( - server, - bench_cmd, - serve_overrides=serve_comb, - bench_overrides=bench_comb, - run_number=run_number, - output_path=_get_comb_run_path(base_path, run_number), - dry_run=dry_run, - ) - - if run_data is not None: - comb_data.append(run_data) - - if dry_run: - return None - - with _get_comb_run_path(base_path, run_number=None).open("w") as f: - json.dump(comb_data, f, indent=4) - - return comb_data - - -def run_combs( - serve_cmd: list[str], - bench_cmd: list[str], - after_bench_cmd: list[str], - *, - show_stdout: bool, - serve_params: list[dict[str, object]], - bench_params: list[dict[str, object]], - output_dir: Path, - num_runs: int, - dry_run: bool, -): - all_data = list[dict[str, object]]() - for serve_comb in serve_params: - with ( - _run_server( - serve_cmd, - after_bench_cmd, - show_stdout=show_stdout, - serve_overrides=serve_comb, - dry_run=dry_run, - ) - if _comb_needs_server(serve_comb, bench_params, output_dir) - else contextlib.nullcontext() - ) as server: - for bench_comb in bench_params: - base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb) - - comb_data = _run_comb( - server, - bench_cmd, - serve_comb=serve_comb, - bench_comb=bench_comb, - base_path=base_path, - num_runs=num_runs, - dry_run=dry_run, - ) - - if comb_data is not None: - all_data.extend(comb_data) - - if dry_run: - return None - - combined_df = pd.DataFrame.from_records(all_data) - combined_df.to_csv(output_dir / "summary.csv") - - return combined_df - - -def _get_sla_base_path( - output_dir: Path, - serve_comb: dict[str, object], - bench_comb: dict[str, object], -): - return output_dir / "-".join( - ( - "SERVE", - *(f"{k}={v}" for k, v in serve_comb.items()), - "BENCH", - *(f"{k}={v}" for k, v in bench_comb.items()), - ) - ).replace("/", "_").replace("..", "__") # Sanitize - - -def _get_sla_iter_path( - base_path: 
Path, - sla_comb: dict[str, SLACriterionBase], - sla_variable: str, - sla_value: int | None, -): - if sla_value is None: - prefix = "-".join(v.format_cond(k) for k, v in sla_comb.items()) - return base_path / f"SLA-{prefix}.json" - - return base_path / f"{sla_variable}={sla_value}" - - -def _get_sla_run_path(iter_path: Path, run_number: int | None): - if run_number is None: - return iter_path / "summary.json" - - return iter_path / f"run={run_number}.json" - - -def _sla_needs_server( - serve_comb: dict[str, object], - bench_combs: list[dict[str, object]], - sla_combs: list[dict[str, SLACriterionBase]], - sla_variable: str, - output_dir: Path, -): - for bench_comb in bench_combs: - base_path = _get_sla_base_path(output_dir, serve_comb, bench_comb) - for sla_comb in sla_combs: - if not _get_sla_iter_path( - base_path, - sla_comb, - sla_variable, - sla_value=None, - ).exists(): - return True - - return False - - -def _run_sla( - server: ServerWrapper | None, - bench_cmd: list[str], - *, - serve_comb: dict[str, object], - bench_comb: dict[str, object], - iter_path: Path, - num_runs: int, - dry_run: bool, -): - iter_data = list[dict[str, object]]() - - for run_number in range(num_runs): - run_data = _run_benchmark( - server, - bench_cmd, - serve_overrides=serve_comb, - bench_overrides=bench_comb, - run_number=run_number, - output_path=_get_sla_run_path(iter_path, run_number), - dry_run=dry_run, - ) - - if run_data is not None: - iter_data.append(run_data) - - if dry_run: - return None - - with _get_sla_run_path(iter_path, run_number=None).open("w") as f: - json.dump(iter_data, f, indent=4) - - return iter_data - - -SLAVariable = Literal["request_rate", "max_concurrency"] - - -def _estimate_sla_value(run_data: dict[str, object], sla_variable: SLAVariable): - request_throughput = float(run_data["request_throughput"]) # type: ignore - if sla_variable == "request_rate": - return request_throughput - if sla_variable == "max_concurrency": - mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore - return request_throughput * mean_latency_ms / 1000 - - assert_never(sla_variable) - - -def _estimate_sla_bounds( - server: ServerWrapper | None, - bench_cmd: list[str], - *, - serve_comb: dict[str, object], - bench_comb: dict[str, object], - sla_comb: dict[str, SLACriterionBase], - base_path: Path, - num_runs: int, - dry_run: bool, - sla_variable: SLAVariable, - init_value: int, - max_value: int, -): - sla_data = list[dict[str, object]]() - - max_passing: int = 0 - min_failing: int = 0 - - val: int = init_value - assert val > 0 - - while True: - print(f"Testing {sla_variable}: {val} req/s") - - iter_data = _run_sla( - server, - bench_cmd, - serve_comb=serve_comb, - bench_comb={**bench_comb, sla_variable: val}, - iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, val), - num_runs=num_runs, - dry_run=dry_run, - ) - - assert iter_data is not None - sla_data.extend(iter_data) - - iter_data_mean = { - k: sum(float(run_data[k]) for run_data in iter_data) / len(iter_data) # type: ignore - for k in sla_comb - } - - sla_results = [ - criterion.print_and_validate(iter_data_mean, k) - for k, criterion in sla_comb.items() - ] - - if all(sla_results): - print("SLA criteria are met.") - max_passing = val - val *= 2 - else: - print("SLA criteria are not met.") - min_failing = val - break - - if val >= max_value: - break - - return sla_data, (max_passing, min_failing) - - -def _find_sla_value( - server: ServerWrapper | None, - bench_cmd: list[str], - *, - serve_comb: dict[str, object], - bench_comb: 
dict[str, object], - sla_comb: dict[str, SLACriterionBase], - base_path: Path, - num_runs: int, - dry_run: bool, - sla_variable: SLAVariable, - min_value: int, - max_value: int, -): - sla_data = list[dict[str, object]]() - - left: int = min_value - right: int = max_value - - while True: - val = (left + right) // 2 - print(f"Testing {sla_variable}: {val} req/s") - - iter_data = _run_sla( - server, - bench_cmd, - serve_comb=serve_comb, - bench_comb={**bench_comb, sla_variable: val}, - iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, val), - num_runs=num_runs, - dry_run=dry_run, - ) - - assert iter_data is not None - sla_data.extend(iter_data) - - iter_data_mean = { - k: sum(float(run_data[k]) for run_data in iter_data) / len(iter_data) # type: ignore - for k in sla_comb - } - - sla_results = [ - criterion.print_and_validate(iter_data_mean, k) - for k, criterion in sla_comb.items() - ] - - if all(sla_results): - print("SLA criteria are met.") - left = val - else: - print("SLA criteria are not met.") - right = val - - if right - left <= 1: - break - - return sla_data, left - - -def _search_sla( - server: ServerWrapper | None, - bench_cmd: list[str], - *, - serve_comb: dict[str, object], - bench_comb: dict[str, object], - sla_comb: dict[str, SLACriterionBase], - sla_variable: SLAVariable, - sla_inf_value: int = 65536, # The value that represents infinite QPS - base_path: Path, - num_runs: int, - dry_run: bool, -): - print("[SLA START]") - print(f"SLA criteria: {', '.join(v.format_cond(k) for k, v in sla_comb.items())}") - - sla_data_0 = _run_sla( - server, - bench_cmd, - serve_comb=serve_comb, - bench_comb={**bench_comb, sla_variable: sla_inf_value}, - iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, sla_inf_value), - num_runs=num_runs, - dry_run=dry_run, - ) - if sla_data_0 is None: - assert dry_run - print("Omitting SLA search.") - print("[SLA END]") - return None - - sla_init_value = math.ceil( - sum(_estimate_sla_value(item, sla_variable) for item in sla_data_0) - / len(sla_data_0) - ) - print(f"Initial {sla_variable} to search: {sla_init_value} req/s.") - - sla_data_1, (sla_min, sla_max) = _estimate_sla_bounds( - server, - bench_cmd, - serve_comb=serve_comb, - bench_comb=bench_comb, - sla_comb=sla_comb, - base_path=base_path, - num_runs=num_runs, - dry_run=dry_run, - sla_variable=sla_variable, - init_value=sla_init_value, - max_value=sla_inf_value, - ) - print(f"Range of {sla_variable} to search: [{sla_min}, {sla_max}] req/s.") - - sla_data_2, sla_value = _find_sla_value( - server, - bench_cmd, - serve_comb=serve_comb, - bench_comb=bench_comb, - sla_comb=sla_comb, - base_path=base_path, - num_runs=num_runs, - dry_run=dry_run, - sla_variable=sla_variable, - min_value=sla_min, - max_value=sla_max, - ) - - sla_data = sla_data_0 + sla_data_1 + sla_data_2 - print(f"Maximum {sla_variable} for SLA: {sla_value} req/s.") - - with _get_sla_iter_path( - base_path, - sla_comb, - sla_variable, - sla_value=None, - ).open("w") as f: - json.dump(sla_data, f, indent=4) - - print("[SLA END]") - - return sla_data - - -def _plot_throughput_latency_curve( - all_data: list[dict[str, object]], - serve_combs: list[dict[str, object]], - bench_comb: dict[str, object], - output_dir: Path, -): - fig_path = output_dir / "-".join( - ( - "BENCH", - *(f"{k}={v}" for k, v in bench_comb.items()), - ) - ).replace("/", "_").replace("..", "__") # Sanitize - - df = pd.DataFrame.from_records( - [item for item in all_data if all(item[k] == bench_comb[k] for k in bench_comb)] - ) - - # Group together 
points with similar throughput - df["request_throughput"] = df["request_throughput"].round() - - # Preserve the key order using dictionary - all_comb_keys = {k: None for comb in serve_combs for k in comb} - for k in all_comb_keys: - df[k] = df[k].astype(str) - - keys_per_comb = [comb.keys() for comb in serve_combs] - if ( - all(ks == keys_per_comb[0] for ks in keys_per_comb) - and len(keys_per_comb[0]) <= 3 - ): - hue, style, size, *_ = (*keys_per_comb[0], None, None) - ax = sns.lineplot( - df, - x="request_throughput", - y="p99_e2el_ms", - hue=hue, - style=style, - size=size, - markers=True, - ) - else: - df["category"] = df[list(all_comb_keys)].agg("-".join, axis=1) - ax = sns.lineplot( - df, - x="request_throughput", - y="p99_e2el_ms", - hue="category", - markers=True, - ) - - sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1)) - - fig = ax.get_figure() - assert fig is not None - - fig.tight_layout() - fig.savefig(fig_path) - - -def _plot_throughput_latency_curves( - all_data: list[dict[str, object]], - serve_combs: list[dict[str, object]], - bench_combs: list[dict[str, object]], - output_dir: Path, -): - for bench_comb in bench_combs: - _plot_throughput_latency_curve(all_data, serve_combs, bench_comb, output_dir) - - -def run_slas( - serve_cmd: list[str], - bench_cmd: list[str], - after_bench_cmd: list[str], - *, - show_stdout: bool, - serve_params: list[dict[str, object]], - bench_params: list[dict[str, object]], - sla_params: list[dict[str, SLACriterionBase]], - sla_variable: SLAVariable, - output_dir: Path, - num_runs: int, - dry_run: bool, -): - if any( - k in bench_comb - for bench_comb in bench_params - for k in _iter_param_key_candidates(sla_variable) - ): - raise ValueError( - f"You should not override `{sla_variable}` in `bench_params` in SLA mode, " - "since it is supposed to be determined automatically." 
- ) - - all_data = list[dict[str, object]]() - for serve_comb in serve_params: - with ( - _run_server( - serve_cmd, - after_bench_cmd, - show_stdout=show_stdout, - serve_overrides=serve_comb, - dry_run=dry_run, - ) - if _sla_needs_server( - serve_comb, - bench_params, - sla_params, - sla_variable, - output_dir, - ) - else contextlib.nullcontext() - ) as server: - for bench_comb in bench_params: - for sla_comb in sla_params: - base_path = _get_sla_base_path(output_dir, serve_comb, bench_comb) - - comb_data = _search_sla( - server, - bench_cmd, - serve_comb=serve_comb, - bench_comb=bench_comb, - sla_comb=sla_comb, - sla_variable=sla_variable, - base_path=base_path, - num_runs=num_runs, - dry_run=dry_run, - ) - - if comb_data is not None: - all_data.extend(comb_data) - - if dry_run: - return None - - combined_df = pd.DataFrame.from_records(all_data) - combined_df.to_csv(output_dir / "summary.csv") - - _plot_throughput_latency_curves(all_data, serve_params, bench_params, output_dir) - - return combined_df - - -def _run_main( - serve_cmd: list[str], - bench_cmd: list[str], - after_bench_cmd: list[str], - *, - show_stdout: bool, - serve_params: list[dict[str, object]], - bench_params: list[dict[str, object]], - sla_params: list[dict[str, SLACriterionBase]], - sla_variable: SLAVariable, - output_dir: Path, - num_runs: int, - dry_run: bool, -): - if sla_params: - return run_slas( - serve_cmd=serve_cmd, - bench_cmd=bench_cmd, - after_bench_cmd=after_bench_cmd, - show_stdout=show_stdout, - serve_params=serve_params, - bench_params=bench_params, - sla_params=sla_params, - sla_variable=sla_variable, - output_dir=output_dir, - num_runs=num_runs, - dry_run=dry_run, - ) - - return run_combs( - serve_cmd=serve_cmd, - bench_cmd=bench_cmd, - after_bench_cmd=after_bench_cmd, - show_stdout=show_stdout, - serve_params=serve_params, - bench_params=bench_params, - output_dir=output_dir, - num_runs=num_runs, - dry_run=dry_run, - ) - - -def run_main( - serve_cmd: list[str], - bench_cmd: list[str], - after_bench_cmd: list[str], - *, - show_stdout: bool, - serve_params: list[dict[str, object]], - bench_params: list[dict[str, object]], - sla_params: list[dict[str, SLACriterionBase]], - sla_variable: SLAVariable, - output_dir: Path, - num_runs: int, - dry_run: bool, - resume: str | None, -): - timestamp = resume or datetime.now().strftime("%Y%m%d_%H%M%S") - output_dir = output_dir / timestamp - - if resume and not output_dir.exists(): - raise ValueError(f"Cannot resume from non-existent directory ({output_dir})") - - try: - return _run_main( - serve_cmd=serve_cmd, - bench_cmd=bench_cmd, - after_bench_cmd=after_bench_cmd, - show_stdout=show_stdout, - serve_params=serve_params, - bench_params=bench_params, - sla_params=sla_params, - sla_variable=sla_variable, - output_dir=output_dir, - num_runs=num_runs, - dry_run=dry_run, - ) - except BaseException as exc: - raise RuntimeError( - f"The script was terminated early. Use `--resume {timestamp}` " - f"to continue the script from its last checkpoint." - ) from exc - - -def main(): - parser = argparse.ArgumentParser( - description="Run vLLM server benchmark on a parameter grid of settings." 
- ) - parser.add_argument( - "--serve-cmd", - type=str, - required=True, - help="The command used to run the server: `vllm serve ...`", - ) - parser.add_argument( - "--bench-cmd", - type=str, - required=True, - help="The command used to run the benchmark: `vllm bench serve ...`", - ) - parser.add_argument( - "--after-bench-cmd", - type=str, - default=None, - help="After a benchmark run is complete, invoke this command instead of the " - "default `ServerWrapper.clear_cache()`.", - ) - parser.add_argument( - "--show-stdout", - action="store_true", - help="If set, logs the standard output of subcommands. " - "Useful for debugging but can be quite spammy.", - ) - parser.add_argument( - "--serve-params", - type=str, - default=None, - help="Path to JSON file containing a list of parameter combinations " - "for the `vllm serve` command. " - "If both `serve_params` and `bench_params` are given, " - "this script will iterate over their Cartesian product.", - ) - parser.add_argument( - "--bench-params", - type=str, - default=None, - help="Path to JSON file containing a list of parameter combinations " - "for the `vllm bench serve` command. " - "If both `serve_params` and `bench_params` are given, " - "this script will iterate over their Cartesian product.", - ) - parser.add_argument( - "--sla-params", - type=str, - default=None, - help="Path to JSON file containing a list of SLA constraints to satisfy. " - 'Each constraint is expressed in `{"": ""}` format, ' - 'e.g.: `{"p99_e2el_ms": "<=500"}` means that ' - "the E2E latency should be less than 500ms 99% of the time. " - "Setting this option runs this script in SLA mode, which searches for the " - "maximum `sla_variable` that satisfies the constraints for each combination " - "of `serve_params`, `bench_params`, and `sla_params`.", - ) - parser.add_argument( - "--sla-variable", - type=str, - choices=get_args(SLAVariable), - default="request_rate", - help="Whether to tune request rate or maximum concurrency to satisfy " - "the SLA constraints.", - ) - parser.add_argument( - "-o", - "--output-dir", - type=str, - default="results", - help="The directory to which results are written.", - ) - parser.add_argument( - "--num-runs", - type=int, - default=3, - help="Number of runs per parameter combination.", - ) - parser.add_argument( - "--dry-run", - action="store_true", - help="If set, prints the commands to run then exits without running them.", - ) - parser.add_argument( - "--resume", - type=str, - default=None, - help="Set this to the name of a directory under `output_dir` (which is a " - "timestamp) to resume a previous execution of this script, i.e., only run " - "parameter combinations for which there are still no output files.", - ) - - args = parser.parse_args() - - serve_cmd = shlex.split(args.serve_cmd) - bench_cmd = shlex.split(args.bench_cmd) - after_bench_cmd = ( - [] if args.after_bench_cmd is None else shlex.split(args.after_bench_cmd) - ) - - serve_params: list[dict[str, object]] - if args.serve_params: - with open(args.serve_params, "rb") as f: - serve_params = _parse_params(json.load(f)) - else: - # i.e.: run serve_cmd without any modification - serve_params = [{}] - - bench_params: list[dict[str, object]] - if args.bench_params: - with open(args.bench_params, "rb") as f: - bench_params = _parse_params(json.load(f)) - else: - # i.e.: run bench_cmd without any modification - bench_params = [{}] - - sla_params: list[dict[str, SLACriterionBase]] - if args.sla_params: - with open(args.sla_params, "rb") as f: - sla_params = 
_parse_sla(json.load(f)) - else: - sla_params = [] - - num_runs = args.num_runs - if num_runs < 1: - raise ValueError("`num_runs` should be at least 1.") - - run_main( - serve_cmd=serve_cmd, - bench_cmd=bench_cmd, - after_bench_cmd=after_bench_cmd, - show_stdout=args.show_stdout, - serve_params=serve_params, - bench_params=bench_params, - sla_params=sla_params, - sla_variable=args.sla_variable, - output_dir=Path(args.output_dir), - num_runs=num_runs, - dry_run=args.dry_run, - resume=args.resume, - ) - - -if __name__ == "__main__": - main() diff --git a/vllm/benchmarks/sweep/__init__.py b/vllm/benchmarks/sweep/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/benchmarks/sweep/param_sweep.py b/vllm/benchmarks/sweep/param_sweep.py new file mode 100644 index 000000000000..986561ed8502 --- /dev/null +++ b/vllm/benchmarks/sweep/param_sweep.py @@ -0,0 +1,91 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +import os +from typing import Any + + +class ParameterSweep(list["ParameterSweepItem"]): + @classmethod + def read_json(cls, filepath: os.PathLike): + with open(filepath, "rb") as f: + records = json.load(f) + + return cls.from_records(records) + + @classmethod + def from_records(cls, records: list[dict[str, object]]): + if not isinstance(records, list): + raise TypeError( + f"The parameter sweep should be a list of dictionaries, " + f"but found type: {type(records)}" + ) + + return cls(ParameterSweepItem.from_record(record) for record in records) + + +class ParameterSweepItem(dict[str, object]): + @classmethod + def from_record(cls, record: dict[str, object]): + if not isinstance(record, dict): + raise TypeError( + f"Each item in the parameter sweep should be a dictionary, " + f"but found type: {type(record)}" + ) + + return cls(record) + + def __or__(self, other: dict[str, Any]): + return type(self)(super().__or__(other)) + + # In JSON, we prefer "_" + def _iter_param_key_candidates(self, param_key: str): + # Inner config arguments are not converted by the CLI + if "." in param_key: + prefix, rest = param_key.split(".", 1) + for prefix_candidate in self._iter_param_key_candidates(prefix): + yield prefix_candidate + "." 
+ rest + + return + + yield param_key + yield param_key.replace("-", "_") + yield param_key.replace("_", "-") + + # In CLI, we prefer "-" + def _iter_cmd_key_candidates(self, param_key: str): + for k in reversed(tuple(self._iter_param_key_candidates(param_key))): + yield "--" + k + + def _normalize_cmd_key(self, param_key: str): + return next(self._iter_cmd_key_candidates(param_key)) + + def has_param(self, param_key: str) -> bool: + return any(k in self for k in self._iter_param_key_candidates(param_key)) + + def apply_to_cmd(self, cmd: list[str]) -> list[str]: + cmd = list(cmd) + + for k, v in self.items(): + for k_candidate in self._iter_cmd_key_candidates(k): + try: + k_idx = cmd.index(k_candidate) + + if isinstance(v, bool): + cmd[k_idx] = self._normalize_cmd_key(k if v else "no-" + k) + else: + cmd[k_idx + 1] = str(v) + + break + except ValueError: + continue + else: + if isinstance(v, bool): + cmd.append(self._normalize_cmd_key(k if v else "no-" + k)) + else: + cmd.extend([self._normalize_cmd_key(k), str(v)]) + + return cmd + + def as_text(self, sep: str = ", ") -> str: + return sep.join(f"{k}={v}" for k, v in self.items()) diff --git a/vllm/benchmarks/sweep/plot.py b/vllm/benchmarks/sweep/plot.py new file mode 100644 index 000000000000..92485c09b416 --- /dev/null +++ b/vllm/benchmarks/sweep/plot.py @@ -0,0 +1,530 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse +import json +from abc import ABC, abstractmethod +from concurrent.futures import ProcessPoolExecutor +from dataclasses import dataclass +from functools import partial +from pathlib import Path +from types import TracebackType + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +from typing_extensions import Self, override + +from vllm.utils.collection_utils import full_groupby + +from .utils import sanitize_filename + + +@dataclass +class PlotFilterBase(ABC): + var: str + target: str + + @classmethod + def parse_str(cls, s: str): + for op_key in PLOT_FILTERS: + if op_key in s: + key, value = s.split(op_key) + return PLOT_FILTERS[op_key]( + key, + value.removeprefix(op_key).strip("'").strip('"'), + ) + else: + raise ValueError( + f"Invalid operator for plot filter '{s}'. " + f"Valid operators are: {sorted(PLOT_FILTERS)}", + ) + + @abstractmethod + def apply(self, df: pd.DataFrame) -> pd.DataFrame: + """Applies this filter to a DataFrame.""" + raise NotImplementedError + + +@dataclass +class PlotEqualTo(PlotFilterBase): + @override + def apply(self, df: pd.DataFrame) -> pd.DataFrame: + try: + target = float(self.target) + except ValueError: + target = self.target + + return df[df[self.var] == target] + + +@dataclass +class PlotLessThan(PlotFilterBase): + @override + def apply(self, df: pd.DataFrame) -> pd.DataFrame: + return df[df[self.var] < float(self.target)] + + +@dataclass +class PlotLessThanOrEqualTo(PlotFilterBase): + @override + def apply(self, df: pd.DataFrame) -> pd.DataFrame: + return df[df[self.var] <= float(self.target)] + + +@dataclass +class PlotGreaterThan(PlotFilterBase): + @override + def apply(self, df: pd.DataFrame) -> pd.DataFrame: + return df[df[self.var] > float(self.target)] + + +@dataclass +class PlotGreaterThanOrEqualTo(PlotFilterBase): + @override + def apply(self, df: pd.DataFrame) -> pd.DataFrame: + return df[df[self.var] >= float(self.target)] + + +# NOTE: The ordering is important! 
Match longer op_keys first +PLOT_FILTERS: dict[str, type[PlotFilterBase]] = { + "==": PlotEqualTo, + "<=": PlotLessThanOrEqualTo, + ">=": PlotGreaterThanOrEqualTo, + "<": PlotLessThan, + ">": PlotGreaterThan, +} + + +class PlotFilters(list[PlotFilterBase]): + @classmethod + def parse_str(cls, s: str): + if not s: + return cls() + + return cls(PlotFilterBase.parse_str(e) for e in s.split(",")) + + def apply(self, df: pd.DataFrame) -> pd.DataFrame: + for item in self: + df = item.apply(df) + + return df + + +@dataclass +class PlotBinner: + var: str + bin_size: float + + @classmethod + def parse_str(cls, s: str): + for op_key in PLOT_BINNERS: + if op_key in s: + key, value = s.split(op_key) + return PLOT_BINNERS[op_key](key, float(value.removeprefix(op_key))) + else: + raise ValueError( + f"Invalid operator for plot binner '{s}'. " + f"Valid operators are: {sorted(PLOT_BINNERS)}", + ) + + def apply(self, df: pd.DataFrame) -> pd.DataFrame: + """Applies this binner to a DataFrame.""" + df = df.copy() + df[self.var] = df[self.var] // self.bin_size * self.bin_size + return df + + +PLOT_BINNERS: dict[str, type[PlotBinner]] = { + "%": PlotBinner, +} + + +class PlotBinners(list[PlotBinner]): + @classmethod + def parse_str(cls, s: str): + if not s: + return cls() + + return cls(PlotBinner.parse_str(e) for e in s.split(",")) + + def apply(self, df: pd.DataFrame) -> pd.DataFrame: + for item in self: + df = item.apply(df) + + return df + + +def _json_load_bytes(path: Path) -> list[dict[str, object]]: + with path.open("rb") as f: + return json.load(f) + + +def _get_metric(run_data: dict[str, object], metric_key: str): + try: + return run_data[metric_key] + except KeyError as exc: + raise ValueError(f"Cannot find metric {metric_key!r} in {run_data=}") from exc + + +def _get_group(run_data: dict[str, object], group_keys: list[str]): + return tuple((k, str(_get_metric(run_data, k))) for k in group_keys) + + +def _get_fig_path(fig_dir: Path, group: tuple[tuple[str, str], ...]): + parts = list[str]() + if group: + parts.extend(("FIGURE-", *(f"{k}={v}" for k, v in group))) + else: + parts.append("figure") + + return fig_dir / sanitize_filename("-".join(parts) + ".png") + + +class DummyExecutor: + map = map + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + exc_traceback: TracebackType | None, + ) -> None: + return None + + +def _plot_fig( + fig_dir: Path, + fig_group_data: tuple[tuple[tuple[str, str], ...], list[dict[str, object]]], + row_by: list[str], + col_by: list[str], + curve_by: list[str], + *, + var_x: str, + var_y: str, + filter_by: PlotFilters, + bin_by: PlotBinners, + scale_x: str | None, + scale_y: str | None, + dry_run: bool, +): + fig_group, fig_data = fig_group_data + + row_groups = full_groupby( + fig_data, + key=lambda item: _get_group(item, row_by), + ) + num_rows = len(row_groups) + num_cols = max( + len(full_groupby(row_data, key=lambda item: _get_group(item, col_by))) + for _, row_data in row_groups + ) + + fig_path = _get_fig_path(fig_dir, fig_group) + + print("[BEGIN FIGURE]") + print(f"Group: {dict(fig_group)}") + print(f"Grid: {num_rows} rows x {num_cols} cols") + print(f"Output file: {fig_path}") + + if dry_run: + print("[END FIGURE]") + return + + df = pd.DataFrame.from_records(fig_data) + + if var_x not in df.columns: + raise ValueError( + f"Cannot find {var_x=!r} in parameter sweep results. 
" + f"Available variables: {df.columns.tolist()}" + ) + if var_y not in df.columns: + raise ValueError( + f"Cannot find {var_y=!r} in parameter sweep results. " + f"Available variables: {df.columns.tolist()}" + ) + for k in row_by: + if k not in df.columns: + raise ValueError( + f"Cannot find row_by={k!r} in parameter sweep results. " + f"Available variables: {df.columns.tolist()}" + ) + for k in col_by: + if k not in df.columns: + raise ValueError( + f"Cannot find col_by={k!r} in parameter sweep results. " + f"Available variables: {df.columns.tolist()}" + ) + for k in curve_by: + if k not in df.columns: + raise ValueError( + f"Cannot find curve_by={k!r} in parameter sweep results. " + f"Available variables: {df.columns.tolist()}" + ) + + df = filter_by.apply(df) + df = bin_by.apply(df) + + df["row_group"] = ( + pd.concat( + [k + "=" + df[k].astype(str) for k in row_by], + axis=1, + ).agg("\n".join, axis=1) + if row_by + else "(All)" + ) + + df["col_group"] = ( + pd.concat( + [k + "=" + df[k].astype(str) for k in col_by], + axis=1, + ).agg("\n".join, axis=1) + if col_by + else "(All)" + ) + + g = sns.FacetGrid(df, row="row_group", col="col_group") + + if row_by and col_by: + g.set_titles("{row_name}\n{col_name}") + elif row_by: + g.set_titles("{row_name}") + elif col_by: + g.set_titles("{col_name}") + else: + g.set_titles("") + + if scale_x: + g.set(xscale=scale_x) + if scale_y: + g.set(yscale=scale_y) + + if len(curve_by) <= 3: + hue, style, size, *_ = (*curve_by, None, None, None) + + g.map_dataframe( + sns.lineplot, + x=var_x, + y=var_y, + hue=hue, + style=style, + size=size, + markers=True, + ) + + g.add_legend(title=hue) + else: + df["curve_group"] = ( + pd.concat( + [k + "=" + df[k].astype(str) for k in curve_by], + axis=1, + ).agg("\n".join, axis=1) + if curve_by + else "(All)" + ) + + g.map_dataframe( + sns.lineplot, + x=var_x, + y=var_y, + hue="curve_group", + markers=True, + ) + + g.add_legend() + + g.savefig(fig_path) + plt.close(g.figure) + + print("[END FIGURE]") + + +def plot( + output_dir: Path, + fig_dir: Path, + fig_by: list[str], + row_by: list[str], + col_by: list[str], + curve_by: list[str], + *, + var_x: str, + var_y: str, + filter_by: PlotFilters, + bin_by: PlotBinners, + scale_x: str | None, + scale_y: str | None, + dry_run: bool, +): + all_data = [ + run_data + for path in output_dir.rglob("**/summary.json") + for run_data in _json_load_bytes(path) + ] + + if not all_data: + raise ValueError(f"Did not find any parameter sweep results under {output_dir}") + + fig_dir.mkdir(parents=True, exist_ok=True) + + fig_groups = full_groupby( + all_data, + key=lambda item: _get_group(item, fig_by), + ) + + with DummyExecutor() if len(fig_groups) <= 1 else ProcessPoolExecutor() as executor: + # Resolve the iterable to ensure that the workers are run + all( + executor.map( + partial( + _plot_fig, + fig_dir, + row_by=row_by, + col_by=col_by, + curve_by=curve_by, + var_x=var_x, + var_y=var_y, + filter_by=filter_by, + bin_by=bin_by, + scale_x=scale_x, + scale_y=scale_y, + dry_run=dry_run, + ), + fig_groups, + ) + ) + + +def add_cli_args(parser: argparse.ArgumentParser): + parser.add_argument( + "OUTPUT_DIR", + type=str, + default="results", + help="The directory containing the results to plot, " + "i.e., the `--output-dir` argument to the parameter sweep script.", + ) + parser.add_argument( + "--fig-dir", + type=str, + default="", + help="The directory to save the figures, relative to `OUTPUT_DIR`. 
" + "By default, the same directory is used.", + ) + parser.add_argument( + "--fig-by", + type=str, + default="", + help="A comma-separated list of variables, such that a separate figure " + "is created for each combination of these variables.", + ) + parser.add_argument( + "--row-by", + type=str, + default="", + help="A comma-separated list of variables, such that a separate row " + "is created for each combination of these variables.", + ) + parser.add_argument( + "--col-by", + type=str, + default="", + help="A comma-separated list of variables, such that a separate column " + "is created for each combination of these variables.", + ) + parser.add_argument( + "--curve-by", + type=str, + default=None, + help="A comma-separated list of variables, such that a separate curve " + "is created for each combination of these variables.", + ) + parser.add_argument( + "--var-x", + type=str, + default="request_throughput", + help="The variable for the x-axis.", + ) + parser.add_argument( + "--var-y", + type=str, + default="p99_e2el_ms", + help="The variable for the y-axis", + ) + parser.add_argument( + "--filter-by", + type=str, + default="", + help="A comma-separated list of statements indicating values to filter by. " + "This is useful to remove outliers. " + "Example: `max_concurrency<1000,max_num_batched_tokens<=4096` means " + "plot only the points where `max_concurrency` is less than 1000 and " + "`max_num_batched_tokens` is no greater than 4096.", + ) + parser.add_argument( + "--bin-by", + type=str, + default="", + help="A comma-separated list of statements indicating values to bin by. " + "This is useful to avoid plotting points that are too close together. " + "Example: `request_throughput%1` means " + "use a bin size of 1 for the `request_throughput` variable.", + ) + parser.add_argument( + "--scale-x", + type=str, + default=None, + help="The scale to use for the x-axis. " + "Currently only accepts string values such as 'log' and 'sqrt'. " + "See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html", + ) + parser.add_argument( + "--scale-y", + type=str, + default=None, + help="The scale to use for the y-axis. " + "Currently only accepts string values such as 'log' and 'sqrt'. " + "See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="If set, prints the information about each figure to plot, " + "then exits without drawing them.", + ) + + +def main(args: argparse.Namespace): + output_dir = Path(args.OUTPUT_DIR) + if not output_dir.exists(): + raise ValueError(f"No parameter sweep results under {output_dir}") + + curve_by = [] if not args.curve_by else args.curve_by.split(",") + row_by = [] if not args.row_by else args.row_by.split(",") + col_by = [] if not args.col_by else args.col_by.split(",") + fig_by = [] if not args.fig_by else args.fig_by.split(",") + + plot( + output_dir=output_dir, + fig_dir=output_dir / args.fig_dir, + fig_by=fig_by, + row_by=row_by, + col_by=col_by, + curve_by=curve_by, + var_x=args.var_x, + var_y=args.var_y, + filter_by=PlotFilters.parse_str(args.filter_by), + bin_by=PlotBinners.parse_str(args.bin_by), + scale_x=args.scale_x, + scale_y=args.scale_y, + dry_run=args.dry_run, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Plot performance curves from parameter sweep results." 
+ ) + add_cli_args(parser) + + main(parser.parse_args()) diff --git a/vllm/benchmarks/sweep/serve.py b/vllm/benchmarks/sweep/serve.py new file mode 100644 index 000000000000..6e408dac0b49 --- /dev/null +++ b/vllm/benchmarks/sweep/serve.py @@ -0,0 +1,407 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse +import contextlib +import json +import shlex +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path + +import pandas as pd + +from .param_sweep import ParameterSweep, ParameterSweepItem +from .server import ServerProcess +from .utils import sanitize_filename + + +@contextlib.contextmanager +def run_server( + serve_cmd: list[str], + after_bench_cmd: list[str], + *, + show_stdout: bool, + serve_overrides: ParameterSweepItem, + dry_run: bool, +): + server_cmd = serve_overrides.apply_to_cmd(serve_cmd) + + print("[BEGIN SERVER]") + print(f"Server overrides: {serve_overrides}") + print(f"Server command: {server_cmd}") + + if dry_run: + yield None + print("[END SERVER]") + return + + with ServerProcess(server_cmd, after_bench_cmd, show_stdout=show_stdout) as server: + yield server + + print("[END SERVER]") + + +def _update_run_data( + run_data: dict[str, object], + serve_overrides: ParameterSweepItem, + bench_overrides: ParameterSweepItem, + run_number: int, +): + run_data["run_number"] = run_number + run_data.update(serve_overrides) + run_data.update(bench_overrides) + + return run_data + + +def run_benchmark( + server: ServerProcess | None, + bench_cmd: list[str], + *, + serve_overrides: ParameterSweepItem, + bench_overrides: ParameterSweepItem, + run_number: int, + output_path: Path, + dry_run: bool, +): + benchmark_cmd = [ + *bench_overrides.apply_to_cmd(bench_cmd), + "--save-result", + "--result-dir", + str(output_path.parent), + "--result-filename", + output_path.name, + ] + + print("[BEGIN BENCHMARK]") + print(f"Benchmark overrides: {bench_overrides}") + print(f"Run Number: {run_number}") + print(f"Benchmark command: {benchmark_cmd}") + print(f"Output file: {output_path}") + + run_data: dict[str, object] + + if output_path.exists(): + print("Found existing results. 
Skipping.") + + with output_path.open("rb") as f: + run_data = json.load(f) + return _update_run_data( + run_data, + serve_overrides, + bench_overrides, + run_number, + ) + + if server is None: + if not dry_run: + raise ValueError(f"Cannot find results at {output_path}") + + print("[END BENCHMARK]") + return None + + output_path.parent.mkdir(parents=True, exist_ok=True) + + server.run_subcommand(benchmark_cmd) + server.after_bench() + + with output_path.open("rb") as f: + run_data = json.load(f) + + run_data = _update_run_data( + run_data, + serve_overrides, + bench_overrides, + run_number, + ) + + with output_path.open("w") as f: + json.dump(run_data, f, indent=4) + + print("[END BENCHMARK]") + + return run_data + + +def _get_comb_base_path( + output_dir: Path, + serve_comb: ParameterSweepItem, + bench_comb: ParameterSweepItem, +): + parts = list[str]() + if serve_comb: + parts.extend(("SERVE-", serve_comb.as_text(sep="-"))) + if bench_comb: + parts.extend(("BENCH-", bench_comb.as_text(sep="-"))) + + return output_dir / sanitize_filename("-".join(parts)) + + +def _get_comb_run_path(base_path: Path, run_number: int | None): + if run_number is None: + return base_path / "summary.json" + + return base_path / f"run={run_number}.json" + + +def _comb_needs_server( + serve_comb: ParameterSweepItem, + bench_combs: ParameterSweep, + output_dir: Path, +): + for bench_comb in bench_combs: + base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb) + if not _get_comb_run_path(base_path, run_number=None).exists(): + return True + + return False + + +def run_comb( + server: ServerProcess | None, + bench_cmd: list[str], + *, + serve_comb: ParameterSweepItem, + bench_comb: ParameterSweepItem, + base_path: Path, + num_runs: int, + dry_run: bool, +): + comb_data = list[dict[str, object]]() + + for run_number in range(num_runs): + run_data = run_benchmark( + server, + bench_cmd, + serve_overrides=serve_comb, + bench_overrides=bench_comb, + run_number=run_number, + output_path=_get_comb_run_path(base_path, run_number), + dry_run=dry_run, + ) + + if run_data is not None: + comb_data.append(run_data) + + if dry_run: + return None + + with _get_comb_run_path(base_path, run_number=None).open("w") as f: + json.dump(comb_data, f, indent=4) + + return comb_data + + +def run_combs( + serve_cmd: list[str], + bench_cmd: list[str], + after_bench_cmd: list[str], + *, + show_stdout: bool, + serve_params: ParameterSweep, + bench_params: ParameterSweep, + output_dir: Path, + num_runs: int, + dry_run: bool, +): + all_data = list[dict[str, object]]() + for serve_comb in serve_params: + with ( + run_server( + serve_cmd, + after_bench_cmd, + show_stdout=show_stdout, + serve_overrides=serve_comb, + dry_run=dry_run, + ) + if _comb_needs_server(serve_comb, bench_params, output_dir) + else contextlib.nullcontext() + ) as server: + for bench_comb in bench_params: + base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb) + + comb_data = run_comb( + server, + bench_cmd, + serve_comb=serve_comb, + bench_comb=bench_comb, + base_path=base_path, + num_runs=num_runs, + dry_run=dry_run, + ) + + if comb_data is not None: + all_data.extend(comb_data) + + if dry_run: + return None + + combined_df = pd.DataFrame.from_records(all_data) + combined_df.to_csv(output_dir / "summary.csv") + + return combined_df + + +@dataclass +class SweepServeArgs: + serve_cmd: list[str] + bench_cmd: list[str] + after_bench_cmd: list[str] + show_stdout: bool + serve_params: ParameterSweep + bench_params: ParameterSweep + output_dir: Path + 
num_runs: int + dry_run: bool + resume: str | None + + @classmethod + def from_cli_args(cls, args: argparse.Namespace): + serve_cmd = shlex.split(args.serve_cmd) + bench_cmd = shlex.split(args.bench_cmd) + after_bench_cmd = ( + [] if args.after_bench_cmd is None else shlex.split(args.after_bench_cmd) + ) + + if args.serve_params: + serve_params = ParameterSweep.read_json(args.serve_params) + else: + # i.e.: run serve_cmd without any modification + serve_params = ParameterSweep.from_records([{}]) + + if args.bench_params: + bench_params = ParameterSweep.read_json(args.bench_params) + else: + # i.e.: run bench_cmd without any modification + bench_params = ParameterSweep.from_records([{}]) + + num_runs = args.num_runs + if num_runs < 1: + raise ValueError("`num_runs` should be at least 1.") + + return cls( + serve_cmd=serve_cmd, + bench_cmd=bench_cmd, + after_bench_cmd=after_bench_cmd, + show_stdout=args.show_stdout, + serve_params=serve_params, + bench_params=bench_params, + output_dir=Path(args.output_dir), + num_runs=num_runs, + dry_run=args.dry_run, + resume=args.resume, + ) + + @classmethod + def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + parser.add_argument( + "--serve-cmd", + type=str, + required=True, + help="The command used to run the server: `vllm serve ...`", + ) + parser.add_argument( + "--bench-cmd", + type=str, + required=True, + help="The command used to run the benchmark: `vllm bench serve ...`", + ) + parser.add_argument( + "--after-bench-cmd", + type=str, + default=None, + help="After a benchmark run is complete, invoke this command instead of " + "the default `ServerWrapper.clear_cache()`.", + ) + parser.add_argument( + "--show-stdout", + action="store_true", + help="If set, logs the standard output of subcommands. " + "Useful for debugging but can be quite spammy.", + ) + parser.add_argument( + "--serve-params", + type=str, + default=None, + help="Path to JSON file containing a list of parameter combinations " + "for the `vllm serve` command. " + "If both `serve_params` and `bench_params` are given, " + "this script will iterate over their Cartesian product.", + ) + parser.add_argument( + "--bench-params", + type=str, + default=None, + help="Path to JSON file containing a list of parameter combinations " + "for the `vllm bench serve` command. 
" + "If both `serve_params` and `bench_params` are given, " + "this script will iterate over their Cartesian product.", + ) + parser.add_argument( + "-o", + "--output-dir", + type=str, + default="results", + help="The directory to which results are written.", + ) + parser.add_argument( + "--num-runs", + type=int, + default=3, + help="Number of runs per parameter combination.", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="If set, prints the commands to run, " + "then exits without executing them.", + ) + parser.add_argument( + "--resume", + type=str, + default=None, + help="Set this to the name of a directory under `output_dir` (which is a " + "timestamp) to resume a previous execution of this script, i.e., only run " + "parameter combinations for which there are still no output files.", + ) + + return parser + + +def run_main(args: SweepServeArgs): + timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = args.output_dir / timestamp + + if args.resume and not output_dir.exists(): + raise ValueError(f"Cannot resume from non-existent directory ({output_dir})") + + try: + return run_combs( + serve_cmd=args.serve_cmd, + bench_cmd=args.bench_cmd, + after_bench_cmd=args.after_bench_cmd, + show_stdout=args.show_stdout, + serve_params=args.serve_params, + bench_params=args.bench_params, + output_dir=output_dir, + num_runs=args.num_runs, + dry_run=args.dry_run, + ) + except BaseException as exc: + raise RuntimeError( + f"The script was terminated early. Use `--resume {timestamp}` " + f"to continue the script from its last checkpoint." + ) from exc + + +def main(args: argparse.Namespace): + run_main(SweepServeArgs.from_cli_args(args)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Run vLLM server benchmark under multiple settings." 
+    )
+    SweepServeArgs.add_cli_args(parser)
+
+    main(parser.parse_args())
diff --git a/vllm/benchmarks/sweep/serve_sla.py b/vllm/benchmarks/sweep/serve_sla.py
new file mode 100644
index 000000000000..62e2917dc22b
--- /dev/null
+++ b/vllm/benchmarks/sweep/serve_sla.py
@@ -0,0 +1,483 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import argparse
+import contextlib
+import json
+import math
+from dataclasses import asdict, dataclass
+from datetime import datetime
+from pathlib import Path
+from typing import Literal, get_args
+
+import pandas as pd
+from typing_extensions import assert_never
+
+from .param_sweep import ParameterSweep, ParameterSweepItem
+from .serve import SweepServeArgs, run_benchmark, run_server
+from .server import ServerProcess
+from .sla_sweep import SLASweep, SLASweepItem
+from .utils import sanitize_filename
+
+
+def _get_sla_base_path(
+    output_dir: Path,
+    serve_comb: ParameterSweepItem,
+    bench_comb: ParameterSweepItem,
+):
+    parts = list[str]()
+    if serve_comb:
+        parts.extend(("SERVE-", serve_comb.as_text(sep="-")))
+    if bench_comb:
+        parts.extend(("BENCH-", bench_comb.as_text(sep="-")))
+
+    return output_dir / sanitize_filename("-".join(parts))
+
+
+def _get_sla_iter_path(
+    base_path: Path,
+    sla_comb: SLASweepItem,
+    sla_variable: str,
+    sla_value: int | None,
+):
+    if sla_value is None:
+        prefix = sla_comb.as_text(sep="-")
+        return base_path / f"SLA--{prefix}.json"
+
+    return base_path / f"{sla_variable}={sla_value}"
+
+
+def _get_sla_run_path(iter_path: Path, run_number: int | None):
+    if run_number is None:
+        return iter_path / "summary.json"
+
+    return iter_path / f"run={run_number}.json"
+
+
+def _sla_needs_server(
+    serve_comb: ParameterSweepItem,
+    bench_combs: ParameterSweep,
+    sla_combs: SLASweep,
+    sla_variable: str,
+    output_dir: Path,
+):
+    for bench_comb in bench_combs:
+        base_path = _get_sla_base_path(output_dir, serve_comb, bench_comb)
+        for sla_comb in sla_combs:
+            if not _get_sla_iter_path(
+                base_path,
+                sla_comb,
+                sla_variable,
+                sla_value=None,
+            ).exists():
+                return True
+
+    return False
+
+
+def run_sla(
+    server: ServerProcess | None,
+    bench_cmd: list[str],
+    *,
+    serve_comb: ParameterSweepItem,
+    bench_comb: ParameterSweepItem,
+    iter_path: Path,
+    num_runs: int,
+    dry_run: bool,
+):
+    iter_data = list[dict[str, object]]()
+
+    for run_number in range(num_runs):
+        run_data = run_benchmark(
+            server,
+            bench_cmd,
+            serve_overrides=serve_comb,
+            bench_overrides=bench_comb,
+            run_number=run_number,
+            output_path=_get_sla_run_path(iter_path, run_number),
+            dry_run=dry_run,
+        )
+
+        if run_data is not None:
+            iter_data.append(run_data)
+
+    if dry_run:
+        return None
+
+    with _get_sla_run_path(iter_path, run_number=None).open("w") as f:
+        json.dump(iter_data, f, indent=4)
+
+    return iter_data
+
+
+SLAVariable = Literal["request_rate", "max_concurrency"]
+
+
+def _estimate_sla_value(run_data: dict[str, object], sla_variable: SLAVariable):
+    request_throughput = float(run_data["request_throughput"])  # type: ignore
+    if sla_variable == "request_rate":
+        return request_throughput
+    if sla_variable == "max_concurrency":
+        mean_latency_ms = float(run_data["mean_e2el_ms"])  # type: ignore
+        return request_throughput * mean_latency_ms / 1000
+
+    assert_never(sla_variable)
+
+
+def _estimate_sla_bounds(
+    server: ServerProcess | None,
+    bench_cmd: list[str],
+    *,
+    serve_comb: ParameterSweepItem,
+    bench_comb: ParameterSweepItem,
+    sla_comb: SLASweepItem,
+    base_path: Path,
+    num_runs: int,
+    dry_run: bool,
+    sla_variable: SLAVariable,
+    init_value: int,
+    max_value: int,
+):
+    sla_data = list[dict[str, object]]()
+
+    max_passing: int = 0
+    min_failing: int = 0
+
+    val: int = init_value
+    assert val > 0
+
+    # Keep doubling the SLA variable until the SLA criteria fail (or `max_value`
+    # is reached), which brackets the largest feasible value.
+    while True:
+        print(f"Testing {sla_variable}: {val} req/s")
+
+        iter_data = run_sla(
+            server,
+            bench_cmd,
+            serve_comb=serve_comb,
+            bench_comb=bench_comb | {sla_variable: val},
+            iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, val),
+            num_runs=num_runs,
+            dry_run=dry_run,
+        )
+
+        assert iter_data is not None
+        sla_data.extend(iter_data)
+
+        iter_data_mean = {
+            k: sum(float(run_data[k]) for run_data in iter_data) / len(iter_data)  # type: ignore
+            for k in sla_comb
+        }
+
+        sla_results = [
+            criterion.print_and_validate(iter_data_mean, k)
+            for k, criterion in sla_comb.items()
+        ]
+
+        if all(sla_results):
+            print("SLA criteria are met.")
+            max_passing = val
+            val *= 2
+        else:
+            print("SLA criteria are not met.")
+            min_failing = val
+            break
+
+        if val >= max_value:
+            break
+
+    return sla_data, (max_passing, min_failing)
+
+
+def _find_sla_value(
+    server: ServerProcess | None,
+    bench_cmd: list[str],
+    *,
+    serve_comb: ParameterSweepItem,
+    bench_comb: ParameterSweepItem,
+    sla_comb: SLASweepItem,
+    base_path: Path,
+    num_runs: int,
+    dry_run: bool,
+    sla_variable: SLAVariable,
+    min_value: int,
+    max_value: int,
+):
+    sla_data = list[dict[str, object]]()
+
+    left: int = min_value
+    right: int = max_value
+
+    # Binary search within [min_value, max_value] for the largest value that
+    # still satisfies all SLA criteria.
+    while True:
+        val = (left + right) // 2
+        print(f"Testing {sla_variable}: {val} req/s")
+
+        iter_data = run_sla(
+            server,
+            bench_cmd,
+            serve_comb=serve_comb,
+            bench_comb=bench_comb | {sla_variable: val},
+            iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, val),
+            num_runs=num_runs,
+            dry_run=dry_run,
+        )
+
+        assert iter_data is not None
+        sla_data.extend(iter_data)
+
+        iter_data_mean = {
+            k: sum(float(run_data[k]) for run_data in iter_data) / len(iter_data)  # type: ignore
+            for k in sla_comb
+        }
+
+        sla_results = [
+            criterion.print_and_validate(iter_data_mean, k)
+            for k, criterion in sla_comb.items()
+        ]
+
+        if all(sla_results):
+            print("SLA criteria are met.")
+            left = val
+        else:
+            print("SLA criteria are not met.")
+            right = val
+
+        if right - left <= 1:
+            break
+
+    return sla_data, left
+
+
+def search_sla(
+    server: ServerProcess | None,
+    bench_cmd: list[str],
+    *,
+    serve_comb: ParameterSweepItem,
+    bench_comb: ParameterSweepItem,
+    sla_comb: SLASweepItem,
+    sla_variable: SLAVariable,
+    sla_inf_value: int = 65536,  # The value that represents infinite QPS
+    base_path: Path,
+    num_runs: int,
+    dry_run: bool,
+):
+    print("[SLA START]")
+    print(f"SLA criteria: {sla_comb.as_text()}")
+
+    sla_data_0 = run_sla(
+        server,
+        bench_cmd,
+        serve_comb=serve_comb,
+        bench_comb=bench_comb | {sla_variable: sla_inf_value},
+        iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, sla_inf_value),
+        num_runs=num_runs,
+        dry_run=dry_run,
+    )
+    if sla_data_0 is None:
+        assert dry_run
+        print("Omitting SLA search.")
+        print("[SLA END]")
+        return None
+
+    sla_init_value = math.ceil(
+        sum(_estimate_sla_value(item, sla_variable) for item in sla_data_0)
+        / len(sla_data_0)
+    )
+    print(f"Initial {sla_variable} to search: {sla_init_value} req/s.")
+
+    sla_data_1, (sla_min, sla_max) = _estimate_sla_bounds(
+        server,
+        bench_cmd,
+        serve_comb=serve_comb,
+        bench_comb=bench_comb,
+        sla_comb=sla_comb,
+        base_path=base_path,
+        num_runs=num_runs,
+        dry_run=dry_run,
+        sla_variable=sla_variable,
+        init_value=sla_init_value,
+        max_value=sla_inf_value,
) + print(f"Range of {sla_variable} to search: [{sla_min}, {sla_max}] req/s.") + + sla_data_2, sla_value = _find_sla_value( + server, + bench_cmd, + serve_comb=serve_comb, + bench_comb=bench_comb, + sla_comb=sla_comb, + base_path=base_path, + num_runs=num_runs, + dry_run=dry_run, + sla_variable=sla_variable, + min_value=sla_min, + max_value=sla_max, + ) + + sla_data = sla_data_0 + sla_data_1 + sla_data_2 + print(f"Maximum {sla_variable} for SLA: {sla_value} req/s.") + + with _get_sla_iter_path( + base_path, + sla_comb, + sla_variable, + sla_value=None, + ).open("w") as f: + json.dump(sla_data, f, indent=4) + + print("[SLA END]") + + return sla_data + + +def run_slas( + serve_cmd: list[str], + bench_cmd: list[str], + after_bench_cmd: list[str], + *, + show_stdout: bool, + serve_params: ParameterSweep, + bench_params: ParameterSweep, + sla_params: SLASweep, + sla_variable: SLAVariable, + output_dir: Path, + num_runs: int, + dry_run: bool, +): + if any(bench_comb.has_param(sla_variable) for bench_comb in bench_params): + raise ValueError( + f"You should not override `{sla_variable}` in `bench_params` in SLA mode, " + "since it is supposed to be determined automatically." + ) + + all_data = list[dict[str, object]]() + for serve_comb in serve_params: + with ( + run_server( + serve_cmd, + after_bench_cmd, + show_stdout=show_stdout, + serve_overrides=serve_comb, + dry_run=dry_run, + ) + if _sla_needs_server( + serve_comb, + bench_params, + sla_params, + sla_variable, + output_dir, + ) + else contextlib.nullcontext() + ) as server: + for bench_comb in bench_params: + for sla_comb in sla_params: + base_path = _get_sla_base_path(output_dir, serve_comb, bench_comb) + + comb_data = search_sla( + server, + bench_cmd, + serve_comb=serve_comb, + bench_comb=bench_comb, + sla_comb=sla_comb, + sla_variable=sla_variable, + base_path=base_path, + num_runs=num_runs, + dry_run=dry_run, + ) + + if comb_data is not None: + all_data.extend(comb_data) + + if dry_run: + return None + + combined_df = pd.DataFrame.from_records(all_data) + combined_df.to_csv(output_dir / "summary.csv") + + return combined_df + + +@dataclass +class SweepServeSLAArgs(SweepServeArgs): + sla_params: SLASweep + sla_variable: SLAVariable + + @classmethod + def from_cli_args(cls, args: argparse.Namespace): + base_args = super().from_cli_args(args) + + if args.sla_params: + sla_params = SLASweep.read_json(args.sla_params) + else: + sla_params = SLASweep.from_records([]) + + return cls( + **asdict(base_args), + sla_params=sla_params, + sla_variable=args.sla_variable, + ) + + @classmethod + def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + parser = super().add_cli_args(parser) + + parser.add_argument( + "--sla-params", + type=str, + required=True, + help="Path to JSON file containing a list of SLA constraints to satisfy. " + 'Each constraint is expressed in `{"": ""}` format, ' + 'e.g.: `{"p99_e2el_ms": "<=500"}` means that ' + "the E2E latency should be less than 500ms 99%% of the time. 
" + "Setting this option runs this script in SLA mode, which searches for " + "the maximum `sla_variable` that satisfies the constraints for " + "each combination of `serve_params`, `bench_params`, and `sla_params`.", + ) + parser.add_argument( + "--sla-variable", + type=str, + choices=get_args(SLAVariable), + default="request_rate", + help="Whether to tune request rate or maximum concurrency to satisfy " + "the SLA constraints.", + ) + + return parser + + +def run_main(args: SweepServeSLAArgs): + timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = args.output_dir / timestamp + + if args.resume and not output_dir.exists(): + raise ValueError(f"Cannot resume from non-existent directory ({output_dir})") + + try: + return run_slas( + serve_cmd=args.serve_cmd, + bench_cmd=args.bench_cmd, + after_bench_cmd=args.after_bench_cmd, + show_stdout=args.show_stdout, + serve_params=args.serve_params, + bench_params=args.bench_params, + sla_params=args.sla_params, + sla_variable=args.sla_variable, + output_dir=output_dir, + num_runs=args.num_runs, + dry_run=args.dry_run, + ) + except BaseException as exc: + raise RuntimeError( + f"The script was terminated early. Use `--resume {timestamp}` " + f"to continue the script from its last checkpoint." + ) from exc + + +def main(args: argparse.Namespace): + run_main(SweepServeSLAArgs.from_cli_args(args)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Tune a variable to meet SLAs under multiple settings." + ) + SweepServeSLAArgs.add_cli_args(parser) + + main(parser.parse_args()) diff --git a/vllm/benchmarks/sweep/server.py b/vllm/benchmarks/sweep/server.py new file mode 100644 index 000000000000..f17578726415 --- /dev/null +++ b/vllm/benchmarks/sweep/server.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import os +import signal +import subprocess +from types import TracebackType + +import requests +from typing_extensions import Self + + +class ServerProcess: + def __init__( + self, + server_cmd: list[str], + after_bench_cmd: list[str], + *, + show_stdout: bool, + ) -> None: + super().__init__() + + self.server_cmd = server_cmd + self.after_bench_cmd = after_bench_cmd + self.show_stdout = show_stdout + + def __enter__(self) -> Self: + self.start() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + exc_traceback: TracebackType | None, + ) -> None: + self.stop() + + def start(self): + # Create new process for clean termination + self._server_process = subprocess.Popen( + self.server_cmd, + start_new_session=True, + stdout=None if self.show_stdout else subprocess.DEVNULL, + # Need `VLLM_SERVER_DEV_MODE=1` for `_reset_caches` + env=os.environ | {"VLLM_SERVER_DEV_MODE": "1"}, + ) + + def stop(self): + server_process = self._server_process + + if server_process.poll() is None: + # In case only some processes have been terminated + with contextlib.suppress(ProcessLookupError): + # We need to kill both API Server and Engine processes + os.killpg(os.getpgid(server_process.pid), signal.SIGKILL) + + def run_subcommand(self, cmd: list[str]): + return subprocess.run( + cmd, + stdout=None if self.show_stdout else subprocess.DEVNULL, + check=True, + ) + + def after_bench(self) -> None: + if not self.after_bench_cmd: + self.reset_caches() + return + + self.run_subcommand(self.after_bench_cmd) + + def _get_vllm_server_address(self) -> str: + 
+        server_cmd = self.server_cmd
+
+        for host_key in ("--host",):
+            if host_key in server_cmd:
+                host = server_cmd[server_cmd.index(host_key) + 1]
+                break
+        else:
+            host = "localhost"
+
+        for port_key in ("-p", "--port"):
+            if port_key in server_cmd:
+                port = int(server_cmd[server_cmd.index(port_key) + 1])
+                break
+        else:
+            port = 8000  # The default value in vllm serve
+
+        return f"http://{host}:{port}"
+
+    def reset_caches(self) -> None:
+        server_cmd = self.server_cmd
+
+        # Use `.endswith()` to match `/bin/...`
+        if server_cmd[0].endswith("vllm"):
+            server_address = self._get_vllm_server_address()
+            print(f"Resetting caches at {server_address}")
+
+            res = requests.post(f"{server_address}/reset_prefix_cache")
+            res.raise_for_status()
+
+            res = requests.post(f"{server_address}/reset_mm_cache")
+            res.raise_for_status()
+        elif server_cmd[0].endswith("infinity_emb"):
+            if "--vector-disk-cache" in server_cmd:
+                raise NotImplementedError(
+                    "Infinity server uses caching but does not expose a method "
+                    "to reset the cache"
+                )
+        else:
+            raise NotImplementedError(
+                f"No implementation of `reset_caches` for `{server_cmd[0]}` server. "
+                "Please specify a custom command via `--after-bench-cmd`."
+            )
diff --git a/vllm/benchmarks/sweep/sla_sweep.py b/vllm/benchmarks/sweep/sla_sweep.py
new file mode 100644
index 000000000000..327e3c7c5897
--- /dev/null
+++ b/vllm/benchmarks/sweep/sla_sweep.py
@@ -0,0 +1,132 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import json
+import os
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+from typing_extensions import override
+
+
+@dataclass
+class SLACriterionBase(ABC):
+    target: float
+
+    @abstractmethod
+    def validate(self, actual: float) -> bool:
+        """Return `True` if this criterion is met; otherwise `False`."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def format_cond(self, lhs: str) -> str:
+        raise NotImplementedError
+
+    def print_and_validate(
+        self,
+        metrics: dict[str, float],
+        metrics_key: str,
+    ) -> bool:
+        metric = metrics[metrics_key]
+        result = self.validate(metric)
+
+        cond = self.format_cond(f"{metrics_key} = {metric:.2f}")
+        print(f"Validating SLA: {cond} | " + ("PASSED" if result else "FAILED"))
+
+        return result
+
+
+@dataclass
+class SLALessThan(SLACriterionBase):
+    @override
+    def validate(self, actual: float) -> bool:
+        return actual < self.target
+
+    @override
+    def format_cond(self, lhs: str) -> str:
+        return f"{lhs}<{self.target:.2f}"
+
+
+@dataclass
+class SLALessThanOrEqualTo(SLACriterionBase):
+    @override
+    def validate(self, actual: float) -> bool:
+        return actual <= self.target
+
+    @override
+    def format_cond(self, lhs: str) -> str:
+        return f"{lhs}<={self.target:.2f}"
+
+
+@dataclass
+class SLAGreaterThan(SLACriterionBase):
+    @override
+    def validate(self, actual: float) -> bool:
+        return actual > self.target
+
+    @override
+    def format_cond(self, lhs: str) -> str:
+        return f"{lhs}>{self.target:.2f}"
+
+
+@dataclass
+class SLAGreaterThanOrEqualTo(SLACriterionBase):
+    @override
+    def validate(self, actual: float) -> bool:
+        return actual >= self.target
+
+    @override
+    def format_cond(self, lhs: str) -> str:
+        return f"{lhs}>={self.target:.2f}"
+
+
+# NOTE: The ordering is important! Match longer op_keys first
+SLA_CRITERIA: dict[str, type[SLACriterionBase]] = {
+    "<=": SLALessThanOrEqualTo,
+    ">=": SLAGreaterThanOrEqualTo,
+    "<": SLALessThan,
+    ">": SLAGreaterThan,
+}
+
+
+class SLASweep(list["SLASweepItem"]):
+    @classmethod
+    def read_json(cls, filepath: os.PathLike):
+        with open(filepath, "rb") as f:
+            records = json.load(f)
+
+        return cls.from_records(records)
+
+    @classmethod
+    def from_records(cls, records: list[dict[str, str]]):
+        if not isinstance(records, list):
+            raise TypeError(
+                f"The SLA sweep should be a list of dictionaries, "
+                f"but found type: {type(records)}"
+            )
+
+        return cls(SLASweepItem.from_record(record) for record in records)
+
+
+class SLASweepItem(dict[str, SLACriterionBase]):
+    @classmethod
+    def from_record(cls, record: dict[str, str]):
+        sla_criteria: dict[str, SLACriterionBase] = {}
+
+        for metric_key, metric_value in record.items():
+            for op_key in SLA_CRITERIA:
+                if metric_value.startswith(op_key):
+                    sla_criteria[metric_key] = SLA_CRITERIA[op_key](
+                        float(metric_value.removeprefix(op_key))
+                    )
+                    break
+            else:
+                raise ValueError(
+                    f"Invalid operator for "
+                    f"SLA constraint '{metric_key}={metric_value}'. "
+                    f"Valid operators are: {sorted(SLA_CRITERIA)}",
+                )
+
+        return cls(sla_criteria)
+
+    def as_text(self, sep: str = ", ") -> str:
+        return sep.join(v.format_cond(k) for k, v in self.items())
diff --git a/vllm/benchmarks/sweep/utils.py b/vllm/benchmarks/sweep/utils.py
new file mode 100644
index 000000000000..49d7867eaf48
--- /dev/null
+++ b/vllm/benchmarks/sweep/utils.py
@@ -0,0 +1,4 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+def sanitize_filename(filename: str) -> str:
+    return filename.replace("/", "_").replace("..", "__").strip("'").strip('"')