@@ -29,7 +29,7 @@
 
 import torch
 from torch import nn
-from transformers import LlamaConfig, activations
+from transformers import LlamaConfig
 
 from vllm.config import QuantizationConfig
 from vllm.model_executor.input_metadata import InputMetadata
@@ -172,13 +172,21 @@ def __init__(
         assert self.total_num_kv_heads % tp_size == 0
         self.num_kv_heads = self.total_num_kv_heads // tp_size
         self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
 
-        intermediate_size = self.total_num_heads * self.head_dim
-        self.q_proj = get_quantized_layer(hidden_size, intermediate_size, quant_config)
-        self.k_proj = get_quantized_layer(hidden_size, intermediate_size, quant_config)
-        self.v_proj = get_quantized_layer(hidden_size, intermediate_size, quant_config)
-        self.o_proj = get_quantized_layer(intermediate_size, hidden_size, quant_config)
+        self.qkv_proj = get_quantized_layer(
+            hidden_size,
+            (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim,
+            quant_config
+        )
+
+        self.o_proj = get_quantized_layer(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            quant_config
+        )
 
         self.attn = PagedAttentionWithRoPE(self.num_heads,
                                            self.head_dim,
@@ -194,9 +202,8 @@ def forward(
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
-        q = self.q_proj(hidden_states)
-        k = self.k_proj(hidden_states)
-        v = self.v_proj(hidden_states)
+        qkv = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         k_cache, v_cache = kv_cache
         attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                 input_metadata, cache_event)
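
Note on the fused projection: qkv_proj emits a single tensor whose last dimension concatenates the query, key, and value outputs, and the split sizes follow directly from the per-rank head counts. A minimal standalone sketch of the same splitting logic (the head counts and head_dim below are illustrative, not taken from this model):

import torch

# Illustrative per-rank sizes: 32 query heads, 8 KV heads (GQA), head_dim 128.
num_heads, num_kv_heads, head_dim = 32, 8, 128
q_size = num_heads * head_dim        # 4096
kv_size = num_kv_heads * head_dim    # 1024

# The fused projection output has width q_size + 2 * kv_size.
qkv = torch.randn(4, q_size + 2 * kv_size)
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
assert q.shape == (4, q_size) and k.shape == (4, kv_size) and v.shape == (4, kv_size)
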
@@ -213,20 +220,19 @@ def __init__(
         quant_config: QuantizationConfig
     ):
         super().__init__()
-        self.gate_proj = get_quantized_layer(hidden_size, intermediate_size, quant_config)
-        self.up_proj = get_quantized_layer(hidden_size, intermediate_size, quant_config)
+        self.gate_up_proj = get_quantized_layer(hidden_size, 2 * intermediate_size, quant_config)
         self.down_proj = get_quantized_layer(intermediate_size, hidden_size, quant_config)
 
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
-        self.act_fn = activations.SiLUActivation()
+        self.act_fn = SiluAndMul()
 
     def forward(self, x):
-        gate_proj = self.act_fn(self.gate_proj(x))
-        gate_up_proj = gate_proj * self.up_proj(x)
-        down_proj = self.down_proj(gate_up_proj)
-        return down_proj
+        gate_up = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x = self.down_proj(x)
+        return x
 
 
 class LlamaDecoderLayer(nn.Module):
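
The switch from transformers' SiLUActivation to vLLM's SiluAndMul follows from the fusion: the activation now receives the concatenated [gate, up] tensor, applies SiLU to the first half, and multiplies elementwise by the second half. A rough functional equivalent for reference (vLLM's actual SiluAndMul dispatches to a fused CUDA kernel):

import torch
import torch.nn.functional as F

def silu_and_mul_ref(gate_up: torch.Tensor) -> torch.Tensor:
    # gate_up has shape (..., 2 * intermediate_size): the first half is the
    # gate projection, the second half is the up projection.
    d = gate_up.shape[-1] // 2
    return F.silu(gate_up[..., :d]) * gate_up[..., d:]
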
@@ -383,6 +389,7 @@ def load_weights(self,
         kv_proj_shard_size = (self.config.hidden_size //
                               self.config.num_attention_heads *
                               self.config.num_key_value_heads // tp_size)
+
         attention_weight_specs = [
             # (weight_name, shard_size, offset)
             ("q_proj", q_proj_shard_size, 0),
@@ -409,6 +416,7 @@ def load_weights(self,
 
             is_quantized = self.quant_config is not None and self.quant_config.method is not None
 
+            # merge linear layers
             if not is_quantized:
                 is_attention_weight = False
                 for weight_name, shard_size, offset in attention_weight_specs:
@@ -445,6 +453,43 @@ def load_weights(self,
                     break
                 if is_gate_up_weight:
                     continue
+            else:
+                # TODO: improve this block of code (not DRY, hacky, specific to AWQ)
+                is_attention_weight = False
+                for stride_id, (weight_name, shard_size, offset) in enumerate(attention_weight_specs):
+                    if weight_name not in name:
+                        continue
+                    param = state_dict[name.replace(weight_name, "qkv_proj")]
+
+                    # TODO: this is specific to AWQ (should be more general)
+                    if 'qweight' in name or 'qzeros' in name:
+                        shard_size = int(shard_size // (32 / self.quant_config.bits))
+                        offset = int(offset // (32 / self.quant_config.bits))
+
+                    param_slice = param.data[:, offset:offset + shard_size]
+                    assert param_slice.shape == loaded_weight.shape
+
+                    param_slice.copy_(loaded_weight)
+                    is_attention_weight = True
+                    break
+                if is_attention_weight:
+                    continue
+
+                is_gate_up_weight = False
+                for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
+                    if weight_name not in name:
+                        continue
+                    param = state_dict[name.replace(weight_name, "gate_up_proj")]
+                    shard_size = param.shape[1] // 2
+
+                    start, end = shard_size * stride_id, shard_size * (stride_id + 1)
+                    param_slice = param.data[:, start:end]
+                    assert param_slice.shape == loaded_weight.shape
+                    param_slice.copy_(loaded_weight)
+                    is_gate_up_weight = True
+                    break
+                if is_gate_up_weight:
+                    continue
 
             param = state_dict[name]
             load_tensor_parallel_weights(param, loaded_weight, name,
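
On the 32 / self.quant_config.bits correction above: AWQ stores qweight and qzeros packed, with 32 // bits quantized values per int32 along the output-channel dimension, so column offsets and shard widths in the packed tensors shrink by that factor relative to the logical weight (scales is not packed, which is why only qweight and qzeros get the correction). A small sanity-check sketch assuming a 4-bit checkpoint; the shard geometry is illustrative:

bits = 4
pack_factor = 32 // bits          # 8 int4 values packed into each int32

# Illustrative logical (unpacked) shard geometry for one rank.
q_proj_shard_size = 4096          # logical output columns of q_proj
k_proj_offset = 4096              # logical column where k_proj starts in qkv_proj

# Packed qweight / qzeros use int32 columns, so both numbers divide by pack_factor.
packed_shard_size = q_proj_shard_size // pack_factor   # 512
packed_offset = k_proj_offset // pack_factor            # 512
print(packed_shard_size, packed_offset)
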