@@ -142,6 +142,9 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
         else:
             logger.info_once("Using Triton backend")
             return Mxfp4Backend.TRITON
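+    # XPU: serve MXFP4 MoE through the IPEX Marlin backend.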
+    elif current_platform.is_xpu():
+        logger.info_once("Using ipex marlin backend on XPU")
+        return Mxfp4Backend.MARLIN
     elif current_platform.is_rocm() and has_triton_kernels():
         logger.info_once("Using Triton backend")
         return Mxfp4Backend.TRITON
@@ -188,7 +191,10 @@ def get_quant_method(
                 return UnquantizedLinearMethod()
             raise NotImplementedError("Mxfp4 linear layer is not implemented")
         elif isinstance(layer, FusedMoE):
-            return Mxfp4MoEMethod(layer.moe_config)
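+            # On XPU, FusedMoE layers use the IPEX-backed method defined below.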
+            if current_platform.is_xpu():
+                return IpexMxfp4MoEMethod(layer.moe_config)
+            else:
+                return Mxfp4MoEMethod(layer.moe_config)
         elif isinstance(layer, Attention):
             raise NotImplementedError("Mxfp4 attention layer is not implemented")
         return None
@@ -245,7 +251,10 @@ def create_weights(
         intermediate_size_per_partition_after_pad = round_up(
             intermediate_size_per_partition, 128
         )
-        hidden_size = round_up(hidden_size, 256)
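+        # On XPU, pad hidden_size to a multiple of 128 (matching the activation
+        # padding in IpexMxfp4MoEMethod.apply); other platforms keep 256.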
+        if current_platform.is_xpu():
+            hidden_size = round_up(hidden_size, 128)
+        else:
+            hidden_size = round_up(hidden_size, 256)
 
         layer.params_dtype = params_dtype
         layer.num_experts = num_experts
@@ -1071,3 +1080,84 @@ def apply(
             )
         else:
             raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
+
+
+class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
+    def __init__(self, moe_config: FusedMoEConfig):
+        super().__init__(moe_config)
+        self.moe_config = moe_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        super().create_weights(
+            layer,
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            params_dtype,
+            **extra_weight_attrs,
+        )
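+        # Remember the unpadded hidden size so apply() can pad and un-pad activations.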
+        self.original_hidden_size = hidden_size
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        import intel_extension_for_pytorch as ipex
+
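+        # View the packed MXFP4 weights as int32 for the IPEX fused MoE module.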
+        layer.w13_weight.data = layer.w13_weight.data.view(torch.int32)
+        layer.w2_weight.data = layer.w2_weight.data.view(torch.int32)
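+        # Fuse w13 (gate/up) and w2 (down) into a single IPEX GatedMLPMOE module.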
+        layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
+            layer.w13_weight,
+            layer.w2_weight,
+            w1_scale_inv=layer.w13_weight_scale,
+            w2_scale_inv=layer.w2_weight_scale,
+            w13_bias=layer.w13_bias,
+            w2_bias=layer.w2_bias,
+            is_mxfp4=True,
+        )
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: int | None = None,
+        num_expert_group: int | None = None,
+        global_num_experts: int = -1,
+        expert_map: torch.Tensor | None = None,
+        custom_routing_function: Callable | None = None,
+        scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
+        e_score_correction_bias: torch.Tensor | None = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: torch.Tensor | None = None,
+        logical_to_physical_map: torch.Tensor | None = None,
+        logical_replica_count: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        assert activation == "swigluoai", (
+            "Only swigluoai activation is supported for IPEX MXFP4 MoE"
+        )
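+        # Pad activations to the 128-aligned hidden size used when creating the weights.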
+        hidden_size_pad = round_up(self.original_hidden_size, 128)
+        x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1)))
+        hidden_states = layer.ipex_fusion(
+            x_pad,
+            use_grouped_topk,
+            top_k,
+            router_logits,
+            renormalize,
+            topk_group,
+            num_expert_group,
+            activation="swiglu_oai",
+        )
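+        # Slice off the padded columns so the output matches the original hidden size.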
+        hidden_states = hidden_states[..., : self.original_hidden_size].contiguous()
+        return hidden_states