@@ -161,6 +161,7 @@ def build_param_groups(
     model,
     b: int = 2,
     group_size: Optional[int] = None,
+    embed_b: int = 4,
 ):
     params_quant, params_embed, params_no_quant = split_param_groups(model)
     quant_kwargs = {}
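
Note: `build_param_groups` previously pinned embeddings to 4 bits; the new `embed_b` parameter makes that configurable while keeping 4 as the default. A minimal usage sketch, assuming a toy model that contains both linear and embedding layers (bit widths below are illustrative):

    # 2-bit linear weights, 8-bit embedding weights
    param_groups = build_param_groups(model, b=2, embed_b=8)
    base_optimizer = torch.optim.AdamW(param_groups)
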
@@ -171,14 +172,27 @@ def build_param_groups(
171172 {"params" : params_no_quant },
172173 ]
173174 if params_embed :
174- param_groups .append ({"params" : params_embed , "quant_bits" : 4 })
175+ param_groups .append ({"params" : params_embed , "quant_bits" : embed_b })
175176 return param_groups
176177
177178
178- def get_optim_kwargs (base_optimizer , embedding = True , quant_cls = UnifTorchaoQuantizer ):
179+ def get_optim_kwargs (
180+ model , base_optimizer , embedding = True , quant_cls = UnifTorchaoQuantizer
181+ ):
179182 optim_kwargs = {}
180183 if embedding :
181- group_idx = len (base_optimizer .param_groups ) - 2
184+ embed_data_ptrs = set (
185+ (
186+ m .weight .data_ptr ()
187+ for m in model .modules ()
188+ if isinstance (m , nn .Embedding )
189+ )
190+ )
191+ group_idx = - 1
192+ for i , group in enumerate (base_optimizer .param_groups ):
193+ if all (p .data_ptr () in embed_data_ptrs for p in group ["params" ]):
194+ group_idx = i
195+ break
182196 assert group_idx > - 1
183197 optim_kwargs ["group_quantizer_map" ] = {group_idx : quant_cls ()}
184198 return optim_kwargs
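
Note: the old lookup hard-coded the embedding group's position in `base_optimizer.param_groups`, which is brittle as soon as the group layout changes. The new code identifies the group by content instead: it matches the group whose parameters all alias an `nn.Embedding` weight by storage pointer, so tied weights that share an embedding's storage are also caught. A standalone restatement of that logic, with a hypothetical helper name:

    import torch.nn as nn

    def find_embedding_group_idx(model, param_groups):
        # Storage pointers of every nn.Embedding weight; a tied linear
        # weight shares its embedding's data_ptr, so it matches as well.
        embed_ptrs = {
            m.weight.data_ptr() for m in model.modules() if isinstance(m, nn.Embedding)
        }
        for i, group in enumerate(param_groups):
            if all(p.data_ptr() in embed_ptrs for p in group["params"]):
                return i
        return -1  # no group consists solely of embedding weights
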
@@ -221,7 +235,7 @@ def compare_parq_convert(
     orig_model = copy.deepcopy(model)  # save copy of PARQ quantized model

     # equivalent to torchao's convert step
-    optimizer.torchao_convert(model, weight_only=weight_only)
+    optimizer.torchao_convert(model, weight_only=weight_only, embed_weight_only=True)

     inputs = model.example_inputs(device=_DEVICE)
     torch.testing.assert_close(model(inputs), orig_model(inputs))
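
Note: `embed_weight_only=True` is passed alongside the existing `weight_only` flag, presumably so embedding weights always receive a weight-only torchao config even when the linear layers get dynamic activation quantization. A minimal call sketch (flag semantics inferred from the name, not confirmed by the diff):

    # convert PARQ-trained weights into torchao tensor subclasses;
    # embeddings are assumed to convert weight-only regardless of weight_only
    optimizer.torchao_convert(model, weight_only=False, embed_weight_only=True)
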
@@ -289,13 +303,15 @@ def test_parq_train_loop(
             quantizer = TernaryUnifQuantizer() if b == 0 else UnifQuantizer()
         else:
             quantizer = LSBQuantizer()
-        param_groups = build_param_groups(model, b)
+        param_groups = build_param_groups(model, b, embed_b=b)
         base_optimizer = torch.optim.AdamW(param_groups)

         prox_map = (
             ProxHardQuant() if hard_prox else ProxPARQ(anneal_start=0, anneal_end=2)
         )
-        optim_kwargs = get_optim_kwargs(base_optimizer)
+        optim_kwargs = get_optim_kwargs(
+            model, base_optimizer, quant_cls=type(quantizer), embedding=False
+        )
         optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map, **optim_kwargs)
         for _ in range(3):
             x = model.example_inputs(device=_DEVICE)
@@ -365,7 +381,7 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):

         b = 4
         base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size))
-        optim_kwargs = get_optim_kwargs(base_optimizer, embedding=False)
+        optim_kwargs = get_optim_kwargs(model, base_optimizer, embedding=False)
         optimizer = QuantOptimizer(
             base_optimizer,
             Int4UnifTorchaoQuantizer(),
@@ -387,7 +403,7 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         quantize_(m_ref, config)

         base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size))
-        optim_kwargs = get_optim_kwargs(base_optimizer, embedding=False)
+        optim_kwargs = get_optim_kwargs(model, base_optimizer, embedding=False)
         optimizer = QuantOptimizer(
             base_optimizer,
             UnifTorchaoQuantizer(),
@@ -464,7 +480,7 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         quantize_(m_ref, config, filter_fn=_is_linear)

         base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size))
-        optim_kwargs = get_optim_kwargs(base_optimizer, embedding=False)
+        optim_kwargs = get_optim_kwargs(model, base_optimizer, embedding=False)
         optimizer = QuantOptimizer(
             base_optimizer,
             quantizer,
@@ -486,7 +502,7 @@ def test_intx_weight_only_tied_embed_linear(

         quantizer = StretchedUnifTorchaoQuantizer(b)
         base_optimizer = torch.optim.SGD(build_param_groups(model, b))
-        optim_kwargs = get_optim_kwargs(base_optimizer)
+        optim_kwargs = get_optim_kwargs(model, base_optimizer)
         optimizer = QuantOptimizer(
             base_optimizer,
             quantizer,
@@ -498,7 +514,7 @@ def test_intx_weight_only_tied_embed_linear(
         optimizer.step()

         apply_activation_quantization(model, optimizer, model_dtype)
-        optimizer.torchao_convert(model)
+        optimizer.torchao_convert(model, embed_weight_only=True)
         check_torchao_tensor_subclass(self, model)
         self.assertTrue(
             torch.equal(model.embed_tokens.weight.qdata, model.linear2.weight.qdata)
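
Note: the `qdata` equality check exercises weight tying: when the output projection shares its weight tensor with the embedding, both must quantize to identical packed data after conversion. A hypothetical tied model matching the attribute names used in the assertion above:

    import torch.nn as nn

    class TiedToyModel(nn.Module):  # illustration only, not the test's model
        def __init__(self, vocab_size=128, dim=16):
            super().__init__()
            self.embed_tokens = nn.Embedding(vocab_size, dim)
            self.linear2 = nn.Linear(dim, vocab_size, bias=False)
            self.linear2.weight = self.embed_tokens.weight  # tie weights
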
@@ -540,7 +556,7 @@ def test_int8_dynamic_activation_intx_e2e(

         # quantize weights with PARQ
         base_optimizer = torch.optim.SGD(build_param_groups(model, b, group_size))
-        optim_kwargs = get_optim_kwargs(base_optimizer, embedding=False)
+        optim_kwargs = get_optim_kwargs(model, base_optimizer, embedding=False)
         optimizer = QuantOptimizer(
             base_optimizer,
             quantizer,