@@ -89,6 +89,7 @@ def __init__(
         self,
         config: CohereConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -99,12 +100,14 @@ def __init__(
             [self.intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
             self.intermediate_size,
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         self.act_fn = SiluAndMul()

@@ -158,12 +161,14 @@ def __init__(
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -244,7 +249,9 @@ def __init__(self,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.self_attn")

-        self.mlp = CohereMLP(config, quant_config=quant_config)
+        self.mlp = CohereMLP(config,
+                             quant_config=quant_config,
+                             prefix=f"{prefix}.mlp")
         self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
                                          eps=config.layer_norm_eps)
