diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index e16fc13c945c..c2039adad99c 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -327,7 +327,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             layer.w13_weight.data = shuffled_w13
             layer.w2_weight.data = shuffled_w2
 
-        if current_platform.is_cpu():
+        if current_platform.is_xpu():
+            import intel_extension_for_pytorch as ipex
+            layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
+                layer.w13_weight,
+                layer.w2_weight,
+                use_prepack=True,
+            )
+        elif current_platform.is_cpu():
             if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
                 from vllm.model_executor.layers.fused_moe import cpu_fused_moe
                 dtype = layer.w13_weight.dtype
@@ -509,6 +516,44 @@ def forward_cpu(
             activation,
         )
 
+    def forward_xpu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
+    ):
+        if enable_eplb is not False or expert_load_view is not None or \
+                logical_to_physical_map is not None or \
+                logical_replica_count is not None:
+            raise NotImplementedError("Expert load balancing is not supported "
+                                      "for XPU.")
+        assert custom_routing_function is None
+        return layer.ipex_fusion(
+            x,
+            use_grouped_topk,
+            top_k,
+            router_logits,
+            renormalize,
+            topk_group,
+            num_expert_group,
+        )
+
     def forward_tpu(
         self,
         layer: torch.nn.Module,
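
Not part of the diff: a minimal standalone sketch of what the new XPU path does, for review context. It assumes `intel_extension_for_pytorch` is installed with XPU support; the tensor shapes, dtype, and routing settings below are illustrative only (in vLLM the expert weights come from the checkpoint and the routing parameters from the model config). It mirrors the two pieces of the change: weights are wrapped in `ipex.llm.modules.GatedMLPMOE` once at load time (`process_weights_after_loading`), and `forward_xpu` then delegates top-k routing plus the expert computation to that prepacked module.

```python
import torch
import intel_extension_for_pytorch as ipex

# Illustrative sizes; vLLM's fused-MoE layout is
#   w13: (num_experts, 2 * intermediate_size, hidden_size)  -- gate and up proj stacked
#   w2:  (num_experts, hidden_size, intermediate_size)
num_experts, hidden_size, intermediate_size = 8, 1024, 2048
w13 = torch.randn(num_experts, 2 * intermediate_size, hidden_size,
                  dtype=torch.bfloat16, device="xpu")
w2 = torch.randn(num_experts, hidden_size, intermediate_size,
                 dtype=torch.bfloat16, device="xpu")

# Mirrors process_weights_after_loading on XPU: prepack the experts once.
ipex_fusion = ipex.llm.modules.GatedMLPMOE(w13, w2, use_prepack=True)

# Mirrors forward_xpu: routing and expert computation happen inside IPEX.
num_tokens, top_k = 4, 2
x = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="xpu")
router_logits = torch.randn(num_tokens, num_experts,
                            dtype=torch.bfloat16, device="xpu")
out = ipex_fusion(
    x,
    False,          # use_grouped_topk
    top_k,
    router_logits,
    True,           # renormalize
    None,           # topk_group (only used with grouped top-k)
    None,           # num_expert_group
)
```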