@@ -89,6 +89,7 @@ def __init__(
         self,
         config: CohereConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -99,12 +100,14 @@ def __init__(
             [self.intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
             self.intermediate_size,
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         self.act_fn = SiluAndMul()

@@ -158,12 +161,14 @@ def __init__(
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -244,7 +249,9 @@ def __init__(self,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.self_attn")

-        self.mlp = CohereMLP(config, quant_config=quant_config)
+        self.mlp = CohereMLP(config,
+                             quant_config=quant_config,
+                             prefix=f"{prefix}.mlp")
         self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
                                          eps=config.layer_norm_eps)

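Note: threading `prefix` through `MergedColumnParallelLinear`, `RowParallelLinear`, and the attention/MLP constructors gives each sub-module a fully qualified name (e.g. `model.layers.0.mlp.gate_up_proj`), which quantization backends can use for per-layer decisions. A minimal sketch of how the prefixes compose, assuming a hypothetical `make_prefix` helper and a parent prefix of `model.layers.0` (neither appears in this diff):

```python
# Minimal sketch (not vLLM code): how a parent prefix composes with each
# sub-module's attribute name as it is threaded through the constructors.
def make_prefix(parent: str, name: str) -> str:
    # An empty parent means the module sits at the model root.
    return f"{parent}.{name}" if parent else name

layer_prefix = make_prefix("model.layers.0", "mlp")        # "model.layers.0.mlp"
proj_prefix = make_prefix(layer_prefix, "gate_up_proj")    # "model.layers.0.mlp.gate_up_proj"
print(proj_prefix)
```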