File tree Expand file tree Collapse file tree 2 files changed +37
-2
lines changed
model_executor/model_loader Expand file tree Collapse file tree 2 files changed +37
-2
lines changed Original file line number Diff line number Diff line change @@ -322,8 +322,28 @@ def compute_hash(self) -> str:
322322 factors .append (self .override_generation_config )
323323 factors .append (self .rope_scaling )
324324 factors .append (self .rope_theta )
325+
325326 # hf_config can control how the model looks!
326- factors .append (self .hf_config .to_json_string ())
327+ try :
328+ hf_config_json = self .hf_config .to_json_string (use_diff = False )
329+ except TypeError :
330+ from transformers import PretrainedConfig
331+
332+ from vllm .utils .jsontree import json_map_leaves
333+
334+ # Handle nested HF configs with unserializable values gracefully
335+ hf_config_json = json .dumps (
336+ json_map_leaves (
337+ lambda v : v .to_dict ()
338+ if isinstance (v , PretrainedConfig ) else str (v ),
339+ self .hf_config .to_dict (),
340+ ),
341+ indent = 2 ,
342+ sort_keys = True ,
343+ ) + "\n "
344+
345+ factors .append (hf_config_json )
346+
327347 str_factors = str (factors )
328348 assert_hashable (str_factors )
329349 return hashlib .sha256 (str (factors ).encode ()).hexdigest ()
Original file line number Diff line number Diff line change @@ -165,7 +165,11 @@ def device_loading_context(module: torch.nn.Module,
165165 # New parameters or parameters already on target device are untouched
166166
167167
168- def get_model_architecture (
168+ _MODEL_ARCH_BY_HASH = dict [str , tuple [type [nn .Module ], str ]]()
169+ """Caches the outputs of `_get_model_architecture`."""
170+
171+
172+ def _get_model_architecture (
169173 model_config : ModelConfig ) -> tuple [type [nn .Module ], str ]:
170174 architectures = getattr (model_config .hf_config , "architectures" , [])
171175
@@ -209,6 +213,17 @@ def get_model_architecture(
209213 return model_cls , arch
210214
211215
216+ def get_model_architecture (
217+ model_config : ModelConfig ) -> tuple [type [nn .Module ], str ]:
218+ key = model_config .compute_hash ()
219+ if key in _MODEL_ARCH_BY_HASH :
220+ return _MODEL_ARCH_BY_HASH [key ]
221+
222+ model_arch = _get_model_architecture (model_config )
223+ _MODEL_ARCH_BY_HASH [key ] = model_arch
224+ return model_arch
225+
226+
212227def get_model_cls (model_config : ModelConfig ) -> type [nn .Module ]:
213228 return get_model_architecture (model_config )[0 ]
214229
You can’t perform that action at this time.
0 commit comments