@@ -54,8 +54,8 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsPP
-from .utils import (PPMissingLayer, is_pp_missing_parameter,
+from .interfaces import SupportsLoRA, SupportsPP
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -327,6 +327,7 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class Dots1Model(nn.Module):
 
     fall_back_to_pt_during_load = False
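This hunk pairs with the removal below: @support_torch_compile moves from the Dots1ForCausalLM wrapper onto the inner Dots1Model, so the compiled region covers only the transformer stack while the logits head and weight-loading glue stay eager. A minimal sketch of the decorator pattern, assuming stock PyTorch only; toy_support_torch_compile and TinyModel are hypothetical stand-ins, not vLLM's decorator:

```python
import torch
import torch.nn as nn


def toy_support_torch_compile(cls):
    """Hypothetical stand-in for vLLM's @support_torch_compile: compile each
    instance's forward() at construction so callers transparently get
    compiled execution."""
    orig_init = cls.__init__

    def patched_init(self, *args, **kwargs):
        orig_init(self, *args, **kwargs)
        # Shadow the class-level forward with a compiled bound method.
        self.forward = torch.compile(self.forward)

    cls.__init__ = patched_init
    return cls


@toy_support_torch_compile
class TinyModel(nn.Module):
    """Plays the Dots1Model role: only the model body gets compiled."""

    def __init__(self, hidden: int = 32):
        super().__init__()
        self.layer = nn.Linear(hidden, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.layer(x))


print(TinyModel()(torch.randn(2, 32)).shape)  # torch.Size([2, 32])
```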
@@ -404,68 +405,12 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
-
-@support_torch_compile
-class Dots1ForCausalLM(nn.Module, SupportsPP):
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        self.config = config
-        self.quant_config = quant_config
-        self.model = Dots1Model(vllm_config=vllm_config,
-                                prefix=maybe_prefix(prefix, "model"))
-        if get_pp_group().is_last_rank:
-            self.lm_head = ParallelLMHead(config.vocab_size,
-                                          config.hidden_size,
-                                          quant_config=quant_config)
-        else:
-            self.lm_head = PPMissingLayer()
-        self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(
-            input_ids,
-            positions,
-            intermediate_tensors,
-            inputs_embeds,
-        )
-        return hidden_states
-
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def make_empty_intermediate_tensors(
-            self, batch_size: int, dtype: torch.dtype,
-            device: torch.device) -> IntermediateTensors:
-        return IntermediateTensors({
-            "hidden_states":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-            "residual":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-        })
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts)
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
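The new get_expert_mapping() hoists what load_weights used to build inline (see the next hunk), and its return type follows vLLM's expert-mapping convention: (param_name, weight_name, expert_id, shard_id) tuples that route each expert's gate_proj/up_proj/down_proj checkpoint tensors into the fused FusedMoE parameters. A rough sketch of the shape of that mapping; the fused prefix strings here are assumptions for illustration, not FusedMoE.make_expert_params_mapping's literal output:

```python
def sketch_expert_mapping(num_experts: int) -> list[tuple[str, str, int, str]]:
    """Illustrative (param_name, weight_name, expert_id, shard_id) tuples."""
    mapping = []
    for expert_id in range(num_experts):
        for shard_id, ckpt_name in (("w1", "gate_proj"),
                                    ("w2", "down_proj"),
                                    ("w3", "up_proj")):
            # gate/up shards land in a fused w13 parameter, down_proj in w2.
            fused = ("experts.w13_weight"
                     if shard_id in ("w1", "w3") else "experts.w2_weight")
            mapping.append(
                (fused, f"experts.{expert_id}.{ckpt_name}.", expert_id,
                 shard_id))
    return mapping


for entry in sketch_expert_mapping(num_experts=2)[:3]:
    print(entry)
# ('experts.w13_weight', 'experts.0.gate_proj.', 0, 'w1')
# ('experts.w2_weight', 'experts.0.down_proj.', 0, 'w2')
# ('experts.w13_weight', 'experts.0.up_proj.', 0, 'w3')
```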
@@ -477,14 +422,9 @@ def load_weights(self, weights: Iterable[tuple[str,
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts)
-
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -534,3 +474,71 @@ def load_weights(self, weights: Iterable[tuple[str,
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
+
+
+class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
+
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Dots1Model(vllm_config=vllm_config,
+                                prefix=maybe_prefix(prefix, "model"))
+        if get_pp_group().is_last_rank:
+            self.lm_head = ParallelLMHead(config.vocab_size,
+                                          config.hidden_size,
+                                          quant_config=quant_config)
+        else:
+            self.lm_head = PPMissingLayer()
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            intermediate_tensors,
+            inputs_embeds,
+        )
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
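Net effect of the reshuffle: Dots1Model keeps the detailed load_weights plus get_expert_mapping, while the rebuilt Dots1ForCausalLM adds SupportsLoRA (packed_modules_mapping tells the LoRA machinery that q/k/v and gate/up projections are fused) and delegates loading to AutoWeightsLoader. A toy sketch of that delegation pattern, assuming only PyTorch; ToyAutoWeightsLoader is illustrative, not vLLM's AutoWeightsLoader:

```python
from collections.abc import Iterable

import torch
import torch.nn as nn


class ToyAutoWeightsLoader:
    """Illustrative prefix-based dispatcher (not vLLM's implementation)."""

    def __init__(self, module: nn.Module):
        self.module = module

    def load_weights(
            self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loaded: set[str] = set()
        params = dict(self.module.named_parameters())
        for name, tensor in weights:
            prefix, _, rest = name.partition(".")
            child = getattr(self.module, prefix, None)
            if child is not None and hasattr(child, "load_weights"):
                # Delegate "model.*" names to the child's own load_weights,
                # which owns the stacked/expert parameter mappings.
                loaded.update(f"{prefix}.{sub}"
                              for sub in child.load_weights([(rest, tensor)]))
            elif name in params:
                # Direct parameters such as lm_head.weight load in place.
                params[name].data.copy_(tensor)
                loaded.add(name)
        return loaded
```

This is why the top-level load_weights shrinks to two lines: the model-specific mapping logic stays in the submodule, and every wrapper class can reuse the same loader.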