Commit a3ac858

Merge pull request vllm-project#3 from ri938/merge_linear_layers
Merge linear layers
2 parents c39ec2a + 033e8c1 commit a3ac858

File tree

1 file changed: +61 −16 lines changed

vllm/model_executor/models/llama.py

Lines changed: 61 additions & 16 deletions
@@ -29,7 +29,7 @@
 
 import torch
 from torch import nn
-from transformers import LlamaConfig, activations
+from transformers import LlamaConfig
 
 from vllm.config import QuantizationConfig
 from vllm.model_executor.input_metadata import InputMetadata
@@ -172,13 +172,21 @@ def __init__(
         assert self.total_num_kv_heads % tp_size == 0
         self.num_kv_heads = self.total_num_kv_heads // tp_size
         self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
 
-        intermediate_size = self.total_num_heads * self.head_dim
-        self.q_proj = get_quantized_layer(hidden_size, intermediate_size, quant_config)
-        self.k_proj = get_quantized_layer(hidden_size, intermediate_size, quant_config)
-        self.v_proj = get_quantized_layer(hidden_size, intermediate_size, quant_config)
-        self.o_proj = get_quantized_layer(intermediate_size, hidden_size, quant_config)
+        self.qkv_proj = get_quantized_layer(
+            hidden_size,
+            (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim,
+            quant_config
+        )
+
+        self.o_proj = get_quantized_layer(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            quant_config
+        )
 
         self.attn = PagedAttentionWithRoPE(self.num_heads,
                                            self.head_dim,
@@ -194,9 +202,8 @@ def forward(
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
-        q = self.q_proj(hidden_states)
-        k = self.k_proj(hidden_states)
-        v = self.v_proj(hidden_states)
+        qkv = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         k_cache, v_cache = kv_cache
         attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                 input_metadata, cache_event)
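
Note: the merged qkv_proj is only a packing change. A single matmul against the q/k/v weight matrices concatenated along the output dimension, followed by the split above, reproduces the three separate projections. A minimal sketch of that equivalence with plain nn.Linear stand-ins (sizes and layers here are illustrative, not vllm's quantized layers):

```python
import torch
from torch import nn

# Illustrative sizes only; nn.Linear stand-ins for get_quantized_layer.
hidden_size, num_heads, num_kv_heads, head_dim = 64, 4, 2, 16
q_size, kv_size = num_heads * head_dim, num_kv_heads * head_dim

q_proj = nn.Linear(hidden_size, q_size, bias=False)
k_proj = nn.Linear(hidden_size, kv_size, bias=False)
v_proj = nn.Linear(hidden_size, kv_size, bias=False)

# Merged layer: stack the three weight matrices along the output dimension,
# mirroring (total_num_heads + 2 * total_num_kv_heads) * head_dim above.
qkv_proj = nn.Linear(hidden_size, q_size + 2 * kv_size, bias=False)
with torch.no_grad():
    qkv_proj.weight.copy_(
        torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))

x = torch.randn(3, hidden_size)
q, k, v = qkv_proj(x).split([q_size, kv_size, kv_size], dim=-1)
assert torch.allclose(q, q_proj(x), atol=1e-6)
assert torch.allclose(k, k_proj(x), atol=1e-6)
assert torch.allclose(v, v_proj(x), atol=1e-6)
```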
@@ -213,20 +220,19 @@ def __init__(
         quant_config: QuantizationConfig
     ):
         super().__init__()
-        self.gate_proj = get_quantized_layer(hidden_size, intermediate_size, quant_config)
-        self.up_proj = get_quantized_layer(hidden_size, intermediate_size, quant_config)
+        self.gate_up_proj = get_quantized_layer(hidden_size, 2 * intermediate_size, quant_config)
         self.down_proj = get_quantized_layer(intermediate_size, hidden_size, quant_config)
 
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
-        self.act_fn = activations.SiLUActivation()
+        self.act_fn = SiluAndMul()
 
     def forward(self, x):
-        gate_proj = self.act_fn(self.gate_proj(x))
-        gate_up_proj = gate_proj * self.up_proj(x)
-        down_proj = self.down_proj(gate_up_proj)
-        return down_proj
+        gate_up = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x = self.down_proj(x)
+        return x
 
 
 class LlamaDecoderLayer(nn.Module):
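
Note: SiluAndMul itself is not shown in this diff. The merged gate_up_proj relies on an activation that splits its input in half, applies SiLU to the gate half and multiplies by the up half. A reference sketch of that assumed behaviour, equivalent to the old act_fn(gate_proj(x)) * up_proj(x) path:

```python
import torch
import torch.nn.functional as F

def silu_and_mul(gate_up: torch.Tensor) -> torch.Tensor:
    """Assumed semantics of SiluAndMul: SiLU(gate) * up, where the last
    dimension of the input is laid out as [gate | up], exactly as produced
    by the merged gate_up_proj."""
    gate, up = gate_up.chunk(2, dim=-1)
    return F.silu(gate) * up

# Pretend 2 * intermediate_size == 22; output has intermediate_size == 11.
x = torch.randn(3, 22)
assert silu_and_mul(x).shape == (3, 11)
```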
@@ -383,6 +389,7 @@ def load_weights(self,
         kv_proj_shard_size = (self.config.hidden_size //
                               self.config.num_attention_heads *
                               self.config.num_key_value_heads // tp_size)
+
         attention_weight_specs = [
             # (weight_name, shard_size, offset)
             ("q_proj", q_proj_shard_size, 0),
@@ -409,6 +416,7 @@ def load_weights(self,
 
             is_quantized = self.quant_config is not None and self.quant_config.method is not None
 
+            # merge linear layers
             if not is_quantized:
                 is_attention_weight = False
                 for weight_name, shard_size, offset in attention_weight_specs:
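
Note: on the unquantized path, each q_proj/k_proj/v_proj checkpoint tensor is copied into its slice of the merged qkv_proj parameter using the (weight_name, shard_size, offset) specs. The spec list is truncated in this diff, so the k/v offsets below are illustrative assumptions; a rough sketch of the bookkeeping with plain tensors:

```python
import torch

# Illustrative single-GPU sizes; the real code divides shard sizes by tp_size.
hidden_size, q_size, kv_size = 64, 64, 32

# Assumed layout of the merged parameter: [q | k | v] along the output dim.
qkv_proj = torch.empty(q_size + 2 * kv_size, hidden_size)
attention_weight_specs = [
    # (weight_name, shard_size, offset) -- offsets beyond q_proj are assumptions
    ("q_proj", q_size, 0),
    ("k_proj", kv_size, q_size),
    ("v_proj", kv_size, q_size + kv_size),
]

checkpoint = {
    f"model.layers.0.self_attn.{n}.weight": torch.randn(s, hidden_size)
    for n, s, _ in attention_weight_specs
}

for name, loaded_weight in checkpoint.items():
    for weight_name, shard_size, offset in attention_weight_specs:
        if weight_name not in name:
            continue
        # Unquantized weights are (out_features, in_features), so slice rows.
        qkv_proj[offset:offset + shard_size, :].copy_(loaded_weight)
        break
```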
@@ -445,6 +453,43 @@ def load_weights(self,
                     break
                 if is_gate_up_weight:
                     continue
+            else:
+                # TODO: improve this block of code (not DRY, hacky, specific to AWQ)
+                is_attention_weight = False
+                for stride_id, (weight_name, shard_size, offset) in enumerate(attention_weight_specs):
+                    if weight_name not in name:
+                        continue
+                    param = state_dict[name.replace(weight_name, "qkv_proj")]
+
+                    # TODO: this is specific to AWQ (should be more general)
+                    if 'qweight' in name or 'qzeros' in name:
+                        shard_size = int(shard_size // (32 / self.quant_config.bits))
+                        offset = int(offset // (32 / self.quant_config.bits))
+
+                    param_slice = param.data[:, offset:offset + shard_size]
+                    assert param_slice.shape == loaded_weight.shape
+
+                    param_slice.copy_(loaded_weight)
+                    is_attention_weight = True
+                    break
+                if is_attention_weight:
+                    continue
+
+                is_gate_up_weight = False
+                for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
+                    if weight_name not in name:
+                        continue
+                    param = state_dict[name.replace(weight_name, "gate_up_proj")]
+                    shard_size = param.shape[1] // 2
+
+                    start, end = shard_size * stride_id, shard_size * (stride_id + 1)
+                    param_slice = param.data[:, start:end]
+                    assert param_slice.shape == loaded_weight.shape
+                    param_slice.copy_(loaded_weight)
+                    is_gate_up_weight = True
+                    break
+                if is_gate_up_weight:
+                    continue
 
             param = state_dict[name]
             load_tensor_parallel_weights(param, loaded_weight, name,
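
Note: the // (32 / self.quant_config.bits) adjustment reflects AWQ's packing, where qweight and qzeros hold 32 // bits quantized values per int32 along the packed output dimension, so shard sizes and column offsets into the merged parameter shrink by that factor. A small worked example of the arithmetic (4-bit quantization and the column counts are assumed for illustration):

```python
# AWQ packs 32 // bits quantized values into each int32 column of qweight/qzeros.
bits = 4
q_proj_shard_size, k_offset = 4096, 4096   # illustrative "unpacked" column counts

packed_shard_size = int(q_proj_shard_size // (32 / bits))   # 4096 / 8 = 512
packed_k_offset = int(k_offset // (32 / bits))               # 512
print(packed_shard_size, packed_k_offset)
```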

0 commit comments
