 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsPP
-from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
+from .utils import (PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -668,6 +668,73 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

+
+class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+        self.quant_config = quant_config
+        self.model = DeepseekV2Model(vllm_config=vllm_config,
+                                     prefix=maybe_prefix(prefix, "model"))
+        if get_pp_group().is_last_rank:
+            self.lm_head = ParallelLMHead(config.vocab_size,
+                                          config.hidden_size,
+                                          quant_config=quant_config)
+        else:
+            self.lm_head = PPMissingLayer()
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = get_sampler()
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def make_empty_intermediate_tensors(
+            self, batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            "hidden_states":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+            "residual":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+        })
+
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
@@ -687,6 +754,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+
             spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
             if spec_layer is not None:
                 continue  # skip spec decode layers for main model
@@ -754,78 +824,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
         return loaded_params


-class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        self.config = config
-        self.quant_config = quant_config
-        self.model = DeepseekV2Model(vllm_config=vllm_config,
-                                     prefix=maybe_prefix(prefix, "model"))
-        if get_pp_group().is_last_rank:
-            self.lm_head = ParallelLMHead(config.vocab_size,
-                                          config.hidden_size,
-                                          quant_config=quant_config)
-        else:
-            self.lm_head = PPMissingLayer()
-        self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.sampler = get_sampler()
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds)
-        return hidden_states
-
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
-    def make_empty_intermediate_tensors(
-            self, batch_size: int, dtype: torch.dtype,
-            device: torch.device) -> IntermediateTensors:
-        return IntermediateTensors({
-            "hidden_states":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-            "residual":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-        })
-
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
-        loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"])
-        return loader.load_weights(weights)
-
-
 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
     pass
