@@ -52,11 +52,6 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-        )
-
         self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.eh_proj = nn.Linear(config.hidden_size * 2,
@@ -74,8 +69,6 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_index: int = 0,
     ) -> torch.Tensor:
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
         assert inputs_embeds is not None
         # masking inputs at position 0, as not needed by MTP
         inputs_embeds[positions == 0] = 0
@@ -112,7 +105,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             for idx in range(self.mtp_start_layer_idx,
                              self.mtp_start_layer_idx + self.num_mtp_layers)
         })
-
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
         self.logits_processor = LogitsProcessor(config.vocab_size)

     def forward(
@@ -123,6 +119,8 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
         current_step_idx = (spec_step_idx % self.num_mtp_layers)
         return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
             input_ids,
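
Taken together, the hunks above move the embedding table out of the per-step MTP module and into the multi-token predictor itself, so the lookup happens once per forward call and every speculative step reuses the same tensor. A minimal, self-contained sketch of that sharing pattern follows; plain nn.Embedding and nn.Linear stand in for VocabParallelEmbedding and the MTP layers, and all names here are illustrative, not the vLLM API:

import torch

# Stand-ins for the real modules: one shared embedding table owned by the
# predictor-level module, plus two per-step blocks that consume embeddings.
embed_tokens = torch.nn.Embedding(32, 8)               # shared lookup table
mtp_steps = [torch.nn.Linear(8, 8) for _ in range(2)]  # per-step MTP blocks

input_ids = torch.tensor([1, 2, 3])
inputs_embeds = embed_tokens(input_ids)    # computed once, at the top level
# Every speculative step receives ready-made embeddings instead of token ids.
outputs = [step(inputs_embeds) for step in mtp_steps]
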
@@ -242,6 +240,12 @@ def load_weights(self, weights: Iterable[tuple[str,
                     if name.endswith(".bias") and name not in params_dict:
                         continue

+                    # Per the DeepSeek-V3 Technical Report, MTP modules share
+                    # the embedding layer; we only load the first copy.
+                    if (spec_layer != self.model.mtp_start_layer_idx
+                            and ".layers" not in name):
+                        continue
+
                     param = params_dict[name]
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
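
The new skip condition only fires for weight names that contain no ".layers" segment, i.e. the shared weights that _rewrite_spec_layer_name (next hunk) promotes to top-level "model.*" names; per-layer weights always load. A small hypothetical walkthrough, assuming mtp_start_layer_idx == 61 and two MTP layers (the example names are illustrative, not taken from a real checkpoint):

mtp_start_layer_idx = 61  # assumed value for illustration

examples = [
    (61, "model.embed_tokens.weight"),                          # first copy
    (62, "model.embed_tokens.weight"),                          # duplicate copy
    (62, "model.layers.62.mtp_block.self_attn.o_proj.weight"),  # per-layer
]
for spec_layer, name in examples:
    # Mirrors the condition added above: skip shared (top-level) weights for
    # every MTP layer except the first one.
    skipped = (spec_layer != mtp_start_layer_idx and ".layers" not in name)
    print(f"{name} (layer {spec_layer}): {'skipped' if skipped else 'loaded'}")
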
@@ -253,17 +257,25 @@ def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
253257        """ 
254258        Rewrite the weight name to match the format of the original model. 
255259        Add .mtp_block for modules in transformer layer block for spec layer 
260+         and rename shared layer weights to be top level. 
256261        """ 
257262        spec_layer_weight_names  =  [
258263            "embed_tokens" , "enorm" , "hnorm" , "eh_proj" , "shared_head" 
259264        ]
265+         shared_weight_names  =  ["embed_tokens" ]
260266        spec_layer_weight  =  False 
267+         shared_weight  =  False 
261268        for  weight_name  in  spec_layer_weight_names :
262269            if  weight_name  in  name :
263270                spec_layer_weight  =  True 
271+                 if  weight_name  in  shared_weight_names :
272+                     shared_weight  =  True 
264273                break 
265274        if  not  spec_layer_weight :
266275            # treat rest weights as weights for transformer layer block 
267276            name  =  name .replace (f"model.layers.{ spec_layer }  ." ,
268277                                f"model.layers.{ spec_layer }  .mtp_block." )
278+         elif  shared_weight :
279+             # treat shared weights as top level weights 
280+             name  =  name .replace (f"model.layers.{ spec_layer }  ." , "model." )
269281        return  name 
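
For reference, a standalone re-implementation of the renaming rules in this hunk with a few worked examples; the layer index 61 and the sample weight names are assumptions for illustration, not values taken from the model code:

def rewrite_spec_layer_name(spec_layer: int, name: str) -> str:
    """Illustrative copy of the logic above, outside the model class."""
    spec_layer_weight_names = [
        "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
    ]
    shared_weight_names = ["embed_tokens"]
    spec_layer_weight = any(w in name for w in spec_layer_weight_names)
    shared_weight = any(w in name for w in shared_weight_names)
    if not spec_layer_weight:
        # Transformer-block weights are routed into the .mtp_block submodule.
        return name.replace(f"model.layers.{spec_layer}.",
                            f"model.layers.{spec_layer}.mtp_block.")
    if shared_weight:
        # Shared weights are promoted to top-level names so they resolve to
        # the single embedding created in the predictor's __init__.
        return name.replace(f"model.layers.{spec_layer}.", "model.")
    return name

print(rewrite_spec_layer_name(61, "model.layers.61.embed_tokens.weight"))
# -> model.embed_tokens.weight          (shared, promoted to top level)
print(rewrite_spec_layer_name(61, "model.layers.61.enorm.weight"))
# -> model.layers.61.enorm.weight       (spec-layer weight, left in place)
print(rewrite_spec_layer_name(61, "model.layers.61.self_attn.o_proj.weight"))
# -> model.layers.61.mtp_block.self_attn.o_proj.weight  (regular block weight)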