Commit d51ea18

Merge pull request vllm-project#25 from jianan-gu/jianan/enable_moe
Add IPEX MOE CPU support
2 parents b16b78d + 3aacc6c commit d51ea18

File tree: 2 files changed (+96 -17 lines)

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@
 import triton
 import triton.language as tl

-import vllm._moe_C as moe_kernels
+# import vllm._moe_C as moe_kernels
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
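
The CUDA MoE kernel import is commented out, presumably so fused_moe.py still imports on CPU-only builds where vllm._moe_C is not compiled. A softer alternative, sketched here and not part of this commit, would guard the import instead:

# Sketch only: keep the CUDA MoE kernels when the extension exists, otherwise fall back to None.
try:
    import vllm._moe_C as moe_kernels
except ImportError:
    moe_kernels = None  # CPU-only build without the compiled extension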

vllm/model_executor/models/mixtral.py

Lines changed: 95 additions & 16 deletions

@@ -51,7 +51,93 @@
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import SamplerOutput
 from vllm.utils import print_warning_once
-
+import intel_extension_for_pytorch as ipex
+from intel_extension_for_pytorch.cpu._auto_kernel_selection import (
+    _enable_tpp,
+    _disable_tpp,
+)
+class _IPEXlinearMOECPU(nn.Module):
+    def __init__(self, W13, W2, W3=None, tpp=False, woq=False):
+        super().__init__()
+        self.tpp = tpp
+        self.woq = woq
+        self.num_experts = W2.shape[0]
+        self.hidden_size = W2.shape[1]
+        self.intermediate_size = W2.shape[2]
+
+        linear_list = []
+        for i in range(W2.shape[0]):
+            if W3 is not None:
+                _W1 = W13[i]
+            else:
+                _W1 = W13[i][0 : self.intermediate_size, :]
+                _W3 = W13[i][self.intermediate_size : 2 * self.intermediate_size, :]
+            linear1 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+            linear1.weight = nn.Parameter(_W1)
+            linear2 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+            linear2.weight = nn.Parameter(W2[i])
+            linear3 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+            linear3.weight = nn.Parameter(_W3)
+            linear_per_expert = nn.ModuleList([linear1, linear2, linear3])
+            linear_list.append(linear_per_expert)
+        self.linear_module_list = nn.ModuleList([linear_list[i] for i in range(W2.shape[0])])
+
+    def forward(self, hidden_states, score, topk):
+        batch_size, head_dim = hidden_states.shape
+        routing_weights = torch.nn.functional.softmax(score, dim=1, dtype=torch.float32)
+        routing_weights, selected_experts = torch.topk(routing_weights, topk, dim=-1)
+        routing_weights = routing_weights.to(hidden_states.dtype)
+        final_hidden_states = torch.zeros(
+            (batch_size, head_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.num_experts
+        ).permute(2, 1, 0)
+        for expert_idx in range(self.num_experts):
+            idx, top_x = torch.where(expert_mask[expert_idx])
+            if (
+                hasattr(self.linear_module_list[expert_idx][0], "use_dnnl")
+                and self.linear_module_list[expert_idx][0].use_dnnl
+            ):
+                final_hidden_states = torch.ops.torch_ipex.mixtral_moe(
+                    hidden_states,
+                    top_x,
+                    idx,
+                    self.linear_module_list[expert_idx][0]._get_forward_weight(),
+                    self.linear_module_list[expert_idx][0].ctx.get_data_handle(),
+                    self.linear_module_list[expert_idx][2]._get_forward_weight(),
+                    self.linear_module_list[expert_idx][2].ctx.get_data_handle(),
+                    self.linear_module_list[expert_idx][1]._get_forward_weight(),
+                    self.linear_module_list[expert_idx][1].ctx.get_data_handle(),
+                    hasattr(self.linear_module_list[expert_idx][0], "use_dnnl")
+                    and self.linear_module_list[expert_idx][0].use_dnnl,
+                    routing_weights,
+                    final_hidden_states,
+                    False,
+                )
+            else:
+                final_hidden_states = torch.ops.torch_ipex.mixtral_moe_tpp(
+                    hidden_states,
+                    top_x,
+                    idx,
+                    self.linear_module_list[expert_idx][0].weight.detach(),
+                    self.linear_module_list[expert_idx][2].weight.detach(),
+                    self.linear_module_list[expert_idx][1].weight.detach(),
+                    (
+                        self.linear_module_list[expert_idx][0].tpp_fallback
+                        if hasattr(
+                            self.linear_module_list[expert_idx][0], "tpp_fallback"
+                        )
+                        else True
+                    ),
+                    routing_weights,
+                    final_hidden_states,
+                    False,
+                )
+
+        return final_hidden_states.view(-1, head_dim)
 
 class MixtralMoE(nn.Module):
     """A tensor-parallel MoE implementation for Mixtral that shards each expert
@@ -108,14 +194,12 @@ def __init__(
                         self.hidden_size,
                         self.intermediate_size,
                         dtype=params_dtype))
-
         set_weight_attrs(self.w13_weight, {
             "weight_loader": self.weight_loader,
         })
         set_weight_attrs(self.w2_weight, {
             "weight_loader": self.weight_loader,
         })
-
         # Used for fp8.
         self.w13_scale = None
         self.w2_scale = None
@@ -221,22 +305,17 @@ def process_weights_after_loading(self):
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
+
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(hidden_states,
-                                        self.w13_weight,
-                                        self.w2_weight,
-                                        router_logits,
-                                        self.top_k,
-                                        renormalize=True,
-                                        inplace=True,
-                                        use_fp8=self.use_fp8,
-                                        w1_scale=self.w13_scale,
-                                        w2_scale=self.w2_scale,
-                                        a1_scale=self.a13_scale,
-                                        a2_scale=self.a2_scale)
-
+        if not hasattr(self, "ipex_moe"):
+            self.ipex_moe = _IPEXlinearMOECPU(self.w13_weight, self.w2_weight)
+            _disable_tpp()
+            if hidden_states.dtype is torch.bfloat16:
+                _enable_tpp()
+            self.ipex_moe = ipex.optimize(self.ipex_moe.eval(), dtype=hidden_states.dtype, inplace=True)
+        final_hidden_states = self.ipex_moe(hidden_states, router_logits, self.top_k)
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
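
The IPEX wrapper is built lazily on the first forward call, once the expert weights have been loaded and the activation dtype is known: TPP kernels are enabled only for bfloat16 before ipex.optimize is applied in place, and the existing tensor-parallel all-reduce stays unchanged. A standalone usage sketch of the new CPU path with made-up sizes (illustrative only; it assumes the _IPEXlinearMOECPU class from the diff above and an intel_extension_for_pytorch build that provides the torch_ipex MoE ops):

import torch
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.cpu._auto_kernel_selection import _enable_tpp, _disable_tpp

n_experts, hidden, inter, n_tokens, topk = 8, 1024, 3584, 4, 2       # toy sizes
w13 = torch.randn(n_experts, 2 * inter, hidden, dtype=torch.bfloat16)  # fused [w1; w3] per expert
w2 = torch.randn(n_experts, hidden, inter, dtype=torch.bfloat16)       # per-expert down projection

moe = _IPEXlinearMOECPU(w13, w2)      # class added in this commit
_disable_tpp()
_enable_tpp()                         # bf16 activations -> TPP kernels, as in MixtralMoE.forward
moe = ipex.optimize(moe.eval(), dtype=torch.bfloat16, inplace=True)

hidden_states = torch.randn(n_tokens, hidden, dtype=torch.bfloat16)
router_logits = torch.randn(n_tokens, n_experts)
out = moe(hidden_states, router_logits, topk)   # (n_tokens, hidden)
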

0 commit comments
