48 | 48 | from vllm.transformers_utils.configs import NemotronConfig |
49 | 49 |
50 | 50 | from .interfaces import SupportsLoRA, SupportsPP |
51 | | -from .utils import (PPMissingLayer, is_pp_missing_parameter, |
| 51 | +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, |
52 | 52 | make_empty_intermediate_tensors_factory, make_layers, |
53 | 53 | maybe_prefix) |
54 | 54 |
@@ -300,6 +300,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
300 | 300 | lora_config = vllm_config.lora_config |
301 | 301 |
302 | 302 | self.config = config |
| 303 | + self.quant_config = quant_config |
303 | 304 | lora_vocab = (lora_config.lora_extra_vocab_size * |
304 | 305 | (lora_config.max_loras or 1)) if lora_config else 0 |
305 | 306 | self.vocab_size = config.vocab_size + lora_vocab |
@@ -362,6 +363,63 @@ def forward( |
362 | 363 | hidden_states, _ = self.norm(hidden_states, residual) |
363 | 364 | return hidden_states |
364 | 365 |
| 366 | + def load_weights(self, weights: Iterable[tuple[str, |
| 367 | + torch.Tensor]]) -> set[str]: |
| 368 | + stacked_params_mapping = [ |
| 369 | + # (param_name, shard_name, shard_id) |
| 370 | + (".qkv_proj", ".q_proj", "q"), |
| 371 | + (".qkv_proj", ".k_proj", "k"), |
| 372 | + (".qkv_proj", ".v_proj", "v"), |
| 373 | + ] |
| 374 | + params_dict = dict(self.named_parameters()) |
| 375 | + loaded_params: set[str] = set() |
| 376 | + for name, loaded_weight in weights: |
| 377 | + if (self.quant_config is not None and |
| 378 | + (scale_name := self.quant_config.get_cache_scale(name))): |
| 379 | + # Loading kv cache quantization scales |
| 380 | + param = params_dict[scale_name] |
| 381 | + weight_loader = getattr(param, "weight_loader", |
| 382 | + default_weight_loader) |
| 383 | + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else |
| 384 | + loaded_weight[0]) |
| 385 | + weight_loader(param, loaded_weight) |
| 386 | + loaded_params.add(scale_name) |
| 387 | + continue |
| 388 | + for (param_name, weight_name, shard_id) in stacked_params_mapping: |
| 389 | + if weight_name not in name: |
| 390 | + continue |
| 391 | + name = name.replace(weight_name, param_name) |
| 392 | + # Skip loading extra bias for GPTQ models. |
| 393 | + if name.endswith(".bias") and name not in params_dict: |
| 394 | + continue |
| 395 | + |
| 396 | + if is_pp_missing_parameter(name, self): |
| 397 | + continue |
| 398 | + |
| 399 | + param = params_dict[name] |
| 400 | + weight_loader = param.weight_loader |
| 401 | + weight_loader(param, loaded_weight, shard_id) |
| 402 | + |
| 403 | + break |
| 404 | + else: |
| 405 | + # Skip loading extra bias for GPTQ models. |
| 406 | + if name.endswith(".bias") and name not in params_dict: |
| 407 | + continue |
| 408 | + # Remapping the name of FP8 kv-scale. |
| 409 | + name = maybe_remap_kv_scale_name(name, params_dict) |
| 410 | + if name is None: |
| 411 | + continue |
| 412 | + |
| 413 | + if is_pp_missing_parameter(name, self): |
| 414 | + continue |
| 415 | + |
| 416 | + param = params_dict[name] |
| 417 | + weight_loader = getattr(param, "weight_loader", |
| 418 | + default_weight_loader) |
| 419 | + weight_loader(param, loaded_weight) |
| 420 | + loaded_params.add(name) |
| 421 | + return loaded_params |
| 422 | + |
365 | 423 |
366 | 424 | class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): |
367 | 425 | packed_modules_mapping = { |
@@ -444,64 +502,14 @@ def compute_logits( |
444 | 502 |
445 | 503 | def load_weights(self, weights: Iterable[tuple[str, |
446 | 504 | torch.Tensor]]) -> set[str]: |
447 | | - stacked_params_mapping = [ |
448 | | - # (param_name, shard_name, shard_id) |
449 | | - (".qkv_proj", ".q_proj", "q"), |
450 | | - (".qkv_proj", ".k_proj", "k"), |
451 | | - (".qkv_proj", ".v_proj", "v"), |
452 | | - ] |
453 | | - params_dict = dict(self.named_parameters()) |
454 | | - loaded_params: set[str] = set() |
455 | | - for name, loaded_weight in weights: |
456 | | - if "rotary_emb.inv_freq" in name: |
457 | | - continue |
458 | | - if ("rotary_emb.cos_cached" in name |
459 | | - or "rotary_emb.sin_cached" in name): |
| 505 | + loader = AutoWeightsLoader( |
| 506 | + self, |
| 507 | + skip_prefixes=([ |
| 508 | + "rotary_emb.inv_freq", |
460 | 509 | # Models trained using ColossalAI may include these tensors in |
461 | 510 | # the checkpoint. Skip them. |
462 | | - continue |
463 | | - if (self.quant_config is not None and |
464 | | - (scale_name := self.quant_config.get_cache_scale(name))): |
465 | | - # Loading kv cache quantization scales |
466 | | - param = params_dict[scale_name] |
467 | | - weight_loader = getattr(param, "weight_loader", |
468 | | - default_weight_loader) |
469 | | - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else |
470 | | - loaded_weight[0]) |
471 | | - weight_loader(param, loaded_weight) |
472 | | - loaded_params.add(scale_name) |
473 | | - continue |
474 | | - for (param_name, weight_name, shard_id) in stacked_params_mapping: |
475 | | - if weight_name not in name: |
476 | | - continue |
477 | | - name = name.replace(weight_name, param_name) |
478 | | - # Skip loading extra bias for GPTQ models. |
479 | | - if name.endswith(".bias") and name not in params_dict: |
480 | | - continue |
481 | | - |
482 | | - if is_pp_missing_parameter(name, self): |
483 | | - continue |
484 | | - |
485 | | - param = params_dict[name] |
486 | | - weight_loader = param.weight_loader |
487 | | - weight_loader(param, loaded_weight, shard_id) |
488 | | - |
489 | | - break |
490 | | - else: |
491 | | - # Skip loading extra bias for GPTQ models. |
492 | | - if name.endswith(".bias") and name not in params_dict: |
493 | | - continue |
494 | | - # Remapping the name of FP8 kv-scale. |
495 | | - name = maybe_remap_kv_scale_name(name, params_dict) |
496 | | - if name is None: |
497 | | - continue |
498 | | - |
499 | | - if is_pp_missing_parameter(name, self): |
500 | | - continue |
501 | | - |
502 | | - param = params_dict[name] |
503 | | - weight_loader = getattr(param, "weight_loader", |
504 | | - default_weight_loader) |
505 | | - weight_loader(param, loaded_weight) |
506 | | - loaded_params.add(name) |
507 | | - return loaded_params |
| 511 | + "rotary_emb.cos_cached", |
| 512 | + "rotary_emb.sin_cached" |
| 513 | + ]), |
| 514 | + ) |
| 515 | + return loader.load_weights(weights) |
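
For context on the pattern this commit adopts (not part of the diff itself): the outer `NemotronForCausalLM.load_weights` no longer walks the checkpoint tensor by tensor. It hands the weight stream to `AutoWeightsLoader`, which skips the `rotary_emb.*` buffers and delegates to the inner `NemotronModel.load_weights`, where the per-shard `q_proj`/`k_proj`/`v_proj` tensors are packed into the fused `qkv_proj` parameter via `stacked_params_mapping`. The sketch below is a hypothetical, self-contained illustration of that skip-and-delegate split; `ToyModel`, `ToyLoader`, and `ToyForCausalLM` are made-up stand-ins with flattened checkpoint names, deliberately much simpler than vLLM's real `AutoWeightsLoader`.

```python
from collections.abc import Iterable

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Inner model: owns the fused qkv_proj and knows how to fill its shards."""

    def __init__(self, hidden: int = 8):
        super().__init__()
        self.hidden = hidden
        # Fused QKV projection: rows [0:h) = q, [h:2h) = k, [2h:3h) = v.
        self.qkv_proj = nn.Linear(hidden, 3 * hidden, bias=False)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        # Analogue of stacked_params_mapping: q/k/v checkpoint shards land in
        # slices of the single fused parameter.
        shard_index = {"q_proj": 0, "k_proj": 1, "v_proj": 2}
        loaded: set[str] = set()
        for name, tensor in weights:
            stem, _, leaf = name.rpartition(".")
            if leaf == "weight" and stem in shard_index:
                row = shard_index[stem] * self.hidden
                with torch.no_grad():
                    self.qkv_proj.weight[row:row + self.hidden].copy_(tensor)
                loaded.add("qkv_proj.weight")
        return loaded


class ToyLoader:
    """Tiny stand-in for AutoWeightsLoader: drop skipped prefixes, then
    delegate everything else to the wrapped module's own load_weights()."""

    def __init__(self, module: nn.Module, skip_prefixes: list[str]):
        self.module = module
        self.skip_prefixes = skip_prefixes

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        kept = ((n, w) for n, w in weights
                if not any(n.startswith(p) for p in self.skip_prefixes))
        return self.module.load_weights(kept)


class ToyForCausalLM(nn.Module):
    """Outer wrapper: no per-tensor logic, only filtering plus delegation."""

    def __init__(self):
        super().__init__()
        self.model = ToyModel()

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = ToyLoader(self.model, skip_prefixes=["rotary_emb."])
        return loader.load_weights(weights)


if __name__ == "__main__":
    h = 8
    ckpt = [
        ("rotary_emb.inv_freq", torch.zeros(4)),  # skipped by prefix
        ("q_proj.weight", torch.randn(h, h)),
        ("k_proj.weight", torch.randn(h, h)),
        ("v_proj.weight", torch.randn(h, h)),
    ]
    print(ToyForCausalLM().load_weights(ckpt))  # {'qkv_proj.weight'}
```

Centralizing the generic boilerplate this way is why the last hunk above shrinks from 64 lines to 14: the skip-and-delegate logic lives in one place, while the model class keeps only its Nemotron-specific shard mapping and kv-scale handling.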