Skip to content

Commit cd77382

Browse files
authored
Improve configs - LoadConfig (#16422)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
1 parent 71b9cde commit cd77382

File tree

3 files changed

+96
-97
lines changed

3 files changed

+96
-97
lines changed

vllm/config.py

Lines changed: 42 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from importlib.util import find_spec
1818
from pathlib import Path
1919
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
20-
Optional, Protocol, Union)
20+
Optional, Protocol, TypeVar, Union)
2121

2222
import torch
2323
from pydantic import BaseModel, Field, PrivateAttr
@@ -45,6 +45,7 @@
4545
random_uuid, resolve_obj_by_qualname)
4646

4747
if TYPE_CHECKING:
48+
from _typeshed import DataclassInstance
4849
from ray.util.placement_group import PlacementGroup
4950

5051
from vllm.executor.executor_base import ExecutorBase
@@ -53,8 +54,11 @@
5354
from vllm.model_executor.model_loader.loader import BaseModelLoader
5455
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
5556
BaseTokenizerGroup)
57+
58+
Config = TypeVar("Config", bound=DataclassInstance)
5659
else:
5760
QuantizationConfig = None
61+
Config = TypeVar("Config")
5862

5963
logger = init_logger(__name__)
6064

@@ -159,7 +163,7 @@ def pairwise(iterable):
159163
return out
160164

161165

