2 changes: 2 additions & 0 deletions vllm/model_executor/layers/fused_moe/__init__.py
@@ -15,6 +15,7 @@
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
from vllm.triton_utils import HAS_TRITON

@@ -42,6 +43,7 @@ def get_config() -> Optional[dict[str, Any]]:
"FusedMoEPermuteExpertsUnpermute",
"FusedMoEActivationFormat",
"FusedMoEPrepareAndFinalize",
"SharedFusedMoE",
"activation_without_mul",
"override_config",
"get_config",
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
@@ -18,13 +18,21 @@ class SharedFusedMoE(FusedMoE):

def __init__(
self,
shared_experts: torch.nn.Module,
shared_experts: Optional[torch.nn.Module],
use_overlapped: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self._shared_experts = shared_experts
self.use_overlapped = use_overlapped
# Disable shared expert overlap if EP is disabled or we are not using
# flashinfer + DP since there is nothing to be gained in this case.
# Disabling the overlap optimization also prevents the shared experts
# from being hidden from torch.compile.
self.use_overlapped = (
use_overlapped
and not (self.use_ep or self.use_flashinfer_cutlass_kernels)
and self._shared_experts is not None
)

@property
def shared_experts(self) -> Optional[torch.nn.Module]:
@@ -36,16 +44,19 @@ def forward(
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
if not self.use_overlapped:
shared_out = self._shared_experts(hidden_states)

# Reduce outputs if necessary, since the MLP should
# have been created with reduce_results=False.
if (
self.reduce_results
and self.tp_size > 1
and self.must_reduce_shared_expert_outputs()
):
shared_out = tensor_model_parallel_all_reduce(shared_out)
if self._shared_experts is not None:
shared_out = self._shared_experts(hidden_states)

# Reduce shared expert outputs if necessary, since the MLP
# should have been created with reduce_results=False.
if (
self.reduce_results
and self.tp_size > 1
and self.must_reduce_shared_expert_outputs()
):
shared_out = tensor_model_parallel_all_reduce(shared_out)
else:
shared_out = None

fused_out = super().forward(
hidden_states=hidden_states,
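
For context, a minimal usage sketch (not part of this PR's diff) of how a model block consumes SharedFusedMoE after this change, mirroring the per-model updates below (e.g. bailing_moe.py): the shared experts are handed to the fused layer at construction time, and when they are present the layer's forward returns a (shared_output, routed_output) pair for the caller to combine. ExampleMoEBlock, its gate argument, and moe_kwargs are hypothetical names; only SharedFusedMoE and its shared_experts keyword come from the diff.

from typing import Optional

import torch

from vllm.model_executor.layers.fused_moe import SharedFusedMoE


class ExampleMoEBlock(torch.nn.Module):
    """Hypothetical block illustrating the SharedFusedMoE calling convention."""

    def __init__(
        self,
        gate: torch.nn.Module,
        shared_experts: Optional[torch.nn.Module],
        **moe_kwargs,
    ):
        super().__init__()
        self.gate = gate
        self.shared_experts = shared_experts
        # The shared experts are owned by the fused layer instead of being
        # invoked separately in the model's forward pass.
        self.experts = SharedFusedMoE(shared_experts=shared_experts, **moe_kwargs)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits = self.gate(hidden_states)
        out = self.experts(hidden_states=hidden_states, router_logits=router_logits)
        if self.shared_experts is not None:
            # With shared experts configured, forward yields
            # (shared_output, routed_output); the caller sums them.
            shared_out, routed_out = out
            return shared_out + routed_out
        return out
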
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -741,6 +741,8 @@ def create_weights(
layer.w13_input_scale = None
layer.w2_input_scale = None

self.rocm_aiter_moe_enabled = False

def process_weights_after_loading(self, layer: Module) -> None:
# Lazy import to avoid importing triton too early.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
5 changes: 0 additions & 5 deletions vllm/model_executor/layers/shared_fused_moe/__init__.py

This file was deleted.

28 changes: 15 additions & 13 deletions vllm/model_executor/models/aria.py
@@ -13,7 +13,7 @@
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -206,7 +206,7 @@ def forward(
return out


class AriaFusedMoE(FusedMoE):
class AriaFusedMoE(SharedFusedMoE):
def weight_loader(
self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str
) -> None:
@@ -260,7 +260,16 @@ def __init__(
torch.empty((self.config.moe_num_experts, self.config.hidden_size))
)

self.shared_experts = LlamaMLP(
config.hidden_size,
config.intermediate_size * config.moe_num_shared_experts,
"silu",
quant_config=quant_config,
bias=config.mlp_bias,
)

self.experts = AriaFusedMoE(
shared_experts=self.shared_experts,
num_experts=config.moe_num_experts,
top_k=config.moe_topk,
hidden_size=config.hidden_size,
@@ -269,13 +278,6 @@ def __init__(
reduce_results=True,
prefix=f"{prefix}.experts",
)
self.shared_experts = LlamaMLP(
config.hidden_size,
config.intermediate_size * config.moe_num_shared_experts,
"silu",
quant_config=quant_config,
bias=config.mlp_bias,
)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
@@ -291,12 +293,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

router_output = torch.nn.functional.linear(hidden_states, self.router_weight)

hidden_states_copy = hidden_states.clone()
# NOTE: hidden_states will be modified inplace by `FusedMoE`
sparse_expert_output = self.experts(hidden_states, router_output)
shared_expert_output = self.shared_experts(hidden_states_copy)

return sparse_expert_output + shared_expert_output
if self.shared_experts is not None:
return sparse_expert_output[0] + sparse_expert_output[1]
else:
return sparse_expert_output


class AriaTextDecoderLayer(LlamaDecoderLayer):
47 changes: 26 additions & 21 deletions vllm/model_executor/models/bailing_moe.py
@@ -43,7 +43,7 @@
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
@@ -276,22 +276,6 @@ def __init__(
# default value for scoring_func
self.score_function = "softmax"

self.experts = FusedMoE(
num_experts=self.num_experts,
top_k=self.top_k,
hidden_size=self.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=self.norm_expert_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
scoring_func=self.score_function,
e_score_correction_bias=self.gate.expert_bias,
num_expert_group=self.n_group,
topk_group=self.topk_group,
use_grouped_topk=self.use_grouped_topk,
)

if self.num_shared_experts > 0:
if hasattr(config, "moe_shared_expert_intermediate_size"):
intermediate_size = config.moe_shared_expert_intermediate_size
@@ -308,11 +292,27 @@ def __init__(
else:
self.shared_experts = None

self.experts = SharedFusedMoE(
shared_experts=self.shared_experts,
num_experts=self.num_experts,
top_k=self.top_k,
hidden_size=self.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=self.norm_expert_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
scoring_func=self.score_function,
e_score_correction_bias=self.gate.expert_bias,
num_expert_group=self.n_group,
topk_group=self.topk_group,
use_grouped_topk=self.use_grouped_topk,
)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_size)
if self.shared_experts:
shared_output = self.shared_experts(hidden_states)

# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states.to(self.router_dtype))
router_logits = router_logits.to(hidden_states.dtype)
@@ -321,9 +321,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states=hidden_states, router_logits=router_logits
)

if self.shared_experts is not None:
shared_output, final_hidden_states = final_hidden_states
else:
shared_output = None

final_hidden_states *= self.routed_scaling_factor

if self.shared_experts:
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output

if self.tp_size > 1:
@@ -475,7 +480,7 @@ def forward(
return hidden_states

def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return FusedMoE.make_expert_params_mapping(
return SharedFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
69 changes: 24 additions & 45 deletions vllm/model_executor/models/deepseek_v2.py
@@ -49,7 +49,7 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
@@ -64,7 +64,6 @@
per_token_group_quant_fp8,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
@@ -205,26 +204,6 @@ def __init__(
)

if config.n_shared_experts is None:
self.experts = FusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
# we do scaling outside, set factor to 1.0 to avoid double mul
routed_scaling_factor=1.0,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
)
self.shared_experts = None
else:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
@@ -239,27 +218,27 @@ def __init__(
prefix=f"{prefix}.shared_experts",
)

self.experts = SharedFusedMoE(
shared_experts=self.shared_experts,
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
# we do scaling outside, set factor to 1.0 to avoid double mul
routed_scaling_factor=1.0,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
)
self.experts = SharedFusedMoE(
shared_experts=self.shared_experts,
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
# we do scaling outside, set factor to 1.0 to avoid double mul
routed_scaling_factor=1.0,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
@@ -1293,7 +1272,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace
self.num_expert_groups = config.n_group

self.moe_layers: list[FusedMoE] = []
self.moe_layers: list[SharedFusedMoE] = []
example_moe = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
@@ -1381,7 +1360,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:

# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
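
For reference, a rough sketch (not from this PR) of how the (param_name, weight_name, expert_id, shard_id) tuples returned by make_expert_params_mapping are typically consumed during weight loading, following the load_weights pattern shown above for deepseek_v2.py. The load_expert_weights helper is hypothetical, and the weight_loader call signature is an assumption that may differ between vLLM versions.

from collections.abc import Iterable

import torch

from vllm.model_executor.layers.fused_moe import SharedFusedMoE


def load_expert_weights(
    model: torch.nn.Module,
    weights: Iterable[tuple[str, torch.Tensor]],
    num_experts: int,
) -> None:
    # Hypothetical helper: routes per-expert checkpoint tensors to the fused
    # MoE parameters using the mapping produced by make_expert_params_mapping.
    params_dict = dict(model.named_parameters())
    expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="gate_proj",
        ckpt_down_proj_name="down_proj",
        ckpt_up_proj_name="up_proj",
        num_experts=num_experts,
    )
    for name, loaded_weight in weights:
        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
            if weight_name not in name:
                continue
            mapped_name = name.replace(weight_name, param_name)
            param = params_dict[mapped_name]
            # Assumed signature, following the common vLLM expert-loader pattern.
            param.weight_loader(
                param,
                loaded_weight,
                mapped_name,
                shard_id=shard_id,
                expert_id=expert_id,
            )
            break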