
Commit

Merge pull request vllm-project#3 from ri938/merge_linear_layers
Merge linear layers
ri938 authored Aug 15, 2023
2 parents c39ec2a + 033e8c1 commit a3ac858
Showing 1 changed file with 61 additions and 16 deletions.
77 changes: 61 additions & 16 deletions vllm/model_executor/models/llama.py
@@ -29,7 +29,7 @@

import torch
from torch import nn
- from transformers import LlamaConfig, activations
+ from transformers import LlamaConfig

from vllm.config import QuantizationConfig
from vllm.model_executor.input_metadata import InputMetadata
@@ -172,13 +172,21 @@ def __init__(
assert self.total_num_kv_heads % tp_size == 0
self.num_kv_heads = self.total_num_kv_heads // tp_size
self.head_dim = hidden_size // self.total_num_heads
+ self.q_size = self.num_heads * self.head_dim
+ self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5

- intermediate_size = self.total_num_heads * self.head_dim
- self.q_proj = get_quantized_layer(hidden_size, intermediate_size, quant_config)
- self.k_proj = get_quantized_layer(hidden_size, intermediate_size, quant_config)
- self.v_proj = get_quantized_layer(hidden_size, intermediate_size, quant_config)
- self.o_proj = get_quantized_layer(intermediate_size, hidden_size, quant_config)
+ self.qkv_proj = get_quantized_layer(
+ hidden_size,
+ (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim,
+ quant_config
+ )
+
+ self.o_proj = get_quantized_layer(
+ self.total_num_heads * self.head_dim,
+ hidden_size,
+ quant_config
+ )

self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim,
@@ -194,9 +202,8 @@ def forward(
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
- q = self.q_proj(hidden_states)
- k = self.k_proj(hidden_states)
- v = self.v_proj(hidden_states)
+ qkv = self.qkv_proj(hidden_states)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
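The merged qkv_proj runs a single matmul, and the output is split back into Q, K and V along the last dimension using the per-rank q_size/kv_size computed in __init__. A rough standalone sketch of that shape arithmetic, with plain nn.Linear standing in for get_quantized_layer and illustrative Llama-7B-like sizes (both are assumptions, not taken from this diff):

import torch
from torch import nn

# Illustrative sizes (assumed): hidden 4096, 32 heads, 32 KV heads, head_dim 128.
hidden_size, num_heads, num_kv_heads, head_dim = 4096, 32, 32, 128
q_size = num_heads * head_dim       # 4096
kv_size = num_kv_heads * head_dim   # 4096

# One fused projection replaces the three separate q/k/v matmuls.
qkv_proj = nn.Linear(hidden_size, (num_heads + 2 * num_kv_heads) * head_dim, bias=False)

hidden_states = torch.randn(2, 16, hidden_size)   # (batch, seq, hidden)
qkv = qkv_proj(hidden_states)
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
assert q.shape[-1] == q_size and k.shape[-1] == kv_size == v.shape[-1]

The split returns views, so fusing the projections reduces the number of matmul launches without changing the tensors the attention kernel sees.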
@@ -213,20 +220,19 @@ def __init__(
quant_config: QuantizationConfig
):
super().__init__()
- self.gate_proj = get_quantized_layer(hidden_size, intermediate_size, quant_config)
- self.up_proj = get_quantized_layer(hidden_size, intermediate_size, quant_config)
+ self.gate_up_proj = get_quantized_layer(hidden_size, 2 * intermediate_size, quant_config)
self.down_proj = get_quantized_layer(intermediate_size, hidden_size, quant_config)

if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
- self.act_fn = activations.SiLUActivation()
+ self.act_fn = SiluAndMul()

def forward(self, x):
- gate_proj = self.act_fn(self.gate_proj(x))
- gate_up_proj = gate_proj * self.up_proj(x)
- down_proj = self.down_proj(gate_up_proj)
- return down_proj
+ gate_up = self.gate_up_proj(x)
+ x = self.act_fn(gate_up)
+ x = self.down_proj(x)
+ return x


class LlamaDecoderLayer(nn.Module):
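In the rewritten MLP, gate_proj and up_proj become one gate_up_proj whose output goes through SiluAndMul, which (as assumed here) splits its input in half along the last dimension and returns silu(gate) * up. A minimal sketch of the equivalent computation using plain PyTorch in place of get_quantized_layer, with illustrative sizes:

import torch
import torch.nn.functional as F
from torch import nn

hidden_size, intermediate_size = 4096, 11008   # illustrative Llama-7B-like sizes (assumed)

# gate_proj and up_proj merged into one weight: first half = gate, second half = up.
gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Assumed SiluAndMul behaviour: silu(first half) * second half.
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up

x = torch.randn(2, 16, hidden_size)
out = down_proj(silu_and_mul(gate_up_proj(x)))
assert out.shape == x.shape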
@@ -383,6 +389,7 @@ def load_weights(self,
kv_proj_shard_size = (self.config.hidden_size //
self.config.num_attention_heads *
self.config.num_key_value_heads // tp_size)

attention_weight_specs = [
# (weight_name, shard_size, offset)
("q_proj", q_proj_shard_size, 0),
@@ -409,6 +416,7 @@

is_quantized = self.quant_config is not None and self.quant_config.method is not None

+ # merge linear layers
if not is_quantized:
is_attention_weight = False
for weight_name, shard_size, offset in attention_weight_specs:
@@ -445,6 +453,43 @@
break
if is_gate_up_weight:
continue
+ else:
+ # TODO: improve this block of code (not DRY, hacky, specific to AWQ)
+ is_attention_weight = False
+ for stride_id, (weight_name, shard_size, offset) in enumerate(attention_weight_specs):
+ if weight_name not in name:
+ continue
+ param = state_dict[name.replace(weight_name, "qkv_proj")]
+
+ # TODO: this is specific to AWQ (should be more general)
+ if 'qweight' in name or 'qzeros' in name:
+ shard_size = int(shard_size // (32 / self.quant_config.bits))
+ offset = int(offset // (32 / self.quant_config.bits))
+
+ param_slice = param.data[:, offset:offset + shard_size]
+ assert param_slice.shape == loaded_weight.shape
+
+ param_slice.copy_(loaded_weight)
+ is_attention_weight = True
+ break
+ if is_attention_weight:
+ continue
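The rescale above exists because AWQ packs several low-bit values into each int32 along the output dimension of qweight and qzeros, so a shard covering N logical output features spans only N / (32 / bits) packed columns. A rough sketch of that arithmetic with assumed numbers (4-bit AWQ; the shard size and offset are illustrative, not read from a real checkpoint):

bits = 4                       # assumed AWQ bit width
pack_factor = 32 // bits       # 8 packed int4 values per int32 column

# A hypothetical shard of 2048 output features starting at logical offset 4096 ...
shard_size, offset = 2048, 4096

# ... covers only shard_size / pack_factor columns of the packed tensors.
packed_shard_size = shard_size // pack_factor   # 256
packed_offset = offset // pack_factor           # 512
print(packed_shard_size, packed_offset)

The scales tensor is not bit-packed, which is consistent with the check above covering only qweight and qzeros.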

+ is_gate_up_weight = False
+ for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
+ if weight_name not in name:
+ continue
+ param = state_dict[name.replace(weight_name, "gate_up_proj")]
+ shard_size = param.shape[1] // 2
+
+ start, end = shard_size * stride_id, shard_size * (stride_id + 1)
+ param_slice = param.data[:, start:end]
+ assert param_slice.shape == loaded_weight.shape
+ param_slice.copy_(loaded_weight)
+ is_gate_up_weight = True
+ break
+ if is_gate_up_weight:
+ continue
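The gate/up branch skips that rescale because shard_size is taken directly from the fused parameter's second dimension (param.shape[1] // 2), which is already in packed units; each loaded tensor is then copied into its half by stride_id, gate_proj into columns [0, shard_size) and up_proj into [shard_size, 2 * shard_size). A minimal sketch of that column-slice copy with plain tensors and made-up shapes:

import torch

# Stand-in for a fused gate_up parameter (e.g. a packed qweight); columns hold gate then up.
fused = torch.zeros(512, 2 * 256, dtype=torch.int32)
loaded_gate = torch.ones(512, 256, dtype=torch.int32)      # pretend checkpoint shard
loaded_up = torch.full((512, 256), 2, dtype=torch.int32)   # pretend checkpoint shard

for stride_id, loaded_weight in enumerate([loaded_gate, loaded_up]):
    shard_size = fused.shape[1] // 2
    start, end = shard_size * stride_id, shard_size * (stride_id + 1)
    param_slice = fused[:, start:end]
    assert param_slice.shape == loaded_weight.shape
    param_slice.copy_(loaded_weight)

assert (fused[:, :256] == 1).all() and (fused[:, 256:] == 2).all()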

param = state_dict[name]
load_tensor_parallel_weights(param, loaded_weight, name,
