@@ -43,7 +43,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -229,6 +229,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
+        self.config = config
 
         self.embed_dim = config.hidden_size
 
@@ -278,6 +279,38 @@ def forward(
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if is_pp_missing_parameter(name, self):
+                continue
+            param = params_dict[name]
+
+            if "query_key_value" in name:
+                # NOTE: BLOOM's fused QKV's output_dim has the shape of
+                # (num_heads * 3 * head_size), while the
+                # required shape is (3 * num_heads * head_size).
+                # Thus, we need weight conversion.
+                output_dim = getattr(param, "output_dim", None)
+                num_heads = self.config.num_attention_heads
+                if output_dim is not None:
+                    loaded_weight_shape = loaded_weight.shape
+                    loaded_weight = loaded_weight.view(
+                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
+                        loaded_weight_shape[output_dim + 1:])
+                    loaded_weight = loaded_weight.transpose(
+                        output_dim, output_dim + 1)
+                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
+
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+
+        return loaded_params
+
 
 class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant):
 
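The layout conversion described in the NOTE above is easiest to see on a toy tensor. Below is a minimal, self-contained sketch of the same view/transpose/reshape steps; the sizes and the `output_dim = 0` choice are illustrative assumptions, not values from a real BLOOM checkpoint. It turns HF BLOOM's head-major fused ordering `(num_heads, 3, head_size)` into the `(3, num_heads, head_size)` ordering that the fused QKV projection expects:

```python
import torch

# Toy sizes for illustration only (not from a real BLOOM config).
num_heads, head_size, hidden = 2, 3, 4
output_dim = 0  # for a weight matrix, the fused q/k/v rows live on dim 0

# HF BLOOM layout along the fused dim: [h0_q, h0_k, h0_v, h1_q, h1_k, h1_v]
w = torch.arange(num_heads * 3 * head_size * hidden,
                 dtype=torch.float32).view(num_heads * 3 * head_size, hidden)

shape = w.shape
w = w.view(shape[:output_dim] + (num_heads, 3, -1) + shape[output_dim + 1:])
w = w.transpose(output_dim, output_dim + 1)  # -> (3, num_heads, head_size, hidden)
w = w.reshape(shape)
# The fused dim now reads [q_h0, q_h1, k_h0, k_h1, v_h0, v_h1], i.e.
# (3 * num_heads * head_size), which matches the required layout.
```

The transpose only reorders slices along the fused dimension; the overall weight shape is unchanged, which is why the tensor can be reshaped back to `loaded_weight_shape` before being passed to the parameter's `weight_loader`.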
@@ -325,35 +358,15 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: set[str] = set()
-        for name, loaded_weight in weights:
-            if name == "lm_head.weight":
-                continue
-            if not name.startswith("transformer."):
-                name = "transformer." + name
-            if is_pp_missing_parameter(name, self):
-                continue
-            param = params_dict[name]
-
-            if "query_key_value" in name:
-                # NOTE: BLOOM's fused QKV's output_dim has the shape of
-                # (num_heads * 3 * head_size), while the
-                # required shape is (3 * num_heads * head_size).
-                # Thus, we need weight conversion.
-                output_dim = getattr(param, "output_dim", None)
-                num_heads = self.config.num_attention_heads
-                if output_dim is not None:
-                    loaded_weight_shape = loaded_weight.shape
-                    loaded_weight = loaded_weight.view(
-                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
-                        loaded_weight_shape[output_dim + 1:])
-                    loaded_weight = loaded_weight.transpose(
-                        output_dim, output_dim + 1)
-                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
-
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head.weight"])
+        weights = _add_transformer_prefix(weights)
+        return loader.load_weights(weights)
+
+
+def _add_transformer_prefix(
+    weights: Iterable[tuple[str, torch.Tensor]]
+) -> Iterable[tuple[str, torch.Tensor]]:
+    for name, tensor in weights:
+        if not name.startswith('transformer.'):
+            name = 'transformer.' + name
+        yield name, tensor
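For reference, here is a minimal, standalone sketch of what the refactored loading path does with checkpoint names. HF BLOOM checkpoints store the decoder weights without the `transformer.` prefix used by the vLLM module hierarchy, so the generator rewrites each name before the stream reaches `AutoWeightsLoader`, which is constructed with `skip_prefixes=["lm_head.weight"]` and walks the module hierarchy to dispatch each tensor to the right parameter. The helper is repeated below, with made-up weight names, only so the snippet runs on its own:

```python
from collections.abc import Iterable

import torch


def _add_transformer_prefix(
    weights: Iterable[tuple[str, torch.Tensor]]
) -> Iterable[tuple[str, torch.Tensor]]:
    # Same remapping as the new helper above: prepend "transformer."
    # whenever a checkpoint name is missing it.
    for name, tensor in weights:
        if not name.startswith('transformer.'):
            name = 'transformer.' + name
        yield name, tensor


# Made-up checkpoint entries, only to show the renaming behaviour.
raw = [("word_embeddings.weight", torch.zeros(1)),
       ("transformer.ln_f.weight", torch.zeros(1))]
print([name for name, _ in _add_transformer_prefix(raw)])
# ['transformer.word_embeddings.weight', 'transformer.ln_f.weight']
```

Because `_add_transformer_prefix` is a generator, the renaming stays lazy: tensors are still streamed one at a time to the loader rather than being materialized as a renamed copy of the whole checkpoint.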