
Commit 18b3982

[XPU] Add gpt-oss model support for Intel GPU (#27786)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

1 parent 4ea62b7 · commit 18b3982
4 files changed: +101 additions, −6 deletions

vllm/attention/utils/fa_utils.py (7 additions, 0 deletions)

@@ -80,6 +80,13 @@ def flash_attn_supports_fp8() -> bool:
     )
 
 
+def flash_attn_supports_sinks() -> bool:
+    if current_platform.is_xpu():
+        return True
+    else:
+        return get_flash_attn_version() == 3
+
+
 def flash_attn_supports_mla():
     from vllm.platforms import current_platform
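
For reference, the new helper gates attention-sink support (which gpt-oss relies on): it returns True unconditionally on XPU, and otherwise only when FlashAttention 3 is in use. A minimal runtime probe, sketched under the assumption of a vLLM checkout that contains this commit:

# Hedged sketch: check whether the current platform can serve models that use
# attention sinks through the flash-attn backend.
from vllm.attention.utils.fa_utils import flash_attn_supports_sinks

if flash_attn_supports_sinks():
    print("attention sinks supported (XPU, or FlashAttention 3 elsewhere)")
else:
    print("attention sinks unsupported on this platform / FA version")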

vllm/model_executor/layers/quantization/mxfp4.py (92 additions, 2 deletions)

@@ -142,6 +142,9 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
         else:
             logger.info_once("Using Triton backend")
             return Mxfp4Backend.TRITON
+    elif current_platform.is_xpu():
+        logger.info_once("Using ipex marlin backend on XPU")
+        return Mxfp4Backend.MARLIN
     elif current_platform.is_rocm() and has_triton_kernels():
         logger.info_once("Using Triton backend")
         return Mxfp4Backend.TRITON

@@ -188,7 +191,10 @@ def get_quant_method(
                 return UnquantizedLinearMethod()
             raise NotImplementedError("Mxfp4 linear layer is not implemented")
         elif isinstance(layer, FusedMoE):
-            return Mxfp4MoEMethod(layer.moe_config)
+            if current_platform.is_xpu():
+                return IpexMxfp4MoEMethod(layer.moe_config)
+            else:
+                return Mxfp4MoEMethod(layer.moe_config)
         elif isinstance(layer, Attention):
             raise NotImplementedError("Mxfp4 attention layer is not implemented")
         return None

@@ -245,7 +251,10 @@ def create_weights(
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 128
             )
-            hidden_size = round_up(hidden_size, 256)
+            if current_platform.is_xpu():
+                hidden_size = round_up(hidden_size, 128)
+            else:
+                hidden_size = round_up(hidden_size, 256)
 
         layer.params_dtype = params_dtype
         layer.num_experts = num_experts

@@ -1071,3 +1080,84 @@ def apply(
             )
         else:
             raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
+
+
+class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
+    def __init__(self, moe_config: FusedMoEConfig):
+        super().__init__(moe_config)
+        self.moe_config = moe_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        super().create_weights(
+            layer,
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            params_dtype,
+            **extra_weight_attrs,
+        )
+        self.original_hidden_size = hidden_size
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        import intel_extension_for_pytorch as ipex
+
+        layer.w13_weight.data = layer.w13_weight.data.view(torch.int32)
+        layer.w2_weight.data = layer.w2_weight.data.view(torch.int32)
+        layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
+            layer.w13_weight,
+            layer.w2_weight,
+            w1_scale_inv=layer.w13_weight_scale,
+            w2_scale_inv=layer.w2_weight_scale,
+            w13_bias=layer.w13_bias,
+            w2_bias=layer.w2_bias,
+            is_mxfp4=True,
+        )
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: int | None = None,
+        num_expert_group: int | None = None,
+        global_num_experts: int = -1,
+        expert_map: torch.Tensor | None = None,
+        custom_routing_function: Callable | None = None,
+        scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
+        e_score_correction_bias: torch.Tensor | None = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: torch.Tensor | None = None,
+        logical_to_physical_map: torch.Tensor | None = None,
+        logical_replica_count: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        assert activation == "swigluoai", (
+            "Only swiglu_oai activation is supported for IPEX MXFP4 MoE"
+        )  # noqa:
+        hidden_size_pad = round_up(self.original_hidden_size, 128)
+        x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1)))
+        hidden_states = layer.ipex_fusion(
+            x_pad,
+            use_grouped_topk,
+            top_k,
+            router_logits,
+            renormalize,
+            topk_group,
+            num_expert_group,
+            activation="swiglu_oai",
+        )
+        hidden_states = hidden_states[..., : self.original_hidden_size].contiguous()
+        return hidden_states
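
On XPU the MoE path pads the hidden size up to a multiple of 128 (rather than 256), runs the fused IPEX kernel on the padded activations, and slices the output back to the original width. A standalone sketch of that round trip; the hidden size of 2880 is gpt-oss's value and is used here purely for illustration:

import torch

def round_up(x: int, multiple: int) -> int:
    # Same rounding the MXFP4 MoE path applies to hidden_size.
    return (x + multiple - 1) // multiple * multiple

hidden_size = 2880                                    # gpt-oss hidden size (illustrative)
padded = round_up(hidden_size, 128)                   # 2944 on XPU, vs. 3072 with 256-alignment
x = torch.randn(4, hidden_size)
x_pad = torch.nn.functional.pad(x, (0, padded - x.size(-1)))  # zero-pad the last dim
out = x_pad[..., :hidden_size].contiguous()           # slice back, as IpexMxfp4MoEMethod.apply does
assert out.shape[-1] == hidden_size and torch.equal(out, x)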

vllm/model_executor/models/gpt_oss.py (0 additions, 3 deletions)

@@ -337,9 +337,6 @@ def _load_weights_mxfp4(
             if is_pp_missing_parameter(name, self):
                 continue
 
-            # FIXME(woosuk): Remove this after testing.
-            weight = weight.cuda()
-
             if ".w13_weight_scale" in name:
                 # Handle MLP gate and up projection weights scale
                 if use_ep:

vllm/v1/attention/backends/flash_attn.py (2 additions, 1 deletion)

@@ -27,6 +27,7 @@
 
 if is_flash_attn_varlen_func_available():
     from vllm.attention.utils.fa_utils import (
+        flash_attn_supports_sinks,
         flash_attn_varlen_func,
         get_scheduler_metadata,
         reshape_and_cache_flash,

@@ -497,7 +498,7 @@ def __init__(
 
         self.sinks = sinks
         if self.sinks is not None:
-            assert self.vllm_flash_attn_version == 3, (
+            assert flash_attn_supports_sinks(), (
                 "Sinks are only supported in FlashAttention 3"
            )
            assert self.sinks.shape[0] == num_heads, (
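
Taken together, the XPU branch in flash_attn_supports_sinks() and the IPEX MXFP4 MoE method are what let gpt-oss load on an Intel GPU. An end-to-end sketch, assuming an XPU build of vLLM with intel_extension_for_pytorch installed; the model id and sampling values are illustrative:

from vllm import LLM, SamplingParams

# The platform (XPU vs. CUDA) is auto-detected; no gpt-oss-specific flags are required.
llm = LLM(model="openai/gpt-oss-20b")
params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["Briefly explain what a mixture-of-experts layer does."], params)
print(outputs[0].outputs[0].text)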
