@@ -99,6 +99,41 @@ def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor,
             return super().forward_oot(x)
 
 
+class CustomDeepseekV2MergedReplicatedLinear(ReplicatedLinear):
+
+    def __init__(
+        self,
+        input_size: int,
+        output_sizes: list[int],
+        bias: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        self.output_sizes = output_sizes
+        super().__init__(input_size,
+                         sum(output_sizes),
+                         bias=bias,
+                         quant_config=quant_config,
+                         prefix=prefix)
+
+    def weight_loader(self, param: torch.nn.Parameter,
+                      loaded_weight: torch.Tensor, loaded_shard_id: int):
+        # GGUF format is not supported yet.
+        assert not getattr(param, "is_gguf_weight", False)
+        assert not getattr(param, "is_gguf_weight_type", False)
+
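+        # The merged parameter concatenates the output_sizes shards along
+        # output_dim; with the usual stacked-params mapping for gate_up_proj,
+        # shard id 0 is gate_proj and shard id 1 is up_proj.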
+        assert loaded_shard_id < len(self.output_sizes)
+        shard_offset = sum(self.output_sizes[:loaded_shard_id])
+        shard_size = self.output_sizes[loaded_shard_id]
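+        # narrow() returns a view into the merged parameter, so the copy_()
+        # below writes the shard in place.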
+        shard = param.data.narrow(param.output_dim, shard_offset, shard_size)
+
+        assert shard.size() == loaded_weight.size(), (
+            f"Tried to load weights of size {loaded_weight.size()} "
+            f"into a parameter shard of id {loaded_shard_id} with "
+            f"size {shard.size()}")
+        shard.copy_(loaded_weight)
+
+
 class CustomDeepseekV2MLP(nn.Module):
 
     def __init__(
@@ -108,20 +143,33 @@ def __init__(
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        force_replicate: bool = False,
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config,
-                                           reduce_results=reduce_results,
-                                           prefix=f"{prefix}.down_proj")
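+        # force_replicate swaps the tensor-parallel projections for replicated
+        # ones, so the MLP needs no TP collectives; the MoE layer passes this
+        # for shared experts when multistream MoE is enabled.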
+        if not force_replicate:
+            self.gate_up_proj = MergedColumnParallelLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj")
+            self.down_proj = RowParallelLinear(intermediate_size,
+                                               hidden_size,
+                                               bias=False,
+                                               quant_config=quant_config,
+                                               reduce_results=reduce_results,
+                                               prefix=f"{prefix}.down_proj")
+        else:
+            self.gate_up_proj = CustomDeepseekV2MergedReplicatedLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj")
+            self.down_proj = ReplicatedLinear(intermediate_size,
+                                              hidden_size,
+                                              bias=False,
+                                              quant_config=quant_config,
+                                              prefix=f"{prefix}.down_proj")
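+        # With fully replicated weights, every rank computes the complete MLP
+        # locally, so no reduce_results/all-reduce is needed on down_proj.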
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -183,6 +231,12 @@ def __init__(
             raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                              "Only silu is supported for now.")
 
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
+        self.enable_multistream_moe = \
+            ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2
+
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.n_routed_experts,
                                      bias=False,
@@ -218,6 +272,7 @@ def __init__(
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 reduce_results=True,
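+                # replicate shared experts when multistream MoE is enabled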
+                force_replicate=self.enable_multistream_moe,
                 prefix=f"{prefix}.shared_experts",
             )
         else:
@@ -232,12 +287,6 @@ def __init__(
 
         self.params_dtype = torch.get_default_dtype()
 
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
-        self.enable_multistream_moe = \
-            ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2
-
     def forward(
             self,
             hidden_states: torch.Tensor,
@@ -276,27 +325,22 @@ def forward(
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
-        kwargs = {}
-        if not use_separated_shared_experts:
-            kwargs.update({
-                "shared_experts": self.shared_experts,
-                "shared_experts_input": old_hidden_states
-            })
-
         experts_hidden_states = self.experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
             is_prefill=is_prefill,
             top_k=CustomDeepseekV2MoE.top_k,
             enable_force_load_balance=enable_force_load_balance,
-            **kwargs)
+            shared_experts=(self.shared_experts
+                            if not use_separated_shared_experts else None),
+        )
 
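+        # When shared_experts is passed, the fused experts call returns a
+        # (routed_output, shared_output) tuple instead of a single tensor.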
         if not isinstance(experts_hidden_states, tuple):
             hidden_states = experts_hidden_states * self.routed_scaling_factor
         else:
-            hidden_states = experts_hidden_states[
-                0] * self.routed_scaling_factor
-            shared_hidden_states = experts_hidden_states[1]
+            hidden_states = (
+                experts_hidden_states[0] * self.routed_scaling_factor +
+                experts_hidden_states[1])
 
         if self.tp_size > 1:
             if (VLLM_ENABLE_MC2
@@ -311,10 +355,8 @@ def forward(
                 hidden_states = tensor_model_parallel_all_reduce(hidden_states)
 
         if use_separated_shared_experts:
-            shared_hidden_states = self.shared_experts(old_hidden_states)
-
-        if self.shared_experts is not None:
-            hidden_states = hidden_states + shared_hidden_states
+            hidden_states = hidden_states + self.shared_experts(
+                old_hidden_states)
 
         return hidden_states.view(num_tokens, hidden_size)
 