|
56 | 56 | _enable_tpp, |
57 | 57 | _disable_tpp, |
58 | 58 | ) |
59 | | -class _IPEXlinearMOECPU(nn.Module): |
60 | | - def __init__(self, W13, W2, W3=None, tpp=False, woq=False): |
61 | | - super().__init__() |
62 | | - self.tpp = tpp |
63 | | - self.woq = woq |
64 | | - self.num_experts = W2.shape[0] |
65 | | - self.hidden_size = W2.shape[1] |
66 | | - self.intermediate_size = W2.shape[2] |
67 | | - |
68 | | - linear_list = [] |
69 | | - for i in range(W2.shape[0]): |
70 | | - if W3 is not None: |
71 | | - _W1 = W13[i] |
72 | | - else: |
73 | | - _W1 = W13[i][0 : self.intermediate_size, :] |
74 | | - _W3 = W13[i][self.intermediate_size : 2 * self.intermediate_size, :] |
75 | | - linear1 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) |
76 | | - linear1.weight = nn.Parameter(_W1) |
77 | | - linear2 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) |
78 | | - linear2.weight = nn.Parameter(W2[i]) |
79 | | - linear3 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) |
80 | | - linear3.weight = nn.Parameter(_W3) |
81 | | - linear_per_expert = nn.ModuleList([linear1, linear2, linear3]) |
82 | | - linear_list.append(linear_per_expert) |
83 | | - self.linear_module_list = nn.ModuleList([linear_list[i] for i in range(W2.shape[0])]) |
84 | | - |
85 | | - def forward(self, hidden_states, score, topk): |
86 | | - batch_size, head_dim = hidden_states.shape |
87 | | - routing_weights = torch.nn.functional.softmax(score, dim=1, dtype=torch.float32) |
88 | | - routing_weights, selected_experts = torch.topk(routing_weights, topk, dim=-1) |
89 | | - routing_weights = routing_weights.to(hidden_states.dtype) |
90 | | - final_hidden_states = torch.zeros( |
91 | | - (batch_size, head_dim), |
92 | | - dtype=hidden_states.dtype, |
93 | | - device=hidden_states.device, |
94 | | - ) |
95 | | - expert_mask = torch.nn.functional.one_hot( |
96 | | - selected_experts, num_classes=self.num_experts |
97 | | - ).permute(2, 1, 0) |
98 | | - for expert_idx in range(self.num_experts): |
99 | | - idx, top_x = torch.where(expert_mask[expert_idx]) |
100 | | - if ( |
101 | | - hasattr(self.linear_module_list[expert_idx][0], "use_dnnl") |
102 | | - and self.linear_module_list[expert_idx][0].use_dnnl |
103 | | - ): |
104 | | - final_hidden_states = torch.ops.torch_ipex.mixtral_moe( |
105 | | - hidden_states, |
106 | | - top_x, |
107 | | - idx, |
108 | | - self.linear_module_list[expert_idx][0]._get_forward_weight(), |
109 | | - self.linear_module_list[expert_idx][0].ctx.get_data_handle(), |
110 | | - self.linear_module_list[expert_idx][2]._get_forward_weight(), |
111 | | - self.linear_module_list[expert_idx][2].ctx.get_data_handle(), |
112 | | - self.linear_module_list[expert_idx][1]._get_forward_weight(), |
113 | | - self.linear_module_list[expert_idx][1].ctx.get_data_handle(), |
114 | | - hasattr(self.linear_module_list[expert_idx][0], "use_dnnl") |
115 | | - and self.linear_module_list[expert_idx][0].use_dnnl, |
116 | | - routing_weights, |
117 | | - final_hidden_states, |
118 | | - False, |
119 | | - ) |
120 | | - else: |
121 | | - final_hidden_states = torch.ops.torch_ipex.mixtral_moe_tpp( |
122 | | - hidden_states, |
123 | | - top_x, |
124 | | - idx, |
125 | | - self.linear_module_list[expert_idx][0].weight.detach(), |
126 | | - self.linear_module_list[expert_idx][2].weight.detach(), |
127 | | - self.linear_module_list[expert_idx][1].weight.detach(), |
128 | | - ( |
129 | | - self.linear_module_list[expert_idx][0].tpp_fallback |
130 | | - if hasattr( |
131 | | - self.linear_module_list[expert_idx][0], "tpp_fallback" |
132 | | - ) |
133 | | - else True |
134 | | - ), |
135 | | - routing_weights, |
136 | | - final_hidden_states, |
137 | | - False, |
138 | | - ) |
139 | | - |
140 | | - return final_hidden_states.view(-1, head_dim) |
| 59 | + |
141 | 60 |
|
142 | 61 | class MixtralMoE(nn.Module): |
143 | 62 | """A tensor-parallel MoE implementation for Mixtral that shards each expert |
@@ -310,11 +229,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
310 | 229 | # router_logits: (num_tokens, n_experts) |
311 | 230 | router_logits, _ = self.gate(hidden_states) |
312 | 231 | if not hasattr(self, "ipex_moe"): |
313 | | - self.ipex_moe = _IPEXlinearMOECPU(self.w13_weight, self.w2_weight) |
| 232 | + self.ipex_moe = ipex.llm.modules.LinearMOE(W13=self.w13_weight, W2=self.w2_weight) |
314 | 233 | _disable_tpp() |
315 | 234 | if hidden_states.dtype is torch.bfloat16: |
316 | 235 | _enable_tpp() |
317 | | - self.ipex_moe = ipex.optimize(self.ipex_moe.eval(), dtype=hidden_states.dtype, inplace=True) |
| 236 | + self.ipex_moe = ipex.llm.optimize(self.ipex_moe.eval(), dtype=hidden_states.dtype, inplace=True) |
318 | 237 | final_hidden_states = self.ipex_moe(hidden_states, router_logits, self.top_k) |
319 | 238 | if self.tp_size > 1: |
320 | 239 | final_hidden_states = tensor_model_parallel_all_reduce( |
@@ -396,8 +315,9 @@ def forward( |
396 | 315 | q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) |
397 | 316 | q, k = self.rotary_emb(positions, q, k) |
398 | 317 | attn_output = self.attn(q, k, v, kv_cache, attn_metadata) |
399 | | - output, _ = self.o_proj(attn_output) |
400 | | - return output |
| 318 | + # o_proj is no longer applied here: MixtralDecoderLayer applies it instead, so that it can be fused with the residual add (linear+add fusion) when tp_size <= 1 |
| 319 | + # output, _ = self.o_proj(attn_output) |
| 320 | + return attn_output |
401 | 321 |
|
402 | 322 |
|
403 | 323 | class MixtralDecoderLayer(nn.Module): |
@@ -452,10 +372,19 @@ def forward( |
452 | 372 | kv_cache=kv_cache, |
453 | 373 | attn_metadata=attn_metadata, |
454 | 374 | ) |
455 | | - |
456 | | - # Fully Connected |
457 | | - hidden_states, residual = self.post_attention_layernorm( |
458 | | - hidden_states, residual) |
| 375 | + if self.self_attn.o_proj.tp_size <=1 and not hasattr(self, "ipex_fusion") and hasattr(self.self_attn.o_proj, "ipex_linear"): |
| 376 | + self.ipex_fusion = ipex.llm.modules.LinearAdd(self.self_attn.o_proj.ipex_linear) |
| 377 | + if hasattr(self, "ipex_fusion"): |
| 378 | + hidden_states = self.ipex_fusion(hidden_states, residual) |
| 379 | + if not self.self_attn.o_proj.skip_bias_add and self.self_attn.o_proj.bias is not None: |
| 380 | + hidden_states = hidden_states + self.self_attn.o_proj.bias |
| 381 | + residual = hidden_states |
| 382 | + hidden_states = self.post_attention_layernorm( |
| 383 | + hidden_states) |
| 384 | + else: |
| 385 | + hidden_states, _ = self.self_attn.o_proj(hidden_states) |
| 386 | + hidden_states, residual = self.post_attention_layernorm( |
| 387 | + hidden_states, residual) |
459 | 388 | hidden_states = self.block_sparse_moe(hidden_states) |
460 | 389 | return hidden_states, residual |
461 | 390 |
|
|
0 commit comments