@@ -327,7 +327,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.w13_weight.data = shuffled_w13
         layer.w2_weight.data = shuffled_w2
 
-        if current_platform.is_cpu():
+        if current_platform.is_xpu():
+            import intel_extension_for_pytorch as ipex
+            layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
+                layer.w13_weight,
+                layer.w2_weight,
+                use_prepack=True,
+            )
+        elif current_platform.is_cpu():
             if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
                 from vllm.model_executor.layers.fused_moe import cpu_fused_moe
                 dtype = layer.w13_weight.dtype
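The new branch above registers a prepacked IPEX fused-MoE module on the layer once weights are loaded, before the existing CPU path. Below is a minimal, self-contained sketch of that construction outside vLLM, assuming an XPU-enabled build of intel_extension_for_pytorch that exposes `ipex.llm.modules.GatedMLPMOE` with the constructor arguments used in the diff; the expert count and tensor shapes are illustrative, not values from this PR.

```python
# Illustrative sketch of what the XPU branch does at weight-loading time.
# Shapes below are made up for the example.
import torch
import intel_extension_for_pytorch as ipex  # requires an XPU-enabled IPEX build

num_experts, hidden_size, intermediate_size = 8, 1024, 2048

# w13 packs each expert's gate and up projections into one tensor; w2 is the
# down projection, matching the layout process_weights_after_loading produces.
w13_weight = torch.nn.Parameter(
    torch.randn(num_experts, 2 * intermediate_size, hidden_size))
w2_weight = torch.nn.Parameter(
    torch.randn(num_experts, hidden_size, intermediate_size))

# use_prepack=True asks IPEX to repack the weights into its preferred layout
# once, up front, so each forward call skips that reordering cost.
ipex_fusion = ipex.llm.modules.GatedMLPMOE(
    w13_weight,
    w2_weight,
    use_prepack=True,
)
```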
@@ -509,6 +516,44 @@ def forward_cpu(
             activation,
         )
 
+    def forward_xpu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
+    ):
+        if (enable_eplb or expert_load_view is not None
+                or logical_to_physical_map is not None
+                or logical_replica_count is not None):
+            raise NotImplementedError(
+                "Expert load balancing is not supported for XPU.")
+        assert custom_routing_function is None
+        return layer.ipex_fusion(
+            x,
+            use_grouped_topk,
+            top_k,
+            router_logits,
+            renormalize,
+            topk_group,
+            num_expert_group,
+        )
+
     def forward_tpu(
         self,
         layer: torch.nn.Module,
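For context, `forward_xpu` slots into the same per-platform dispatch that already selects `forward_cpu` and `forward_tpu`: the platform is resolved once, and every subsequent call goes straight to the chosen implementation, which is why `forward_xpu` only has to handle the IPEX path. The sketch below is a simplified, self-contained illustration of that pattern; the `Platform` and `MoEMethod` classes and their return strings are invented for the example and are not vLLM's real classes.

```python
# Simplified illustration of per-platform forward dispatch (hypothetical
# names, not vLLM's actual CustomOp machinery).
from typing import Callable


class Platform:
    def __init__(self, name: str) -> None:
        self.name = name

    def is_xpu(self) -> bool:
        return self.name == "xpu"

    def is_cpu(self) -> bool:
        return self.name == "cpu"


class MoEMethod:
    def __init__(self, platform: Platform) -> None:
        # Resolve the platform once; later calls skip the branching entirely.
        self._forward: Callable[[], str]
        if platform.is_xpu():
            self._forward = self.forward_xpu
        elif platform.is_cpu():
            self._forward = self.forward_cpu
        else:
            self._forward = self.forward_native

    def forward_xpu(self) -> str:
        return "ipex_fusion path"

    def forward_cpu(self) -> str:
        return "cpu_fused_moe path"

    def forward_native(self) -> str:
        return "reference path"


print(MoEMethod(Platform("xpu"))._forward())  # -> ipex_fusion path
```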