|
25 | 25 | from vllm.model_executor.layers.vocab_parallel_embedding import ( |
26 | 26 | DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) |
27 | 27 | from vllm.model_executor.model_loader.weight_utils import default_weight_loader |
| 28 | +from vllm.model_executor.models.llama import LlamaMLP as JambaMLP |
28 | 29 | from vllm.model_executor.models.mamba_cache import (MambaCacheManager, |
29 | 30 | MambaCacheParams) |
30 | 31 | from vllm.model_executor.sampling_metadata import SamplingMetadata |
|
33 | 34 |
|
34 | 35 | from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, |
35 | 36 | SupportsV0Only) |
36 | | -from .utils import (is_pp_missing_parameter, |
| 37 | +from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, |
37 | 38 | make_empty_intermediate_tensors_factory, make_layers, |
38 | 39 | maybe_prefix) |
39 | 40 |
|
@@ -87,23 +88,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
87 | 88 | return hidden_states.view(orig_shape) |
88 | 89 |
|
89 | 90 |
|
90 | | -class JambaMLP(JambaMoE): |
91 | | - |
92 | | - def __init__(self, |
93 | | - config: JambaConfig, |
94 | | - params_dtype: Optional[torch.dtype] = None, |
95 | | - tp_size: Optional[int] = None, |
96 | | - quant_config: Optional[QuantizationConfig] = None, |
97 | | - prefix: str = ""): |
98 | | - super().__init__(config, |
99 | | - num_experts=1, |
100 | | - top_k=1, |
101 | | - params_dtype=params_dtype, |
102 | | - tp_size=tp_size, |
103 | | - quant_config=quant_config, |
104 | | - prefix=prefix) |
105 | | - |
106 | | - |
107 | 91 | class JambaMambaDecoderLayer(nn.Module): |
108 | 92 |
|
109 | 93 | def __init__(self, |
@@ -132,10 +116,20 @@ def __init__(self, |
132 | 116 | ) |
133 | 117 |
|
134 | 118 | num_experts = config.layers_num_experts[layer_idx] |
135 | | - ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP |
136 | | - self.feed_forward = ffn_layer_class(config, |
137 | | - quant_config=quant_config, |
138 | | - prefix=f"{prefix}.feed_forward") |
| 119 | + if num_experts > 1: |
| 120 | + self.feed_forward = JambaMoE( |
| 121 | + config, |
| 122 | + quant_config=quant_config, |
| 123 | + prefix=f"{prefix}.feed_forward", |
| 124 | + ) |
| 125 | + else: |
| 126 | + self.feed_forward = JambaMLP( |
| 127 | + config.hidden_size, |
| 128 | + config.intermediate_size, |
| 129 | + config.hidden_act, |
| 130 | + quant_config=quant_config, |
| 131 | + prefix=f"{prefix}.feed_forward", |
| 132 | + ) |
139 | 133 | self.input_layernorm = RMSNorm(config.hidden_size, |
140 | 134 | eps=config.rms_norm_eps) |
141 | 135 | self.pre_ff_layernorm = RMSNorm(config.hidden_size, |
@@ -216,10 +210,20 @@ def __init__(self, |
216 | 210 | ) |
217 | 211 |
|
218 | 212 | num_experts = config.layers_num_experts[layer_idx] |
219 | | - ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP |
220 | | - self.feed_forward = ffn_layer_class(config, |
221 | | - quant_config=quant_config, |
222 | | - prefix=f"{prefix}.feed_forward") |
| 213 | + if num_experts > 1: |
| 214 | + self.feed_forward = JambaMoE( |
| 215 | + config, |
| 216 | + quant_config=quant_config, |
| 217 | + prefix=f"{prefix}.feed_forward", |
| 218 | + ) |
| 219 | + else: |
| 220 | + self.feed_forward = JambaMLP( |
| 221 | + config.hidden_size, |
| 222 | + config.intermediate_size, |
| 223 | + config.hidden_act, |
| 224 | + quant_config=quant_config, |
| 225 | + prefix=f"{prefix}.feed_forward", |
| 226 | + ) |
223 | 227 | self.input_layernorm = RMSNorm(config.hidden_size, |
224 | 228 | eps=config.rms_norm_eps) |
225 | 229 | self.pre_ff_layernorm = RMSNorm(config.hidden_size, |
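Note on the dense branch above: `JambaMLP` is now just the standard SiLU-gated MLP from the Llama implementation (see the import alias at the top of the diff), taking `hidden_size`, `intermediate_size`, and `hidden_act` directly instead of wrapping a single-expert `JambaMoE`. A rough functional sketch of that gated MLP, assuming `hidden_act == "silu"` (Jamba's default) and ignoring tensor parallelism, bias, and quantization:

```python
import torch
import torch.nn.functional as F

def gated_mlp(x: torch.Tensor,
              gate_up_weight: torch.Tensor,  # [2 * intermediate_size, hidden_size]
              down_weight: torch.Tensor,     # [hidden_size, intermediate_size]
              ) -> torch.Tensor:
    # Fused gate/up projection (gate is shard 0, up is shard 1), SiLU gating,
    # then the down projection back to hidden_size.
    gate, up = (x @ gate_up_weight.t()).chunk(2, dim=-1)
    return (F.silu(gate) * up) @ down_weight.t()
```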
@@ -359,15 +363,97 @@ def forward( |
359 | 363 | hidden_states, _ = self.final_layernorm(hidden_states, residual) |
360 | 364 | return hidden_states |
361 | 365 |
|
| 366 | + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: |
| 367 | + # Params for weights, fp8 weight scales, fp8 activation scales |
| 368 | + # (param_name, weight_name, expert_id, shard_id) |
| 369 | + return FusedMoE.make_expert_params_mapping( |
| 370 | + ckpt_gate_proj_name="gate_proj", |
| 371 | + ckpt_down_proj_name="down_proj", |
| 372 | + ckpt_up_proj_name="up_proj", |
| 373 | + num_experts=self.config.num_experts) |
| 374 | + |
| 375 | + def load_weights(self, weights: Iterable[tuple[str, |
| 376 | + torch.Tensor]]) -> set[str]: |
| 377 | + stacked_params_mapping = [ |
| 378 | + # (param_name, shard_name, shard_id) |
| 379 | + ("qkv_proj", "q_proj", "q"), |
| 380 | + ("qkv_proj", "k_proj", "k"), |
| 381 | + ("qkv_proj", "v_proj", "v"), |
| 382 | + (".gate_up_proj", ".gate_proj", 0), |
| 383 | + (".gate_up_proj", ".up_proj", 1), |
| 384 | + ] |
| 385 | + |
| 386 | + params_dict = dict(self.named_parameters()) |
| 387 | + loaded_params: set[str] = set() |
| 388 | + expert_params_mapping = self.get_expert_mapping() |
| 389 | + for name, loaded_weight in weights: |
| 390 | + if "rotary_emb.inv_freq" in name: |
| 391 | + continue |
| 392 | + for param_name, weight_name, shard_id in stacked_params_mapping: |
| 393 | + if weight_name not in name: |
| 394 | + continue |
| 395 | + if 'experts' in name: |
| 396 | + continue |
| 397 | + name = name.replace(weight_name, param_name) |
| 398 | + # Skip loading extra bias for GPTQ models. |
| 399 | + if name.endswith(".bias") and name not in params_dict: |
| 400 | + continue |
| 401 | + # Skip layers on other devices. |
| 402 | + if is_pp_missing_parameter(name, self): |
| 403 | + continue |
| 404 | + param = params_dict[name] |
| 405 | + weight_loader = param.weight_loader |
| 406 | + weight_loader(param, loaded_weight, shard_id) |
| 407 | + break |
| 408 | + else: |
| 409 | + for ( |
| 410 | + param_name, |
| 411 | + weight_name, |
| 412 | + expert_id, |
| 413 | + shard_id, |
| 414 | + ) in expert_params_mapping: |
| 415 | + if weight_name not in name: |
| 416 | + continue |
| 417 | + |
| 418 | + if is_pp_missing_parameter(name, self): |
| 419 | + continue |
| 420 | + name = name.replace(weight_name, param_name) |
| 421 | + param = params_dict[name] |
| 422 | + weight_loader = param.weight_loader |
| 423 | + weight_loader(param, |
| 424 | + loaded_weight, |
| 425 | + name, |
| 426 | + shard_id=shard_id, |
| 427 | + expert_id=expert_id) |
| 428 | + break |
| 429 | + else: |
| 430 | + # Skip loading extra bias for GPTQ models. |
| 431 | + if name.endswith(".bias") and name not in params_dict: |
| 432 | + continue |
| 433 | + if is_pp_missing_parameter(name, self): |
| 434 | + continue |
| 435 | + |
| 436 | + param = params_dict[name] |
| 437 | + weight_loader = getattr(param, "weight_loader", |
| 438 | + default_weight_loader) |
| 439 | + weight_loader(param, loaded_weight) |
| 440 | + loaded_params.add(name) |
| 441 | + return loaded_params |
| 442 | + |
362 | 443 |
|
363 | 444 | class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, |
364 | 445 | IsHybrid, SupportsV0Only): |
| 446 | + hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={ |
| 447 | + ".self_attn.": ".", |
| 448 | + ".A_log": ".A" |
| 449 | + }, ) |
365 | 450 | packed_modules_mapping = { |
366 | 451 | "qkv_proj": [ |
367 | 452 | "q_proj", |
368 | 453 | "k_proj", |
369 | 454 | "v_proj", |
370 | 455 | ], |
| 456 | + "gate_up_proj": ["gate_proj", "up_proj"], |
371 | 457 | "in_proj": ["in_proj"], |
372 | 458 | } |
373 | 459 |
|
@@ -468,96 +554,11 @@ def compute_logits( |
468 | 554 |
|
469 | 555 | def load_weights(self, weights: Iterable[tuple[str, |
470 | 556 | torch.Tensor]]) -> set[str]: |
471 | | - stacked_params_mapping = [ |
472 | | - # (param_name, shard_name, shard_id) |
473 | | - ("qkv_proj", "q_proj", "q"), |
474 | | - ("qkv_proj", "k_proj", "k"), |
475 | | - ("qkv_proj", "v_proj", "v"), |
476 | | - ] |
477 | | - |
478 | | - # Params for weights, fp8 weight scales, fp8 activation scales |
479 | | - # (param_name, weight_name, expert_id, shard_id) |
480 | | - expert_params_mapping = FusedMoE.make_expert_params_mapping( |
481 | | - ckpt_gate_proj_name="gate_proj", |
482 | | - ckpt_down_proj_name="down_proj", |
483 | | - ckpt_up_proj_name="up_proj", |
484 | | - num_experts=self.config.num_experts) |
485 | | - |
486 | | - params_dict = dict(self.named_parameters()) |
487 | | - loaded_params: set[str] = set() |
488 | | - for name, loaded_weight in weights: |
489 | | - if "rotary_emb.inv_freq" in name: |
490 | | - continue |
491 | | - |
492 | | - if "A_log" in name: |
493 | | - name = name.replace("A_log", "A") |
494 | | - |
495 | | - if ".self_attn." in name: |
496 | | - name = name.replace(".self_attn", "") |
497 | | - |
498 | | - if "feed_forward" in name and not _is_moe_layer(name): |
499 | | - ## map MLP layers to expert with ID=0 |
500 | | - name = name.replace("feed_forward", "feed_forward.experts.0") |
501 | | - |
502 | | - for param_name, weight_name, shard_id in stacked_params_mapping: |
503 | | - if weight_name not in name: |
504 | | - continue |
505 | | - if 'experts' in name: |
506 | | - continue |
507 | | - name = name.replace(weight_name, param_name) |
508 | | - # Skip loading extra bias for GPTQ models. |
509 | | - |
510 | | - if name.endswith(".bias") and name not in params_dict: |
511 | | - continue |
512 | | - # Skip layers on other devices. |
513 | | - if is_pp_missing_parameter(name, self): |
514 | | - continue |
515 | | - param = params_dict[name] |
516 | | - weight_loader = param.weight_loader |
517 | | - weight_loader(param, loaded_weight, shard_id) |
518 | | - break |
519 | | - else: |
520 | | - for ( |
521 | | - param_name, |
522 | | - weight_name, |
523 | | - expert_id, |
524 | | - shard_id, |
525 | | - ) in expert_params_mapping: |
526 | | - if weight_name not in name: |
527 | | - continue |
528 | | - |
529 | | - if is_pp_missing_parameter(name, self): |
530 | | - continue |
531 | | - name = name.replace(weight_name, param_name) |
532 | | - param = params_dict[name] |
533 | | - weight_loader = param.weight_loader |
534 | | - weight_loader(param, |
535 | | - loaded_weight, |
536 | | - name, |
537 | | - shard_id=shard_id, |
538 | | - expert_id=expert_id) |
539 | | - break |
540 | | - else: |
541 | | - # Skip loading extra bias for GPTQ models. |
542 | | - if name.endswith(".bias") and name not in params_dict: |
543 | | - continue |
544 | | - if is_pp_missing_parameter(name, self): |
545 | | - continue |
546 | | - |
547 | | - param = params_dict[name] |
548 | | - weight_loader = getattr(param, "weight_loader", |
549 | | - default_weight_loader) |
550 | | - weight_loader(param, loaded_weight) |
551 | | - loaded_params.add(name) |
552 | | - return loaded_params |
553 | | - |
| 557 | + loader = AutoWeightsLoader(self) |
| 558 | + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) |
554 | 559 |
|
555 | | -def _is_moe_layer(name: str): |
556 | | - return any( |
557 | | - [experts_name in name for experts_name in [ |
558 | | - "experts", |
559 | | - "router", |
560 | | - ]]) |
| 560 | + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: |
| 561 | + return self.model.get_expert_mapping() |
561 | 562 |
|
562 | 563 |
|
563 | 564 | class JambaForSequenceClassification(JambaForCausalLM): |
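The shard-fusion logic removed here now lives in `JambaModel.load_weights` above, and the single-expert remapping (`_is_moe_layer`, `feed_forward.experts.0`) is no longer needed because dense layers use a plain MLP. For reference, a conceptual illustration of how the `stacked_params_mapping` entries fold separate checkpoint projections into vLLM's fused parameters, using a hypothetical dense-layer key (not the actual loader code):

```python
stacked_params_mapping = [
    # (fused_param, checkpoint_shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
]

name = "model.layers.2.feed_forward.up_proj.weight"  # hypothetical key
for fused, shard_name, shard_id in stacked_params_mapping:
    if shard_name in name:
        fused_name = name.replace(shard_name, fused)
        # The fused parameter's weight_loader copies this tensor into the
        # matching slice, e.g. the "up" half of gate_up_proj here.
        print(fused_name, shard_id)
        # -> model.layers.2.feed_forward.gate_up_proj.weight 1
        break
```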
|