
Commit abc60fd

Handle weight-only embeddings in torchao_convert
1 parent af5fb97 commit abc60fd

2 files changed: 38 additions & 15 deletions

test/prototype/test_parq.py

Lines changed: 28 additions & 12 deletions
@@ -161,6 +161,7 @@ def build_param_groups(
     model,
     b: int = 2,
     group_size: Optional[int] = None,
+    embed_b: int = 4,
 ):
     params_quant, params_embed, params_no_quant = split_param_groups(model)
     quant_kwargs = {}
@@ -171,14 +172,27 @@ def build_param_groups(
         {"params": params_no_quant},
     ]
     if params_embed:
-        param_groups.append({"params": params_embed, "quant_bits": 4})
+        param_groups.append({"params": params_embed, "quant_bits": embed_b})
     return param_groups


-def get_optim_kwargs(base_optimizer, embedding=True, quant_cls=UnifTorchaoQuantizer):
+def get_optim_kwargs(
+    model, base_optimizer, embedding=True, quant_cls=UnifTorchaoQuantizer
+):
     optim_kwargs = {}
     if embedding:
-        group_idx = len(base_optimizer.param_groups) - 2
+        embed_data_ptrs = set(
+            (
+                m.weight.data_ptr()
+                for m in model.modules()
+                if isinstance(m, nn.Embedding)
+            )
+        )
+        group_idx = -1
+        for i, group in enumerate(base_optimizer.param_groups):
+            if all(p.data_ptr() in embed_data_ptrs for p in group["params"]):
+                group_idx = i
+                break
         assert group_idx > -1
         optim_kwargs["group_quantizer_map"] = {group_idx: quant_cls()}
     return optim_kwargs
@@ -221,7 +235,7 @@ def compare_parq_convert(
     orig_model = copy.deepcopy(model)  # save copy of PARQ quantized model

     # equivalent to torchao's convert step
-    optimizer.torchao_convert(model, weight_only=weight_only)
+    optimizer.torchao_convert(model, weight_only=weight_only, embed_weight_only=True)

     inputs = model.example_inputs(device=_DEVICE)
     torch.testing.assert_close(model(inputs), orig_model(inputs))
@@ -289,13 +303,15 @@ def test_parq_train_loop(
             quantizer = TernaryUnifQuantizer() if b == 0 else UnifQuantizer()
         else:
             quantizer = LSBQuantizer()
-        param_groups = build_param_groups(model, b)
+        param_groups = build_param_groups(model, b, embed_b=b)
         base_optimizer = torch.optim.AdamW(param_groups)

         prox_map = (
             ProxHardQuant() if hard_prox else ProxPARQ(anneal_start=0, anneal_end=2)
         )
-        optim_kwargs = get_optim_kwargs(base_optimizer)
+        optim_kwargs = get_optim_kwargs(
+            model, base_optimizer, quant_cls=type(quantizer), embedding=False
+        )
         optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map, **optim_kwargs)
         for _ in range(3):
             x = model.example_inputs(device=_DEVICE)
@@ -365,7 +381,7 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):

         b = 4
         base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size))
-        optim_kwargs = get_optim_kwargs(base_optimizer, embedding=False)
+        optim_kwargs = get_optim_kwargs(model, base_optimizer, embedding=False)
         optimizer = QuantOptimizer(
             base_optimizer,
             Int4UnifTorchaoQuantizer(),
@@ -387,7 +403,7 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         quantize_(m_ref, config)

         base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size))
-        optim_kwargs = get_optim_kwargs(base_optimizer, embedding=False)
+        optim_kwargs = get_optim_kwargs(model, base_optimizer, embedding=False)
         optimizer = QuantOptimizer(
             base_optimizer,
             UnifTorchaoQuantizer(),
@@ -464,7 +480,7 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         quantize_(m_ref, config, filter_fn=_is_linear)

         base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size))
-        optim_kwargs = get_optim_kwargs(base_optimizer, embedding=False)
+        optim_kwargs = get_optim_kwargs(model, base_optimizer, embedding=False)
         optimizer = QuantOptimizer(
             base_optimizer,
             quantizer,
@@ -486,7 +502,7 @@ def test_intx_weight_only_tied_embed_linear(

         quantizer = StretchedUnifTorchaoQuantizer(b)
         base_optimizer = torch.optim.SGD(build_param_groups(model, b))
-        optim_kwargs = get_optim_kwargs(base_optimizer)
+        optim_kwargs = get_optim_kwargs(model, base_optimizer)
         optimizer = QuantOptimizer(
             base_optimizer,
             quantizer,
@@ -498,7 +514,7 @@ def test_intx_weight_only_tied_embed_linear(
         optimizer.step()

         apply_activation_quantization(model, optimizer, model_dtype)
-        optimizer.torchao_convert(model)
+        optimizer.torchao_convert(model, embed_weight_only=True)
         check_torchao_tensor_subclass(self, model)
         self.assertTrue(
             torch.equal(model.embed_tokens.weight.qdata, model.linear2.weight.qdata)
@@ -540,7 +556,7 @@ def test_int8_dynamic_activation_intx_e2e(

         # quantize weights with PARQ
         base_optimizer = torch.optim.SGD(build_param_groups(model, b, group_size))
-        optim_kwargs = get_optim_kwargs(base_optimizer, embedding=False)
+        optim_kwargs = get_optim_kwargs(model, base_optimizer, embedding=False)
         optimizer = QuantOptimizer(
             base_optimizer,
             quantizer,
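
The reworked get_optim_kwargs above no longer assumes the embedding parameters sit at a fixed position in base_optimizer.param_groups; it finds the group whose parameters all share a data_ptr() with an nn.Embedding weight. Below is a minimal standalone sketch of that lookup, using a hypothetical toy model and param-group layout that are not part of this commit:

import torch
import torch.nn as nn

# Hypothetical toy model and param groups, only to illustrate the lookup
# performed by the new get_optim_kwargs; none of this is from the commit.
model = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 100))
embed_weight = model[0].weight
other_params = [p for p in model.parameters() if p.data_ptr() != embed_weight.data_ptr()]
base_optimizer = torch.optim.SGD(
    [{"params": other_params}, {"params": [embed_weight], "quant_bits": 4}],
    lr=0.1,
)

# Collect the data pointers of every embedding weight, then pick the first
# param group made up entirely of those weights instead of hard-coding its index.
embed_data_ptrs = {
    m.weight.data_ptr() for m in model.modules() if isinstance(m, nn.Embedding)
}
group_idx = -1
for i, group in enumerate(base_optimizer.param_groups):
    if all(p.data_ptr() in embed_data_ptrs for p in group["params"]):
        group_idx = i
        break
assert group_idx == 1  # the embedding group, wherever it was appended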

torchao/prototype/parq/optim/quantopt.py

Lines changed: 10 additions & 3 deletions
@@ -160,9 +160,15 @@ def _get_quantizer(self, group_idx: int) -> Optional[Quantizer]:
             return self.group_quantizer_map[group_idx]
         return self.quantizer

-    def torchao_convert(self, model: nn.Module, weight_only: bool = False) -> None:
+    def torchao_convert(
+        self,
+        model: nn.Module,
+        weight_only: bool = False,
+        embed_weight_only: bool = False,
+    ) -> None:
         """Converts model parameters to torchao quantized tensor subclasses."""
         model.eval()
+        self.restore_latent_params()

         # TODO(lvj): find more robust way to identify embedding layers
         embed_data_ptrs = set()
@@ -175,9 +181,10 @@ def torchao_convert(self, model: nn.Module, weight_only: bool = False) -> None:
             elif _is_linear(module) and module.weight.data_ptr() not in embed_data_ptrs:
                 linear_data_ptrs.add(module.weight.data_ptr())

-        tied_embeddings = getattr(model, "_tied_weights_keys", None) is not None
-        if tied_embeddings:
+        tied_embeddings = False
+        if not embed_weight_only and getattr(model, "_tied_weights_keys", None):
             # Workaround for dynamic activations on tied embeddings
+            tied_embeddings = True
             for module in embed_modules:
                 setattr(module, "bias", None)

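
The second hunk gates the tied-embeddings workaround on the new flag: when embed_weight_only=True the workaround is skipped and embeddings stay on the weight-only path named in the commit title. A rough standalone sketch of just that branch condition, using a hypothetical tied embedding/linear model (the _tied_weights_keys attribute follows the Hugging Face convention that torchao_convert checks via getattr):

import torch.nn as nn

class TinyTiedModel(nn.Module):
    # Hypothetical model for illustration only: an embedding tied to an LM head,
    # advertising the tie the same way Hugging Face models do.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, vocab: int = 100, dim: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dim)
        self.lm_head = nn.Linear(dim, vocab, bias=False)
        self.lm_head.weight = self.embed_tokens.weight  # tie the weights

def takes_tied_embeddings_workaround(model: nn.Module, embed_weight_only: bool) -> bool:
    # Mirrors the updated condition in torchao_convert: the workaround for
    # dynamic activations on tied embeddings only runs when embed_weight_only
    # is False and the model reports tied weights.
    return not embed_weight_only and bool(getattr(model, "_tied_weights_keys", None))

model = TinyTiedModel()
assert takes_tied_embeddings_workaround(model, embed_weight_only=False)
assert not takes_tied_embeddings_workaround(model, embed_weight_only=True)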
