2727 UnifQuantizer ,
2828 UnifTorchaoQuantizer ,
2929)
30- from torchao .prototype .parq .quant .config_torchao import TRANSFORMERS_AVAIL , _is_hf_model
30+ from torchao .prototype .parq .quant .config_torchao import (
31+ TRANSFORMERS_AVAIL ,
32+ _attach_hf_quantization_config ,
33+ _is_hf_model ,
34+ )
3135from torchao .prototype .parq .quant .uniform_torchao import _BIT_WIDTH_TO_DTYPE
32- from torchao .quantization .granularity import PerGroup
36+ from torchao .quantization .granularity import PerAxis , PerGroup
3337from torchao .quantization .qat import IntxFakeQuantizeConfig , QATConfig
3438from torchao .quantization .quant_api import (
3539 Int4WeightOnlyConfig ,
4953
5054_DEVICE = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
5155
56+ if TRANSFORMERS_AVAIL :
57+ from transformers import PretrainedConfig , TorchAoConfig
58+ from transformers .quantizers .quantizer_torchao import TorchAoHfQuantizer
59+
5260
5361def split_param_groups (model ) -> tuple [list , list , list ]:
5462 params_quant , params_embed , params_no_quant = [], [], []
@@ -206,9 +214,12 @@ def __init__(
206214
207215 if embedding and tied_weights :
208216 assert self .embedding .weight .shape == self .linear2 .weight .shape
209- self .linear2 . weight = self . embedding . weight
217+ self .tie_weights ()
210218 self ._tied_weights_keys .append ("linear2.weight" )
211219
220+ def tie_weights (self ):
221+ self .linear2 .weight = self .embedding .weight
222+
212223 def reset_parameters (self ):
213224 for module in (self .linear1 , self .linear2 ):
214225 nn .init .xavier_uniform_ (module .weight )
@@ -223,6 +234,9 @@ def example_inputs(self, device=None):
223234 inputs = torch .randint (1 , k , (1 , self .linear1 .in_features ), device = device )
224235 return inputs
225236
237+ def get_input_embeddings (self ) -> nn .Module :
238+ return self .embedding
239+
226240 def forward (self , x ):
227241 x = self .embedding (x )
228242 x = self .relu (self .linear1 (x ))
@@ -506,8 +520,6 @@ def test_int8_dynamic_activation_intx_e2e(
506520
507521 attach_hf_config = False
508522 if TRANSFORMERS_AVAIL :
509- from transformers import PretrainedConfig
510-
511523 model .config = PretrainedConfig () # pretend this is a HF model
512524 attach_hf_config = _is_hf_model (model )
513525 self .assertTrue (attach_hf_config )
@@ -530,6 +542,55 @@ def test_int8_dynamic_activation_intx_e2e(
530542 self .assertTrue (isinstance (torchao_config , config .__class__ ))
531543
532544
class TestTorchAoConfigIntegration(common_utils.TestCase):
    @unittest.skipIf(not TRANSFORMERS_AVAIL, "Need transformers")
    def test_tied_weights_quantization(self, b: int = 4):
        """Attach a HF quantization config to a tied-weight model and check that
        HF-style per-parameter quantization leaves the embedding/linear2 tie
        intact (both weights alias the same storage afterwards)."""
        model = M(m=128, n=128, tied_weights=True).to(_DEVICE)
        model.config = PretrainedConfig()  # pretend this is a HF model

        stretched = StretchedUnifTorchaoQuantizer(b)
        default_linear_cfg = StretchedIntxWeightConfig(
            b=b,
            quant_min=stretched.quant_min,
            quant_max=stretched.quant_max,
            granularity=PerAxis(0),
        )
        embedding_cfg = IntxWeightOnlyConfig(
            weight_dtype=_BIT_WIDTH_TO_DTYPE[b], granularity=PerGroup(32)
        )
        # Embeddings get their own config; everything else falls to "_default".
        _attach_hf_quantization_config(
            model,
            [lambda mod: isinstance(mod, nn.Embedding)],
            [embedding_cfg],
            {"_default": default_linear_cfg},
        )

        hf_qconfig = getattr(model.config, "quantization_config", None)
        self.assertIsInstance(hf_qconfig, TorchAoConfig)
        self.assertEqual(hf_qconfig.modules_to_not_convert, ["linear2"])

        # Simulate transformers.PreTrainedModel.from_pretrained
        hf_quantizer = TorchAoHfQuantizer(
            hf_qconfig,
            pre_quantized=False,
            modules_to_not_convert=hf_qconfig.modules_to_not_convert,
        )
        sd = model.state_dict()
        leftover_keys = []
        for name, param in sd.items():
            if not hf_quantizer.check_quantized_param(model, param, name, sd):
                continue
            hf_quantizer.create_quantized_param(
                model, param, name, _DEVICE, sd, leftover_keys
            )
        model.tie_weights()

        check_torchao_tensor_subclass(self, model.linear1)
        check_torchao_tensor_subclass(self, model.linear2, weight_only=True)
        check_torchao_tensor_subclass(self, model.embedding, weight_only=True)

        # The tied tensors must still share storage after quantization.
        self.assertEqual(
            model.linear2.weight.data_ptr(), model.embedding.weight.data_ptr()
        )
593+
# Expand @parametrize-decorated tests into concrete test methods.
for _test_case in (
    TestPARQuantization,
    TestUnifTorchaoQuantizer,
    TestInt8DynamicActivationTorchaoQuantizer,
):
    common_utils.instantiate_parametrized_tests(_test_case)
0 commit comments