Merged

33 commits
10d543b
Enable Pydantic mypy plugin
hmellor May 2, 2025
29fc238
Fix Pydantic errors in `protocol.py`
hmellor May 2, 2025
0ecc5d2
Fix other mypy errors
hmellor May 2, 2025
ad9535a
Convert dataclasses to pydantic dataclasses
hmellor May 2, 2025
2ec9ba6
Merge branch 'main' into enable-pydantic-mypy
hmellor May 4, 2025
0426733
Merge branch 'main' into enable-pydantic-mypy
hmellor May 12, 2025
f6b1be7
Fix missing imports
hmellor May 12, 2025
b1f0fdc
Make mypy pass
hmellor May 12, 2025
d2ff0da
Assert no longer needed
hmellor May 12, 2025
2ed29c8
Pydantic base models correctly handle mutable defaults
hmellor May 12, 2025
bb47b84
remove parenthesis
hmellor May 12, 2025
30fcc16
remove comment
hmellor May 12, 2025
73dba34
Remove model validator only used for tokenizer
hmellor May 12, 2025
eca37a2
Fix Pydantic runtime error
hmellor May 12, 2025
a72251f
Add Pydantic validation to dataclass instantiation from CLI
hmellor May 12, 2025
712b312
Skip validation for defaults which are deferred
hmellor May 12, 2025
4397de9
Fix docs build
hmellor May 12, 2025
caa1dc5
Fix docs build 2
hmellor May 15, 2025
b267203
Merge branch 'main' into enable-pydantic-mypy
hmellor May 15, 2025
96876a6
Merge branch 'main' into enable-pydantic-mypy
hmellor May 16, 2025
18ea0ac
`VllmConfig.compilation_config` should never be `None`
hmellor May 16, 2025
905ab3e
Type adapter works for non-Pydantic dataclasses
hmellor May 16, 2025
80d03a6
Using stdlib dataclass when not type checking breaks pydantic validation
hmellor May 16, 2025
92e2b75
Fix `compilation_config_instance` being `None`
hmellor May 16, 2025
e995cc0
Use stdlib dataclasses when not in docs build
hmellor May 16, 2025
fdea28a
Undo whitespace change
hmellor May 16, 2025
ead89d7
Make docs build pass
hmellor May 16, 2025
0f82baf
Merge branch 'main' into enable-pydantic-mypy
hmellor May 26, 2025
ae07688
Fix now unused import
hmellor May 26, 2025
4c9ac3d
Remove sphinx stuff now that we use mkdocs
hmellor May 26, 2025
d900098
`hf_overrides` is not optional in `EngineArgs`
hmellor May 27, 2025
934050a
Fix validation errors
hmellor May 27, 2025
2aaa91f
Fix LoRA test
hmellor May 28, 2025
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -58,7 +58,7 @@ repos:
entry: tools/mypy.sh 0 "local"
language: python
types: [python]
additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests]
additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests, pydantic]
stages: [pre-commit] # Don't run in CI
- id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.9
1 change: 1 addition & 0 deletions pyproject.toml
@@ -110,6 +110,7 @@ ignore = [
]

[tool.mypy]
plugins = ['pydantic.mypy']
ignore_missing_imports = true
check_untyped_defs = true
follow_imports = "silent"
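The `plugins = ['pydantic.mypy']` addition above is what lets mypy understand the `__init__` signatures that Pydantic synthesizes for its models and dataclasses. A minimal sketch of the effect — `ExampleConfig` and its fields are hypothetical, not part of this diff:

```python
# Minimal sketch (not from this PR) of what the Pydantic mypy plugin enables:
# mypy learns the synthesized __init__ of a Pydantic dataclass, so wrong
# argument types and unknown keyword arguments are flagged statically.
from pydantic.dataclasses import dataclass


@dataclass
class ExampleConfig:
    block_size: int = 16
    enable_chunked_prefill: bool = False


ExampleConfig(block_size=32)          # OK at runtime and under mypy
# ExampleConfig(block_size="32")      # mypy: incompatible argument type
# ExampleConfig(unknown_field=True)   # mypy: unexpected keyword argument
```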
14 changes: 7 additions & 7 deletions tests/lora/test_quant_model.py
@@ -24,16 +24,16 @@ class ModelWithQuantization:
MODELS = [
ModelWithQuantization(
model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
quantization="GPTQ"),
quantization="gptq"),
]
else:
MODELS = [
ModelWithQuantization(
model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
quantization="AWQ"),
quantization="awq"),
ModelWithQuantization(
model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
quantization="GPTQ"),
quantization="gptq"),
]


@@ -100,7 +100,7 @@ def test_quant_model_lora(tinyllama_lora_files, model):
"#ff8050",
"#ff8080",
]
elif model.quantization == "AWQ":
elif model.quantization == "awq":
expected_no_lora_output = [
"I'm sorry, I don't understand",
"I'm sorry, I don't understand",
@@ -109,7 +109,7 @@ def test_quant_model_lora(tinyllama_lora_files, model):
"#f07700: A v",
"#f00000: A v",
]
elif model.quantization == "GPTQ":
elif model.quantization == "gptq":
expected_no_lora_output = [
"I'm sorry, I don't have",
"I'm sorry, I don't have",
@@ -122,7 +122,7 @@ def test_quant_model_lora(tinyllama_lora_files, model):
def expect_match(output, expected_output):
# HACK: GPTQ lora outputs are just incredibly unstable.
# Assert that the outputs changed.
if (model.quantization == "GPTQ"
if (model.quantization == "gptq"
and expected_output is expected_lora_output):
assert output != expected_no_lora_output
for i, o in enumerate(output):
@@ -172,7 +172,7 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
model):
if num_gpus_available < 2:
pytest.skip(f"Not enough GPUs for tensor parallelism {2}")
if model.quantization == "GPTQ":
if model.quantization == "gptq":
pytest.skip("GPTQ lora outputs are just incredibly unstable")
llm_tp1 = vllm.LLM(
model=model.model_path,
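The switch from `"GPTQ"`/`"AWQ"` to lowercase strings in this test lines up with the `field_validator("quantization", mode="before")` that the `vllm/config.py` diff below adds to `ModelConfig`, which lowercases the value before validation. A hedged, self-contained sketch of that pattern — the `Literal` here is an illustrative subset, not vLLM's full `QuantizationMethods`:

```python
# Sketch of the "lowercase before validation" pattern used by ModelConfig.
# QuantizedModelConfig and the two-member Literal are illustrative only.
from typing import Any, Literal, Optional

from pydantic import field_validator
from pydantic.dataclasses import dataclass

QuantizationMethods = Literal["gptq", "awq"]  # illustrative subset


@dataclass
class QuantizedModelConfig:
    quantization: Optional[QuantizationMethods] = None

    @field_validator("quantization", mode="before")
    @classmethod
    def _lowercase_quantization(cls, value: Any) -> Any:
        # Normalize user input such as "GPTQ" before the Literal is checked.
        return value.lower() if isinstance(value, str) else value


assert QuantizedModelConfig(quantization="GPTQ").quantization == "gptq"
```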
2 changes: 1 addition & 1 deletion tests/tracing/test_tracing.py
@@ -173,7 +173,7 @@ def test_traces_with_detailed_steps(
llm = LLM(
model=model,
otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS,
collect_detailed_traces="all",
collect_detailed_traces=["all"],
)
prompts = ["This is a short prompt"]
outputs = llm.generate(prompts, sampling_params=sampling_params)
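The tracing test now passes `["all"]` because, once the owning config is validated by Pydantic, a field annotated as an optional list of literals no longer accepts a bare string. A hedged illustration — `ObservabilityConfigSketch` and this `DetailedTraceModules` literal are stand-ins, not the actual vLLM definitions:

```python
# Illustration of why collect_detailed_traces must now be a list: Pydantic v2
# rejects a bare string for a list-typed field instead of coercing it.
from typing import Literal, Optional

from pydantic import ValidationError
from pydantic.dataclasses import dataclass

DetailedTraceModules = Literal["model", "worker", "all"]  # stand-in


@dataclass
class ObservabilityConfigSketch:
    collect_detailed_traces: Optional[list[DetailedTraceModules]] = None


ObservabilityConfigSketch(collect_detailed_traces=["all"])    # accepted
try:
    ObservabilityConfigSketch(collect_detailed_traces="all")  # not a list
except ValidationError as err:
    print(err.errors()[0]["type"])  # "list_type"
```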
108 changes: 60 additions & 48 deletions vllm/config.py
@@ -11,8 +11,8 @@
import warnings
from collections import Counter
from contextlib import contextmanager
from dataclasses import (MISSING, Field, asdict, dataclass, field, fields,
is_dataclass, replace)
from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass,
replace)
from functools import cached_property
from importlib.util import find_spec
from pathlib import Path
@@ -21,9 +21,12 @@

import regex as re
import torch
from pydantic import (ConfigDict, SkipValidation, TypeAdapter, field_validator,
model_validator)
from pydantic.dataclasses import dataclass
from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig
from typing_extensions import deprecated
from typing_extensions import deprecated, runtime_checkable

import vllm.envs as envs
from vllm import version
@@ -57,10 +60,15 @@
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader import BaseModelLoader
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

ConfigType = type[DataclassInstance]
else:
PlacementGroup = Any
ExecutorBase = Any
QuantizationConfig = Any
BaseModelLoader = Any
TensorizerConfig = Any
ConfigType = type

logger = init_logger(__name__)
@@ -92,6 +100,7 @@
PretrainedConfig]]


@runtime_checkable
class SupportsHash(Protocol):

def compute_hash(self) -> str:
@@ -223,7 +232,7 @@ def is_init_field(cls: ConfigType, name: str) -> bool:


@config
@dataclass
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class ModelConfig:
"""Configuration for the model."""

@@ -236,7 +245,7 @@ class ModelConfig:
task, even if the same model can be used for multiple tasks. When the model
only supports one task, "auto" can be used to select it; otherwise, you
must specify explicitly which task to use."""
tokenizer: str = None # type: ignore
tokenizer: SkipValidation[str] = None # type: ignore
"""Name or path of the Hugging Face tokenizer to use. If unspecified, model
name or path will be used."""
tokenizer_mode: TokenizerMode = "auto"
@@ -284,7 +293,7 @@ class ModelConfig:
"""The specific revision to use for the tokenizer on the Hugging Face Hub.
It can be a branch name, a tag name, or a commit id. If unspecified, will
use the default version."""
max_model_len: int = None # type: ignore
max_model_len: SkipValidation[int] = None # type: ignore
"""Model context length (prompt and output). If unspecified, will be
automatically derived from the model config.

@@ -607,6 +616,22 @@ def __post_init__(self) -> None:
self._verify_cuda_graph()
self._verify_bnb_config()

@field_validator("quantization", mode="before")
@classmethod
def validate_quantization_before(cls, value: Any) -> Any:
if isinstance(value, str):
return value.lower()
return value

@model_validator(mode="after")
def validate_model_config_after(self: "ModelConfig") -> "ModelConfig":
if not isinstance(self.tokenizer, str):
raise ValueError("tokenizer must be a string after __post_init__.")
if not isinstance(self.max_model_len, int):
raise ValueError(
"max_model_len must be an integer after __post_init__.")
return self

@property
def registry(self):
return ModelRegistry
@@ -833,8 +858,7 @@ def _verify_quantization(self) -> None:
"quark", "modelopt_fp4", "bitblas", "gptq_bitblas"
]
if self.quantization is not None:
self.quantization = cast(QuantizationMethods,
self.quantization.lower())
self.quantization = cast(QuantizationMethods, self.quantization)

# Parse quantization method from the HF model config, if available.
quant_cfg = self._parse_quant_hf_config()
@@ -1397,7 +1421,7 @@ def matryoshka_dimensions(self):
class CacheConfig:
"""Configuration for the KV cache."""

block_size: BlockSize = None # type: ignore
block_size: SkipValidation[BlockSize] = None # type: ignore
"""Size of a contiguous cache block in number of tokens. This is ignored on
neuron devices and set to `--max-model-len`. On CUDA devices, only block
sizes up to 32 are supported. On HPU devices, block size defaults to 128.
@@ -1619,7 +1643,8 @@ class LoadConfig:
download_dir: Optional[str] = None
"""Directory to download and load the weights, default to the default
cache directory of Hugging Face."""
model_loader_extra_config: dict = field(default_factory=dict)
model_loader_extra_config: Union[dict, TensorizerConfig] = field(
default_factory=dict)
"""Extra config for model loader. This will be passed to the model loader
corresponding to the chosen load_format."""
ignore_patterns: Optional[Union[list[str], str]] = None
@@ -1929,19 +1954,19 @@ class SchedulerConfig:
runner_type: RunnerType = "generate"
"""The runner type to launch for the model."""

max_num_batched_tokens: int = None # type: ignore
max_num_batched_tokens: SkipValidation[int] = None # type: ignore
"""Maximum number of tokens to be processed in a single iteration.

This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context."""

max_num_seqs: int = None # type: ignore
max_num_seqs: SkipValidation[int] = None # type: ignore
"""Maximum number of sequences to be processed in a single iteration.

This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context."""

max_model_len: int = None # type: ignore
max_model_len: SkipValidation[int] = None # type: ignore
"""Maximum length of a sequence (including prompt and generated text). This
is primarily set in `ModelConfig` and that value should be manually
duplicated here."""
@@ -1980,7 +2005,7 @@ class SchedulerConfig:
"""Apply a delay (of delay factor multiplied by previous
prompt latency) before scheduling next prompt."""

enable_chunked_prefill: bool = None # type: ignore
enable_chunked_prefill: SkipValidation[bool] = None # type: ignore
"""If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens."""

@@ -2202,7 +2227,7 @@ def is_multi_step(self) -> bool:


@config
@dataclass
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class DeviceConfig:
"""Configuration for the device to use for vLLM execution."""

@@ -2260,8 +2285,8 @@ def __post_init__(self):
self.device = torch.device(self.device_type)


SpeculativeMethod = Literal["ngram", "eagle", "medusa", "mlp_speculator",
"draft_model", "deepseek_mtp"]
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
"mlp_speculator", "draft_model", "deepseek_mtp"]
SpeculativeAcceptanceMethod = Literal["rejection_sampler",
"typical_acceptance_sampler"]

@@ -2272,8 +2297,7 @@
"""Configuration for speculative decoding."""

# General speculative decoding control
num_speculative_tokens: int = field(default=None,
init=True) # type: ignore
num_speculative_tokens: SkipValidation[int] = None # type: ignore
"""The number of speculative tokens, if provided. It will default to the
number in the draft model config if present, otherwise, it is required."""
model: Optional[str] = None
@@ -2349,26 +2373,23 @@ class SpeculativeConfig:
"""Specifies the tree structure for speculative token generation.
"""
# required configuration params passed from engine
target_model_config: ModelConfig = field(default=None,
init=True) # type: ignore
target_model_config: SkipValidation[ModelConfig] = None # type: ignore
"""The configuration of the target model."""
target_parallel_config: ParallelConfig = field(default=None,
init=True) # type: ignore
target_parallel_config: SkipValidation[
ParallelConfig] = None # type: ignore
"""The parallel configuration for the target model."""
enable_chunked_prefill: bool = field(default=None,
init=True) # type: ignore
enable_chunked_prefill: SkipValidation[bool] = None # type: ignore
"""Whether vLLM is configured to use chunked prefill or not. Used for
raising an error since it's not yet compatible with speculative decode."""
disable_log_stats: bool = field(default=None, init=True) # type: ignore
disable_log_stats: SkipValidation[bool] = None # type: ignore
"""Whether to disable the periodic printing of stage times in speculative
decoding."""

# params generated in the post-init stage
draft_model_config: ModelConfig = field(default=None,
init=True) # type: ignore
draft_model_config: SkipValidation[ModelConfig] = None # type: ignore
"""The configuration of the draft model initialized internal."""
draft_parallel_config: ParallelConfig = field(default=None,
init=True) # type: ignore
draft_parallel_config: SkipValidation[
ParallelConfig] = None # type: ignore
"""The parallel configuration for the draft model initialized internal."""

def compute_hash(self) -> str:
@@ -2766,7 +2787,7 @@ def __repr__(self) -> str:


@config
@dataclass
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class LoRAConfig:
"""Configuration for LoRA."""

@@ -2863,7 +2884,7 @@ def verify_lora_support(self):


@config
@dataclass
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class PromptAdapterConfig:
"""Configuration for PromptAdapters."""

@@ -3888,17 +3909,11 @@ def __repr__(self) -> str:
"pass_config",
"traced_files",
}
include = dict()
for k, v in asdict(self).items():
if k in exclude:
continue
f = get_field(CompilationConfig, k)
if (d := f.default) is not MISSING and d == v:
continue
if (df := f.default_factory) is not MISSING and df() == v:
continue
include[k] = v
return json.dumps(include)
# The cast to string is necessary because Pydantic is mocked in docs
# builds and sphinx-argparse doesn't know the return type of decode()
return str(
TypeAdapter(CompilationConfig).dump_json(
self, exclude=exclude, exclude_unset=True).decode())

__str__ = __repr__

@@ -3907,7 +3922,7 @@ def from_cli(cls, cli_value: str) -> "CompilationConfig":
"""Parse the CLI value for the compilation config."""
if cli_value in ["0", "1", "2", "3"]:
return cls(level=int(cli_value))
return cls(**json.loads(cli_value))
return TypeAdapter(CompilationConfig).validate_json(cli_value)

def __post_init__(self) -> None:
count_none = self.custom_ops.count("none")
@@ -4033,7 +4048,7 @@ def set_splitting_ops_for_v1(self):


@config
@dataclass
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig:
"""Dataclass which contains all vllm-related configuration. This
simplifies passing around the distinct configurations in the codebase.
@@ -4290,9 +4305,6 @@ def __post_init__(self):
"To workaround this limitation, vLLM will set 'ieee' input "
"precision for chunked prefill triton kernels.")

if self.compilation_config is None:
self.compilation_config = CompilationConfig()

# async tp is built on top of sequence parallelism
# and requires it to be enabled.
if self.compilation_config.pass_config.enable_async_tp:
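Two Pydantic patterns recur throughout the `vllm/config.py` changes above: `SkipValidation[...]` marks fields whose `None` placeholders are filled in later (for example by `EngineArgs`), and `TypeAdapter` validates and dumps dataclasses, as `CompilationConfig.from_cli` and `__repr__` now do. A self-contained sketch under those assumptions — `SchedulerSketch` and its fields are illustrative, not vLLM's real config:

```python
# Sketch of the two patterns: SkipValidation for deferred defaults, and
# TypeAdapter for JSON round-tripping of a (Pydantic) dataclass.
from pydantic import SkipValidation, TypeAdapter
from pydantic.dataclasses import dataclass


@dataclass
class SchedulerSketch:
    # SkipValidation keeps the eventual type in the annotation, but the None
    # placeholder (filled in later, e.g. by engine arguments) is not rejected
    # when the dataclass is constructed.
    max_num_seqs: SkipValidation[int] = None  # type: ignore[assignment]
    max_model_len: SkipValidation[int] = None  # type: ignore[assignment]
    delay_factor: float = 0.0


# TypeAdapter is how the diff parses and prints CompilationConfig:
# validate_json for CLI input, dump_json for __repr__.
adapter = TypeAdapter(SchedulerSketch)
cfg = adapter.validate_json('{"max_num_seqs": 64, "delay_factor": 0.5}')
print(adapter.dump_json(cfg, exclude={"max_model_len"}).decode())
# -> {"max_num_seqs":64,"delay_factor":0.5}
```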