Skip to content

Commit dddd40f

Browse files
authored
Merge pull request vllm-project#26 from intel-sandbox/jianan/enable_linear_fusion_and_prepack
Enable linear fusion/prepack and MOE AWQ fusion
2 parents 92e7866 + 70af56f commit dddd40f

File tree

5 files changed

+111
-216
lines changed

5 files changed

+111
-216
lines changed

vllm/model_executor/layers/linear.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
from vllm.model_executor.layers.quantization.base_config import (
1515
QuantizationConfig, QuantizeMethodBase)
1616
from vllm.model_executor.utils import set_weight_attrs
17-
17+
import intel_extension_for_pytorch as ipex
18+
from intel_extension_for_pytorch.cpu._auto_kernel_selection import (
19+
_enable_tpp,
20+
_disable_tpp,
21+
)
1822
logger = init_logger(__name__)
1923

2024

@@ -103,6 +107,20 @@ def apply(self,
103107
layer: torch.nn.Module,
104108
x: torch.Tensor,
105109
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
110+
if not hasattr(layer, "ipex_linear"):
111+
linear = torch.nn.Linear(layer.weight.shape[1], layer.weight.shape[0], bias=True if bias is not None else False)
112+
linear.weight = layer.weight
113+
if bias is not None:
114+
linear.bias = bias
115+
_disable_tpp()
116+
if layer.weight.dtype is torch.bfloat16:
117+
_enable_tpp()
118+
layer.ipex_linear = ipex.llm.optimize(linear.eval(), dtype=layer.weight.dtype, inplace=True)
119+
120+
if hasattr(layer, "ipex_linear"):
121+
res = layer.ipex_linear(x)
122+
return res
123+
106124
weight = layer.weight
107125
if self.separate_bias_add:
108126
if bias is not None:

vllm/model_executor/layers/quantization/awq.py

Lines changed: 4 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import torch
44
from torch.nn.parameter import Parameter
5-
import intel_extension_for_pytorch
5+
import intel_extension_for_pytorch as ipex
66
from vllm import _custom_ops as ops
77
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
88
from vllm.model_executor.layers.quantization.base_config import (
@@ -150,78 +150,6 @@ def create_weights(self, layer: torch.nn.Module,
150150
layer.register_parameter("scales", scales)
151151
set_weight_attrs(scales, extra_weight_attrs)
152152

153-
def awq_reverse_reorder_int_tensor(self,int_tensor, bits: int):
154-
assert bits == 4
155-
156-
int_tensor = int_tensor.T.contiguous()
157-
compress_ratio = (32 // bits)
158-
assert int_tensor.shape[-1] % compress_ratio == 0
159-
160-
order_map = [0, 2, 4, 6, 1, 3, 5, 7]
161-
order_tensor = torch.tensor(
162-
order_map, dtype=torch.int32, device=int_tensor.device).reshape(1, -1)
163-
order_tensor = order_tensor.repeat(
164-
int_tensor.shape[1]//compress_ratio, 1)
165-
order_tensor = order_tensor + torch.arange(0, int_tensor.shape[1],
166-
compress_ratio, dtype=torch.int32, device=int_tensor.device).reshape(-1, 1)
167-
order_tensor = order_tensor.reshape(-1)
168-
169-
reverse_order_tensor = torch.arange(order_tensor.shape[0])[order_tensor]
170-
reverse_order_tensor = reverse_order_tensor[order_tensor]
171-
int_tensor = int_tensor[:, reverse_order_tensor]
172-
return int_tensor
173-
def unpack_awq(self, awq_qweight: torch.Tensor, awq_qzeros: torch.Tensor, awq_scales: torch.Tensor, bits: int, group_size: int):
174-
"""
175-
Args:
176-
awq_qweight (`torch.LongTensor`):
177-
Expected shape: (in_features, out_features // (32 // bits))
178-
awq_qzeros (`torch.LongTensor`):
179-
Expected shape: (in_features // group_size, out_features // (32 // bits))
180-
awq_scales (`torch.LongTensor`):
181-
Expected shape: (in_features // group_size, out_features)
182-
183-
Returns:
184-
fp16_weight (`torch.LongTensor`):
185-
With shape (in_features, out_features).
186-
zeros (`torch.LongTensor`):
187-
With shape (in_features // group_size, out_features).
188-
"""
189-
assert bits == 4
190-
191-
qzeros = awq_qzeros
192-
qweight = awq_qweight
193-
qweight = qweight.T.contiguous()
194-
195-
scales = awq_scales
196-
scales = scales.reshape(-1, 1, scales.shape[-1])
197-
198-
infeatures = awq_qweight.shape[0]
199-
200-
wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32, device=qzeros.device).unsqueeze(0)
201-
zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2), wf.unsqueeze(0)).to(
202-
torch.int16 if bits == 8 else torch.int8)
203-
204-
#zeros = zeros + 1
205-
206-
torch.bitwise_and(zeros, (2 ** bits) - 1, out=zeros)
207-
208-
zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])
209-
210-
weight = torch.bitwise_right_shift(torch.unsqueeze(
211-
qweight, 1), wf.unsqueeze(-1)).to(torch.int16 if bits == 8 else torch.int8)
212-
torch.bitwise_and(weight, (2 ** bits) - 1, out=weight)
213-
weight = weight.reshape(-1, group_size, weight.shape[2])
214-
215-
weight = weight.view(-1, weight.shape[-1])
216-
zeros = zeros.view(-1, zeros.shape[-1])
217-
218-
zeros = zeros.T.contiguous()
219-
zeros = self.awq_reverse_reorder_int_tensor(zeros, bits)
220-
weight = self.awq_reverse_reorder_int_tensor(weight, bits)
221-
222-
return weight.contiguous(), zeros.contiguous()
223-
224-
225153
def apply(self,
226154
layer: torch.nn.Module,
227155
x: torch.Tensor,
@@ -232,27 +160,8 @@ def apply(self,
232160
pack_factor = self.quant_config.pack_factor
233161
out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
234162
reshaped_x = x.reshape(-1, x.shape[-1])
163+
if not hasattr(layer,"ipex_qlinear") :
164+
layer.ipex_qlinear = ipex.nn.modules.weight_only_quantization.WeightOnlyQuantizedLinear.from_int4_weight(qweight, scales, qzeros, x.shape[-1], out_shape[-1], bias=bias, group_size=self.quant_config.group_size)
165+
out = layer.ipex_qlinear(reshaped_x)
235166

236-
if not hasattr(self,"_op_context") :
237-
t, zp_x = self.unpack_awq(qweight, qzeros, scales, 4, 128)
238-
# # transpose -> [N, K]
239-
t = t.T.contiguous()
240-
qweight_ = t[:, 1::2].bitwise_left_shift(4).bitwise_or_(t[:, ::2]).to(torch.uint8)
241-
scales_ = scales.t().contiguous()
242-
self._op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack_int4(
243-
qweight_,
244-
scales_,
245-
zp_x.t_().contiguous(),
246-
bias,
247-
None,
248-
None,
249-
128,
250-
2, # 2 for bf16 compute, 3 for int8 compute
251-
1,
252-
)
253-
254-
out = torch.ops.torch_ipex.ipex_woq_linear(reshaped_x, self._op_context.get_data_handle())
255-
256-
if bias is not None:
257-
out.add_(bias)
258167
return out.reshape(out_shape)

vllm/model_executor/models/gpt_j.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import torch
2222
from torch import nn
23+
import intel_extension_for_pytorch as ipex
2324
from transformers import GPTJConfig
2425

2526
from vllm.attention import Attention, AttentionMetadata
@@ -130,9 +131,18 @@ def __init__(
130131
intermediate_size)
131132

132133
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
133-
hidden_states, _ = self.fc_in(hidden_states)
134-
hidden_states = self.act(hidden_states)
135-
hidden_states, _ = self.fc_out(hidden_states)
134+
if not hasattr(self, "ipex_fusion"):
135+
if hasattr(self.fc_in, "ipex_linear"):
136+
self.ipex_fusion = ipex.llm.modules.LinearNewGelu(self.fc_in.ipex_linear)
137+
elif hasattr(self.fc_in, "ipex_qlinear"):
138+
self.ipex_fusion = ipex.llm.modules.LinearNewGelu(self.fc_in.ipex_qlinear)
139+
if hasattr(self, "ipex_fusion"):
140+
hidden_states = self.ipex_fusion(hidden_states)
141+
else:
142+
hidden_states, _ = self.fc_in(hidden_states)
143+
hidden_states = self.act(hidden_states)
144+
# move self.fc_out to GPTJBlock to enable linear+add+add fusion when tp_size <=1c
145+
# hidden_states, _ = self.fc_out(hidden_states)
136146
return hidden_states
137147

138148

@@ -167,7 +177,20 @@ def forward(
167177
attn_metadata=attn_metadata,
168178
)
169179
mlp_output = self.mlp(hidden_states)
170-
hidden_states = attn_output + mlp_output + residual
180+
if self.mlp.fc_out.tp_size <=1 and not hasattr(self, "ipex_fusion"):
181+
if hasattr(self.mlp.fc_out, "ipex_linear"):
182+
self.ipex_fusion = ipex.llm.modules.LinearAddAdd(self.mlp.fc_out.ipex_linear)
183+
elif hasattr(self.mlp.fc_out, "ipex_qlinear"):
184+
self.ipex_fusion = ipex.llm.modules.LinearAddAdd(self.mlp.fc_out.ipex_qlinear)
185+
if hasattr(self, "ipex_fusion"):
186+
hidden_states = self.ipex_fusion(
187+
mlp_output, attn_output, residual
188+
)
189+
if not self.mlp.fc_out.skip_bias_add and self.mlp.fc_out.bias is not None:
190+
hidden_states = hidden_states + self.mlp.fc_out.bias
191+
else:
192+
mlp_output, _ = self.mlp.fc_out(mlp_output)
193+
hidden_states = attn_output + mlp_output + residual
171194
return hidden_states
172195

173196

vllm/model_executor/models/mixtral.py

Lines changed: 19 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -56,88 +56,7 @@
5656
_enable_tpp,
5757
_disable_tpp,
5858
)
59-
class _IPEXlinearMOECPU(nn.Module):
60-
def __init__(self, W13, W2, W3=None, tpp=False, woq=False):
61-
super().__init__()
62-
self.tpp = tpp
63-
self.woq = woq
64-
self.num_experts = W2.shape[0]
65-
self.hidden_size = W2.shape[1]
66-
self.intermediate_size = W2.shape[2]
67-
68-
linear_list = []
69-
for i in range(W2.shape[0]):
70-
if W3 is not None:
71-
_W1 = W13[i]
72-
else:
73-
_W1 = W13[i][0 : self.intermediate_size, :]
74-
_W3 = W13[i][self.intermediate_size : 2 * self.intermediate_size, :]
75-
linear1 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
76-
linear1.weight = nn.Parameter(_W1)
77-
linear2 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
78-
linear2.weight = nn.Parameter(W2[i])
79-
linear3 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
80-
linear3.weight = nn.Parameter(_W3)
81-
linear_per_expert = nn.ModuleList([linear1, linear2, linear3])
82-
linear_list.append(linear_per_expert)
83-
self.linear_module_list = nn.ModuleList([linear_list[i] for i in range(W2.shape[0])])
84-
85-
def forward(self, hidden_states, score, topk):
86-
batch_size, head_dim = hidden_states.shape
87-
routing_weights = torch.nn.functional.softmax(score, dim=1, dtype=torch.float32)
88-
routing_weights, selected_experts = torch.topk(routing_weights, topk, dim=-1)
89-
routing_weights = routing_weights.to(hidden_states.dtype)
90-
final_hidden_states = torch.zeros(
91-
(batch_size, head_dim),
92-
dtype=hidden_states.dtype,
93-
device=hidden_states.device,
94-
)
95-
expert_mask = torch.nn.functional.one_hot(
96-
selected_experts, num_classes=self.num_experts
97-
).permute(2, 1, 0)
98-
for expert_idx in range(self.num_experts):
99-
idx, top_x = torch.where(expert_mask[expert_idx])
100-
if (
101-
hasattr(self.linear_module_list[expert_idx][0], "use_dnnl")
102-
and self.linear_module_list[expert_idx][0].use_dnnl
103-
):
104-
final_hidden_states = torch.ops.torch_ipex.mixtral_moe(
105-
hidden_states,
106-
top_x,
107-
idx,
108-
self.linear_module_list[expert_idx][0]._get_forward_weight(),
109-
self.linear_module_list[expert_idx][0].ctx.get_data_handle(),
110-
self.linear_module_list[expert_idx][2]._get_forward_weight(),
111-
self.linear_module_list[expert_idx][2].ctx.get_data_handle(),
112-
self.linear_module_list[expert_idx][1]._get_forward_weight(),
113-
self.linear_module_list[expert_idx][1].ctx.get_data_handle(),
114-
hasattr(self.linear_module_list[expert_idx][0], "use_dnnl")
115-
and self.linear_module_list[expert_idx][0].use_dnnl,
116-
routing_weights,
117-
final_hidden_states,
118-
False,
119-
)
120-
else:
121-
final_hidden_states = torch.ops.torch_ipex.mixtral_moe_tpp(
122-
hidden_states,
123-
top_x,
124-
idx,
125-
self.linear_module_list[expert_idx][0].weight.detach(),
126-
self.linear_module_list[expert_idx][2].weight.detach(),
127-
self.linear_module_list[expert_idx][1].weight.detach(),
128-
(
129-
self.linear_module_list[expert_idx][0].tpp_fallback
130-
if hasattr(
131-
self.linear_module_list[expert_idx][0], "tpp_fallback"
132-
)
133-
else True
134-
),
135-
routing_weights,
136-
final_hidden_states,
137-
False,
138-
)
139-
140-
return final_hidden_states.view(-1, head_dim)
59+
14160

14261
class MixtralMoE(nn.Module):
14362
"""A tensor-parallel MoE implementation for Mixtral that shards each expert
@@ -310,11 +229,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
310229
# router_logits: (num_tokens, n_experts)
311230
router_logits, _ = self.gate(hidden_states)
312231
if not hasattr(self, "ipex_moe"):
313-
self.ipex_moe = _IPEXlinearMOECPU(self.w13_weight, self.w2_weight)
232+
self.ipex_moe = ipex.llm.modules.LinearMOE(W13=self.w13_weight, W2=self.w2_weight)
314233
_disable_tpp()
315234
if hidden_states.dtype is torch.bfloat16:
316235
_enable_tpp()
317-
self.ipex_moe = ipex.optimize(self.ipex_moe.eval(), dtype=hidden_states.dtype, inplace=True)
236+
self.ipex_moe = ipex.llm.optimize(self.ipex_moe.eval(), dtype=hidden_states.dtype, inplace=True)
318237
final_hidden_states = self.ipex_moe(hidden_states, router_logits, self.top_k)
319238
if self.tp_size > 1:
320239
final_hidden_states = tensor_model_parallel_all_reduce(
@@ -396,8 +315,9 @@ def forward(
396315
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
397316
q, k = self.rotary_emb(positions, q, k)
398317
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
399-
output, _ = self.o_proj(attn_output)
400-
return output
318+
# move self.o_proj to MixtralDecoderLayer to enable linear+add fusion when tp_size <=1
319+
# output, _ = self.o_proj(attn_output)
320+
return attn_output
401321

402322

403323
class MixtralDecoderLayer(nn.Module):
@@ -452,10 +372,19 @@ def forward(
452372
kv_cache=kv_cache,
453373
attn_metadata=attn_metadata,
454374
)
455-
456-
# Fully Connected
457-
hidden_states, residual = self.post_attention_layernorm(
458-
hidden_states, residual)
375+
if self.self_attn.o_proj.tp_size <=1 and not hasattr(self, "ipex_fusion") and hasattr(self.self_attn.o_proj, "ipex_linear"):
376+
self.ipex_fusion = ipex.llm.modules.LinearAdd(self.self_attn.o_proj.ipex_linear)
377+
if hasattr(self, "ipex_fusion"):
378+
hidden_states = self.ipex_fusion(hidden_states, residual)
379+
if not self.self_attn.o_proj.skip_bias_add and self.self_attn.o_proj.bias is not None:
380+
hidden_states = hidden_states + self.self_attn.o_proj.bias
381+
residual = hidden_states
382+
hidden_states = self.post_attention_layernorm(
383+
hidden_states)
384+
else:
385+
hidden_states, _ = self.self_attn.o_proj(hidden_states)
386+
hidden_states, residual = self.post_attention_layernorm(
387+
hidden_states, residual)
459388
hidden_states = self.block_sparse_moe(hidden_states)
460389
return hidden_states, residual
461390

0 commit comments

Comments
 (0)