@@ -163,17 +163,24 @@ def _get_quantizer(self, group_idx: int) -> Optional[Quantizer]:
     def torchao_convert(self, model: nn.Module, weight_only: bool = False) -> None:
         """Converts model parameters to torchao quantized tensor subclasses."""
         model.eval()
-        self.restore_latent_params()
 
         # TODO(lvj): find more robust way to identify embedding layers
         embed_data_ptrs = set()
         linear_data_ptrs = set()
+        embed_modules = []
         for module in model.modules():
             if isinstance(module, nn.Embedding):
+                embed_modules.append(module)
                 embed_data_ptrs.add(module.weight.data_ptr())
             elif _is_linear(module) and module.weight.data_ptr() not in embed_data_ptrs:
                 linear_data_ptrs.add(module.weight.data_ptr())
 
+        tied_embeddings = getattr(model, "_tied_weights_keys", None) is not None
+        if tied_embeddings:
+            # Workaround for dynamic activations on tied embeddings
+            for module in embed_modules:
+                setattr(module, "bias", None)
+
         filter_fns = []
         configs = []
         attach_hf_config = _is_hf_model(model)
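For context, here is a minimal sketch of the situation the new tied_embeddings branch targets. It uses plain PyTorch only; the module names are illustrative and not from this repository, and the comments about the classification loop restate what the hunk above does rather than library guarantees.

# Minimal sketch (assumption: plain PyTorch; names are illustrative).
# When an nn.Embedding weight is tied to the output projection, both modules
# share one storage, so they report the same data_ptr() that the loop above
# uses to classify parameters.
import torch.nn as nn

embed = nn.Embedding(100, 16)
lm_head = nn.Linear(16, 100, bias=False)
lm_head.weight = embed.weight  # weight tying, as many causal LM heads do

assert embed.weight.data_ptr() == lm_head.weight.data_ptr()
# The shared tensor is recorded in embed_data_ptrs and therefore skipped when
# building linear_data_ptrs, so without the tied_embeddings check the group
# would be treated as a plain embedding.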
@@ -194,7 +201,7 @@ def torchao_convert(self, model: nn.Module, weight_only: bool = False) -> None:
             any_embed = any(p.data_ptr() in embed_data_ptrs for p in group["params"])
             config = _get_config_from_quantizer(
                 quantizer,
-                weight_only or any_embed,
+                weight_only or (any_embed and not tied_embeddings),
                 device,
                 group["quant_bits"],
                 group.get("quant_block_size"),
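The second hunk only changes the boolean passed to _get_config_from_quantizer. A small sketch of the resulting routing follows; the assumption (inferred from the diff, not confirmed by it) is that a True second argument selects a weight-only config and False selects a dynamic-activation config.

# Sketch of the weight-only decision after this change (helper name is mine).
def use_weight_only(weight_only: bool, any_embed: bool, tied_embeddings: bool) -> bool:
    return weight_only or (any_embed and not tied_embeddings)

# Untied embedding groups stay weight-only; tied ones now follow the linear path.
print(use_weight_only(False, True, False))  # True  -> weight-only embedding
print(use_weight_only(False, True, True))   # False -> dynamic activations despite embedding
print(use_weight_only(True, False, True))   # True  -> the weight_only flag always wins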