diff --git a/vllm/config/model.py b/vllm/config/model.py index 21457d3660a2..5548ce83d617 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -322,8 +322,28 @@ def compute_hash(self) -> str: factors.append(self.override_generation_config) factors.append(self.rope_scaling) factors.append(self.rope_theta) + # hf_config can control how the model looks! - factors.append(self.hf_config.to_json_string()) + try: + hf_config_json = self.hf_config.to_json_string(use_diff=False) + except TypeError: + from transformers import PretrainedConfig + + from vllm.utils.jsontree import json_map_leaves + + # Handle nested HF configs with unserializable values gracefully + hf_config_json = json.dumps( + json_map_leaves( + lambda v: v.to_dict() + if isinstance(v, PretrainedConfig) else str(v), + self.hf_config.to_dict(), + ), + indent=2, + sort_keys=True, + ) + "\n" + + factors.append(hf_config_json) + str_factors = str(factors) assert_hashable(str_factors) return hashlib.sha256(str(factors).encode()).hexdigest() diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index bd1773c753a9..e007d431880e 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -165,7 +165,11 @@ def device_loading_context(module: torch.nn.Module, # New parameters or parameters already on target device are untouched -def get_model_architecture( +_MODEL_ARCH_BY_HASH = dict[str, tuple[type[nn.Module], str]]() +"""Caches the outputs of `_get_model_architecture`.""" + + +def _get_model_architecture( model_config: ModelConfig) -> tuple[type[nn.Module], str]: architectures = getattr(model_config.hf_config, "architectures", []) @@ -209,6 +213,17 @@ def get_model_architecture( return model_cls, arch +def get_model_architecture( + model_config: ModelConfig) -> tuple[type[nn.Module], str]: + key = model_config.compute_hash() + if key in _MODEL_ARCH_BY_HASH: + return _MODEL_ARCH_BY_HASH[key] + + model_arch = _get_model_architecture(model_config) + _MODEL_ARCH_BY_HASH[key] = model_arch + return model_arch + + def get_model_cls(model_config: ModelConfig) -> type[nn.Module]: return get_model_architecture(model_config)[0]