
Commit 7ffd56e

jeejeelee authored and epwalsh committed
[Quantization] Enable BNB support for more MoE models (vllm-project#21370)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent 2173f40 commit 7ffd56e
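
With get_expert_mapping() exposed on these MoE models, the bitsandbytes (BNB) loader can resolve their per-expert checkpoint weights into the fused MoE parameters, so the models can be served with in-flight 4-bit quantization. Below is a minimal usage sketch, not part of this commit: the model path is a placeholder, and depending on the vLLM version, load_format="bitsandbytes" may also need to be passed alongside quantization="bitsandbytes".

# Minimal sketch: serving a BNB-quantized MoE model after this change.
# "org/dots1-or-glm4-moe-model" is a placeholder, not a real checkpoint ID.
from vllm import LLM, SamplingParams

llm = LLM(
    model="org/dots1-or-glm4-moe-model",  # placeholder HF repo ID
    quantization="bitsandbytes",          # in-flight 4-bit BNB quantization
    trust_remote_code=True,
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)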

2 files changed: +93 −80 lines changed

vllm/model_executor/models/dots1.py

Lines changed: 78 additions & 70 deletions
@@ -54,8 +54,8 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsPP
-from .utils import (PPMissingLayer, is_pp_missing_parameter,
+from .interfaces import SupportsLoRA, SupportsPP
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -327,6 +327,7 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class Dots1Model(nn.Module):
 
     fall_back_to_pt_during_load = False

@@ -404,68 +405,12 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
-
-@support_torch_compile
-class Dots1ForCausalLM(nn.Module, SupportsPP):
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        self.config = config
-        self.quant_config = quant_config
-        self.model = Dots1Model(vllm_config=vllm_config,
-                                prefix=maybe_prefix(prefix, "model"))
-        if get_pp_group().is_last_rank:
-            self.lm_head = ParallelLMHead(config.vocab_size,
-                                          config.hidden_size,
-                                          quant_config=quant_config)
-        else:
-            self.lm_head = PPMissingLayer()
-        self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(
-            input_ids,
-            positions,
-            intermediate_tensors,
-            inputs_embeds,
-        )
-        return hidden_states
-
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def make_empty_intermediate_tensors(
-            self, batch_size: int, dtype: torch.dtype,
-            device: torch.device) -> IntermediateTensors:
-        return IntermediateTensors({
-            "hidden_states":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-            "residual":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-        })
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts)
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:

@@ -477,14 +422,9 @@ def load_weights(self, weights: Iterable[tuple[str,
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts)
-
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue

@@ -534,3 +474,71 @@ def load_weights(self, weights: Iterable[tuple[str,
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
+
+
+class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
+
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Dots1Model(vllm_config=vllm_config,
+                                prefix=maybe_prefix(prefix, "model"))
+        if get_pp_group().is_last_rank:
+            self.lm_head = ParallelLMHead(config.vocab_size,
+                                          config.hidden_size,
+                                          quant_config=quant_config)
+        else:
+            self.lm_head = PPMissingLayer()
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            intermediate_tensors,
+            inputs_embeds,
+        )
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
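
For reference, FusedMoE.make_expert_params_mapping returns one (param_name, weight_name, expert_id, shard_id) tuple per expert projection; get_expert_mapping() simply exposes that list on the model so weight loaders, including the BNB path, can route per-expert checkpoint tensors into the fused expert parameters. The sketch below only illustrates the shape of such a mapping, assuming the usual gate/up/down naming and w13/w2 fused parameter names; it is not guaranteed to match vLLM's exact strings.

# Illustrative sketch of the mapping shape returned by get_expert_mapping().
# Entries are (param_name, weight_name, expert_id, shard_id); the concrete
# names here are assumptions for illustration, not vLLM's actual output.
def sketch_expert_mapping(num_experts: int) -> list[tuple[str, str, int, str]]:
    mapping: list[tuple[str, str, int, str]] = []
    for expert_id in range(num_experts):
        for shard_id, ckpt_name in (("w1", "gate_proj"),
                                    ("w2", "down_proj"),
                                    ("w3", "up_proj")):
            # gate/up projections are fused into one parameter; down stays separate
            fused = "experts.w13_" if shard_id in ("w1", "w3") else "experts.w2_"
            mapping.append((fused, f"experts.{expert_id}.{ckpt_name}.",
                            expert_id, shard_id))
    return mapping

# e.g. sketch_expert_mapping(2)[0] -> ("experts.w13_", "experts.0.gate_proj.", 0, "w1")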

vllm/model_executor/models/glm4_moe.py

Lines changed: 15 additions & 10 deletions
@@ -53,7 +53,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -461,6 +461,15 @@ def make_empty_intermediate_tensors(
                         device=device),
         })
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts)
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [

@@ -472,16 +481,9 @@ def load_weights(self, weights: Iterable[tuple[str,
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts)
-
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
             if spec_layer is not None:

@@ -570,7 +572,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params
 
 
-class Glm4MoeForCausalLM(nn.Module, SupportsPP):
+class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",

@@ -677,6 +679,9 @@ def load_weights(self, weights: Iterable[tuple[str,
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
+
 
 def get_spec_layer_idx_from_weight_name(config: PretrainedConfig,
                                         weight_name: str) -> Optional[int]:
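
The same pattern is applied here: Glm4MoeForCausalLM now advertises SupportsLoRA and forwards get_expert_mapping() to its inner model. A loader that needs to place a quantized checkpoint tensor can consume such a mapping roughly as sketched below; this is a hypothetical helper, assuming (param_name, weight_name, expert_id, shard_id) tuples as above, not code from vLLM.

# Hypothetical helper showing how a weight loader could consume the mapping:
# given a checkpoint tensor name, find the fused parameter it belongs to.
from typing import Optional

def resolve_expert_weight(
    ckpt_name: str,
    expert_mapping: list[tuple[str, str, int, str]],
) -> Optional[tuple[str, int, str]]:
    for param_name, weight_name, expert_id, shard_id in expert_mapping:
        if weight_name in ckpt_name:
            # Substitute the per-expert checkpoint prefix with the fused param name.
            fused_name = ckpt_name.replace(weight_name, param_name)
            return fused_name, expert_id, shard_id
    return None  # not an expert weight; handled by the regular stacked mappings

# Example with a single illustrative mapping entry:
mapping = [("experts.w13_", "experts.0.gate_proj.", 0, "w1")]
print(resolve_expert_weight("model.layers.0.mlp.experts.0.gate_proj.weight", mapping))
# -> ('model.layers.0.mlp.experts.w13_weight', 0, 'w1')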
