Commit af5fb97

Update torchao_convert to match notebook
1 parent f07488e commit af5fb97

1 file changed (+9, −2 lines)

torchao/prototype/parq/optim/quantopt.py

Lines changed: 9 additions & 2 deletions
@@ -163,17 +163,24 @@ def _get_quantizer(self, group_idx: int) -> Optional[Quantizer]:
     def torchao_convert(self, model: nn.Module, weight_only: bool = False) -> None:
         """Converts model parameters to torchao quantized tensor subclasses."""
         model.eval()
-        self.restore_latent_params()
 
         # TODO(lvj): find more robust way to identify embedding layers
         embed_data_ptrs = set()
         linear_data_ptrs = set()
+        embed_modules = []
         for module in model.modules():
             if isinstance(module, nn.Embedding):
+                embed_modules.append(module)
                 embed_data_ptrs.add(module.weight.data_ptr())
             elif _is_linear(module) and module.weight.data_ptr() not in embed_data_ptrs:
                 linear_data_ptrs.add(module.weight.data_ptr())
 
+        tied_embeddings = getattr(model, "_tied_weights_keys", None) is not None
+        if tied_embeddings:
+            # Workaround for dynamic activations on tied embeddings
+            for module in embed_modules:
+                setattr(module, "bias", None)
+
         filter_fns = []
         configs = []
         attach_hf_config = _is_hf_model(model)
@@ -194,7 +201,7 @@ def torchao_convert(self, model: nn.Module, weight_only: bool = False) -> None:
             any_embed = any(p.data_ptr() in embed_data_ptrs for p in group["params"])
             config = _get_config_from_quantizer(
                 quantizer,
-                weight_only or any_embed,
+                weight_only or (any_embed and not tied_embeddings),
                 device,
                 group["quant_bits"],
                 group.get("quant_block_size"),