Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Permalink
[Frontend] Add FlexibleArgumentParser to support both underscore and …
Browse files Browse the repository at this point in the history
…dash in names (vllm-project#5718)
  • Loading branch information
mgoin authored and robertgshaw2-neuralmagic committed Jun 23, 2024
1 parent 51dfab0 commit c477239
Show file tree
Hide file tree
Showing 22 changed files with 72 additions and 45 deletions.
3 changes: 2 additions & 1 deletion benchmarks/benchmark_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptStrictInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser


def main(args: argparse.Namespace):
Expand Down Expand Up @@ -120,7 +121,7 @@ def run_to_completion(profile_dir: Optional[str] = None):


if __name__ == '__main__':
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description='Benchmark the latency of processing a single batch of '
'requests till completion.')
parser.add_argument('--model', type=str, default='facebook/opt-125m')
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/benchmark_prefix_caching.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import time

from vllm import LLM, SamplingParams
from vllm.utils import FlexibleArgumentParser

PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501

Expand Down Expand Up @@ -44,7 +44,7 @@ def main(args):


if __name__ == "__main__":
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description='Benchmark the performance with or without automatic '
'prefix caching.')
parser.add_argument('--model',
Expand Down
7 changes: 6 additions & 1 deletion benchmarks/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@
except ImportError:
from backend_request_func import get_tokenizer

try:
from vllm.utils import FlexibleArgumentParser
except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser


@dataclass
class BenchmarkMetrics:
Expand Down Expand Up @@ -511,7 +516,7 @@ def main(args: argparse.Namespace):


if __name__ == "__main__":
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="Benchmark the online serving throughput.")
parser.add_argument(
"--backend",
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/benchmark_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser


def sample_requests(
Expand Down Expand Up @@ -261,7 +262,7 @@ def main(args: argparse.Namespace):


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--backend",
type=str,
choices=["vllm", "hf", "mii"],
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())[1:]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
Expand Down Expand Up @@ -293,7 +294,7 @@ def to_torch_dtype(dt):
return torch.float8_e4m3fn
raise ValueError("unsupported dtype")

parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="""
Benchmark Cutlass GEMM.
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/kernels/benchmark_aqlm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import argparse
import os
import sys
from typing import Optional
Expand All @@ -10,6 +9,7 @@
from vllm.model_executor.layers.quantization.aqlm import (
dequantize_weight, generic_dequantize_gemm, get_int_dtype,
optimized_dequantize_gemm)
from vllm.utils import FlexibleArgumentParser

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

Expand Down Expand Up @@ -137,7 +137,7 @@ def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:

def main():

parser = argparse.ArgumentParser(description="Benchmark aqlm performance.")
parser = FlexibleArgumentParser(description="Benchmark aqlm performance.")

# Add arguments
parser.add_argument("--nbooks",
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/kernels/benchmark_marlin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import argparse
from typing import List

import torch
Expand All @@ -16,6 +15,7 @@
MarlinWorkspace, marlin_24_quantize, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, quantize_weights, sort_weights)
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
Expand Down Expand Up @@ -211,7 +211,7 @@ def main(args):
# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501
#
if __name__ == "__main__":
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="Benchmark Marlin across specified models/shapes/batches")
parser.add_argument(
"--models",
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/kernels/benchmark_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from transformers import AutoConfig

from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.utils import FlexibleArgumentParser


class BenchmarkConfig(TypedDict):
Expand Down Expand Up @@ -315,7 +316,7 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]:


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = FlexibleArgumentParser()
parser.add_argument("--model",
type=str,
default="mistralai/Mixtral-8x7B-Instruct-v0.1")
Expand Down
6 changes: 3 additions & 3 deletions benchmarks/kernels/benchmark_paged_attention.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import argparse
import random
import time
from typing import List, Optional

import torch

from vllm import _custom_ops as ops
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
create_kv_caches_with_random)

NUM_BLOCKS = 1024
PARTITION_SIZE = 512
Expand Down Expand Up @@ -161,7 +161,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:


if __name__ == '__main__':
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="Benchmark the paged attention kernel.")
parser.add_argument("--version",
type=str,
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/kernels/benchmark_rope.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import argparse
from itertools import accumulate
from typing import List, Optional

Expand All @@ -7,6 +6,7 @@

from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding,
get_rope)
from vllm.utils import FlexibleArgumentParser


def benchmark_rope_kernels_multi_lora(
Expand Down Expand Up @@ -86,7 +86,7 @@ def benchmark_rope_kernels_multi_lora(


if __name__ == '__main__':
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="Benchmark the rotary embedding kernels.")
parser.add_argument("--is-neox-style", type=bool, default=True)
parser.add_argument("--batch-size", type=int, default=16)
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/overheads/benchmark_hashing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import argparse
import cProfile
import pstats

from vllm import LLM, SamplingParams
from vllm.utils import FlexibleArgumentParser

# A very long prompt, total number of tokens is about 15k.
LONG_PROMPT = ["You are an expert in large language models, aren't you?"
Expand Down Expand Up @@ -47,7 +47,7 @@ def main(args):


if __name__ == "__main__":
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description='Benchmark the performance of hashing function in'
'automatic prefix caching.')
parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k')
Expand Down
5 changes: 2 additions & 3 deletions examples/aqlm_example.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import argparse

from vllm import LLM, SamplingParams
from vllm.utils import FlexibleArgumentParser


def main():

parser = argparse.ArgumentParser(description='AQLM examples')
parser = FlexibleArgumentParser(description='AQLM examples')

parser.add_argument('--model',
'-m',
Expand Down
3 changes: 2 additions & 1 deletion examples/llm_engine_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, Tuple

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.utils import FlexibleArgumentParser


def create_test_prompts() -> List[Tuple[str, SamplingParams]]:
Expand Down Expand Up @@ -55,7 +56,7 @@ def main(args: argparse.Namespace):


if __name__ == '__main__':
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description='Demo on using the LLMEngine class directly')
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
Expand Down
4 changes: 2 additions & 2 deletions examples/save_sharded_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@
tensor_parallel_size=8,
)
"""
import argparse
import dataclasses
import os
import shutil
from pathlib import Path

from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = argparse.ArgumentParser()
parser = FlexibleArgumentParser()
EngineArgs.add_cli_args(parser)
parser.add_argument("--output",
"-o",
Expand Down
3 changes: 2 additions & 1 deletion examples/tensorize_vllm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from vllm.model_executor.model_loader.tensorizer import (TensorizerArgs,
TensorizerConfig,
tensorize_vllm_model)
from vllm.utils import FlexibleArgumentParser

# yapf conflicts with isort for this docstring
# yapf: disable
Expand Down Expand Up @@ -96,7 +97,7 @@


def parse_args():
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="An example script that can be used to serialize and "
"deserialize vLLM models. These models "
"can be loaded using tensorizer directly to the GPU "
Expand Down
4 changes: 2 additions & 2 deletions tests/async_engine/api_server_async_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""vllm.entrypoints.api_server with some extra logging for testing."""
import argparse
from typing import Any, Dict

import uvicorn
Expand All @@ -8,6 +7,7 @@
import vllm.entrypoints.api_server
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.utils import FlexibleArgumentParser

app = vllm.entrypoints.api_server.app

Expand All @@ -33,7 +33,7 @@ def stats() -> Response:


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = FlexibleArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser = AsyncEngineArgs.add_cli_args(parser)
Expand Down
17 changes: 8 additions & 9 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
SpeculativeConfig, TokenizerPoolConfig,
VisionLanguageConfig)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import str_to_int_tuple
from vllm.utils import FlexibleArgumentParser, str_to_int_tuple


def nullable_str(val: str):
Expand Down Expand Up @@ -114,7 +114,7 @@ def __post_init__(self):

@staticmethod
def add_cli_args_for_vlm(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument('--image-input-type',
type=nullable_str,
default=None,
Expand Down Expand Up @@ -160,8 +160,7 @@ def add_cli_args_for_vlm(
return parser

@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Shared CLI arguments for vLLM engine."""

# Model arguments
Expand Down Expand Up @@ -816,8 +815,8 @@ class AsyncEngineArgs(EngineArgs):
max_log_len: Optional[int] = None

@staticmethod
def add_cli_args(parser: argparse.ArgumentParser,
async_args_only: bool = False) -> argparse.ArgumentParser:
def add_cli_args(parser: FlexibleArgumentParser,
async_args_only: bool = False) -> FlexibleArgumentParser:
if not async_args_only:
parser = EngineArgs.add_cli_args(parser)
parser.add_argument('--engine-use-ray',
Expand All @@ -838,13 +837,13 @@ def add_cli_args(parser: argparse.ArgumentParser,

# These functions are used by sphinx to build the documentation
def _engine_args_parser():
return EngineArgs.add_cli_args(argparse.ArgumentParser())
return EngineArgs.add_cli_args(FlexibleArgumentParser())


def _async_engine_args_parser():
return AsyncEngineArgs.add_cli_args(argparse.ArgumentParser(),
return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(),
async_args_only=True)


def _vlm_engine_args_parser():
return EngineArgs.add_cli_args_for_vlm(argparse.ArgumentParser())
return EngineArgs.add_cli_args_for_vlm(FlexibleArgumentParser())
5 changes: 2 additions & 3 deletions vllm/entrypoints/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
change `vllm/entrypoints/openai/api_server.py` instead.
"""

import argparse
import json
import ssl
from typing import AsyncGenerator
Expand All @@ -19,7 +18,7 @@
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import random_uuid
from vllm.utils import FlexibleArgumentParser, random_uuid

TIMEOUT_KEEP_ALIVE = 5 # seconds.
app = FastAPI()
Expand Down Expand Up @@ -80,7 +79,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = FlexibleArgumentParser()
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--ssl-keyfile", type=str, default=None)
Expand Down
Loading

0 comments on commit c477239

Please sign in to comment.