@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import copy
+import tempfile
 import unittest
 from typing import Optional
 
@@ -27,9 +28,13 @@
     UnifQuantizer,
     UnifTorchaoQuantizer,
 )
-from torchao.prototype.parq.quant.config_torchao import TRANSFORMERS_AVAIL, _is_hf_model
+from torchao.prototype.parq.quant.config_torchao import (
+    TRANSFORMERS_AVAIL,
+    _attach_hf_quantization_config,
+    _is_hf_model,
+)
 from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE
-from torchao.quantization.granularity import PerGroup
+from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig
 from torchao.quantization.quant_api import (
     Int4WeightOnlyConfig,
@@ -50,6 +55,84 @@
 _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
+class M(nn.Module):
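+    # transformers reads _tied_weights_keys to handle parameters that share storage on save/load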
+    _tied_weights_keys: list[str] = []
+
+    def __init__(
+        self, m=256, n=128, k=16, bias=False, embedding=True, tied_weights=False
+    ):
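+        # init nn.Module directly; PreTrainedM below runs PreTrainedModel.__init__ itself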
+        nn.Module.__init__(self)
+        self.embed_tokens = nn.Embedding(k, m) if embedding else nn.Identity()
+        self.linear1 = nn.Linear(m, n, bias=bias)
+        self.linear2 = nn.Linear(n, k, bias=bias)
+        self.relu = nn.ReLU()
+        self.sigmoid = nn.Sigmoid()
+
+        if embedding and tied_weights:
+            assert self.embed_tokens.weight.shape == self.linear2.weight.shape
+            self.tie_weights()
+            self._tied_weights_keys.append("linear2.weight")
+
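+    # same name as PreTrainedModel.tie_weights, which HF calls after loading checkpoints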
+    def tie_weights(self):
+        self.linear2.weight = self.embed_tokens.weight
+
+    def example_inputs(self, device=None):
+        if isinstance(self.embed_tokens, nn.Identity):
+            inputs = torch.randn(1, self.linear1.in_features, device=device)
+        else:
+            k = self.embed_tokens.num_embeddings
+            inputs = torch.randint(1, k, (1, self.linear1.in_features), device=device)
+        return inputs
+
+    def forward(self, x):
+        x = self.embed_tokens(x)
+        x = self.relu(self.linear1(x))
+        x = self.sigmoid(self.linear2(x))
+        return x
+
+
+if TRANSFORMERS_AVAIL:
+    from transformers import PretrainedConfig, PreTrainedModel, TorchAoConfig
+
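+    # minimal HF config/model pair so the toy model supports save_pretrained/from_pretrained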
+    class MConfig(PretrainedConfig):
+        def __init__(
+            self,
+            m=256,
+            n=128,
+            k=16,
+            bias=False,
+            embedding=True,
+            tied_weights=False,
+            **kwargs,
+        ):
+            super().__init__(**kwargs)
+            self.m = m
+            self.n = n
+            self.k = k
+            self.bias = bias
+            self.embedding = embedding
+            self.tied_weights = tied_weights
+
+    class PreTrainedM(M, PreTrainedModel):
+        base_model_prefix = "base"
+        config_class = MConfig
+
+        def __init__(self, config: MConfig):
+            PreTrainedModel.__init__(self, config)
+            M.__init__(
+                self,
+                m=config.m,
+                n=config.n,
+                k=config.k,
+                bias=config.bias,
+                embedding=config.embedding,
+                tied_weights=config.tied_weights,
+            )
+
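+        # HF locates the embedding through this hook, e.g., when tying weights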
+        def get_input_embeddings(self) -> nn.Module:
+            return self.embed_tokens
+
+
 def split_param_groups(model) -> tuple[list, list, list]:
     params_quant, params_embed, params_no_quant = [], [], []
 
@@ -191,49 +274,9 @@ def apply_activation_quantization(
         pass
 
 
-class M(nn.Module):
-    _tied_weights_keys: list[str] = []
-
-    def __init__(
-        self, m=256, n=128, k=16, bias=False, embedding=True, tied_weights=False
-    ):
-        super().__init__()
-        self.embedding = nn.Embedding(k, m) if embedding else nn.Identity()
-        self.linear1 = nn.Linear(m, n, bias=bias)
-        self.linear2 = nn.Linear(n, k, bias=bias)
-        self.relu = nn.ReLU()
-        self.sigmoid = nn.Sigmoid()
-
-        if embedding and tied_weights:
-            assert self.embedding.weight.shape == self.linear2.weight.shape
-            self.linear2.weight = self.embedding.weight
-            self._tied_weights_keys.append("linear2.weight")
-
-    def reset_parameters(self):
-        for module in (self.linear1, self.linear2):
-            nn.init.xavier_uniform_(module.weight)
-            if module.bias is not None:
-                nn.init.zeros_(module.bias)
-
-    def example_inputs(self, device=None):
-        if isinstance(self.embedding, nn.Identity):
-            inputs = torch.randn(1, self.linear1.in_features, device=device)
-        else:
-            k = self.embedding.num_embeddings
-            inputs = torch.randint(1, k, (1, self.linear1.in_features), device=device)
-        return inputs
-
-    def forward(self, x):
-        x = self.embedding(x)
-        x = self.relu(self.linear1(x))
-        x = self.sigmoid(self.linear2(x))
-        return x
-
-
 class TestPARQuantization(common_utils.TestCase):
     def setUp(self):
         torch.manual_seed(123)
-        self.model = M(bias=True).to(_DEVICE)
 
     @common_utils.parametrize("b", [0, 1, 2, 4])
     @common_utils.parametrize("unif_quant", [True, False])
@@ -242,13 +285,13 @@ def setUp(self):
     def test_parq_train_loop(
         self, b: int = 2, unif_quant=True, hard_prox=True, per_group_quantizer=False
     ):
-        self.model.reset_parameters()
+        model = M(bias=True).to(_DEVICE)
         if unif_quant:
             quantizer = TernaryUnifQuantizer() if b == 0 else UnifQuantizer()
         else:
             quantizer = LSBQuantizer()
         param_groups = build_param_groups(
-            self.model, b, quantizer=quantizer if per_group_quantizer else None
+            model, b, quantizer=quantizer if per_group_quantizer else None
         )
         base_optimizer = torch.optim.AdamW(param_groups)
 
@@ -257,12 +300,12 @@ def test_parq_train_loop(
         )
         optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map)
         for _ in range(3):
-            x = self.model.example_inputs(device=_DEVICE)
-            out = self.model(x)
+            x = model.example_inputs(device=_DEVICE)
+            out = model(x)
             out.sum().backward()
             optimizer.step()
 
-        for child in self.model.children():
+        for child in model.children():
             if isinstance(child, nn.Linear):
                 self.assertEqual(
                     child.weight.unique().numel(), quantizer.get_quant_size(b)
@@ -281,7 +324,6 @@ def setUp(self):
     @common_utils.parametrize("group_size", [32, 256])
     def test_int4_weight_only(self, group_size: int = 32):
         model = M(m=512, n=512).to(_DEVICE, dtype=torch.bfloat16)
-        model.reset_parameters()
 
         m_ref = copy.deepcopy(model).eval().to(_DEVICE)
         config = Int4WeightOnlyConfig(group_size=group_size)
@@ -299,7 +341,6 @@ def test_int4_weight_only(self, group_size: int = 32):
     @common_utils.parametrize("group_size", [32, 512])
     def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
         model = M(m=512, n=512).to(_DEVICE)
-        model.reset_parameters()
 
         m_ref = copy.deepcopy(model).eval().to(_DEVICE)
         quantize_(
@@ -319,7 +360,6 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
     )
     def test_int4_weight_only_e2e(self, group_size: int = 32):
         model = M(m=512, n=512, embedding=False).to(torch.bfloat16).to(_DEVICE)
-        model.reset_parameters()
 
         m_ref = copy.deepcopy(model).eval().to(_DEVICE)
         config = Int4WeightOnlyConfig(group_size=group_size)
@@ -339,7 +379,6 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
     @common_utils.parametrize("b", [2, 3, 4, 8])
     def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         model = M(m=512, n=512, embedding=False).to(_DEVICE)
-        model.reset_parameters()
 
         m_ref = copy.deepcopy(model).eval().to(_DEVICE)
         config = IntxWeightOnlyConfig(
@@ -366,7 +405,6 @@ def setUp(self):
     @common_utils.parametrize("group_size", [32, 256])
     def test_intx_weight_only_parq_equivalent(self, b: int = 2, group_size: int = 32):
         model = M(m=512, n=512).to(_DEVICE)
-        model.reset_parameters()
 
         quantizer_ref = UnifQuantizer()
         quantizer = StretchedUnifTorchaoQuantizer(b)
@@ -389,7 +427,6 @@ def test_intx_weight_only_parq_equivalent(self, b: int = 2, group_size: int = 32
     @common_utils.parametrize("group_size", [32, 512])
     def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
         model = M(m=512, n=512).to(_DEVICE)
-        model.reset_parameters()
 
         quantizer = StretchedUnifTorchaoQuantizer(b)
 
@@ -411,7 +448,6 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
     @common_utils.parametrize("b", [2, 3])
     def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         model = M(m=512, n=512, embedding=False).to(_DEVICE)
-        model.reset_parameters()
 
         quantizer = StretchedUnifTorchaoQuantizer(b)
 
@@ -456,14 +492,16 @@ def test_intx_weight_only_tied_embed_linear(
         optimizer.torchao_convert(model)
         check_torchao_tensor_subclass(self, model)
         self.assertTrue(
-            torch.equal(model.embedding.weight.qdata, model.linear2.weight.qdata)
+            torch.equal(model.embed_tokens.weight.qdata, model.linear2.weight.qdata)
         )
 
 
 class TestInt8DynamicActivationTorchaoQuantizer(common_utils.TestCase):
     def setUp(self):
         torch.manual_seed(123)
 
+    @unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
+    @unittest.skipIf(not TRANSFORMERS_AVAIL, "Need transformers")
     @common_utils.parametrize("b", [2, 3, 4, 8])
     @common_utils.parametrize(
         "model_dtype", [torch.float16, torch.float32, torch.bfloat16]
@@ -475,7 +513,8 @@ def test_int8_dynamic_activation_intx_e2e(
         model_dtype: torch.dtype = torch.float32,
         group_size: int = 32,
     ):
-        model = M(embedding=False, bias=True).to(_DEVICE, dtype=model_dtype)
+        config = MConfig(embedding=False, bias=True)
+        model = PreTrainedM(config).to(_DEVICE, dtype=model_dtype)
         x = model.example_inputs(device=_DEVICE).to(model_dtype)
 
         # reference model using native quantization
@@ -506,9 +545,6 @@ def test_int8_dynamic_activation_intx_e2e(
 
         attach_hf_config = False
         if TRANSFORMERS_AVAIL:
-            from transformers import PretrainedConfig
-
-            model.config = PretrainedConfig()  # pretend this is a HF model
             attach_hf_config = _is_hf_model(model)
             self.assertTrue(attach_hf_config)
 
@@ -530,6 +566,49 @@ def test_int8_dynamic_activation_intx_e2e(
         self.assertTrue(isinstance(torchao_config, config.__class__))
 
 
+class TestTorchAoConfigIntegration(common_utils.TestCase):
+    @unittest.skipIf(torch.backends.mps.is_available(), "MPS not supported")
+    @unittest.skipIf(not TRANSFORMERS_AVAIL, "Need transformers")
+    def test_tied_weights_quantization(self, b: int = 4):
+        config = MConfig(m=128, n=128, tied_weights=True)
+        model = PreTrainedM(config).to(_DEVICE)
+
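+        # tied_weights=True makes linear2.weight share storage with embed_tokens.weight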
+        quantizer = StretchedUnifTorchaoQuantizer(b)
+        linear_config = StretchedIntxWeightConfig(
+            b=b,
+            quant_min=quantizer.quant_min,
+            quant_max=quantizer.quant_max,
+            granularity=PerAxis(0),
+        )
+        embed_config = IntxWeightOnlyConfig(
+            weight_dtype=_BIT_WIDTH_TO_DTYPE[b], granularity=PerGroup(32)
+        )
+        module_to_config = {"_default": linear_config}
+        configs = [embed_config]
+        filter_fns = [lambda m: isinstance(m, nn.Embedding)]
+        _attach_hf_quantization_config(model, filter_fns, configs, module_to_config)
+
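+        # linear2 shares the embedding weight, so it should land in modules_to_not_convert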
+        quantization_config = getattr(model.config, "quantization_config", None)
+        self.assertTrue(isinstance(quantization_config, TorchAoConfig))
+        self.assertTrue(quantization_config.modules_to_not_convert == ["linear2"])
+
+        # Let HF apply quantize_ given quantization_config
+        del model.config.quantization_config
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir, safe_serialization=False)
+            model = PreTrainedM.from_pretrained(
+                tmp_dir, quantization_config=quantization_config
+            )
+
+        check_torchao_tensor_subclass(self, model.linear1)
+        check_torchao_tensor_subclass(self, model.linear2, weight_only=True)
+        check_torchao_tensor_subclass(self, model.embed_tokens, weight_only=True)
+
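+        # reloading must preserve the tie: both weights should share one storage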
+        self.assertTrue(
+            model.linear2.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
+        )
+
+
 common_utils.instantiate_parametrized_tests(TestPARQuantization)
 common_utils.instantiate_parametrized_tests(TestUnifTorchaoQuantizer)
 common_utils.instantiate_parametrized_tests(TestInt8DynamicActivationTorchaoQuantizer)