vllm/model_executor/layers/fused_moe/layer.py: 47 changes (46 additions, 1 deletion)
@@ -327,7 +327,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             layer.w13_weight.data = shuffled_w13
             layer.w2_weight.data = shuffled_w2

-        if current_platform.is_cpu():
+        if current_platform.is_xpu():
+            import intel_extension_for_pytorch as ipex
+            layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
+                layer.w13_weight,
+                layer.w2_weight,
+                use_prepack=True,
+            )
+        elif current_platform.is_cpu():
             if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
                 from vllm.model_executor.layers.fused_moe import cpu_fused_moe
                 dtype = layer.w13_weight.dtype
@@ -509,6 +516,44 @@ def forward_cpu(
             activation,
         )

+    def forward_xpu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
+    ):
+        if enable_eplb is not False or expert_load_view is not None or \
+                logical_to_physical_map is not None or \
+                logical_replica_count is not None:
+            raise NotImplementedError("Expert load balancing is not supported "
+                                      "for XPU.")
+        assert custom_routing_function is None
+        return layer.ipex_fusion(
+            x,
+            use_grouped_topk,
+            top_k,
+            router_logits,
+            renormalize,
+            topk_group,
+            num_expert_group,
+        )
+
     def forward_tpu(
         self,
         layer: torch.nn.Module,
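Note: as a rough sketch for context (not part of the diff above), the new forward_xpu hook slots into vLLM's existing per-platform dispatch: the platform is checked at runtime and the matching forward is used, in the same order as the checks added in process_weights_after_loading. The select_moe_forward helper below is purely illustrative and not a vLLM API; current_platform, forward_xpu, forward_cpu, and forward_cuda are the names assumed from the diff and from vLLM itself.

# Illustrative sketch only; `select_moe_forward` is a hypothetical helper.
from vllm.platforms import current_platform


def select_moe_forward(moe_method):
    # Prefer the new XPU path on Intel GPUs; otherwise fall back to the
    # existing CPU or CUDA forwards, mirroring the order of the checks
    # added in process_weights_after_loading above.
    if current_platform.is_xpu():
        return moe_method.forward_xpu
    if current_platform.is_cpu():
        return moe_method.forward_cpu
    return moe_method.forward_cuda

A caller would then invoke the returned forward with the same arguments as forward_cpu; on XPU that call simply delegates routing and expert computation to the prepacked layer.ipex_fusion module created at weight-loading time.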