From 9fd25421fe3c8dbb8eabcaaf5972a4f8c3eeb73e Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Thu, 25 Sep 2025 15:20:28 +0000
Subject: [PATCH 1/3] [Optimization] Cache the hash of the model config

Signed-off-by: DarkLight1337
---
 tests/test_config.py | 10 ++++++++++
 vllm/config/model.py | 41 ++++++++++++++++++++++++++++++++++++++---
 2 files changed, 48 insertions(+), 3 deletions(-)

diff --git a/tests/test_config.py b/tests/test_config.py
index 90d0c78c451f..caa866fa5028 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -77,6 +77,16 @@ def test_update_config():
     new_config3 = update_config(config3, {"a": "new_value"})
 
 
+def test_model_config_recompute_hash():
+    config = ModelConfig("Qwen/Qwen2.5-1.5B-Instruct", task="auto")
+    orig_hash = config.compute_hash()
+
+    config.model = "Qwen/Qwen2.5-0.5B-Instruct"
+    new_hash = config.compute_hash()
+
+    assert orig_hash != new_hash
+
+
 # Can remove once --task option is fully deprecated
 @pytest.mark.parametrize(
     ("model_id", "expected_runner_type", "expected_convert_type",
diff --git a/vllm/config/model.py b/vllm/config/model.py
index 0ded70388b8a..fb26b3aba05d 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -6,8 +6,8 @@
 import warnings
 from dataclasses import InitVar, field
 from importlib.util import find_spec
-from typing import (TYPE_CHECKING, Any, Callable, Literal, Optional, Union,
-                    cast, get_args)
+from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
+                    Union, cast, get_args)
 
 import torch
 from pydantic import (ConfigDict, SkipValidation, field_validator,
@@ -285,6 +285,28 @@ class ModelConfig:
     interleave_mm_strings: InitVar[Optional[bool]] = None
     skip_mm_profiling: InitVar[Optional[bool]] = None
 
+    _hash: Optional[str] = None
+    """The cached hash value, or `None` if not cached."""
+
+    _HASH_FACTORS: ClassVar[tuple[str, ...]] = (
+        "model",
+        "dtype",
+        "quantization",
+        "revision",
+        "code_revision",
+        "max_model_len",
+        "max_logprobs",
+        "disable_sliding_window",
+        "trust_remote_code",
+        "generation_config",
+        "model_impl",
+        "override_generation_config",
+        "rope_scaling",
+        "rope_theta",
+        "hf_config",
+    )
+    """Used to determine whether the hash needs to be recomputed."""
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
@@ -297,6 +319,9 @@ def compute_hash(self) -> str:
         excluding anything before input ids/embeddings and after
         the final hidden states.
         """
+        if self._hash is not None:
+            return self._hash
+
         factors: list[Any] = []
         factors.append(self.model)
         factors.append(self.dtype)
@@ -336,7 +361,17 @@ def compute_hash(self) -> str:
 
         str_factors = str(factors)
         assert_hashable(str_factors)
-        return hashlib.sha256(str(factors).encode()).hexdigest()
+
+        out_hash = hashlib.sha256(str_factors.encode()).hexdigest()
+        self._hash = out_hash
+        return out_hash
+
+    def __setattr__(self, name: str, value: Any, /) -> None:
+        super().__setattr__(name, value)
+
+        # Trigger recomputation next time compute_hash is called
+        if name in self._HASH_FACTORS:
+            self._hash = None
 
     def __post_init__(
         self,

From 45371cec31e7da5b4b27374f516f8b67600fd73d Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Thu, 25 Sep 2025 18:14:05 +0000
Subject: [PATCH 2/3] Revert "[Optimization] Cache the hash of the model config"

This reverts commit 9fd25421fe3c8dbb8eabcaaf5972a4f8c3eeb73e.
---
 tests/test_config.py | 10 ----------
 vllm/config/model.py | 41 +++--------------------------------------
 2 files changed, 3 insertions(+), 48 deletions(-)

diff --git a/tests/test_config.py b/tests/test_config.py
index caa866fa5028..90d0c78c451f 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -77,16 +77,6 @@ def test_update_config():
     new_config3 = update_config(config3, {"a": "new_value"})
 
 
-def test_model_config_recompute_hash():
-    config = ModelConfig("Qwen/Qwen2.5-1.5B-Instruct", task="auto")
-    orig_hash = config.compute_hash()
-
-    config.model = "Qwen/Qwen2.5-0.5B-Instruct"
-    new_hash = config.compute_hash()
-
-    assert orig_hash != new_hash
-
-
 # Can remove once --task option is fully deprecated
 @pytest.mark.parametrize(
     ("model_id", "expected_runner_type", "expected_convert_type",
diff --git a/vllm/config/model.py b/vllm/config/model.py
index fb26b3aba05d..0ded70388b8a 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -6,8 +6,8 @@
 import warnings
 from dataclasses import InitVar, field
 from importlib.util import find_spec
-from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
-                    Union, cast, get_args)
+from typing import (TYPE_CHECKING, Any, Callable, Literal, Optional, Union,
+                    cast, get_args)
 
 import torch
 from pydantic import (ConfigDict, SkipValidation, field_validator,
@@ -285,28 +285,6 @@ class ModelConfig:
     interleave_mm_strings: InitVar[Optional[bool]] = None
     skip_mm_profiling: InitVar[Optional[bool]] = None
 
-    _hash: Optional[str] = None
-    """The cached hash value, or `None` if not cached."""
-
-    _HASH_FACTORS: ClassVar[tuple[str, ...]] = (
-        "model",
-        "dtype",
-        "quantization",
-        "revision",
-        "code_revision",
-        "max_model_len",
-        "max_logprobs",
-        "disable_sliding_window",
-        "trust_remote_code",
-        "generation_config",
-        "model_impl",
-        "override_generation_config",
-        "rope_scaling",
-        "rope_theta",
-        "hf_config",
-    )
-    """Used to determine whether the hash needs to be recomputed."""
-
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
@@ -319,9 +297,6 @@ def compute_hash(self) -> str:
         excluding anything before input ids/embeddings and after
         the final hidden states.
""" - if self._hash is not None: - return self._hash - factors: list[Any] = [] factors.append(self.model) factors.append(self.dtype) @@ -361,17 +336,7 @@ def compute_hash(self) -> str: str_factors = str(factors) assert_hashable(str_factors) - - out_hash = hashlib.sha256(str_factors.encode()).hexdigest() - self._hash = out_hash - return out_hash - - def __setattr__(self, name: str, value: Any, /) -> None: - super().__setattr__(name, value) - - # Trigger recomputation next time compute_hash is called - if name in self._HASH_FACTORS: - self._hash = None + return hashlib.sha256(str(factors).encode()).hexdigest() def __post_init__( self, From fa7579a3c2ff70a8e1b32f375592749fe82353f3 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 25 Sep 2025 18:19:37 +0000 Subject: [PATCH 3/3] Use a cheaper cache key Signed-off-by: DarkLight1337 --- vllm/model_executor/model_loader/utils.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index e007d431880e..03202e13c280 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -165,7 +165,7 @@ def device_loading_context(module: torch.nn.Module, # New parameters or parameters already on target device are untouched -_MODEL_ARCH_BY_HASH = dict[str, tuple[type[nn.Module], str]]() +_MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]() """Caches the outputs of `_get_model_architecture`.""" @@ -215,7 +215,14 @@ def _get_model_architecture( def get_model_architecture( model_config: ModelConfig) -> tuple[type[nn.Module], str]: - key = model_config.compute_hash() + key = hash(( + model_config.model, + model_config.convert_type, + model_config.runner_type, + model_config.trust_remote_code, + model_config.model_impl, + tuple(getattr(model_config.hf_config, "architectures", [])), + )) if key in _MODEL_ARCH_BY_HASH: return _MODEL_ARCH_BY_HASH[key]