Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion vllm/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,28 @@ def compute_hash(self) -> str:
factors.append(self.override_generation_config)
factors.append(self.rope_scaling)
factors.append(self.rope_theta)

# hf_config can control how the model looks!
factors.append(self.hf_config.to_json_string())
try:
hf_config_json = self.hf_config.to_json_string(use_diff=False)
except TypeError:
from transformers import PretrainedConfig

from vllm.utils.jsontree import json_map_leaves

# Handle nested HF configs with unserializable values gracefully
hf_config_json = json.dumps(
json_map_leaves(
lambda v: v.to_dict()
if isinstance(v, PretrainedConfig) else str(v),
self.hf_config.to_dict(),
),
indent=2,
sort_keys=True,
) + "\n"

factors.append(hf_config_json)

str_factors = str(factors)
assert_hashable(str_factors)
return hashlib.sha256(str(factors).encode()).hexdigest()
Expand Down
17 changes: 16 additions & 1 deletion vllm/model_executor/model_loader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,11 @@ def device_loading_context(module: torch.nn.Module,
# New parameters or parameters already on target device are untouched


def get_model_architecture(
_MODEL_ARCH_BY_HASH = dict[str, tuple[type[nn.Module], str]]()
"""Caches the outputs of `_get_model_architecture`."""


def _get_model_architecture(
model_config: ModelConfig) -> tuple[type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])

Expand Down Expand Up @@ -209,6 +213,17 @@ def get_model_architecture(
return model_cls, arch


def get_model_architecture(
model_config: ModelConfig) -> tuple[type[nn.Module], str]:
key = model_config.compute_hash()
if key in _MODEL_ARCH_BY_HASH:
return _MODEL_ARCH_BY_HASH[key]

model_arch = _get_model_architecture(model_config)
_MODEL_ARCH_BY_HASH[key] = model_arch
return model_arch


def get_model_cls(model_config: ModelConfig) -> type[nn.Module]:
return get_model_architecture(model_config)[0]

Expand Down