From e5434efe11ad9ac5b12e3e45f9aa998d1a6321a7 Mon Sep 17 00:00:00 2001
From: root
Date: Tue, 15 Aug 2023 18:09:27 +0000
Subject: [PATCH 1/3] working prototype

---
 vllm/model_executor/models/llama.py | 73 +++++++++++++++++++++++------
 1 file changed, 59 insertions(+), 14 deletions(-)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index fd6979e657b64..dae07f5339803 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -29,7 +29,7 @@
 import torch
 from torch import nn
-from transformers import LlamaConfig, activations
+from transformers import LlamaConfig
 
 from vllm.config import QuantizationConfig
 from vllm.model_executor.input_metadata import InputMetadata
@@ -172,12 +172,12 @@ def __init__(
         assert self.total_num_kv_heads % tp_size == 0
         self.num_kv_heads = self.total_num_kv_heads // tp_size
         self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         intermediate_size = self.total_num_heads * self.head_dim
 
-        self.q_proj = get_quantized_layer(hidden_size, intermediate_size, quant_config)
-        self.k_proj = get_quantized_layer(hidden_size, intermediate_size, quant_config)
-        self.v_proj = get_quantized_layer(hidden_size, intermediate_size, quant_config)
+        self.qkv_proj = get_quantized_layer(hidden_size, self.q_size + 2 * self.kv_size, quant_config)
         self.o_proj = get_quantized_layer(intermediate_size, hidden_size, quant_config)
 
         self.attn = PagedAttentionWithRoPE(self.num_heads,
@@ -194,9 +194,8 @@ def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
-        q = self.q_proj(hidden_states)
-        k = self.k_proj(hidden_states)
-        v = self.v_proj(hidden_states)
+        qkv = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         k_cache, v_cache = kv_cache
         attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                 input_metadata, cache_event)
@@ -213,20 +212,19 @@ def __init__(
         self,
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
         quant_config: QuantizationConfig
     ):
         super().__init__()
-        self.gate_proj = get_quantized_layer(hidden_size, intermediate_size, quant_config)
-        self.up_proj = get_quantized_layer(hidden_size, intermediate_size, quant_config)
+        self.gate_up_proj = get_quantized_layer(hidden_size, 2 * intermediate_size, quant_config)
         self.down_proj = get_quantized_layer(intermediate_size, hidden_size, quant_config)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
" "Only silu is supported for now.") - self.act_fn = activations.SiLUActivation() + self.act_fn = SiluAndMul() def forward(self, x): - gate_proj = self.act_fn(self.gate_proj(x)) - gate_up_proj = gate_proj * self.up_proj(x) - down_proj = self.down_proj(gate_up_proj) - return down_proj + gate_up = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x = self.down_proj(x) + return x class LlamaDecoderLayer(nn.Module): @@ -445,9 +443,56 @@ def load_weights(self, break if is_gate_up_weight: continue + else: + is_attention_weight = False + for weight_name, shard_size, offset in attention_weight_specs: + if weight_name not in name: + continue + param = state_dict[name.replace(weight_name, "qkv_proj")] + + if 'q_proj' in weight_name: + pos = 0 + elif 'k_proj' in weight_name: + pos = 1 + else: + assert 'v_proj' in weight_name + pos = 2 + + if 'scale' in name: + size = 5120 + else: + size = 640 + + #param_slice = param.data[:, offset:offset + shard_size] + param_slice = param.data[:, (pos * size):((pos+1) * size)] + #print(weight_name, param_slice.shape, loaded_weight.shape, pos, param.data.shape) + assert param_slice.shape == loaded_weight.shape + + param_slice.copy_(loaded_weight) + is_attention_weight = True + break + if is_attention_weight: + continue + + is_gate_up_weight = False + for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): + if weight_name not in name: + continue + param = state_dict[name.replace(weight_name, "gate_up_proj")] + shard_size = param.shape[1] // 2 + + param_slice = param.data[:, shard_size * stride_id:shard_size * + (stride_id + 1)] + assert param_slice.shape == loaded_weight.shape + param_slice.copy_(loaded_weight) + is_gate_up_weight = True + break + if is_gate_up_weight: + continue param = state_dict[name] load_tensor_parallel_weights(param, loaded_weight, name, self._column_parallel_weights, self._row_parallel_weights, tensor_model_parallel_rank) + print(self.model) From ff4d69394f4c02d4bb95a121ec8944980afc6ad8 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 15 Aug 2023 18:42:41 +0000 Subject: [PATCH 2/3] merge linear layers --- vllm/model_executor/models/llama.py | 45 +++++++++++++++-------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index dae07f5339803..011a17d070aa2 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -176,9 +176,17 @@ def __init__( self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 - intermediate_size = self.total_num_heads * self.head_dim - self.qkv_proj = get_quantized_layer(hidden_size, self.q_size + 2 * self.kv_size, quant_config) - self.o_proj = get_quantized_layer(intermediate_size, hidden_size, quant_config) + self.qkv_proj = get_quantized_layer( + hidden_size, + (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim, + quant_config + ) + + self.o_proj = get_quantized_layer( + self.total_num_heads * self.head_dim, + hidden_size, + quant_config + ) self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_dim, @@ -381,6 +389,7 @@ def load_weights(self, kv_proj_shard_size = (self.config.hidden_size // self.config.num_attention_heads * self.config.num_key_value_heads // tp_size) + attention_weight_specs = [ # (weight_name, shard_size, offset) ("q_proj", q_proj_shard_size, 0), @@ -407,6 +416,7 @@ def load_weights(self, is_quantized = self.quant_config is not None and self.quant_config.method is not None + # merge linear layers if not 
                 is_attention_weight = False
                 for weight_name, shard_size, offset in attention_weight_specs:
@@ -444,28 +454,19 @@ def load_weights(self,
                 if is_gate_up_weight:
                     continue
             else:
+                # TODO: improve this block of code (not DRY, hacky, specific to AWQ)
                 is_attention_weight = False
-                for weight_name, shard_size, offset in attention_weight_specs:
+                for stride_id, (weight_name, shard_size, offset) in enumerate(attention_weight_specs):
                     if weight_name not in name:
                         continue
                     param = state_dict[name.replace(weight_name, "qkv_proj")]
+
+                    # TODO: this is specific to AWQ (should be more general)
+                    if 'qweight' in name or 'qzeros' in name:
+                        shard_size = int(shard_size // (32 / self.quant_config.bits))
+                        offset = int(offset // (32 / self.quant_config.bits))
 
-                    if 'q_proj' in weight_name:
-                        pos = 0
-                    elif 'k_proj' in weight_name:
-                        pos = 1
-                    else:
-                        assert 'v_proj' in weight_name
-                        pos = 2
-
-                    if 'scale' in name:
-                        size = 5120
-                    else:
-                        size = 640
-
-                    #param_slice = param.data[:, offset:offset + shard_size]
-                    param_slice = param.data[:, (pos * size):((pos+1) * size)]
-                    #print(weight_name, param_slice.shape, loaded_weight.shape, pos, param.data.shape)
+                    param_slice = param.data[:, offset:offset + shard_size]
                     assert param_slice.shape == loaded_weight.shape
 
                     param_slice.copy_(loaded_weight)
@@ -481,8 +482,8 @@ def load_weights(self,
                     param = state_dict[name.replace(weight_name, "gate_up_proj")]
                     shard_size = param.shape[1] // 2
 
-                    param_slice = param.data[:, shard_size * stride_id:shard_size *
-                                             (stride_id + 1)]
+                    start, end = shard_size * stride_id, shard_size * (stride_id + 1)
+                    param_slice = param.data[:, start:end]
                     assert param_slice.shape == loaded_weight.shape
                     param_slice.copy_(loaded_weight)
                     is_gate_up_weight = True
                     break

From 033e8c1d81330eb246f2db9c8886c1cad16d2ae7 Mon Sep 17 00:00:00 2001
From: root
Date: Tue, 15 Aug 2023 18:45:54 +0000
Subject: [PATCH 3/3] update

---
 vllm/model_executor/models/llama.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 011a17d070aa2..f4aea9366f13f 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -496,4 +496,3 @@ def load_weights(self,
                                          self._column_parallel_weights,
                                          self._row_parallel_weights,
                                          tensor_model_parallel_rank)
-        print(self.model)
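
Note on the technique (a sketch, not part of the patch series): patches 1 and 2 fuse the separate q/k/v and gate/up projections into single linear layers and then load each original checkpoint tensor into a column slice of the fused parameter. For the AWQ path, patch 2 takes the qkv_proj slice in packed coordinates, dividing shard_size and offset by 32 / bits, which is consistent with 32/bits quantized values being packed into each int32 of qweight and qzeros along the output dimension; scales are stored unpacked, which matches the patch leaving them unadjusted. The minimal sketch below shows the fuse-and-split idea with plain torch.nn.Linear standing in for this fork's get_quantized_layer; all sizes are made-up toy values.

# A minimal sketch of the fused-QKV idea, with torch.nn.Linear standing in
# for this fork's get_quantized_layer. All sizes below are made-up toy values.
import torch
from torch import nn

hidden_size, num_heads, num_kv_heads = 64, 4, 2
head_dim = hidden_size // num_heads
q_size = num_heads * head_dim
kv_size = num_kv_heads * head_dim

# Separate projections, as before patch 1.
q_proj = nn.Linear(hidden_size, q_size, bias=False)
k_proj = nn.Linear(hidden_size, kv_size, bias=False)
v_proj = nn.Linear(hidden_size, kv_size, bias=False)

# One fused projection, as after patch 1. nn.Linear stores its weight as
# [out_features, in_features], so the per-projection slices are rows here;
# the AWQ tensors in the patch are laid out [in, out], hence its dim-1 slices.
qkv_proj = nn.Linear(hidden_size, q_size + 2 * kv_size, bias=False)
with torch.no_grad():
    # Mirrors load_weights: copy each source tensor into its own slice of
    # the fused parameter.
    qkv_proj.weight[:q_size].copy_(q_proj.weight)
    qkv_proj.weight[q_size:q_size + kv_size].copy_(k_proj.weight)
    qkv_proj.weight[q_size + kv_size:].copy_(v_proj.weight)

x = torch.randn(3, hidden_size)
# One matmul, then split the output back into q, k, v, as in the new forward.
q, k, v = qkv_proj(x).split([q_size, kv_size, kv_size], dim=-1)

# The fused layer reproduces the separate projections.
assert torch.allclose(q, q_proj(x), atol=1e-6)
assert torch.allclose(k, k_proj(x), atol=1e-6)
assert torch.allclose(v, v_proj(x), atol=1e-6)

The payoff is one larger GEMM per attention block (and one for gate/up in the MLP) instead of three or two smaller ones, with load_weights doing the bookkeeping of where each checkpoint tensor lands inside the fused parameter.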