
Commit 3aacc6c

add tpp as bf16 moe default
1 parent d3a8427 commit 3aacc6c

File tree

1 file changed: +18 -12 lines changed


vllm/model_executor/models/mixtral.py

Lines changed: 18 additions & 12 deletions
@@ -52,6 +52,10 @@
 from vllm.sequence import SamplerOutput
 from vllm.utils import print_warning_once
 import intel_extension_for_pytorch as ipex
+from intel_extension_for_pytorch.cpu._auto_kernel_selection import (
+    _enable_tpp,
+    _disable_tpp,
+)
 class _IPEXlinearMOECPU(nn.Module):
     def __init__(self, W13, W2, W3=None, tpp=False, woq=False):
         super().__init__()
@@ -64,16 +68,16 @@ def __init__(self, W13, W2, W3=None, tpp=False, woq=False):
         linear_list = []
         for i in range(W2.shape[0]):
             if W3 is not None:
-                W1 = W13[i]
+                _W1 = W13[i]
             else:
-                W1 = W13[i][0 : self.intermediate_size, :]
-                W3 = W13[i][self.intermediate_size : 2 * self.intermediate_size, :]
-            linear1 = nn.Linear(self.intermediate_size, self.hidden_size)
-            linear1.weight = nn.Parameter(W1)
-            linear2 = nn.Linear(self.intermediate_size, self.hidden_size)
+                _W1 = W13[i][0 : self.intermediate_size, :]
+                _W3 = W13[i][self.intermediate_size : 2 * self.intermediate_size, :]
+            linear1 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+            linear1.weight = nn.Parameter(_W1)
+            linear2 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
             linear2.weight = nn.Parameter(W2[i])
-            linear3 = nn.Linear(self.hidden_size, self.intermediate_size)
-            linear3.weight = nn.Parameter(W3)
+            linear3 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+            linear3.weight = nn.Parameter(_W3)
             linear_per_expert = nn.ModuleList([linear1, linear2, linear3])
             linear_list.append(linear_per_expert)
         self.linear_module_list = nn.ModuleList([linear_list[i] for i in range(W2.shape[0])])
@@ -118,9 +122,9 @@ def forward(self, hidden_states, score, topk):
                 hidden_states,
                 top_x,
                 idx,
-                self.linear_module_list[expert_idx][0].weight,
-                self.linear_module_list[expert_idx][2].weight,
-                self.linear_module_list[expert_idx][1].weight,
+                self.linear_module_list[expert_idx][0].weight.detach(),
+                self.linear_module_list[expert_idx][2].weight.detach(),
+                self.linear_module_list[expert_idx][1].weight.detach(),
                 (
                     self.linear_module_list[expert_idx][0].tpp_fallback
                     if hasattr(
@@ -307,8 +311,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         router_logits, _ = self.gate(hidden_states)
         if not hasattr(self, "ipex_moe"):
             self.ipex_moe = _IPEXlinearMOECPU(self.w13_weight, self.w2_weight)
+            _disable_tpp()
+            if hidden_states.dtype is torch.bfloat16:
+                _enable_tpp()
             self.ipex_moe = ipex.optimize(self.ipex_moe.eval(), dtype=hidden_states.dtype, inplace=True)
-            breakpoint()
         final_hidden_states = self.ipex_moe(hidden_states, router_logits, self.top_k)
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
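For readers outside the vLLM codebase, the pattern this commit adopts can be summarized by the minimal sketch below: IPEX's auto kernel selection is reset with _disable_tpp(), TPP kernels are enabled only when the module will run in bfloat16, and the module is then handed to ipex.optimize. The TinyMLP module and the optimize_for_dtype helper are illustrative stand-ins, not part of vLLM or IPEX; only the enable/disable-around-optimize pattern comes from the diff.

import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.cpu._auto_kernel_selection import (
    _enable_tpp,
    _disable_tpp,
)

class TinyMLP(nn.Module):
    # Illustrative stand-in for one expert's linear stack (not vLLM code).
    def __init__(self, hidden=64, intermediate=128):
        super().__init__()
        self.up = nn.Linear(hidden, intermediate, bias=False)
        self.down = nn.Linear(intermediate, hidden, bias=False)

    def forward(self, x):
        return self.down(torch.nn.functional.silu(self.up(x)))

def optimize_for_dtype(module, dtype):
    # Mirror the commit: TPP off by default, opted in only for bf16,
    # then ipex.optimize prepares the eval-mode module for that dtype.
    _disable_tpp()
    if dtype is torch.bfloat16:
        _enable_tpp()
    return ipex.optimize(module.eval(), dtype=dtype, inplace=True)

x = torch.randn(4, 64).to(torch.bfloat16)
mlp = optimize_for_dtype(TinyMLP(), x.dtype)
with torch.no_grad():
    y = mlp(x)  # uses TPP-backed bf16 kernels when available

Disabling first and then re-enabling only on the bf16 path means the optimized MoE module opts into TPP explicitly rather than inheriting whatever kernel-selection state the process had before.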
