
Commit a8cd0fe

bnellnm authored and 0xrushi committed
[Model] Apply shared experts overlap optimization to all models with shared experts (vllm-project#26145)
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
1 parent c1167e7 commit a8cd0fe

15 files changed, +271 −283 lines changed


vllm/model_executor/layers/fused_moe/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,7 @@
     FusedMoEPermuteExpertsUnpermute,
     FusedMoEPrepareAndFinalize,
 )
+from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
 from vllm.triton_utils import HAS_TRITON

@@ -42,6 +43,7 @@ def get_config() -> Optional[dict[str, Any]]:
     "FusedMoEPermuteExpertsUnpermute",
     "FusedMoEActivationFormat",
     "FusedMoEPrepareAndFinalize",
+    "SharedFusedMoE",
     "activation_without_mul",
     "override_config",
     "get_config",

vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py renamed to vllm/model_executor/layers/fused_moe/shared_fused_moe.py

Lines changed: 23 additions & 12 deletions
@@ -18,13 +18,21 @@ class SharedFusedMoE(FusedMoE):

     def __init__(
         self,
-        shared_experts: torch.nn.Module,
+        shared_experts: Optional[torch.nn.Module],
         use_overlapped: bool = True,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self._shared_experts = shared_experts
-        self.use_overlapped = use_overlapped
+        # Disable shared expert overlap if EP is disabled or we are not using
+        # flashinfer + DP since there is nothing to be gained in this case.
+        # Disabling the overlap optimization also prevents the shared experts
+        # from being hidden from torch.compile.
+        self.use_overlapped = (
+            use_overlapped
+            and not (self.use_ep or self.use_flashinfer_cutlass_kernels)
+            and self._shared_experts is not None
+        )

     @property
     def shared_experts(self) -> Optional[torch.nn.Module]:

@@ -36,16 +44,19 @@ def forward(
         router_logits: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         if not self.use_overlapped:
-            shared_out = self._shared_experts(hidden_states)
-
-            # Reduce outputs if necessary, since the MLP should
-            # have been created with reduce_results=False.
-            if (
-                self.reduce_results
-                and self.tp_size > 1
-                and self.must_reduce_shared_expert_outputs()
-            ):
-                shared_out = tensor_model_parallel_all_reduce(shared_out)
+            if self._shared_experts is not None:
+                shared_out = self._shared_experts(hidden_states)
+
+                # Reduce shared expert outputs if necessary, since the MLP
+                # should have been created with reduce_results=False.
+                if (
+                    self.reduce_results
+                    and self.tp_size > 1
+                    and self.must_reduce_shared_expert_outputs()
+                ):
+                    shared_out = tensor_model_parallel_all_reduce(shared_out)
+            else:
+                shared_out = None

         fused_out = super().forward(
             hidden_states=hidden_states,

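For readers following the diff above: after this change SharedFusedMoE accepts shared_experts=None and, when shared experts are registered, its forward hands back a (shared_out, routed_out) pair instead of a single tensor. A minimal, self-contained sketch of that calling convention follows (hypothetical stand-in class, not the vLLM implementation; the real layer also handles tensor-parallel all-reduce, the overlap gating shown above, and kernel scheduling):

from typing import Optional
import torch

class SharedMoELayerSketch(torch.nn.Module):
    """Hypothetical stand-in illustrating only the return convention."""

    def __init__(self, routed: torch.nn.Module, shared: Optional[torch.nn.Module]):
        super().__init__()
        self.routed = routed    # stands in for the fused routed experts
        self.shared = shared    # optional shared-expert MLP, may be None

    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
        fused_out = self.routed(hidden_states, router_logits)
        if self.shared is None:
            # No shared experts registered: behaves like a plain fused-MoE layer.
            return fused_out
        # With shared experts, the caller receives a (shared_out, routed_out)
        # pair and is responsible for summing the two contributions.
        shared_out = self.shared(hidden_states)
        return shared_out, fused_out

With this convention, each model touched below can register its shared-expert MLP on the fused layer and simply unpack the pair in its own forward.
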
vllm/model_executor/layers/quantization/fp8.py

Lines changed: 2 additions & 0 deletions
@@ -741,6 +741,8 @@ def create_weights(
             layer.w13_input_scale = None
             layer.w2_input_scale = None

+        self.rocm_aiter_moe_enabled = False
+
     def process_weights_after_loading(self, layer: Module) -> None:
         # Lazy import to avoid importing triton too early.
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (

vllm/model_executor/layers/shared_fused_moe/__init__.py

Lines changed: 0 additions & 5 deletions
This file was deleted.

vllm/model_executor/models/aria.py

Lines changed: 15 additions & 13 deletions
@@ -13,7 +13,7 @@
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig

@@ -206,7 +206,7 @@ def forward(
         return out


-class AriaFusedMoE(FusedMoE):
+class AriaFusedMoE(SharedFusedMoE):
     def weight_loader(
         self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str
     ) -> None:

@@ -260,7 +260,16 @@ def __init__(
             torch.empty((self.config.moe_num_experts, self.config.hidden_size))
         )

+        self.shared_experts = LlamaMLP(
+            config.hidden_size,
+            config.intermediate_size * config.moe_num_shared_experts,
+            "silu",
+            quant_config=quant_config,
+            bias=config.mlp_bias,
+        )
+
         self.experts = AriaFusedMoE(
+            shared_experts=self.shared_experts,
             num_experts=config.moe_num_experts,
             top_k=config.moe_topk,
             hidden_size=config.hidden_size,

@@ -269,13 +278,6 @@ def __init__(
             reduce_results=True,
             prefix=f"{prefix}.experts",
         )
-        self.shared_experts = LlamaMLP(
-            config.hidden_size,
-            config.intermediate_size * config.moe_num_shared_experts,
-            "silu",
-            quant_config=quant_config,
-            bias=config.mlp_bias,
-        )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """

@@ -291,12 +293,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

         router_output = torch.nn.functional.linear(hidden_states, self.router_weight)

-        hidden_states_copy = hidden_states.clone()
-        # NOTE: hidden_states will be modified inplace by `FusedMoE`
         sparse_expert_output = self.experts(hidden_states, router_output)
-        shared_expert_output = self.shared_experts(hidden_states_copy)

-        return sparse_expert_output + shared_expert_output
+        if self.shared_experts is not None:
+            return sparse_expert_output[0] + sparse_expert_output[1]
+        else:
+            return sparse_expert_output


 class AriaTextDecoderLayer(LlamaDecoderLayer):

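The Aria change above is representative of the model-side pattern: the shared-expert MLP is built first, handed to the MoE layer, and the layer's output is unpacked only when shared experts are present. A condensed sketch of that consumer logic (hypothetical helper, simplified from the diff):

import torch

def combine_moe_output(experts_out, has_shared_experts: bool) -> torch.Tensor:
    # When shared experts are registered, the fused layer returns a
    # (shared_out, routed_out) pair; otherwise it returns a plain tensor.
    if has_shared_experts:
        shared_out, routed_out = experts_out
        return shared_out + routed_out
    return experts_out
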
vllm/model_executor/models/bailing_moe.py

Lines changed: 26 additions & 21 deletions
@@ -43,7 +43,7 @@
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,

@@ -276,22 +276,6 @@ def __init__(
         # default value for scoring_func
         self.score_function = "softmax"

-        self.experts = FusedMoE(
-            num_experts=self.num_experts,
-            top_k=self.top_k,
-            hidden_size=self.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
-            renormalize=self.norm_expert_prob,
-            quant_config=quant_config,
-            prefix=f"{prefix}.experts",
-            scoring_func=self.score_function,
-            e_score_correction_bias=self.gate.expert_bias,
-            num_expert_group=self.n_group,
-            topk_group=self.topk_group,
-            use_grouped_topk=self.use_grouped_topk,
-        )
-
         if self.num_shared_experts > 0:
             if hasattr(config, "moe_shared_expert_intermediate_size"):
                 intermediate_size = config.moe_shared_expert_intermediate_size

@@ -308,11 +292,27 @@ def __init__(
         else:
             self.shared_experts = None

+        self.experts = SharedFusedMoE(
+            shared_experts=self.shared_experts,
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            hidden_size=self.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=self.norm_expert_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts",
+            scoring_func=self.score_function,
+            e_score_correction_bias=self.gate.expert_bias,
+            num_expert_group=self.n_group,
+            topk_group=self.topk_group,
+            use_grouped_topk=self.use_grouped_topk,
+        )
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_size)
-        if self.shared_experts:
-            shared_output = self.shared_experts(hidden_states)
+
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states.to(self.router_dtype))
         router_logits = router_logits.to(hidden_states.dtype)

@@ -321,9 +321,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             hidden_states=hidden_states, router_logits=router_logits
         )

+        if self.shared_experts is not None:
+            shared_output, final_hidden_states = final_hidden_states
+        else:
+            shared_output = None
+
         final_hidden_states *= self.routed_scaling_factor

-        if self.shared_experts:
+        if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output

         if self.tp_size > 1:

@@ -475,7 +480,7 @@ def forward(
         return hidden_states

     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
-        return FusedMoE.make_expert_params_mapping(
+        return SharedFusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",

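BailingMoE shows the same convention when extra post-processing is needed: the pair is unpacked first, only the routed part is scaled, and the shared output is added back afterwards. A hedged sketch of that ordering (hypothetical helper mirroring the forward shown above):

import torch

def finalize_routed_output(experts_out, routed_scaling_factor: float) -> torch.Tensor:
    # Unpack the (shared, routed) pair if shared experts were registered.
    if isinstance(experts_out, tuple):
        shared_output, final_hidden_states = experts_out
    else:
        shared_output, final_hidden_states = None, experts_out
    # Scale only the routed contribution, then add the shared output back.
    final_hidden_states = final_hidden_states * routed_scaling_factor
    if shared_output is not None:
        final_hidden_states = final_hidden_states + shared_output
    return final_hidden_states
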
vllm/model_executor/models/deepseek_v2.py

Lines changed: 24 additions & 45 deletions
@@ -49,7 +49,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,

@@ -64,7 +64,6 @@
     per_token_group_quant_fp8,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,

@@ -205,26 +204,6 @@ def __init__(
         )

         if config.n_shared_experts is None:
-            self.experts = FusedMoE(
-                num_experts=config.n_routed_experts,
-                top_k=config.num_experts_per_tok,
-                hidden_size=config.hidden_size,
-                intermediate_size=config.moe_intermediate_size,
-                reduce_results=False,
-                renormalize=config.norm_topk_prob,
-                quant_config=quant_config,
-                use_grouped_topk=True,
-                num_expert_group=config.n_group,
-                topk_group=config.topk_group,
-                prefix=f"{prefix}.experts",
-                scoring_func=config.scoring_func,
-                # we do scaling outside, set factor to 1.0 to avoid double mul
-                routed_scaling_factor=1.0,
-                e_score_correction_bias=self.gate.e_score_correction_bias,
-                enable_eplb=self.enable_eplb,
-                num_redundant_experts=self.n_redundant_experts,
-                is_sequence_parallel=self.is_sequence_parallel,
-            )
             self.shared_experts = None
         else:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts

@@ -239,27 +218,27 @@ def __init__(
                 prefix=f"{prefix}.shared_experts",
             )

-            self.experts = SharedFusedMoE(
-                shared_experts=self.shared_experts,
-                num_experts=config.n_routed_experts,
-                top_k=config.num_experts_per_tok,
-                hidden_size=config.hidden_size,
-                intermediate_size=config.moe_intermediate_size,
-                reduce_results=False,
-                renormalize=config.norm_topk_prob,
-                quant_config=quant_config,
-                use_grouped_topk=True,
-                num_expert_group=config.n_group,
-                topk_group=config.topk_group,
-                prefix=f"{prefix}.experts",
-                scoring_func=config.scoring_func,
-                # we do scaling outside, set factor to 1.0 to avoid double mul
-                routed_scaling_factor=1.0,
-                e_score_correction_bias=self.gate.e_score_correction_bias,
-                enable_eplb=self.enable_eplb,
-                num_redundant_experts=self.n_redundant_experts,
-                is_sequence_parallel=self.is_sequence_parallel,
-            )
+        self.experts = SharedFusedMoE(
+            shared_experts=self.shared_experts,
+            num_experts=config.n_routed_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            topk_group=config.topk_group,
+            prefix=f"{prefix}.experts",
+            scoring_func=config.scoring_func,
+            # we do scaling outside, set factor to 1.0 to avoid double mul
+            routed_scaling_factor=1.0,
+            e_score_correction_bias=self.gate.e_score_correction_bias,
+            enable_eplb=self.enable_eplb,
+            num_redundant_experts=self.n_redundant_experts,
+            is_sequence_parallel=self.is_sequence_parallel,
+        )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape

@@ -1306,7 +1285,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace
         self.num_expert_groups = config.n_group

-        self.moe_layers: list[FusedMoE] = []
+        self.moe_layers: list[SharedFusedMoE] = []
         example_moe = None
         for layer in self.model.layers:
             if isinstance(layer, PPMissingLayer):

@@ -1394,7 +1373,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
ckpt_up_proj_name="up_proj",
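The DeepSeek-V2 hunks above show the main simplification enabled by shared_experts accepting None: the fused-MoE layer is constructed once, outside the branch on config.n_shared_experts, instead of being duplicated in both branches. A minimal sketch of that construction flow (hypothetical factory arguments; the real constructor takes many more parameters, as in the diff):

from typing import Callable, Optional
import torch

def build_experts_layer(config,
                        make_shared_mlp: Callable[..., torch.nn.Module],
                        make_fused_moe: Callable[..., torch.nn.Module]) -> torch.nn.Module:
    # Build the optional shared-expert MLP first; None is a valid value.
    shared: Optional[torch.nn.Module] = (
        make_shared_mlp(config) if getattr(config, "n_shared_experts", None) else None
    )
    # A single construction path covers both configurations, since the
    # fused layer now tolerates shared_experts=None.
    return make_fused_moe(
        shared_experts=shared,
        num_experts=config.n_routed_experts,
        top_k=config.num_experts_per_tok,
        hidden_size=config.hidden_size,
        intermediate_size=config.moe_intermediate_size,
    )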
