Skip to content

Commit b58c30d

Browse files
vrdn-23 and hmellor
authored and committed
[Feature] Change cache.py with pydantic validation (vllm-project#26390)
Signed-off-by: Vinay Damodaran <vrdn@hey.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent a9f479f commit b58c30d

File tree

2 files changed

+25
-60
lines changed

2 files changed

+25
-60
lines changed

vllm/config/cache.py

Lines changed: 17 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,11 @@
33

44
import hashlib
55
from dataclasses import field
6-
from typing import TYPE_CHECKING, Any, Literal, Optional, get_args
6+
from typing import TYPE_CHECKING, Any, Literal, Optional
77

8-
from pydantic import SkipValidation, model_validator
8+
from pydantic import Field, SkipValidation, field_validator
99
from pydantic.dataclasses import dataclass
10-
from typing_extensions import Self
1110

12-
import vllm.envs as envs
1311
from vllm.config.utils import config
1412
from vllm.logger import init_logger
1513
from vllm.utils import GiB_bytes, get_cpu_memory
@@ -39,15 +37,15 @@ class CacheConfig:
3937
This config has no static default. If left unspecified by the user, it will
4038
be set in `Platform.check_and_update_config()` based on the current
4139
platform."""
42-
gpu_memory_utilization: float = 0.9
40+
gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1)
4341
"""The fraction of GPU memory to be used for the model executor, which can
4442
range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
4543
utilization. If unspecified, will use the default value of 0.9. This is a
4644
per-instance limit, and only applies to the current vLLM instance. It does
4745
not matter if you have another vLLM instance running on the same GPU. For
4846
example, if you have two vLLM instances running on the same GPU, you can
4947
set the GPU memory utilization to 0.5 for each instance."""
50-
swap_space: float = 4
48+
swap_space: float = Field(default=4, ge=0)
5149
"""Size of the CPU swap space per GPU (in GiB)."""
5250
cache_dtype: CacheDType = "auto"
5351
"""Data type for kv cache storage. If "auto", will use model data type.
@@ -73,7 +71,7 @@ class CacheConfig:
7371
- "sha256" uses Pickle for object serialization before hashing.\n
7472
- "sha256_cbor" provides a reproducible, cross-language compatible hash. It
7573
serializes objects using canonical CBOR and hashes them with SHA-256."""
76-
cpu_offload_gb: float = 0
74+
cpu_offload_gb: float = Field(default=0, ge=0)
7775
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
7876
no offloading. Intuitively, this argument can be seen as a virtual way to
7977
increase the GPU memory size. For example, if you have one 24 GB GPU and
@@ -147,74 +145,33 @@ def compute_hash(self) -> str:
147145
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
148146
return hash_str
149147

150-
def __post_init__(self) -> None:
151-
self.swap_space_bytes = self.swap_space * GiB_bytes
152-
153-
self._verify_cache_dtype()
154-
self._verify_prefix_caching()
155-
156148
def metrics_info(self):
157149
# convert cache_config to dict(key: str, value: str) for prometheus
158150
# metrics info
159151
return {key: str(value) for key, value in self.__dict__.items()}
160152

161-
@model_validator(mode="after")
162-
def _verify_args(self) -> Self:
163-
if self.cpu_offload_gb < 0:
164-
raise ValueError(
165-
f"CPU offload space must be non-negative, but got {self.cpu_offload_gb}"
166-
)
167-
168-
if self.gpu_memory_utilization > 1.0:
169-
raise ValueError(
170-
"GPU memory utilization must be less than 1.0. Got "
171-
f"{self.gpu_memory_utilization}."
172-
)
173-
174-
return self
175-
176-
def _verify_cache_dtype(self) -> None:
177-
if self.cache_dtype == "auto":
178-
pass
179-
elif self.cache_dtype in get_args(CacheDType):
180-
if self.cache_dtype.startswith("fp8"):
181-
logger.info(
182-
"Using fp8 data type to store kv cache. It reduces the GPU "
183-
"memory footprint and boosts the performance. "
184-
"Meanwhile, it may cause accuracy drop without a proper "
185-
"scaling factor."
186-
)
187-
else:
188-
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
189-
190-
def _verify_prefix_caching(self) -> None:
191-
if not self.enable_prefix_caching:
192-
return
193-
194-
if self.sliding_window is not None and not envs.VLLM_USE_V1:
195-
raise NotImplementedError(
196-
"Prefix caching is not supported with sliding window. "
197-
"Run with --disable-sliding-window to use prefix caching."
198-
)
199-
200-
if self.enable_prefix_caching and self.prefix_caching_hash_algo not in get_args(
201-
PrefixCachingHashAlgo
202-
):
203-
raise ValueError(
204-
"Unknown prefix caching hash algorithm: "
205-
f"{self.prefix_caching_hash_algo}. Must be one of "
206-
f"{get_args(PrefixCachingHashAlgo)}."
153+
@field_validator("cache_dtype", mode="after")
154+
@classmethod
155+
def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType:
156+
if cache_dtype.startswith("fp8"):
157+
logger.info(
158+
"Using fp8 data type to store kv cache. It reduces the GPU "
159+
"memory footprint and boosts the performance. "
160+
"Meanwhile, it may cause accuracy drop without a proper "
161+
"scaling factor."
207162
)
163+
return cache_dtype
208164

209165
def verify_with_parallel_config(
210166
self,
211167
parallel_config: ParallelConfig,
212168
) -> None:
169+
swap_space_bytes = self.swap_space * GiB_bytes
213170
total_cpu_memory = get_cpu_memory()
214171
# FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
215172
# group are in the same node. However, the GPUs may span multiple nodes.
216173
num_gpus_per_node = parallel_config.tensor_parallel_size
217-
cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
174+
cpu_memory_usage = swap_space_bytes * num_gpus_per_node
218175

219176
msg = (
220177
f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the "

vllm/engine/arg_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import regex as re
2828
import torch
2929
from pydantic import TypeAdapter, ValidationError
30+
from pydantic.fields import FieldInfo
3031
from typing_extensions import TypeIs, deprecated
3132

3233
import vllm.envs as envs
@@ -209,6 +210,13 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
209210
# Get the default value of the field
210211
if field.default is not MISSING:
211212
default = field.default
213+
# Handle pydantic.Field defaults
214+
if isinstance(default, FieldInfo):
215+
default = (
216+
default.default
217+
if default.default_factory is None
218+
else default.default_factory()
219+
)
212220
elif field.default_factory is not MISSING:
213221
default = field.default_factory()
214222

0 commit comments

Comments (0)