Commit 61a6905

[Model] Refactor JambaForCausalLM (#21394)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent 37efc63 commit 61a6905

File tree

1 file changed: +116 / -115 lines changed


vllm/model_executor/models/jamba.py

Lines changed: 116 additions & 115 deletions
@@ -25,6 +25,7 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.llama import LlamaMLP as JambaMLP
 from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -33,7 +34,7 @@
 
 from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
                          SupportsV0Only)
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -87,23 +88,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states.view(orig_shape)
 
 
-class JambaMLP(JambaMoE):
-
-    def __init__(self,
-                 config: JambaConfig,
-                 params_dtype: Optional[torch.dtype] = None,
-                 tp_size: Optional[int] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
-        super().__init__(config,
-                         num_experts=1,
-                         top_k=1,
-                         params_dtype=params_dtype,
-                         tp_size=tp_size,
-                         quant_config=quant_config,
-                         prefix=prefix)
-
-
 class JambaMambaDecoderLayer(nn.Module):
 
     def __init__(self,
@@ -132,10 +116,20 @@ def __init__(self,
         )
 
         num_experts = config.layers_num_experts[layer_idx]
-        ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
-        self.feed_forward = ffn_layer_class(config,
-                                            quant_config=quant_config,
-                                            prefix=f"{prefix}.feed_forward")
+        if num_experts > 1:
+            self.feed_forward = JambaMoE(
+                config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.feed_forward",
+            )
+        else:
+            self.feed_forward = JambaMLP(
+                config.hidden_size,
+                config.intermediate_size,
+                config.hidden_act,
+                quant_config=quant_config,
+                prefix=f"{prefix}.feed_forward",
+            )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size,
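
Note: the else branch above now constructs the reused LlamaMLP directly from (hidden_size, intermediate_size, hidden_act) instead of going through a single-expert JambaMoE. As orientation only, below is a rough torch-only sketch of the gated-MLP shape LlamaMLP provides; it is an assumption-level illustration, not vLLM's tensor-parallel implementation, and the fused gate_up_proj naming is what the "gate_up_proj": ["gate_proj", "up_proj"] entry added to packed_modules_mapping later in this diff refers to.

    import torch
    from torch import nn
    import torch.nn.functional as F

    class GatedMLPSketch(nn.Module):
        """Toy stand-in for LlamaMLP: fused gate/up projection, then down projection."""

        def __init__(self, hidden_size: int, intermediate_size: int):
            super().__init__()
            # gate_proj and up_proj live in one fused weight, mirroring the
            # "gate_up_proj" packed module named in this diff.
            self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
            self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Split the fused projection into its gate and up halves.
            gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
            return self.down_proj(F.silu(gate) * up)
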
@@ -216,10 +210,20 @@ def __init__(self,
         )
 
         num_experts = config.layers_num_experts[layer_idx]
-        ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
-        self.feed_forward = ffn_layer_class(config,
-                                            quant_config=quant_config,
-                                            prefix=f"{prefix}.feed_forward")
+        if num_experts > 1:
+            self.feed_forward = JambaMoE(
+                config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.feed_forward",
+            )
+        else:
+            self.feed_forward = JambaMLP(
+                config.hidden_size,
+                config.intermediate_size,
+                config.hidden_act,
+                quant_config=quant_config,
+                prefix=f"{prefix}.feed_forward",
+            )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size,
@@ -359,15 +363,97 @@ def forward(
         hidden_states, _ = self.final_layernorm(hidden_states, residual)
         return hidden_states
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts)
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                if 'experts' in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for (
+                        param_name,
+                        weight_name,
+                        expert_id,
+                        shard_id,
+                ) in expert_params_mapping:
+                    if weight_name not in name:
+                        continue
+
+                    if is_pp_missing_parameter(name, self):
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  name,
+                                  shard_id=shard_id,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    if is_pp_missing_parameter(name, self):
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                        IsHybrid, SupportsV0Only):
+    hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={
+        ".self_attn.": ".",
+        ".A_log": ".A"
+    }, )
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
             "k_proj",
             "v_proj",
         ],
+        "gate_up_proj": ["gate_proj", "up_proj"],
         "in_proj": ["in_proj"],
     }
 
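
Note: the hf_to_vllm_mapper added above declares the checkpoint-key renames that the old JambaForCausalLM.load_weights performed by hand. A minimal toy illustration of that substring renaming follows (a hypothetical helper, not vLLM's WeightsMapper implementation):

    # Toy illustration of the substring renames declared in hf_to_vllm_mapper;
    # the real renaming is applied by vLLM's WeightsMapper before loading.
    HF_TO_VLLM_SUBSTR = {".self_attn.": ".", ".A_log": ".A"}

    def rename_hf_key(name: str) -> str:
        for old, new in HF_TO_VLLM_SUBSTR.items():
            name = name.replace(old, new)
        return name

    # Attention keys lose the ".self_attn." segment, Mamba's A_log becomes A.
    assert (rename_hf_key("model.layers.0.self_attn.q_proj.weight")
            == "model.layers.0.q_proj.weight")
    assert rename_hf_key("model.layers.1.mamba.A_log") == "model.layers.1.mamba.A"
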

@@ -468,96 +554,11 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-        ]
-
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.num_experts)
-
-        params_dict = dict(self.named_parameters())
-        loaded_params: set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-
-            if "A_log" in name:
-                name = name.replace("A_log", "A")
-
-            if ".self_attn." in name:
-                name = name.replace(".self_attn", "")
-
-            if "feed_forward" in name and not _is_moe_layer(name):
-                ## map MLP layers to expert with ID=0
-                name = name.replace("feed_forward", "feed_forward.experts.0")
-
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                if 'experts' in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip layers on other devices.
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                for (
-                        param_name,
-                        weight_name,
-                        expert_id,
-                        shard_id,
-                ) in expert_params_mapping:
-                    if weight_name not in name:
-                        continue
-
-                    if is_pp_missing_parameter(name, self):
-                        continue
-                    name = name.replace(weight_name, param_name)
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param,
-                                  loaded_weight,
-                                  name,
-                                  shard_id=shard_id,
-                                  expert_id=expert_id)
-                    break
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-                    if is_pp_missing_parameter(name, self):
-                        continue
-
-                    param = params_dict[name]
-                    weight_loader = getattr(param, "weight_loader",
-                                            default_weight_loader)
-                    weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
-
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
-def _is_moe_layer(name: str):
-    return any(
-        [experts_name in name for experts_name in [
-            "experts",
-            "router",
-        ]])
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
 
 
 class JambaForSequenceClassification(JambaForCausalLM):
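
Note: with the renames expressed in hf_to_vllm_mapper and dense layers now being real MLPs rather than single-expert MoE blocks, load_weights can simply delegate through AutoWeightsLoader to the model-level load_weights added earlier in this diff. The toy snippet below sketches what one stacked_params_mapping entry from that method does; it is illustrative only, not vLLM's loader:

    # How an entry such as ("qkv_proj", "q_proj", "q") is used in the
    # model-level load_weights above: the checkpoint key is rewritten to
    # point at the fused parameter, and the shard id tells that parameter's
    # weight_loader which slice to fill.
    stacked_params_mapping = [
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    ]

    name = "model.layers.0.q_proj.weight"  # key after hf_to_vllm_mapper renaming
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            fused_name = name.replace(weight_name, param_name)
            print(fused_name, shard_id)  # -> model.layers.0.qkv_proj.weight q
            break
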
