@@ -82,20 +82,35 @@ def __init__(
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        force_replicate: bool = False,
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config,
-                                           reduce_results=reduce_results,
-                                           prefix=f"{prefix}.down_proj")
+        if not force_replicate:
+            self.gate_up_proj = MergedColumnParallelLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj")
+            self.down_proj = RowParallelLinear(intermediate_size,
+                                               hidden_size,
+                                               bias=False,
+                                               quant_config=quant_config,
+                                               reduce_results=reduce_results,
+                                               prefix=f"{prefix}.down_proj")
+        else:
+            self.gate_up_proj = ReplicatedLinear(
+                hidden_size,
+                intermediate_size * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj")
+            self.down_proj = ReplicatedLinear(intermediate_size,
+                                              hidden_size,
+                                              bias=False,
+                                              quant_config=quant_config,
+                                              prefix=f"{prefix}.down_proj")
+
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -202,8 +217,12 @@ def __init__(
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 reduce_results=True,
+                force_replicate=envs_ascend.
+                VLLM_ENABLE_MULTISTREAM_SHARED_EXPERT,
                 prefix=f"{prefix}.shared_experts",
             )
+        else:
+            self.shared_experts = None  # type: ignore
         CustomDeepseekV2MoE.top_k = config.num_experts_per_tok

         self.dp_size = get_dp_group().world_size
@@ -224,8 +243,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             is_prefill = attn_metadata.num_prefills > 0
             enable_force_load_balance = False
         num_tokens, hidden_dim = hidden_states.shape
+        use_separated_shared_expert = (
+            self.n_shared_experts is not None
+            and (is_prefill
+                 or not envs_ascend.VLLM_ENABLE_MULTISTREAM_SHARED_EXPERT))

-        if self.n_shared_experts is not None:
+        if use_separated_shared_expert:
             shared_output = self.shared_experts(hidden_states)

         if self.tp_size > 1:
@@ -248,13 +271,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(local_hidden_states)

-        router_hidden_states = self.experts(
+        experts_hidden_states = self.experts(
             hidden_states=local_hidden_states,
             router_logits=router_logits,
             is_prefill=is_prefill,
             top_k=CustomDeepseekV2MoE.top_k,
             enable_force_load_balance=enable_force_load_balance,
-        ) * self.routed_scaling_factor
+            shared_experts=(self.shared_experts
+                            if not use_separated_shared_expert else None),
+        )
+
+        if not isinstance(experts_hidden_states, tuple):
+            router_hidden_states = experts_hidden_states * self.routed_scaling_factor
+        else:
+            router_hidden_states = (
+                experts_hidden_states[0] * self.routed_scaling_factor +
+                experts_hidden_states[1])

         if self.tp_size > 1:
             dist.all_gather(list(chunk_hidden_states), router_hidden_states,
@@ -265,7 +297,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         else:
             final_hidden_states = router_hidden_states

-        if shared_output is not None:
+        if use_separated_shared_expert:
             final_hidden_states = final_hidden_states + shared_output

         return final_hidden_states.view(num_tokens, hidden_dim)
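
For readers skimming the diff, here is a minimal, self-contained sketch of the control flow it introduces. It is not part of the change; the class and helper names below are illustrative, and it only assumes what the hunks show: when a `shared_experts` module is passed into the experts call, that call returns a `(routed, shared)` tuple, otherwise a single routed tensor.

```python
# Illustrative toy stand-ins for the real layers touched by this diff.
import torch
import torch.nn as nn


class ToyMoELayer(nn.Module):
    """Mimics the separated vs. fused shared-expert paths assumed above."""

    def __init__(self, hidden_size: int, multistream_shared_expert: bool):
        super().__init__()
        self.multistream_shared_expert = multistream_shared_expert
        self.routed_scaling_factor = 2.5
        self.shared_experts = nn.Linear(hidden_size, hidden_size, bias=False)
        self.experts = nn.Linear(hidden_size, hidden_size, bias=False)

    def _run_experts(self, x: torch.Tensor, shared_experts):
        routed = self.experts(x)
        if shared_experts is None:
            # Separated path: only the routed output is returned.
            return routed
        # Fused path: routed and shared outputs come back together.
        return routed, shared_experts(x)

    def forward(self, x: torch.Tensor, is_prefill: bool) -> torch.Tensor:
        # Same condition as use_separated_shared_expert in the diff.
        use_separated = is_prefill or not self.multistream_shared_expert
        shared_output = self.shared_experts(x) if use_separated else None

        out = self._run_experts(
            x, None if use_separated else self.shared_experts)

        # Mirror the tuple handling added in the forward() hunk.
        if not isinstance(out, tuple):
            hidden = out * self.routed_scaling_factor
        else:
            hidden = out[0] * self.routed_scaling_factor + out[1]

        if use_separated:
            hidden = hidden + shared_output
        return hidden


# Decode with the flag on takes the fused path; prefill stays separated.
layer = ToyMoELayer(hidden_size=16, multistream_shared_expert=True)
x = torch.randn(4, 16)
decode_out = layer(x, is_prefill=False)
prefill_out = layer(x, is_prefill=True)
```

The `force_replicate` branch presumably exists for the same reason: a `ReplicatedLinear` shared expert holds full weights on every rank, so its output can be added to the scaled routed output locally without the tensor-parallel reduce that `RowParallelLinear` would require.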