Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Add FlexibleArgumentParser to support both underscore and dash in names #5718

Merged
merged 2 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion benchmarks/benchmark_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptStrictInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser


def main(args: argparse.Namespace):
Expand Down Expand Up @@ -120,7 +121,7 @@ def run_to_completion(profile_dir: Optional[str] = None):


if __name__ == '__main__':
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description='Benchmark the latency of processing a single batch of '
'requests till completion.')
parser.add_argument('--model', type=str, default='facebook/opt-125m')
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/benchmark_prefix_caching.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import time

from vllm import LLM, SamplingParams
from vllm.utils import FlexibleArgumentParser

PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501

Expand Down Expand Up @@ -44,7 +44,7 @@ def main(args):


if __name__ == "__main__":
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description='Benchmark the performance with or without automatic '
'prefix caching.')
parser.add_argument('--model',
Expand Down
7 changes: 6 additions & 1 deletion benchmarks/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@
except ImportError:
from backend_request_func import get_tokenizer

try:
from vllm.utils import FlexibleArgumentParser
except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser


@dataclass
class BenchmarkMetrics:
Expand Down Expand Up @@ -511,7 +516,7 @@ def main(args: argparse.Namespace):


if __name__ == "__main__":
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="Benchmark the online serving throughput.")
parser.add_argument(
"--backend",
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/benchmark_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser


def sample_requests(
Expand Down Expand Up @@ -261,7 +262,7 @@ def main(args: argparse.Namespace):


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--backend",
type=str,
choices=["vllm", "hf", "mii"],
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())[1:]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
Expand Down Expand Up @@ -293,7 +294,7 @@ def to_torch_dtype(dt):
return torch.float8_e4m3fn
raise ValueError("unsupported dtype")

parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="""
Benchmark Cutlass GEMM.

Expand Down
4 changes: 2 additions & 2 deletions benchmarks/kernels/benchmark_aqlm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import argparse
import os
import sys
from typing import Optional
Expand All @@ -10,6 +9,7 @@
from vllm.model_executor.layers.quantization.aqlm import (
dequantize_weight, generic_dequantize_gemm, get_int_dtype,
optimized_dequantize_gemm)
from vllm.utils import FlexibleArgumentParser

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

Expand Down Expand Up @@ -137,7 +137,7 @@ def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:

def main():

parser = argparse.ArgumentParser(description="Benchmark aqlm performance.")
parser = FlexibleArgumentParser(description="Benchmark aqlm performance.")

# Add arguments
parser.add_argument("--nbooks",
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/kernels/benchmark_marlin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import argparse
from typing import List

import torch
Expand All @@ -16,6 +15,7 @@
MarlinWorkspace, marlin_24_quantize, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, quantize_weights, sort_weights)
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
Expand Down Expand Up @@ -211,7 +211,7 @@ def main(args):
# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501
#
if __name__ == "__main__":
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="Benchmark Marlin across specified models/shapes/batches")
parser.add_argument(
"--models",
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/kernels/benchmark_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from transformers import AutoConfig

from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.utils import FlexibleArgumentParser


class BenchmarkConfig(TypedDict):
Expand Down Expand Up @@ -315,7 +316,7 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]:


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = FlexibleArgumentParser()
parser.add_argument("--model",
type=str,
default="mistralai/Mixtral-8x7B-Instruct-v0.1")
Expand Down
6 changes: 3 additions & 3 deletions benchmarks/kernels/benchmark_paged_attention.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import argparse
import random
import time
from typing import List, Optional

import torch

from vllm import _custom_ops as ops
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
create_kv_caches_with_random)

NUM_BLOCKS = 1024
PARTITION_SIZE = 512
Expand Down Expand Up @@ -161,7 +161,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:


if __name__ == '__main__':
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="Benchmark the paged attention kernel.")
parser.add_argument("--version",
type=str,
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/kernels/benchmark_rope.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import argparse
from itertools import accumulate
from typing import List, Optional

Expand All @@ -7,6 +6,7 @@

from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding,
get_rope)
from vllm.utils import FlexibleArgumentParser


def benchmark_rope_kernels_multi_lora(
Expand Down Expand Up @@ -86,7 +86,7 @@ def benchmark_rope_kernels_multi_lora(


if __name__ == '__main__':
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="Benchmark the rotary embedding kernels.")
parser.add_argument("--is-neox-style", type=bool, default=True)
parser.add_argument("--batch-size", type=int, default=16)
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/overheads/benchmark_hashing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import argparse
import cProfile
import pstats

from vllm import LLM, SamplingParams
from vllm.utils import FlexibleArgumentParser

# A very long prompt, total number of tokens is about 15k.
LONG_PROMPT = ["You are an expert in large language models, aren't you?"
Expand Down Expand Up @@ -47,7 +47,7 @@ def main(args):


if __name__ == "__main__":
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description='Benchmark the performance of hashing function in'
'automatic prefix caching.')
parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k')
Expand Down
5 changes: 2 additions & 3 deletions examples/aqlm_example.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import argparse

from vllm import LLM, SamplingParams
from vllm.utils import FlexibleArgumentParser


def main():

parser = argparse.ArgumentParser(description='AQLM examples')
parser = FlexibleArgumentParser(description='AQLM examples')

parser.add_argument('--model',
'-m',
Expand Down
3 changes: 2 additions & 1 deletion examples/llm_engine_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, Tuple

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.utils import FlexibleArgumentParser


def create_test_prompts() -> List[Tuple[str, SamplingParams]]:
Expand Down Expand Up @@ -55,7 +56,7 @@ def main(args: argparse.Namespace):


if __name__ == '__main__':
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description='Demo on using the LLMEngine class directly')
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
Expand Down
4 changes: 2 additions & 2 deletions examples/save_sharded_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@
tensor_parallel_size=8,
)
"""
import argparse
import dataclasses
import os
import shutil
from pathlib import Path

from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = argparse.ArgumentParser()
parser = FlexibleArgumentParser()
EngineArgs.add_cli_args(parser)
parser.add_argument("--output",
"-o",
Expand Down
3 changes: 2 additions & 1 deletion examples/tensorize_vllm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from vllm.model_executor.model_loader.tensorizer import (TensorizerArgs,
TensorizerConfig,
tensorize_vllm_model)
from vllm.utils import FlexibleArgumentParser

# yapf conflicts with isort for this docstring
# yapf: disable
Expand Down Expand Up @@ -96,7 +97,7 @@


def parse_args():
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="An example script that can be used to serialize and "
"deserialize vLLM models. These models "
"can be loaded using tensorizer directly to the GPU "
Expand Down
4 changes: 2 additions & 2 deletions tests/async_engine/api_server_async_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""vllm.entrypoints.api_server with some extra logging for testing."""
import argparse
from typing import Any, Dict

import uvicorn
Expand All @@ -8,6 +7,7 @@
import vllm.entrypoints.api_server
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.utils import FlexibleArgumentParser

app = vllm.entrypoints.api_server.app

Expand All @@ -33,7 +33,7 @@ def stats() -> Response:


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = FlexibleArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser = AsyncEngineArgs.add_cli_args(parser)
Expand Down
17 changes: 8 additions & 9 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
SpeculativeConfig, TokenizerPoolConfig,
VisionLanguageConfig)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import str_to_int_tuple
from vllm.utils import FlexibleArgumentParser, str_to_int_tuple


def nullable_str(val: str):
Expand Down Expand Up @@ -110,7 +110,7 @@ def __post_init__(self):

@staticmethod
def add_cli_args_for_vlm(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument('--image-input-type',
type=nullable_str,
default=None,
Expand Down Expand Up @@ -156,8 +156,7 @@ def add_cli_args_for_vlm(
return parser

@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Shared CLI arguments for vLLM engine."""

# Model arguments
Expand Down Expand Up @@ -800,8 +799,8 @@ class AsyncEngineArgs(EngineArgs):
max_log_len: Optional[int] = None

@staticmethod
def add_cli_args(parser: argparse.ArgumentParser,
async_args_only: bool = False) -> argparse.ArgumentParser:
def add_cli_args(parser: FlexibleArgumentParser,
async_args_only: bool = False) -> FlexibleArgumentParser:
if not async_args_only:
parser = EngineArgs.add_cli_args(parser)
parser.add_argument('--engine-use-ray',
Expand All @@ -822,13 +821,13 @@ def add_cli_args(parser: argparse.ArgumentParser,

# These functions are used by sphinx to build the documentation
def _engine_args_parser():
return EngineArgs.add_cli_args(argparse.ArgumentParser())
return EngineArgs.add_cli_args(FlexibleArgumentParser())


def _async_engine_args_parser():
return AsyncEngineArgs.add_cli_args(argparse.ArgumentParser(),
return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(),
async_args_only=True)


def _vlm_engine_args_parser():
return EngineArgs.add_cli_args_for_vlm(argparse.ArgumentParser())
return EngineArgs.add_cli_args_for_vlm(FlexibleArgumentParser())
5 changes: 2 additions & 3 deletions vllm/entrypoints/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
change `vllm/entrypoints/openai/api_server.py` instead.
"""

import argparse
import json
import ssl
from typing import AsyncGenerator
Expand All @@ -19,7 +18,7 @@
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import random_uuid
from vllm.utils import FlexibleArgumentParser, random_uuid

TIMEOUT_KEEP_ALIVE = 5 # seconds.
app = FastAPI()
Expand Down Expand Up @@ -80,7 +79,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = FlexibleArgumentParser()
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--ssl-keyfile", type=str, default=None)
Expand Down
Loading
Loading