
Commit b81bcd1

[Models] Apply SharedFusedMoE to all models with shared experts
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent fc67969 commit b81bcd1

15 files changed: +315, -335 lines


vllm/model_executor/layers/fused_moe/__init__.py

Lines changed: 5 additions & 4 deletions
@@ -11,10 +11,10 @@
     FusedMoeWeightScaleSupported,
 )
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
-    FusedMoEActivationFormat,
-    FusedMoEPermuteExpertsUnpermute,
-    FusedMoEPrepareAndFinalize,
-)
+    FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute,
+    FusedMoEPrepareAndFinalize)
+from vllm.model_executor.layers.fused_moe.shared_fused_moe import (
+    SharedFusedMoE)
 from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
 from vllm.triton_utils import HAS_TRITON
@@ -42,6 +42,7 @@ def get_config() -> Optional[dict[str, Any]]:
     "FusedMoEPermuteExpertsUnpermute",
     "FusedMoEActivationFormat",
     "FusedMoEPrepareAndFinalize",
+    "SharedFusedMoE",
     "activation_without_mul",
     "override_config",
     "get_config",

vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py renamed to vllm/model_executor/layers/fused_moe/shared_fused_moe.py

Lines changed: 18 additions & 12 deletions
@@ -18,13 +18,19 @@ class SharedFusedMoE(FusedMoE):

     def __init__(
         self,
-        shared_experts: torch.nn.Module,
+        shared_experts: Optional[torch.nn.Module],
         use_overlapped: bool = True,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self._shared_experts = shared_experts
-        self.use_overlapped = use_overlapped
+        # Disable shared expert overlap if EP is disabled or we are not using
+        # flashinfer + DP since there is nothing to be gained in this case.
+        # Disabling the overlap optimization also prevents the shared experts
+        # from being hidden from torch.compile.
+        self.use_overlapped = use_overlapped and not (
+            self.use_ep or self.use_flashinfer_cutlass_kernels
+        ) and self._shared_experts is not None

     @property
     def shared_experts(self) -> Optional[torch.nn.Module]:
@@ -36,16 +42,16 @@ def forward(
         router_logits: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         if not self.use_overlapped:
-            shared_out = self._shared_experts(hidden_states)
-
-            # Reduce outputs if necessary, since the MLP should
-            # have been created with reduce_results=False.
-            if (
-                self.reduce_results
-                and self.tp_size > 1
-                and self.must_reduce_shared_expert_outputs()
-            ):
-                shared_out = tensor_model_parallel_all_reduce(shared_out)
+            if self._shared_experts is not None:
+                shared_out = self._shared_experts(hidden_states)
+
+                # Reduce shared expert outputs if necessary, since the MLP
+                # should have been created with reduce_results=False.
+                if (self.reduce_results and self.tp_size > 1
+                        and self.must_reduce_shared_expert_outputs()):
+                    shared_out = tensor_model_parallel_all_reduce(shared_out)
+            else:
+                shared_out = None

             fused_out = super().forward(
                 hidden_states=hidden_states,
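
The behavioural contract the model updates below rely on: when shared experts are attached, SharedFusedMoE.forward returns a (shared_out, fused_out) pair, and with no shared experts it behaves like a plain FusedMoE, so callers treat the result as a single tensor. A caller-side sketch of that pattern, with illustrative names (out and final are not from the diff):

# Illustrative consumption pattern, mirroring the model changes below
# (aria, bailing_moe, deepseek_v2); not a definitive API description.
out = self.experts(hidden_states=hidden_states, router_logits=router_logits)
if self.shared_experts is not None:
    shared_out, fused_out = out     # shared and routed halves
    final = fused_out + shared_out  # combine the two expert outputs
else:
    final = out                     # routed experts only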

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 2 additions & 0 deletions
@@ -741,6 +741,8 @@ def create_weights(
             layer.w13_input_scale = None
             layer.w2_input_scale = None

+        self.rocm_aiter_moe_enabled = False
+
     def process_weights_after_loading(self, layer: Module) -> None:
         # Lazy import to avoid importing triton too early.
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (

vllm/model_executor/layers/shared_fused_moe/__init__.py

Lines changed: 0 additions & 5 deletions
This file was deleted.

vllm/model_executor/models/aria.py

Lines changed: 20 additions & 17 deletions
@@ -13,8 +13,9 @@
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.fused_moe import FusedMoE
-from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -206,10 +207,10 @@ def forward(
         return out


-class AriaFusedMoE(FusedMoE):
-    def weight_loader(
-        self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str
-    ) -> None:
+class AriaFusedMoE(SharedFusedMoE):
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
+                      shard_id: str) -> None:
         # Override the weight_loader to handle the expert weights in the Aria
         # model, which are already packed with experts, and merge the gate and
         # up weights for each expert.
@@ -260,7 +261,16 @@ def __init__(
             torch.empty((self.config.moe_num_experts, self.config.hidden_size))
         )

+        self.shared_experts = LlamaMLP(
+            config.hidden_size,
+            config.intermediate_size * config.moe_num_shared_experts,
+            "silu",
+            quant_config=quant_config,
+            bias=config.mlp_bias,
+        )
+
         self.experts = AriaFusedMoE(
+            shared_experts=self.shared_experts,
             num_experts=config.moe_num_experts,
             top_k=config.moe_topk,
             hidden_size=config.hidden_size,
@@ -269,13 +279,6 @@ def __init__(
             reduce_results=True,
             prefix=f"{prefix}.experts",
         )
-        self.shared_experts = LlamaMLP(
-            config.hidden_size,
-            config.intermediate_size * config.moe_num_shared_experts,
-            "silu",
-            quant_config=quant_config,
-            bias=config.mlp_bias,
-        )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """
@@ -291,12 +294,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

         router_output = torch.nn.functional.linear(hidden_states, self.router_weight)

-        hidden_states_copy = hidden_states.clone()
-        # NOTE: hidden_states will be modified inplace by `FusedMoE`
         sparse_expert_output = self.experts(hidden_states, router_output)
-        shared_expert_output = self.shared_experts(hidden_states_copy)

-        return sparse_expert_output + shared_expert_output
+        if self.shared_experts is not None:
+            return sparse_expert_output[0] + sparse_expert_output[1]
+        else:
+            return sparse_expert_output


 class AriaTextDecoderLayer(LlamaDecoderLayer):
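
For Aria, the pair return value replaces the old clone-and-add: the hidden_states.clone() workaround for FusedMoE's in-place update is no longer needed because the shared experts now run inside SharedFusedMoE, and indexing [0]/[1] just sums the shared and routed halves. An equivalent, more explicit spelling (sketch only, not part of the commit):

# Sketch: explicit unpacking equivalent to sparse_expert_output[0] + sparse_expert_output[1].
out = self.experts(hidden_states, router_output)
if self.shared_experts is not None:
    shared_out, routed_out = out
    return shared_out + routed_out
else:
    return out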

vllm/model_executor/models/bailing_moe.py

Lines changed: 28 additions & 21 deletions
@@ -43,7 +43,7 @@
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
@@ -276,22 +276,6 @@ def __init__(
         # default value for scoring_func
         self.score_function = "softmax"

-        self.experts = FusedMoE(
-            num_experts=self.num_experts,
-            top_k=self.top_k,
-            hidden_size=self.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
-            renormalize=self.norm_expert_prob,
-            quant_config=quant_config,
-            prefix=f"{prefix}.experts",
-            scoring_func=self.score_function,
-            e_score_correction_bias=self.gate.expert_bias,
-            num_expert_group=self.n_group,
-            topk_group=self.topk_group,
-            use_grouped_topk=self.use_grouped_topk,
-        )
-
         if self.num_shared_experts > 0:
             if hasattr(config, "moe_shared_expert_intermediate_size"):
                 intermediate_size = config.moe_shared_expert_intermediate_size
@@ -308,11 +292,29 @@
         else:
             self.shared_experts = None

+        self.experts = SharedFusedMoE(
+            shared_experts=self.shared_experts,
+            fused_output_scaling_factor=self.routed_scaling_factor,
+            shared_output_scaling_factor=1.0,
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            hidden_size=self.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=self.norm_expert_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts",
+            scoring_func=self.score_function,
+            e_score_correction_bias=self.gate.expert_bias,
+            num_expert_group=self.n_group,
+            topk_group=self.topk_group,
+            use_grouped_topk=self.use_grouped_topk,
+        )
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_size)
-        if self.shared_experts:
-            shared_output = self.shared_experts(hidden_states)
+
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states.to(self.router_dtype))
         router_logits = router_logits.to(hidden_states.dtype)
@@ -321,9 +323,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             hidden_states=hidden_states, router_logits=router_logits
         )

+        if self.shared_experts is not None:
+            shared_output, final_hidden_states = final_hidden_states
+        else:
+            shared_output = None
+
         final_hidden_states *= self.routed_scaling_factor

-        if self.shared_experts:
+        if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output

         if self.tp_size > 1:
@@ -475,7 +482,7 @@ def forward(
         return hidden_states

     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
-        return FusedMoE.make_expert_params_mapping(
+        return SharedFusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",

vllm/model_executor/models/deepseek_v2.py

Lines changed: 24 additions & 45 deletions
@@ -49,7 +49,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
@@ -64,7 +64,6 @@
     per_token_group_quant_fp8,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -205,26 +204,6 @@ def __init__(
         )

         if config.n_shared_experts is None:
-            self.experts = FusedMoE(
-                num_experts=config.n_routed_experts,
-                top_k=config.num_experts_per_tok,
-                hidden_size=config.hidden_size,
-                intermediate_size=config.moe_intermediate_size,
-                reduce_results=False,
-                renormalize=config.norm_topk_prob,
-                quant_config=quant_config,
-                use_grouped_topk=True,
-                num_expert_group=config.n_group,
-                topk_group=config.topk_group,
-                prefix=f"{prefix}.experts",
-                scoring_func=config.scoring_func,
-                # we do scaling outside, set factor to 1.0 to avoid double mul
-                routed_scaling_factor=1.0,
-                e_score_correction_bias=self.gate.e_score_correction_bias,
-                enable_eplb=self.enable_eplb,
-                num_redundant_experts=self.n_redundant_experts,
-                is_sequence_parallel=self.is_sequence_parallel,
-            )
             self.shared_experts = None
         else:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
@@ -239,27 +218,27 @@
                 prefix=f"{prefix}.shared_experts",
             )

-            self.experts = SharedFusedMoE(
-                shared_experts=self.shared_experts,
-                num_experts=config.n_routed_experts,
-                top_k=config.num_experts_per_tok,
-                hidden_size=config.hidden_size,
-                intermediate_size=config.moe_intermediate_size,
-                reduce_results=False,
-                renormalize=config.norm_topk_prob,
-                quant_config=quant_config,
-                use_grouped_topk=True,
-                num_expert_group=config.n_group,
-                topk_group=config.topk_group,
-                prefix=f"{prefix}.experts",
-                scoring_func=config.scoring_func,
-                # we do scaling outside, set factor to 1.0 to avoid double mul
-                routed_scaling_factor=1.0,
-                e_score_correction_bias=self.gate.e_score_correction_bias,
-                enable_eplb=self.enable_eplb,
-                num_redundant_experts=self.n_redundant_experts,
-                is_sequence_parallel=self.is_sequence_parallel,
-            )
+        self.experts = SharedFusedMoE(
+            shared_experts=self.shared_experts,
+            num_experts=config.n_routed_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            topk_group=config.topk_group,
+            prefix=f"{prefix}.experts",
+            scoring_func=config.scoring_func,
+            # we do scaling outside, set factor to 1.0 to avoid double mul
+            routed_scaling_factor=1.0,
+            e_score_correction_bias=self.gate.e_score_correction_bias,
+            enable_eplb=self.enable_eplb,
+            num_redundant_experts=self.n_redundant_experts,
+            is_sequence_parallel=self.is_sequence_parallel,
+        )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
@@ -1293,7 +1272,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace
         self.num_expert_groups = config.n_group

-        self.moe_layers: list[FusedMoE] = []
+        self.moe_layers: list[SharedFusedMoE] = []
         example_moe = None
         for layer in self.model.layers:
             if isinstance(layer, PPMissingLayer):
@@ -1381,7 +1360,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
