diff --git a/vllm/config.py b/vllm/config.py
index d24082799d00..0c009e93bf9d 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -17,7 +17,7 @@
 from importlib.util import find_spec
 from pathlib import Path
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
-                    Optional, Protocol, Union)
+                    Optional, Protocol, TypeVar, Union)
 
 import torch
 from pydantic import BaseModel, Field, PrivateAttr
@@ -45,6 +45,7 @@
                         random_uuid, resolve_obj_by_qualname)
 
 if TYPE_CHECKING:
+    from _typeshed import DataclassInstance
     from ray.util.placement_group import PlacementGroup
 
     from vllm.executor.executor_base import ExecutorBase
@@ -53,8 +54,11 @@
     from vllm.model_executor.model_loader.loader import BaseModelLoader
     from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
         BaseTokenizerGroup)
+
+    Config = TypeVar("Config", bound=DataclassInstance)
 else:
     QuantizationConfig = None
+    Config = TypeVar("Config")
 
 logger = init_logger(__name__)
 
@@ -159,7 +163,7 @@ def pairwise(iterable):
     return out
 
 
-def config(cls: type[Any]) -> type[Any]:
+def config(cls: type[Config]) -> type[Config]:
     """
     A decorator that ensures all fields in a dataclass have default values
     and that each field has a docstring.
@@ -1431,44 +1435,47 @@ class LoadFormat(str, enum.Enum):
     FASTSAFETENSORS = "fastsafetensors"
 
 
+@config
 @dataclass
 class LoadConfig:
-    """
-    download_dir: Directory to download and load the weights, default to the
-        default cache directory of huggingface.
-    load_format: The format of the model weights to load:
-        "auto" will try to load the weights in the safetensors format and
-            fall back to the pytorch bin format if safetensors format is
-            not available.
-        "pt" will load the weights in the pytorch bin format.
-        "safetensors" will load the weights in the safetensors format.
-        "npcache" will load the weights in pytorch format and store
-            a numpy cache to speed up the loading.
-        "dummy" will initialize the weights with random values, which is
-            mainly for profiling.
-        "tensorizer" will use CoreWeave's tensorizer library for
-            fast weight loading.
-        "bitsandbytes" will load nf4 type weights.
-        "sharded_state" will load weights from pre-sharded checkpoint files,
-            supporting efficient loading of tensor-parallel models.
-        "gguf" will load weights from GGUF format files.
-        "mistral" will load weights from consolidated safetensors files used
-            by Mistral models.
-        "runai_streamer" will load weights from RunAI streamer format files.
-    model_loader_extra_config: The extra config for the model loader.
-    ignore_patterns: The list of patterns to ignore when loading the model.
-        Default to "original/**/*" to avoid repeated loading of llama's
-        checkpoints.
-    use_tqdm_on_load: Whether to enable tqdm for showing progress bar during
-        loading. Default to True
-    """
-
-    load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
+    """Configuration for loading the model weights."""
+
+    load_format: Union[str, LoadFormat,
+                       "BaseModelLoader"] = LoadFormat.AUTO.value
+    """The format of the model weights to load:\n
+    - "auto" will try to load the weights in the safetensors format and fall
+    back to the pytorch bin format if safetensors format is not available.\n
+    - "pt" will load the weights in the pytorch bin format.\n
+    - "safetensors" will load the weights in the safetensors format.\n
+    - "npcache" will load the weights in pytorch format and store a numpy cache
+    to speed up the loading.\n
+    - "dummy" will initialize the weights with random values, which is mainly
+    for profiling.\n
+    - "tensorizer" will use CoreWeave's tensorizer library for fast weight
+    loading. See the Tensorize vLLM Model script in the Examples section for
+    more information.\n
+    - "runai_streamer" will load the Safetensors weights using Run:ai Model
+    Streamer.\n
+    - "bitsandbytes" will load the weights using bitsandbytes quantization.\n
+    - "sharded_state" will load weights from pre-sharded checkpoint files,
+    supporting efficient loading of tensor-parallel models.\n
+    - "gguf" will load weights from GGUF format files (details specified in
+    https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n
+    - "mistral" will load weights from consolidated safetensors files used by
+    Mistral models."""
     download_dir: Optional[str] = None
-    model_loader_extra_config: Optional[Union[str, dict]] = field(
-        default_factory=dict)
+    """Directory to download and load the weights, default to the default
+    cache directory of Hugging Face."""
+    model_loader_extra_config: Optional[Union[str, dict]] = None
+    """Extra config for model loader. This will be passed to the model loader
+    corresponding to the chosen load_format. This should be a JSON string that
+    will be parsed into a dictionary."""
     ignore_patterns: Optional[Union[list[str], str]] = None
+    """The list of patterns to ignore when loading the model. Default to
+    "original/**/*" to avoid repeated loading of llama's checkpoints."""
     use_tqdm_on_load: bool = True
+    """Whether to enable tqdm for showing progress bar when loading model
+    weights."""
 
     def compute_hash(self) -> str:
         """
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 3eafb6827d49..70e628ed1680 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -101,8 +101,8 @@ class EngineArgs:
     tokenizer_mode: str = 'auto'
     trust_remote_code: bool = False
    allowed_local_media_path: str = ""
-    download_dir: Optional[str] = None
-    load_format: str = 'auto'
+    download_dir: Optional[str] = LoadConfig.download_dir
+    load_format: str = LoadConfig.load_format
     config_format: ConfigFormat = ConfigFormat.AUTO
     dtype: str = 'auto'
     kv_cache_dtype: str = 'auto'
@@ -174,8 +174,10 @@ class EngineArgs:
     ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
     num_gpu_blocks_override: Optional[int] = None
     num_lookahead_slots: int = 0
-    model_loader_extra_config: Optional[dict] = None
-    ignore_patterns: Optional[Union[str, List[str]]] = None
+    model_loader_extra_config: Optional[
+        dict] = LoadConfig.model_loader_extra_config
+    ignore_patterns: Optional[Union[str,
+                                    List[str]]] = LoadConfig.ignore_patterns
     preemption_mode: Optional[str] = None
 
     scheduler_delay_factor: float = 0.0
@@ -213,7 +215,7 @@ class EngineArgs:
     additional_config: Optional[Dict[str, Any]] = None
     enable_reasoning: Optional[bool] = None
     reasoning_parser: Optional[str] = None
-    use_tqdm_on_load: bool = True
+    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
 
     def __post_init__(self):
         if not self.tokenizer:
@@ -234,9 +236,13 @@ def __post_init__(self):
     def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         """Shared CLI arguments for vLLM engine."""
 
+        def is_type_in_union(cls: type[Any], type: type[Any]) -> bool:
+            """Check if the class is a type in a union type."""
+            return get_origin(cls) is Union and type in get_args(cls)
+
         def is_optional(cls: type[Any]) -> bool:
             """Check if the class is an optional type."""
-            return get_origin(cls) is Union and type(None) in get_args(cls)
+            return is_type_in_union(cls, type(None))
 
         def get_kwargs(cls: type[Any]) -> Dict[str, Any]:
             cls_docs = get_attr_docs(cls)
@@ -255,6 +261,10 @@ def get_kwargs(cls: type[Any]) -> Dict[str, Any]:
                 if is_optional(field.type):
                     kwargs[name]["type"] = nullable_str
                     continue
+                # Handle str in union fields
+                if is_type_in_union(field.type, str):
+                    kwargs[name]["type"] = str
+                    continue
                 kwargs[name]["type"] = field.type
             return kwargs
 
@@ -333,38 +343,23 @@ def get_kwargs(cls: type[Any]) -> Dict[str, Any]:
             "from directories specified by the server file system. "
             "This is a security risk. "
             "Should only be enabled in trusted environments.")
-        parser.add_argument('--download-dir',
-                            type=nullable_str,
-                            default=EngineArgs.download_dir,
-                            help='Directory to download and load the weights.')
-        parser.add_argument(
-            '--load-format',
-            type=str,
-            default=EngineArgs.load_format,
-            choices=[f.value for f in LoadFormat],
-            help='The format of the model weights to load.\n\n'
-            '* "auto" will try to load the weights in the safetensors format '
-            'and fall back to the pytorch bin format if safetensors format '
-            'is not available.\n'
-            '* "pt" will load the weights in the pytorch bin format.\n'
-            '* "safetensors" will load the weights in the safetensors format.\n'
-            '* "npcache" will load the weights in pytorch format and store '
-            'a numpy cache to speed up the loading.\n'
-            '* "dummy" will initialize the weights with random values, '
-            'which is mainly for profiling.\n'
-            '* "tensorizer" will load the weights using tensorizer from '
-            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
-            'section for more information.\n'
-            '* "runai_streamer" will load the Safetensors weights using Run:ai'
-            'Model Streamer.\n'
-            '* "bitsandbytes" will load the weights using bitsandbytes '
-            'quantization.\n'
-            '* "sharded_state" will load weights from pre-sharded checkpoint '
-            'files, supporting efficient loading of tensor-parallel models\n'
-            '* "gguf" will load weights from GGUF format files (details '
-            'specified in https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n'
-            '* "mistral" will load weights from consolidated safetensors files '
-            'used by Mistral models.\n')
+        # Model loading arguments
+        load_kwargs = get_kwargs(LoadConfig)
+        load_group = parser.add_argument_group(
+            title="LoadConfig",
+            description=LoadConfig.__doc__,
+        )
+        load_group.add_argument('--load-format',
+                                choices=[f.value for f in LoadFormat],
+                                **load_kwargs["load_format"])
+        load_group.add_argument('--download-dir',
+                                **load_kwargs["download_dir"])
+        load_group.add_argument('--model-loader-extra-config',
+                                **load_kwargs["model_loader_extra_config"])
+        load_group.add_argument('--use-tqdm-on-load',
+                                action=argparse.BooleanOptionalAction,
+                                **load_kwargs["use_tqdm_on_load"])
+
         parser.add_argument(
             '--config-format',
             default=EngineArgs.config_format,
@@ -770,14 +765,6 @@ def get_kwargs(cls: type[Any]) -> Dict[str, Any]:
                             default=1,
                             help=('Maximum number of forward steps per '
                                   'scheduler call.'))
-        parser.add_argument(
-            '--use-tqdm-on-load',
-            dest='use_tqdm_on_load',
-            action=argparse.BooleanOptionalAction,
-            default=EngineArgs.use_tqdm_on_load,
-            help='Whether to enable/disable progress bar '
-            'when loading model weights.',
-        )
 
         parser.add_argument(
             '--multi-step-stream-outputs',
@@ -806,15 +793,6 @@ def get_kwargs(cls: type[Any]) -> Dict[str, Any]:
                             default=None,
                             help='The configurations for speculative decoding.'
                             ' Should be a JSON string.')
-
-        parser.add_argument('--model-loader-extra-config',
-                            type=nullable_str,
-                            default=EngineArgs.model_loader_extra_config,
-                            help='Extra config for model loader. '
-                            'This will be passed to the model loader '
-                            'corresponding to the chosen load_format. '
-                            'This should be a JSON string that will be '
-                            'parsed into a dictionary.')
         parser.add_argument(
             '--ignore-patterns',
             action="append",
diff --git a/vllm/utils.py b/vllm/utils.py
index 551f1a4c9d26..0fa3384aa090 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -2,7 +2,6 @@
 
 from __future__ import annotations
 
-import argparse
 import asyncio
 import concurrent
 import contextlib
@@ -25,6 +24,7 @@
 import subprocess
 import sys
 import tempfile
+import textwrap
 import threading
 import time
 import traceback
@@ -32,6 +32,8 @@
 import uuid
 import warnings
 import weakref
+from argparse import (Action, ArgumentDefaultsHelpFormatter, ArgumentParser,
+                      ArgumentTypeError)
 from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
 from collections import UserDict, defaultdict
 from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
@@ -1209,7 +1211,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
     return wrapper
 
 
-class StoreBoolean(argparse.Action):
+class StoreBoolean(Action):
 
     def __call__(self, parser, namespace, values, option_string=None):
         if values.lower() == "true":
@@ -1221,15 +1223,28 @@ def __call__(self, parser, namespace, values, option_string=None):
             "Expected 'true' or 'false'.")
 
 
-class SortedHelpFormatter(argparse.ArgumentDefaultsHelpFormatter):
+class SortedHelpFormatter(ArgumentDefaultsHelpFormatter):
     """SortedHelpFormatter that sorts arguments by their option strings."""
 
+    def _split_lines(self, text, width):
+        """
+        1. Sentences split across lines have their single newlines removed.
+        2. Paragraphs and explicit newlines are split into separate lines.
+        3. Each line is wrapped to the specified width (width of terminal).
+        """
+        # The patterns also include whitespace after the newline
+        single_newline = re.compile("(?
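The net effect of the patch is that `LoadConfig` becomes the single source of truth: its field defaults feed `EngineArgs`, and its per-field docstrings feed the auto-generated `--help` text. Below is a minimal, self-contained sketch of that pattern; `DemoLoadConfig` and `build_parser` are hypothetical names for illustration, not vLLM's actual `get_kwargs`/`get_attr_docs` machinery (which additionally harvests the per-field docstrings for help text).

```python
import argparse
from dataclasses import MISSING, dataclass, fields
from typing import Optional


@dataclass
class DemoLoadConfig:
    """Configuration for loading the model weights."""

    load_format: str = "auto"
    download_dir: Optional[str] = None
    use_tqdm_on_load: bool = True


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # One argument group per config class, titled and described by the
    # dataclass itself -- mirroring add_argument_group(title="LoadConfig",
    # description=LoadConfig.__doc__) in the diff.
    group = parser.add_argument_group(
        title=DemoLoadConfig.__name__,
        description=DemoLoadConfig.__doc__,
    )
    for field in fields(DemoLoadConfig):
        # Every field must carry a default, which becomes the CLI default;
        # this is the invariant the @config decorator enforces in the diff.
        assert field.default is not MISSING
        flag = "--" + field.name.replace("_", "-")
        if field.type in (bool, "bool"):
            # Bools get paired --flag/--no-flag options.
            group.add_argument(flag,
                               action=argparse.BooleanOptionalAction,
                               default=field.default)
        else:
            group.add_argument(flag, default=field.default)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(
        ["--load-format", "safetensors", "--no-use-tqdm-on-load"])
    print(args.load_format, args.use_tqdm_on_load)  # safetensors False
```

Parsing `--no-use-tqdm-on-load` exercises `argparse.BooleanOptionalAction`, the same mechanism the diff uses when registering `--use-tqdm-on-load` on the `LoadConfig` argument group.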