Commit 28af6b0

[Models] Apply SharedFusedMoE to all models with shared experts
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 502640c commit 28af6b0

23 files changed: +264 -341 lines changed

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 2 additions & 0 deletions

@@ -630,6 +630,8 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
         layer.w13_input_scale = None
         layer.w2_input_scale = None
 
+        self.rocm_aiter_moe_enabled = False
+
     def process_weights_after_loading(self, layer: Module) -> None:
         # Lazy import to avoid importing triton too early.
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
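The fp8 change is small: rocm_aiter_moe_enabled now gets a default value in create_weights, presumably so the attribute exists even before process_weights_after_loading has run. A minimal sketch of the pattern; the class and the surrounding code are illustrative, not the real Fp8MoEMethod:

from torch.nn import Module


class MoEMethodSketch:
    """Illustrates default-initializing a backend flag at weight-creation time."""

    def create_weights(self, layer: Module, num_experts: int,
                       hidden_size: int) -> None:
        layer.w13_input_scale = None
        layer.w2_input_scale = None
        # Give the flag a well-defined value up front so any code path that
        # checks it before post-processing sees False rather than no attribute.
        self.rocm_aiter_moe_enabled = False

    def process_weights_after_loading(self, layer: Module) -> None:
        # The real method lazily imports the ROCm AITER fused-MoE kernels and
        # may enable the flag when they are available.
        pass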

vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py

Lines changed: 14 additions & 9 deletions

@@ -18,13 +18,15 @@ class SharedFusedMoE(FusedMoE):
 
     def __init__(
         self,
-        shared_experts: torch.nn.Module,
+        shared_experts: Optional[torch.nn.Module],
         use_overlapped: bool = True,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self._shared_experts = shared_experts
-        self.use_overlapped = use_overlapped
+        self.use_overlapped = use_overlapped and not (
+            self.use_ep or self.use_flashinfer_cutlass_kernels
+        ) and self.shared_experts is not None
 
     @property
     def shared_experts(self) -> Optional[torch.nn.Module]:
@@ -36,13 +38,16 @@ def forward(
         router_logits: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         if not self.use_overlapped:
-            shared_out = self._shared_experts(hidden_states)
-
-            # Reduce outputs if necessary, since the MLP should
-            # have been created with reduce_results=False.
-            if (self.reduce_results and self.tp_size > 1
-                    and self.must_reduce_shared_expert_outputs()):
-                shared_out = tensor_model_parallel_all_reduce(shared_out)
+            if self._shared_experts is not None:
+                shared_out = self._shared_experts(hidden_states)
+
+                # Reduce shared expert outputs if necessary, since the MLP
+                # should have been created with reduce_results=False.
+                if (self.reduce_results and self.tp_size > 1
+                        and self.must_reduce_shared_expert_outputs()):
+                    shared_out = tensor_model_parallel_all_reduce(shared_out)
+            else:
+                shared_out = None
 
         fused_out = super().forward(
             hidden_states=hidden_states,
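The model-facing contract after this change: SharedFusedMoE accepts shared_experts=None, disables overlapping when expert parallelism or the FlashInfer CUTLASS path is active, and in the non-overlapped path runs the shared experts itself, including the all-reduce when required. A minimal usage sketch under those assumptions; the config fields and the stand-in shared MLP are placeholders, not code from this commit:

import torch
from torch import nn

from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE


class MoEBlockSketch(nn.Module):
    """Hypothetical MoE block wired the way the migrated models are."""

    def __init__(self, config, quant_config=None, prefix: str = ""):
        super().__init__()
        # Stand-in for the model's shared-expert MLP; the real models build a
        # gated SiLU MLP with reduce_results=False.
        self.shared_experts = nn.Sequential(
            nn.Linear(config.hidden_size, config.moe_intermediate_size, bias=False),
            nn.SiLU(),
            nn.Linear(config.moe_intermediate_size, config.hidden_size, bias=False),
        )
        self.gate = nn.Linear(config.hidden_size, config.n_routed_experts, bias=False)
        self.experts = SharedFusedMoE(
            shared_experts=self.shared_experts,  # None is also accepted now
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits = self.gate(hidden_states)
        # Shared and routed outputs are combined by the layer; the model no
        # longer adds shared_output or all-reduces by hand (see the model
        # diffs below).
        return self.experts(hidden_states=hidden_states,
                            router_logits=router_logits)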

vllm/model_executor/models/aria.py

Lines changed: 13 additions & 15 deletions

@@ -13,6 +13,7 @@
 from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -199,7 +200,7 @@ def forward(
         return out
 
 
-class AriaFusedMoE(FusedMoE):
+class AriaFusedMoE(SharedFusedMoE):
 
     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                       shard_id: str) -> None:
@@ -253,22 +254,23 @@ def __init__(
             torch.empty(
                 (self.config.moe_num_experts, self.config.hidden_size)))
 
+        self.shared_experts = LlamaMLP(
+            config.hidden_size,
+            config.intermediate_size * config.moe_num_shared_experts,
+            "silu",
+            quant_config=quant_config,
+            bias=config.mlp_bias,
+        )
+
         self.experts = AriaFusedMoE(
+            shared_experts=self.shared_experts,
             num_experts=config.moe_num_experts,
             top_k=config.moe_topk,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             quant_config=quant_config,
-            reduce_results=True,
             prefix=f"{prefix}.experts",
         )
-        self.shared_experts = LlamaMLP(
-            config.hidden_size,
-            config.intermediate_size * config.moe_num_shared_experts,
-            "silu",
-            quant_config=quant_config,
-            bias=config.mlp_bias,
-        )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """
@@ -285,12 +287,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         router_output = torch.nn.functional.linear(hidden_states,
                                                    self.router_weight)
 
-        hidden_states_copy = hidden_states.clone()
-        # NOTE: hidden_states will be modified inplace by `FusedMoE`
-        sparse_expert_output = self.experts(hidden_states, router_output)
-        shared_expert_output = self.shared_experts(hidden_states_copy)
-
-        return sparse_expert_output + shared_expert_output
+        # NOTE: hidden_states will be modified inplace by `SharedFusedMoE`
+        return self.experts(hidden_states, router_output)
 
 
 class AriaTextDecoderLayer(LlamaDecoderLayer):
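The Aria change shows the migration pattern repeated across the commit: build the shared-expert MLP before the fused layer so it can be passed in, drop reduce_results=True (reduction now happens inside the layer), and delete the hidden_states.clone() workaround because the ordering of shared and routed computation is handled inside SharedFusedMoE. Condensed paraphrase of the forward-path change above (methods shown out of their class for brevity, not verbatim code):

import torch


def moe_forward_before(self, hidden_states: torch.Tensor) -> torch.Tensor:
    # Old pattern: clone because FusedMoE mutates hidden_states in place,
    # then combine routed and shared outputs by hand.
    router_output = torch.nn.functional.linear(hidden_states, self.router_weight)
    hidden_states_copy = hidden_states.clone()
    sparse_expert_output = self.experts(hidden_states, router_output)
    shared_expert_output = self.shared_experts(hidden_states_copy)
    return sparse_expert_output + shared_expert_output


def moe_forward_after(self, hidden_states: torch.Tensor) -> torch.Tensor:
    # New pattern: SharedFusedMoE owns the shared experts and returns the
    # combined result directly.
    router_output = torch.nn.functional.linear(hidden_states, self.router_weight)
    return self.experts(hidden_states, router_output)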

vllm/model_executor/models/bailing_moe.py

Lines changed: 35 additions & 28 deletions

@@ -36,8 +36,7 @@
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_reduce)
+                              get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -47,6 +46,7 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -266,22 +266,6 @@ def __init__(
         # default value for scoring_func
         self.score_function = "softmax"
 
-        self.experts = FusedMoE(
-            num_experts=self.num_experts,
-            top_k=self.top_k,
-            hidden_size=self.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
-            renormalize=self.norm_expert_prob,
-            quant_config=quant_config,
-            prefix=f"{prefix}.experts",
-            scoring_func=self.score_function,
-            e_score_correction_bias=self.gate.expert_bias,
-            num_expert_group=self.n_group,
-            topk_group=self.topk_group,
-            use_grouped_topk=self.use_grouped_topk,
-        )
-
         if self.num_shared_experts > 0:
             if hasattr(config, "moe_shared_expert_intermediate_size"):
                 intermediate_size = config.moe_shared_expert_intermediate_size
@@ -294,29 +278,52 @@ def __init__(
                 quant_config=quant_config,
                 reduce_results=False,
                 prefix=f"{prefix}.shared_experts")
+
+            self.experts = SharedFusedMoE(
+                shared_experts=self.shared_experts,
+                fused_output_scaling_factor=self.routed_scaling_factor,
+                shared_output_scaling_factor=1.0,
+                num_experts=self.num_experts,
+                top_k=self.top_k,
+                hidden_size=self.hidden_size,
+                intermediate_size=config.moe_intermediate_size,
+                renormalize=self.norm_expert_prob,
+                quant_config=quant_config,
+                prefix=f"{prefix}.experts",
+                scoring_func=self.score_function,
+                e_score_correction_bias=self.gate.expert_bias,
+                num_expert_group=self.n_group,
+                topk_group=self.topk_group,
+                use_grouped_topk=self.use_grouped_topk,
+            )
         else:
+            self.experts = FusedMoE(
+                num_experts=self.num_experts,
+                top_k=self.top_k,
+                hidden_size=self.hidden_size,
+                intermediate_size=config.moe_intermediate_size,
+                renormalize=self.norm_expert_prob,
+                quant_config=quant_config,
+                prefix=f"{prefix}.experts",
+                scoring_func=self.score_function,
+                e_score_correction_bias=self.gate.expert_bias,
+                num_expert_group=self.n_group,
+                topk_group=self.topk_group,
+                use_grouped_topk=self.use_grouped_topk,
+            )
             self.shared_experts = None
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_size)
-        if self.shared_experts:
-            shared_output = self.shared_experts(hidden_states)
+
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states.to(self.router_dtype))
         router_logits = router_logits.to(hidden_states.dtype)
 
         final_hidden_states = self.experts(hidden_states=hidden_states,
                                            router_logits=router_logits)
 
-        final_hidden_states *= self.routed_scaling_factor
-
-        if self.shared_experts:
-            final_hidden_states = final_hidden_states + shared_output
-
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
         return final_hidden_states.view(num_tokens, hidden_size)
 
 
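In BailingMoE the arithmetic that used to follow the expert call (scaling the routed output by routed_scaling_factor, adding the shared output, then an optional all-reduce) is now expressed at construction time through fused_output_scaling_factor and shared_output_scaling_factor. Conceptually the layer combines the two paths like this; a sketch of the intended math, assuming those kwargs scale the fused and shared outputs respectively, not the fused kernel itself:

from typing import Optional

import torch


def combine_moe_outputs(fused_out: torch.Tensor,
                        shared_out: Optional[torch.Tensor],
                        fused_output_scaling_factor: float,
                        shared_output_scaling_factor: float) -> torch.Tensor:
    # Equivalent of the lines removed from BailingMoE.forward:
    #   final_hidden_states *= self.routed_scaling_factor
    #   final_hidden_states = final_hidden_states + shared_output
    out = fused_out * fused_output_scaling_factor
    if shared_out is not None:
        out = out + shared_out * shared_output_scaling_factor
    return out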

vllm/model_executor/models/deepseek_v2.py

Lines changed: 41 additions & 67 deletions

@@ -36,7 +36,7 @@
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import (CacheConfig, ParallelConfig, VllmConfig,
+from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, VllmConfig,
                          get_current_vllm_config)
 from vllm.distributed import (get_ep_group, get_pp_group,
                               get_tensor_model_parallel_rank,
@@ -133,6 +133,7 @@ class DeepseekV2MoE(nn.Module):
     def __init__(
         self,
         config: Union[DeepseekV2Config, DeepseekV3Config],
+        model_config: ModelConfig,
         parallel_config: ParallelConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
@@ -182,27 +183,9 @@
                                 self.n_local_physical_experts)
 
         if config.n_shared_experts is None:
-            self.experts = FusedMoE(
-                num_experts=config.n_routed_experts,
-                top_k=config.num_experts_per_tok,
-                hidden_size=config.hidden_size,
-                intermediate_size=config.moe_intermediate_size,
-                reduce_results=False,
-                renormalize=config.norm_topk_prob,
-                quant_config=quant_config,
-                use_grouped_topk=True,
-                num_expert_group=config.n_group,
-                topk_group=config.topk_group,
-                prefix=f"{prefix}.experts",
-                scoring_func=config.scoring_func,
-                # we do scaling outside, set factor to 1.0 to avoid double mul
-                routed_scaling_factor=1.0,
-                e_score_correction_bias=self.gate.e_score_correction_bias,
-                enable_eplb=self.enable_eplb,
-                num_redundant_experts=self.n_redundant_experts,
-                is_sequence_parallel=self.is_sequence_parallel,
-            )
             self.shared_experts = None
+            fused_output_scaling_factor = 1.0
+            shared_output_scaling_factor = 1.0
         else:
             intermediate_size = (config.moe_intermediate_size *
                                  config.n_shared_experts)
@@ -213,31 +196,42 @@
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 is_sequence_parallel=self.is_sequence_parallel,
-                reduce_results=False,
+                reduce_results=False,  # XXXXX
                 prefix=f"{prefix}.shared_experts",
             )
 
-            self.experts = SharedFusedMoE(
-                shared_experts=self.shared_experts,
-                num_experts=config.n_routed_experts,
-                top_k=config.num_experts_per_tok,
-                hidden_size=config.hidden_size,
-                intermediate_size=config.moe_intermediate_size,
-                reduce_results=False,
-                renormalize=config.norm_topk_prob,
-                quant_config=quant_config,
-                use_grouped_topk=True,
-                num_expert_group=config.n_group,
-                topk_group=config.topk_group,
-                prefix=f"{prefix}.experts",
-                scoring_func=config.scoring_func,
-                # we do scaling outside, set factor to 1.0 to avoid double mul
-                routed_scaling_factor=1.0,
-                e_score_correction_bias=self.gate.e_score_correction_bias,
-                enable_eplb=self.enable_eplb,
-                num_redundant_experts=self.n_redundant_experts,
-                is_sequence_parallel=self.is_sequence_parallel,
-            )
+            # Fix FP16 overflow
+            # See DeepseekV2DecoderLayer for more details.
+            if model_config.dtype != torch.float16:
+                fused_output_scaling_factor = self.routed_scaling_factor
+                shared_output_scaling_factor = 1.0
+            else:
+                fused_output_scaling_factor = 1.0
+                shared_output_scaling_factor = (1. /
+                                                self.routed_scaling_factor)
+
+        self.experts = SharedFusedMoE(
+            shared_experts=self.shared_experts,
+            fused_output_scaling_factor=fused_output_scaling_factor,
+            shared_output_scaling_factor=shared_output_scaling_factor,
+            num_experts=config.n_routed_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            topk_group=config.topk_group,
+            prefix=f"{prefix}.experts",
+            scoring_func=config.scoring_func,
+            # we do scaling outside, set factor to 1.0 to avoid double mul
+            routed_scaling_factor=1.0,
+            e_score_correction_bias=self.gate.e_score_correction_bias,
+            enable_eplb=self.enable_eplb,
+            num_redundant_experts=self.n_redundant_experts,
+            is_sequence_parallel=self.is_sequence_parallel,
+        )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
@@ -253,36 +247,15 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
-        fused_moe_out = self.experts(hidden_states=hidden_states,
-                                     router_logits=router_logits)
-
-        if self.shared_experts is not None:
-            shared_output, final_hidden_states = fused_moe_out
-        else:
-            shared_output = None
-            final_hidden_states = fused_moe_out
-
-        # Fix FP16 overflow
-        # See DeepseekV2DecoderLayer for more details.
-        if hidden_states.dtype != torch.float16:
-            final_hidden_states *= self.routed_scaling_factor
-        elif self.shared_experts is not None:
-            assert shared_output is not None
-            shared_output *= (1. / self.routed_scaling_factor)
-
-        if self.shared_experts is not None:
-            assert shared_output is not None
-            final_hidden_states += shared_output
+        final_hidden_states = self.experts(hidden_states=hidden_states,
+                                           router_logits=router_logits)
 
         if self.is_sequence_parallel:
             final_hidden_states = tensor_model_parallel_all_gather(
                 final_hidden_states, 0)
             final_hidden_states = final_hidden_states[:num_tokens]
-        elif self.tp_size > 1:
-            final_hidden_states = (
-                self.experts.maybe_all_reduce_tensor_model_parallel(
-                    final_hidden_states))
 
+        # TODO(bnell): why is this view needed?
         return final_hidden_states.view(num_tokens, hidden_dim)
 
 
@@ -1036,6 +1009,7 @@ def __init__(self,
                 and layer_idx % config.moe_layer_freq == 0):
             self.mlp = DeepseekV2MoE(
                 config=config,
+                model_config=model_config,
                 parallel_config=parallel_config,
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",
