 from vllm.sequence import SamplerOutput
 from vllm.utils import print_warning_once
 import intel_extension_for_pytorch as ipex
+from intel_extension_for_pytorch.cpu._auto_kernel_selection import (
+    _enable_tpp,
+    _disable_tpp,
+)
 class _IPEXlinearMOECPU(nn.Module):
     def __init__(self, W13, W2, W3=None, tpp=False, woq=False):
         super().__init__()
@@ -64,16 +68,16 @@ def __init__(self, W13, W2, W3=None, tpp=False, woq=False):
         linear_list = []
         for i in range(W2.shape[0]):
             if W3 is not None:
-                W1 = W13[i]
+                _W1 = W13[i]
             else:
-                W1 = W13[i][0 : self.intermediate_size, :]
-                W3 = W13[i][self.intermediate_size : 2 * self.intermediate_size, :]
-            linear1 = nn.Linear(self.intermediate_size, self.hidden_size)
-            linear1.weight = nn.Parameter(W1)
-            linear2 = nn.Linear(self.intermediate_size, self.hidden_size)
+                _W1 = W13[i][0 : self.intermediate_size, :]
+                _W3 = W13[i][self.intermediate_size : 2 * self.intermediate_size, :]
+            linear1 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+            linear1.weight = nn.Parameter(_W1)
+            linear2 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
             linear2.weight = nn.Parameter(W2[i])
-            linear3 = nn.Linear(self.hidden_size, self.intermediate_size)
-            linear3.weight = nn.Parameter(W3)
+            linear3 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+            linear3.weight = nn.Parameter(_W3)
             linear_per_expert = nn.ModuleList([linear1, linear2, linear3])
             linear_list.append(linear_per_expert)
         self.linear_module_list = nn.ModuleList([linear_list[i] for i in range(W2.shape[0])])
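
For readability, here is a small standalone sketch (not part of the diff) of the W13 split the loop above performs per expert: the fused per-expert weight stacks what appear to be the gate and up projections along dim 0, so two row slices recover the matrices assigned to linear1 and linear3, while W2[i] backs the remaining projection. Shapes and names below are illustrative.

```python
import torch

# Illustrative shapes only; the real sizes come from the model config.
hidden_size, intermediate_size = 16, 32
w13_expert = torch.randn(2 * intermediate_size, hidden_size)  # fused per-expert weight

w1 = w13_expert[0:intermediate_size, :]                        # first half  -> linear1
w3 = w13_expert[intermediate_size : 2 * intermediate_size, :]  # second half -> linear3
assert w1.shape == w3.shape == (intermediate_size, hidden_size)
```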
@@ -118,9 +122,9 @@ def forward(self, hidden_states, score, topk):
                 hidden_states,
                 top_x,
                 idx,
-                self.linear_module_list[expert_idx][0].weight,
-                self.linear_module_list[expert_idx][2].weight,
-                self.linear_module_list[expert_idx][1].weight,
+                self.linear_module_list[expert_idx][0].weight.detach(),
+                self.linear_module_list[expert_idx][2].weight.detach(),
+                self.linear_module_list[expert_idx][1].weight.detach(),
                 (
                     self.linear_module_list[expert_idx][0].tpp_fallback
                     if hasattr(
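
A note on the `.weight.detach()` change above: detaching hands the fused call a plain tensor that shares storage with the `nn.Parameter` but carries no autograd tracking. A minimal sketch, separate from the vLLM code path:

```python
import torch.nn as nn

linear = nn.Linear(8, 8, bias=False)
w = linear.weight                 # nn.Parameter, requires_grad=True
w_plain = linear.weight.detach()  # same storage, no autograd history

assert w_plain.requires_grad is False
assert w_plain.data_ptr() == w.data_ptr()
```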
@@ -307,8 +311,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         router_logits, _ = self.gate(hidden_states)
         if not hasattr(self, "ipex_moe"):
             self.ipex_moe = _IPEXlinearMOECPU(self.w13_weight, self.w2_weight)
+            _disable_tpp()
+            if hidden_states.dtype is torch.bfloat16:
+                _enable_tpp()
             self.ipex_moe = ipex.optimize(self.ipex_moe.eval(), dtype=hidden_states.dtype, inplace=True)
-            breakpoint()
         final_hidden_states = self.ipex_moe(hidden_states, router_logits, self.top_k)
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
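
The last hunk builds the IPEX MoE module lazily on first use and replaces the stray `breakpoint()` with a TPP toggle keyed on dtype. A minimal sketch of that preparation step, using the same private IPEX helpers imported at the top of this diff; `prepare_ipex_moe` is an illustrative name, not a function in this change:

```python
import torch
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.cpu._auto_kernel_selection import (
    _enable_tpp,
    _disable_tpp,
)

def prepare_ipex_moe(moe_module: torch.nn.Module, dtype: torch.dtype) -> torch.nn.Module:
    # Reset kernel selection, then opt into TPP kernels only for bfloat16,
    # mirroring the dtype check in the hunk above.
    _disable_tpp()
    if dtype is torch.bfloat16:
        _enable_tpp()
    return ipex.optimize(moe_module.eval(), dtype=dtype, inplace=True)
```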