@@ -98,6 +98,41 @@ def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor,
         return super().forward_oot(x)
 
 
+class CustomDeepseekV2MergedReplicatedLinear(ReplicatedLinear):
+
+    def __init__(
+        self,
+        input_size: int,
+        output_sizes: list[int],
+        bias: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        self.output_sizes = output_sizes
+        super().__init__(input_size,
+                         sum(output_sizes),
+                         bias=bias,
+                         quant_config=quant_config,
+                         prefix=prefix)
+
+    def weight_loader(self, param: torch.nn.Parameter,
+                      loaded_weight: torch.Tensor, loaded_shard_id: int):
+        # GGUF format is not supported yet.
+        assert not getattr(param, "is_gguf_weight", False)
+        assert not getattr(param, "is_gguf_weight_type", False)
+
+        assert loaded_shard_id < len(self.output_sizes)
+        shard_offset = sum(self.output_sizes[:loaded_shard_id])
+        shard_size = self.output_sizes[loaded_shard_id]
+        shard = param.data.narrow(param.output_dim, shard_offset, shard_size)
+
+        assert shard.size() == loaded_weight.size(), (
+            f"Tried to load weights of size {loaded_weight.size()} "
+            f"to a parameter shard of id {loaded_shard_id} size {shard.size()}"
+        )
+        shard.copy_(loaded_weight)
+
+
 class CustomDeepseekV2MLP(nn.Module):
 
     def __init__(
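The new `CustomDeepseekV2MergedReplicatedLinear` keeps the full, unsharded merged weight on every rank and loads each sub-projection by row offset. A minimal usage sketch of that loading contract, with toy sizes; it assumes the parameter carries vLLM's `output_dim` attribute, which the loader above requires (depending on the vLLM version it may need to be set explicitly):

```python
import torch

# Illustration only (not part of the diff): loading the two merged
# sub-weights of a gate/up projection into one replicated parameter.
hidden_size, intermediate_size = 128, 256  # toy sizes
layer = CustomDeepseekV2MergedReplicatedLinear(
    hidden_size, [intermediate_size, intermediate_size], bias=False)

gate_w = torch.randn(intermediate_size, hidden_size)
up_w = torch.randn(intermediate_size, hidden_size)
# Shard 0 lands at row offset 0, shard 1 at row offset `intermediate_size`
# of the merged [2 * intermediate_size, hidden_size] weight.
layer.weight_loader(layer.weight, gate_w, loaded_shard_id=0)
layer.weight_loader(layer.weight, up_w, loaded_shard_id=1)
```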
@@ -107,20 +142,33 @@ def __init__(
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        force_replicate: bool = False,
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config,
-                                           reduce_results=reduce_results,
-                                           prefix=f"{prefix}.down_proj")
+        if not force_replicate:
+            self.gate_up_proj = MergedColumnParallelLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj")
+            self.down_proj = RowParallelLinear(intermediate_size,
+                                               hidden_size,
+                                               bias=False,
+                                               quant_config=quant_config,
+                                               reduce_results=reduce_results,
+                                               prefix=f"{prefix}.down_proj")
+        else:
+            self.gate_up_proj = CustomDeepseekV2MergedReplicatedLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj")
+            self.down_proj = ReplicatedLinear(intermediate_size,
+                                              hidden_size,
+                                              bias=False,
+                                              quant_config=quant_config,
+                                              prefix=f"{prefix}.down_proj")
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -181,6 +229,12 @@ def __init__(
             raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                              "Only silu is supported for now.")
 
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
+        self.enable_multistream_moe = \
+            ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2
+
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.n_routed_experts,
                                      bias=False,
@@ -216,6 +270,7 @@ def __init__(
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 reduce_results=True,
+                force_replicate=self.enable_multistream_moe,
                 prefix=f"{prefix}.shared_experts",
             )
         else:
@@ -230,12 +285,6 @@ def __init__(
 
         self.params_dtype = torch.get_default_dtype()
 
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
-        self.enable_multistream_moe = \
-            ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -274,27 +323,22 @@ def forward(
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
-        kwargs = {}
-        if not use_separated_shared_experts:
-            kwargs.update({
-                "shared_experts": self.shared_experts,
-                "shared_experts_input": old_hidden_states
-            })
-
         experts_hidden_states = self.experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
             is_prefill=is_prefill,
             top_k=CustomDeepseekV2MoE.top_k,
             enable_force_load_balance=enable_force_load_balance,
-            **kwargs)
+            shared_experts=(self.shared_experts
+                            if not use_separated_shared_experts else None),
+        )
 
         if not isinstance(experts_hidden_states, tuple):
             hidden_states = experts_hidden_states * self.routed_scaling_factor
         else:
-            hidden_states = experts_hidden_states[
-                0] * self.routed_scaling_factor
-            shared_hidden_states = experts_hidden_states[1]
+            hidden_states = (
+                experts_hidden_states[0] * self.routed_scaling_factor +
+                experts_hidden_states[1])
 
         if self.tp_size > 1:
             if (VLLM_ENABLE_MC2
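After this hunk, `self.experts` is overloaded in what it returns: a plain tensor when the shared experts run separately, or a `(routed_output, shared_output)` tuple when they were fused in via the new `shared_experts=` argument. A minimal sketch of the combine step the branch above implements (hypothetical helper name; the real `experts` module lives elsewhere in the plugin):

```python
import torch

# Hypothetical helper mirroring the branch above (illustration only).
def combine_moe_output(experts_out, routed_scaling_factor: float) -> torch.Tensor:
    if not isinstance(experts_out, tuple):
        # Shared experts ran separately; only the routed output came back.
        return experts_out * routed_scaling_factor
    routed, shared = experts_out
    # Fused path: scale the routed part, then add the shared-expert output.
    return routed * routed_scaling_factor + shared
```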
@@ -309,10 +353,8 @@ def forward(
             hidden_states = tensor_model_parallel_all_reduce(hidden_states)
 
         if use_separated_shared_experts:
-            shared_hidden_states = self.shared_experts(old_hidden_states)
-
-        if self.shared_experts is not None:
-            hidden_states = hidden_states + shared_hidden_states
+            hidden_states = hidden_states + self.shared_experts(
+                old_hidden_states)
 
         return hidden_states.view(num_tokens, hidden_size)
 
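One behavioral note on this last hunk: the old code computed `shared_hidden_states` under `use_separated_shared_experts` but added it under a separate `self.shared_experts is not None` check; the new code folds both into a single branch. That relies on an invariant the diff leaves implicit, sketched here as a guard (illustrative, not in the code):

```python
# Assumed invariant behind the fold: the separated path is only taken
# when a shared-experts module actually exists.
assert self.shared_experts is not None or not use_separated_shared_experts
```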