Skip to content

Commit df0d15d

Browse files
DarkLight1337 and charlifu
authored and committed
[Optimization] Avoid repeated model architecture conversion for pooling models (vllm-project#25261)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: charlifu <charlifu@amd.com>
1 parent 3e213d4 commit df0d15d

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

vllm/config/model.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,8 +322,28 @@ def compute_hash(self) -> str:
322322
factors.append(self.override_generation_config)
323323
factors.append(self.rope_scaling)
324324
factors.append(self.rope_theta)
325+
325326
# hf_config can control how the model looks!
326-
factors.append(self.hf_config.to_json_string())
327+
try:
328+
hf_config_json = self.hf_config.to_json_string(use_diff=False)
329+
except TypeError:
330+
from transformers import PretrainedConfig
331+
332+
from vllm.utils.jsontree import json_map_leaves
333+
334+
# Handle nested HF configs with unserializable values gracefully
335+
hf_config_json = json.dumps(
336+
json_map_leaves(
337+
lambda v: v.to_dict()
338+
if isinstance(v, PretrainedConfig) else str(v),
339+
self.hf_config.to_dict(),
340+
),
341+
indent=2,
342+
sort_keys=True,
343+
) + "\n"
344+
345+
factors.append(hf_config_json)
346+
327347
str_factors = str(factors)
328348
assert_hashable(str_factors)
329349
return hashlib.sha256(str(factors).encode()).hexdigest()

vllm/model_executor/model_loader/utils.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,11 @@ def device_loading_context(module: torch.nn.Module,
165165
# New parameters or parameters already on target device are untouched
166166

167167

168-
def get_model_architecture(
168+
# Module-level memo keyed by ModelConfig.compute_hash() — lets repeated
# lookups for an identical configuration skip architecture resolution.
_MODEL_ARCH_BY_HASH: dict[str, tuple[type[nn.Module], str]] = {}
"""Caches the outputs of `_get_model_architecture`."""
170+
171+
172+
def _get_model_architecture(
169173
model_config: ModelConfig) -> tuple[type[nn.Module], str]:
170174
architectures = getattr(model_config.hf_config, "architectures", [])
171175

@@ -209,6 +213,17 @@ def get_model_architecture(
209213
return model_cls, arch
210214

211215

216+
def get_model_architecture(
        model_config: ModelConfig) -> tuple[type[nn.Module], str]:
    """Resolve the model class and architecture name for *model_config*.

    Results are memoized in `_MODEL_ARCH_BY_HASH`, keyed by the config's
    `compute_hash()`, so repeated calls with an equivalent configuration
    avoid re-running `_get_model_architecture`.
    """
    cache_key = model_config.compute_hash()
    cached = _MODEL_ARCH_BY_HASH.get(cache_key)
    if cached is not None:
        return cached

    resolved = _get_model_architecture(model_config)
    _MODEL_ARCH_BY_HASH[cache_key] = resolved
    return resolved
225+
226+
212227
def get_model_cls(model_config: ModelConfig) -> type[nn.Module]:
    """Return only the resolved model class (discarding the arch name)."""
    model_cls, _arch = get_model_architecture(model_config)
    return model_cls
214229

0 commit comments

Comments
 (0)