162-
def config(cls: type[Any]) -> type[Any]:
166+
def config(cls: type[Config]) -> type[Config]:
163167
"""
164168
A decorator that ensures all fields in a dataclass have default values
165169
and that each field has a docstring.
@@ -1431,44 +1435,47 @@ class LoadFormat(str, enum.Enum):
14311435
FASTSAFETENSORS = "fastsafetensors"
14321436

14331437

1438+
@config
14341439
@dataclass
14351440
class LoadConfig:
1436-
"""
1437-
download_dir: Directory to download and load the weights, default to the
1438-
default cache directory of huggingface.
1439-
load_format: The format of the model weights to load:
1440-
"auto" will try to load the weights in the safetensors format and
1441-
fall back to the pytorch bin format if safetensors format is
1442-
not available.
1443-
"pt" will load the weights in the pytorch bin format.
1444-
"safetensors" will load the weights in the safetensors format.
1445-
"npcache" will load the weights in pytorch format and store
1446-
a numpy cache to speed up the loading.
1447-
"dummy" will initialize the weights with random values, which is
1448-
mainly for profiling.
1449-
"tensorizer" will use CoreWeave's tensorizer library for
1450-
fast weight loading.
1451-
"bitsandbytes" will load nf4 type weights.
1452-
"sharded_state" will load weights from pre-sharded checkpoint files,
1453-
supporting efficient loading of tensor-parallel models.
1454-
"gguf" will load weights from GGUF format files.
1455-
"mistral" will load weights from consolidated safetensors files used
1456-
by Mistral models.
1457-
"runai_streamer" will load weights from RunAI streamer format files.
1458-
model_loader_extra_config: The extra config for the model loader.
1459-
ignore_patterns: The list of patterns to ignore when loading the model.
1460-
Default to "original/**/*" to avoid repeated loading of llama's
1461-
checkpoints.
1462-
use_tqdm_on_load: Whether to enable tqdm for showing progress bar during
1463-
loading. Default to True
1464-
"""
1465-
1466-
load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
1441+
"""Configuration for loading the model weights."""
1442+
1443+
load_format: Union[str, LoadFormat,
1444+
"BaseModelLoader"] = LoadFormat.AUTO.value
1445+
"""The format of the model weights to load:\n
1446+
- "auto" will try to load the weights in the safetensors format and fall
1447+
back to the pytorch bin format if safetensors format is not available.\n
1448+
- "pt" will load the weights in the pytorch bin format.\n
1449+
- "safetensors" will load the weights in the safetensors format.\n
1450+
- "npcache" will load the weights in pytorch format and store a numpy cache
1451+
to speed up the loading.\n
1452+
- "dummy" will initialize the weights with random values, which is mainly
1453+
for profiling.\n
1454+
- "tensorizer" will use CoreWeave's tensorizer library for fast weight
1455+
loading. See the Tensorize vLLM Model script in the Examples section for
1456+
more information.\n
1457+
- "runai_streamer" will load the Safetensors weights using Run:ai Model
1458+
Streamer.\n
1459+
- "bitsandbytes" will load the weights using bitsandbytes quantization.\n
1460+
- "sharded_state" will load weights from pre-sharded checkpoint files,
1461+
supporting efficient loading of tensor-parallel models.\n
1462+
- "gguf" will load weights from GGUF format files (details specified in
1463+
https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n
1464+
- "mistral" will load weights from consolidated safetensors files used by
1465+
Mistral models."""
14671466
download_dir: Optional[str] = None
1468-
model_loader_extra_config: Optional[Union[str, dict]] = field(
1469-
default_factory=dict)
1467+
"""Directory to download and load the weights, default to the default
1468+
cache directory of Hugging Face."""
1469+
model_loader_extra_config: Optional[Union[str, dict]] = None
1470+
"""Extra config for model loader. This will be passed to the model loader
1471+
corresponding to the chosen load_format. This should be a JSON string that
1472+
will be parsed into a dictionary."""
14701473
ignore_patterns: Optional[Union[list[str], str]] = None
1474+
"""The list of patterns to ignore when loading the model. Default to
1475+
"original/**/*" to avoid repeated loading of llama's checkpoints."""
14711476
use_tqdm_on_load: bool = True
1477+
"""Whether to enable tqdm for showing progress bar when loading model
1478+
weights."""
14721479

14731480
def compute_hash(self) -> str:
14741481
"""

vllm/engine/arg_utils.py

Lines changed: 33 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ class EngineArgs:
101101
tokenizer_mode: str = 'auto'
102102
trust_remote_code: bool = False
103103
allowed_local_media_path: str = ""
104-
download_dir: Optional[str] = None
105-
load_format: str = 'auto'
104+
download_dir: Optional[str] = LoadConfig.download_dir
105+
load_format: str = LoadConfig.load_format
106106
config_format: ConfigFormat = ConfigFormat.AUTO
107107
dtype: str = 'auto'
108108
kv_cache_dtype: str = 'auto'
@@ -174,8 +174,10 @@ class EngineArgs:
174174
ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
175175
num_gpu_blocks_override: Optional[int] = None
176176
num_lookahead_slots: int = 0
177-
model_loader_extra_config: Optional[dict] = None
178-
ignore_patterns: Optional[Union[str, List[str]]] = None
177+
model_loader_extra_config: Optional[
178+
dict] = LoadConfig.model_loader_extra_config
179+
ignore_patterns: Optional[Union[str,
180+
List[str]]] = LoadConfig.ignore_patterns
179181
preemption_mode: Optional[str] = None
180182

181183
scheduler_delay_factor: float = 0.0
@@ -213,7 +215,7 @@ class EngineArgs:
213215
additional_config: Optional[Dict[str, Any]] = None
214216
enable_reasoning: Optional[bool] = None
215217
reasoning_parser: Optional[str] = None
216-
use_tqdm_on_load: bool = True
218+
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
217219

218220
def __post_init__(self):
219221
if not self.tokenizer:
@@ -234,9 +236,13 @@ def __post_init__(self):
234236
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
235237
"""Shared CLI arguments for vLLM engine."""
236238

239+
def is_type_in_union(cls: type[Any], type: type[Any]) -> bool:
    """Return whether `cls` is a `Union` that lists `type` as a member."""
    # Anything whose origin is not `Union` has no members to search.
    if get_origin(cls) is not Union:
        return False
    return type in get_args(cls)
242+
237243
def is_optional(cls: type[Any]) -> bool:
238244
"""Check if the class is an optional type."""
239-
return get_origin(cls) is Union and type(None) in get_args(cls)
245+
return is_type_in_union(cls, type(None))
240246

241247
def get_kwargs(cls: type[Any]) -> Dict[str, Any]:
242248
cls_docs = get_attr_docs(cls)
@@ -255,6 +261,10 @@ def get_kwargs(cls: type[Any]) -> Dict[str, Any]:
255261
if is_optional(field.type):
256262
kwargs[name]["type"] = nullable_str
257263
continue
264+
# Handle str in union fields
265+
if is_type_in_union(field.type, str):
266+
kwargs[name]["type"] = str
267+
continue
258268
kwargs[name]["type"] = field.type
259269
return kwargs
260270

@@ -333,38 +343,23 @@ def get_kwargs(cls: type[Any]) -> Dict[str, Any]:
333343
"from directories specified by the server file system. "
334344
"This is a security risk. "
335345
"Should only be enabled in trusted environments.")
336-
parser.add_argument('--download-dir',
337-
type=nullable_str,
338-
default=EngineArgs.download_dir,
339-
help='Directory to download and load the weights.')
340-
parser.add_argument(
341-
'--load-format',
342-
type=str,
343-
default=EngineArgs.load_format,
344-
choices=[f.value for f in LoadFormat],
345-
help='The format of the model weights to load.\n\n'
346-
'* "auto" will try to load the weights in the safetensors format '
347-
'and fall back to the pytorch bin format if safetensors format '
348-
'is not available.\n'
349-
'* "pt" will load the weights in the pytorch bin format.\n'
350-
'* "safetensors" will load the weights in the safetensors format.\n'
351-
'* "npcache" will load the weights in pytorch format and store '
352-
'a numpy cache to speed up the loading.\n'
353-
'* "dummy" will initialize the weights with random values, '
354-
'which is mainly for profiling.\n'
355-
'* "tensorizer" will load the weights using tensorizer from '
356-
'CoreWeave. See the Tensorize vLLM Model script in the Examples '
357-
'section for more information.\n'
358-
'* "runai_streamer" will load the Safetensors weights using Run:ai'
359-
'Model Streamer.\n'
360-
'* "bitsandbytes" will load the weights using bitsandbytes '
361-
'quantization.\n'
362-
'* "sharded_state" will load weights from pre-sharded checkpoint '
363-
'files, supporting efficient loading of tensor-parallel models\n'
364-
'* "gguf" will load weights from GGUF format files (details '
365-
'specified in https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n'
366-
'* "mistral" will load weights from consolidated safetensors files '
367-
'used by Mistral models.\n')
346+
# Model loading arguments
347+
load_kwargs = get_kwargs(LoadConfig)
348+
load_group = parser.add_argument_group(
349+
title="LoadConfig",
350+
description=LoadConfig.__doc__,
351+
)
352+
load_group.add_argument('--load-format',
353+
choices=[f.value for f in LoadFormat],
354+
**load_kwargs["load_format"])
355+
load_group.add_argument('--download-dir',
356+
**load_kwargs["download_dir"])
357+
load_group.add_argument('--model-loader-extra-config',
358+
**load_kwargs["model_loader_extra_config"])
359+
load_group.add_argument('--use-tqdm-on-load',
360+
action=argparse.BooleanOptionalAction,
361+
**load_kwargs["use_tqdm_on_load"])
362+
368363
parser.add_argument(
369364
'--config-format',
370365
default=EngineArgs.config_format,
@@ -770,14 +765,6 @@ def get_kwargs(cls: type[Any]) -> Dict[str, Any]:
770765
default=1,
771766
help=('Maximum number of forward steps per '
772767
'scheduler call.'))
773-
parser.add_argument(
774-
'--use-tqdm-on-load',
775-
dest='use_tqdm_on_load',
776-
action=argparse.BooleanOptionalAction,
777-
default=EngineArgs.use_tqdm_on_load,
778-
help='Whether to enable/disable progress bar '
779-
'when loading model weights.',
780-
)
781768

782769
parser.add_argument(
783770
'--multi-step-stream-outputs',
@@ -806,15 +793,6 @@ def get_kwargs(cls: type[Any]) -> Dict[str, Any]:
806793
default=None,
807794
help='The configurations for speculative decoding.'
808795
' Should be a JSON string.')
809-
810-
parser.add_argument('--model-loader-extra-config',
811-
type=nullable_str,
812-
default=EngineArgs.model_loader_extra_config,
813-
help='Extra config for model loader. '
814-
'This will be passed to the model loader '
815-
'corresponding to the chosen load_format. '
816-
'This should be a JSON string that will be '
817-
'parsed into a dictionary.')
818796
parser.add_argument(
819797
'--ignore-patterns',
820798
action="append",

vllm/utils.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from __future__ import annotations
44

5-
import argparse
65
import asyncio
76
import concurrent
87
import contextlib
@@ -25,13 +24,16 @@
2524
import subprocess
2625
import sys
2726
import tempfile
27+
import textwrap
2828
import threading
2929
import time
3030
import traceback
3131
import types
3232
import uuid
3333
import warnings
3434
import weakref
35+
from argparse import (Action, ArgumentDefaultsHelpFormatter, ArgumentParser,
36+
ArgumentTypeError)
3537
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
3638
from collections import UserDict, defaultdict
3739
from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
@@ -1209,7 +1211,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
12091211
return wrapper
12101212

12111213

1212-
class StoreBoolean(argparse.Action):
1214+
class StoreBoolean(Action):
12131215

12141216
def __call__(self, parser, namespace, values, option_string=None):
12151217
if values.lower() == "true":
@@ -1221,15 +1223,28 @@ def __call__(self, parser, namespace, values, option_string=None):
12211223
"Expected 'true' or 'false'.")
12221224

12231225

1224-
class SortedHelpFormatter(argparse.ArgumentDefaultsHelpFormatter):
1226+
class SortedHelpFormatter(ArgumentDefaultsHelpFormatter):
12251227
"""SortedHelpFormatter that sorts arguments by their option strings."""
12261228

1229+
def _split_lines(self, text, width):
1230+
"""
1231+
1. Sentences split across lines have their single newlines removed.
1232+
2. Paragraphs and explicit newlines are split into separate lines.
1233+
3. Each line is wrapped to the specified width (width of terminal).
1234+
"""
1235+
# The patterns also include whitespace after the newline
1236+
single_newline = re.compile("(?<!\n)\n(?!\n)\s*")
1237+
multiple_newlines = re.compile("\n{2,}\s*")
1238+
text = single_newline.sub(' ', text)
1239+
lines = re.split(multiple_newlines, text)
1240+
return sum([textwrap.wrap(line, width) for line in lines], [])
1241+
12271242
def add_arguments(self, actions):
    """Pass the actions to the parent formatter, sorted by option strings."""
    ordered = sorted(actions, key=lambda action: action.option_strings)
    super().add_arguments(ordered)
12301245

12311246

1232-
class FlexibleArgumentParser(argparse.ArgumentParser):
1247+
class FlexibleArgumentParser(ArgumentParser):
12331248
"""ArgumentParser that allows both underscore and dash in names."""
12341249

12351250
def __init__(self, *args, **kwargs):
@@ -1280,11 +1295,10 @@ def check_port(self, value):
12801295
value = int(value)
12811296
except ValueError:
12821297
msg = "Port must be an integer"
1283-
raise argparse.ArgumentTypeError(msg) from None
1298+
raise ArgumentTypeError(msg) from None
12841299

12851300
if not (1024 <= value <= 65535):
1286-
raise argparse.ArgumentTypeError(
1287-
"Port must be between 1024 and 65535")
1301+
raise ArgumentTypeError("Port must be between 1024 and 65535")
12881302

12891303
return value
12901304

0 commit comments

Comments
 (0)