| 
51 | 51 | from vllm.model_executor.utils import set_weight_attrs  | 
52 | 52 | from vllm.sequence import SamplerOutput  | 
53 | 53 | from vllm.utils import print_warning_once  | 
54 |  | - | 
 | 54 | +import intel_extension_for_pytorch as ipex  | 
 | 55 | +from intel_extension_for_pytorch.cpu._auto_kernel_selection import (  | 
 | 56 | +    _enable_tpp,  | 
 | 57 | +    _disable_tpp,  | 
 | 58 | +)  | 
 | 59 | +class _IPEXlinearMOECPU(nn.Module):  | 
 | 60 | +    def __init__(self, W13, W2, W3=None, tpp=False, woq=False):  | 
 | 61 | +        super().__init__()  | 
 | 62 | +        self.tpp = tpp  | 
 | 63 | +        self.woq = woq  | 
 | 64 | +        self.num_experts = W2.shape[0]  | 
 | 65 | +        self.hidden_size = W2.shape[1]  | 
 | 66 | +        self.intermediate_size = W2.shape[2]  | 
 | 67 | + | 
 | 68 | +        linear_list = []  | 
 | 69 | +        for i in range(W2.shape[0]):  | 
 | 70 | +            if W3 is not None:  | 
 | 71 | +                _W1, _W3 = W13[i], W3[i]  | 
 | 72 | +            else:  | 
 | 73 | +                _W1 = W13[i][0 : self.intermediate_size, :]  | 
 | 74 | +                _W3 = W13[i][self.intermediate_size : 2 * self.intermediate_size, :]  | 
 | 75 | +            linear1 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)  | 
 | 76 | +            linear1.weight = nn.Parameter(_W1)  | 
 | 77 | +            linear2 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)  | 
 | 78 | +            linear2.weight = nn.Parameter(W2[i])  | 
 | 79 | +            linear3 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)  | 
 | 80 | +            linear3.weight = nn.Parameter(_W3)  | 
 | 81 | +            linear_per_expert = nn.ModuleList([linear1, linear2, linear3])  | 
 | 82 | +            linear_list.append(linear_per_expert)  | 
 | 83 | +        self.linear_module_list = nn.ModuleList(linear_list)  | 
 | 84 | + | 
 | 85 | +    def forward(self, hidden_states, score, topk):  | 
 | 86 | +        batch_size, head_dim = hidden_states.shape  | 
 | 87 | +        routing_weights = torch.nn.functional.softmax(score, dim=1, dtype=torch.float32)  | 
 | 88 | +        routing_weights, selected_experts = torch.topk(routing_weights, topk, dim=-1)  | 
 | 89 | +        routing_weights = (routing_weights / routing_weights.sum(dim=-1, keepdim=True)).to(hidden_states.dtype)  | 
 | 90 | +        final_hidden_states = torch.zeros(  | 
 | 91 | +            (batch_size, head_dim),  | 
 | 92 | +            dtype=hidden_states.dtype,  | 
 | 93 | +            device=hidden_states.device,  | 
 | 94 | +        )  | 
 | 95 | +        expert_mask = torch.nn.functional.one_hot(  | 
 | 96 | +            selected_experts, num_classes=self.num_experts  | 
 | 97 | +        ).permute(2, 1, 0)  | 
 | 98 | +        for expert_idx in range(self.num_experts):  | 
 | 99 | +            idx, top_x = torch.where(expert_mask[expert_idx])  | 
 | 100 | +            if (  | 
 | 101 | +                hasattr(self.linear_module_list[expert_idx][0], "use_dnnl")  | 
 | 102 | +                and self.linear_module_list[expert_idx][0].use_dnnl  | 
 | 103 | +            ):  | 
 | 104 | +                final_hidden_states = torch.ops.torch_ipex.mixtral_moe(  | 
 | 105 | +                    hidden_states,  | 
 | 106 | +                    top_x,  | 
 | 107 | +                    idx,  | 
 | 108 | +                    self.linear_module_list[expert_idx][0]._get_forward_weight(),  | 
 | 109 | +                    self.linear_module_list[expert_idx][0].ctx.get_data_handle(),  | 
 | 110 | +                    self.linear_module_list[expert_idx][2]._get_forward_weight(),  | 
 | 111 | +                    self.linear_module_list[expert_idx][2].ctx.get_data_handle(),  | 
 | 112 | +                    self.linear_module_list[expert_idx][1]._get_forward_weight(),  | 
 | 113 | +                    self.linear_module_list[expert_idx][1].ctx.get_data_handle(),  | 
 | 114 | +                    hasattr(self.linear_module_list[expert_idx][0], "use_dnnl")  | 
 | 115 | +                    and self.linear_module_list[expert_idx][0].use_dnnl,  | 
 | 116 | +                    routing_weights,  | 
 | 117 | +                    final_hidden_states,  | 
 | 118 | +                    False,  | 
 | 119 | +                )  | 
 | 120 | +            else:  | 
 | 121 | +                final_hidden_states = torch.ops.torch_ipex.mixtral_moe_tpp(  | 
 | 122 | +                    hidden_states,  | 
 | 123 | +                    top_x,  | 
 | 124 | +                    idx,  | 
 | 125 | +                    self.linear_module_list[expert_idx][0].weight.detach(),  | 
 | 126 | +                    self.linear_module_list[expert_idx][2].weight.detach(),  | 
 | 127 | +                    self.linear_module_list[expert_idx][1].weight.detach(),  | 
 | 128 | +                    (  | 
 | 129 | +                        self.linear_module_list[expert_idx][0].tpp_fallback  | 
 | 130 | +                        if hasattr(  | 
 | 131 | +                            self.linear_module_list[expert_idx][0], "tpp_fallback"  | 
 | 132 | +                        )  | 
 | 133 | +                        else True  | 
 | 134 | +                    ),  | 
 | 135 | +                    routing_weights,  | 
 | 136 | +                    final_hidden_states,  | 
 | 137 | +                    False,  | 
 | 138 | +                )  | 
 | 139 | + | 
 | 140 | +        return final_hidden_states.view(-1, head_dim)  | 
55 | 141 |  | 
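The routing math in `_IPEXlinearMOECPU.forward` mirrors the Hugging Face Mixtral dispatch: softmax over the router logits, top-k selection per token, then a permuted one-hot mask so that, for each expert, `torch.where` recovers which tokens (`top_x`) routed to it and through which top-k slot (`idx`). A minimal, self-contained sketch of just that dispatch step (toy shapes, plain PyTorch, no IPEX kernels) might look like:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, num_experts, top_k = 6, 4, 2   # toy sizes, not the model's real dims

router_logits = torch.randn(num_tokens, num_experts)

# Softmax over experts, then keep the top-k experts per token,
# as in _IPEXlinearMOECPU.forward.
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

# One-hot mask permuted to (num_experts, top_k, num_tokens): per expert,
# torch.where yields the top-k slot (idx) and the token index (top_x).
expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)
for expert_idx in range(num_experts):
    idx, top_x = torch.where(expert_mask[expert_idx])
    print(f"expert {expert_idx}: tokens {top_x.tolist()} (slots {idx.tolist()})")
```

In the patch itself this bookkeeping is handed to the fused `torch.ops.torch_ipex.mixtral_moe` / `mixtral_moe_tpp` kernels, which also apply `routing_weights` and accumulate into `final_hidden_states`.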
56 | 142 | class MixtralMoE(nn.Module):  | 
57 | 143 |     """A tensor-parallel MoE implementation for Mixtral that shards each expert  | 
@@ -108,14 +194,12 @@ def __init__(  | 
108 | 194 |                         self.hidden_size,  | 
109 | 195 |                         self.intermediate_size,  | 
110 | 196 |                         dtype=params_dtype))  | 
111 |  | - | 
112 | 197 |         set_weight_attrs(self.w13_weight, {  | 
113 | 198 |             "weight_loader": self.weight_loader,  | 
114 | 199 |         })  | 
115 | 200 |         set_weight_attrs(self.w2_weight, {  | 
116 | 201 |             "weight_loader": self.weight_loader,  | 
117 | 202 |         })  | 
118 |  | - | 
119 | 203 |         # Used for fp8.  | 
120 | 204 |         self.w13_scale = None  | 
121 | 205 |         self.w2_scale = None  | 
@@ -221,22 +305,17 @@ def process_weights_after_loading(self):  | 
221 | 305 |  | 
222 | 306 |     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:  | 
223 | 307 |         num_tokens, hidden_size = hidden_states.shape  | 
 | 308 | + | 
224 | 309 |         hidden_states = hidden_states.view(-1, self.hidden_size)  | 
225 | 310 |         # router_logits: (num_tokens, n_experts)  | 
226 | 311 |         router_logits, _ = self.gate(hidden_states)  | 
227 |  | -        final_hidden_states = fused_moe(hidden_states,  | 
228 |  | -                                        self.w13_weight,  | 
229 |  | -                                        self.w2_weight,  | 
230 |  | -                                        router_logits,  | 
231 |  | -                                        self.top_k,  | 
232 |  | -                                        renormalize=True,  | 
233 |  | -                                        inplace=True,  | 
234 |  | -                                        use_fp8=self.use_fp8,  | 
235 |  | -                                        w1_scale=self.w13_scale,  | 
236 |  | -                                        w2_scale=self.w2_scale,  | 
237 |  | -                                        a1_scale=self.a13_scale,  | 
238 |  | -                                        a2_scale=self.a2_scale)  | 
239 |  | - | 
 | 312 | +        if not hasattr(self, "ipex_moe"):  | 
 | 313 | +            self.ipex_moe = _IPEXlinearMOECPU(self.w13_weight, self.w2_weight)  | 
 | 314 | +            _disable_tpp()  | 
 | 315 | +            if hidden_states.dtype is torch.bfloat16:  | 
 | 316 | +                _enable_tpp()  | 
 | 317 | +            self.ipex_moe = ipex.optimize(self.ipex_moe.eval(), dtype=hidden_states.dtype, inplace=True)  | 
 | 318 | +        final_hidden_states = self.ipex_moe(hidden_states, router_logits, self.top_k)  | 
240 | 319 |         if self.tp_size > 1:  | 
241 | 320 |             final_hidden_states = tensor_model_parallel_all_reduce(  | 
242 | 321 |                 final_hidden_states)  | 
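The new `MixtralMoE.forward` path builds the IPEX module lazily on the first call: TPP kernels are enabled only when the activations are bf16, and `ipex.optimize` then repacks the weights for CPU execution. A hedged sketch of that one-time setup on a stand-in module (the `nn.Sequential` below is an assumption used for illustration, not the real `_IPEXlinearMOECPU`; it requires intel_extension_for_pytorch installed on a CPU host):

```python
import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.cpu._auto_kernel_selection import (
    _disable_tpp,
    _enable_tpp,
)

# Toy stand-in for the MoE block; shapes are illustrative only.
moe = nn.Sequential(nn.Linear(32, 64, bias=False), nn.SiLU(),
                    nn.Linear(64, 32, bias=False))
hidden_states = torch.randn(4, 32, dtype=torch.bfloat16)

# Same one-time setup as the lazy branch above: TPP only for bf16,
# then let ipex.optimize repack weights for the CPU backend.
_disable_tpp()
if hidden_states.dtype == torch.bfloat16:
    _enable_tpp()
moe = ipex.optimize(moe.to(hidden_states.dtype).eval(),
                    dtype=hidden_states.dtype, inplace=True)

with torch.no_grad():
    out = moe(hidden_states)
print(out.shape)  # torch.Size([4, 32])
```

Caching the optimized module on `self`, as the patch does with `self.ipex_moe`, limits this cost to the first forward pass; later calls reuse the repacked weights.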