 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import copy
+import tempfile
 import unittest
 from typing import Optional



 _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+
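+# Toy embedding -> linear1 -> linear2 model shared by the tests below;
+# tied_weights=True ties linear2.weight to the embedding table (HF-style).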
+class M(nn.Module):
+    _tied_weights_keys: list[str] = []
+
+    def __init__(
+        self, m=256, n=128, k=16, bias=False, embedding=True, tied_weights=False
+    ):
+        nn.Module.__init__(self)
+        self.embed_tokens = nn.Embedding(k, m) if embedding else nn.Identity()
+        self.linear1 = nn.Linear(m, n, bias=bias)
+        self.linear2 = nn.Linear(n, k, bias=bias)
+        self.relu = nn.ReLU()
+        self.sigmoid = nn.Sigmoid()
+
+        if embedding and tied_weights:
+            assert self.embed_tokens.weight.shape == self.linear2.weight.shape
+            self.tie_weights()
+            self._tied_weights_keys.append("linear2.weight")
+
+    def tie_weights(self):
+        self.linear2.weight = self.embed_tokens.weight
+
+    def example_inputs(self, device=None):
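+        # Integer token ids when an embedding front end is present, random
+        # floats otherwise.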
+        if isinstance(self.embed_tokens, nn.Identity):
+            inputs = torch.randn(1, self.linear1.in_features, device=device)
+        else:
+            k = self.embed_tokens.num_embeddings
+            inputs = torch.randint(1, k, (1, self.linear1.in_features), device=device)
+        return inputs
+
+    def forward(self, x):
+        x = self.embed_tokens(x)
+        x = self.relu(self.linear1(x))
+        x = self.sigmoid(self.linear2(x))
+        return x
+
+
 if TRANSFORMERS_AVAIL:
-    from transformers import PretrainedConfig, TorchAoConfig
-    from transformers.quantizers.quantizer_torchao import TorchAoHfQuantizer
+    from transformers import PretrainedConfig, PreTrainedModel, TorchAoConfig
+
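+    # MConfig/PreTrainedM wrap M as a Hugging Face model so the tests can
+    # exercise TorchAoConfig through save_pretrained/from_pretrained.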
+    class MConfig(PretrainedConfig):
+        def __init__(
+            self,
+            m=256,
+            n=128,
+            k=16,
+            bias=False,
+            embedding=True,
+            tied_weights=False,
+            **kwargs,
+        ):
+            super().__init__(**kwargs)
+            self.m = m
+            self.n = n
+            self.k = k
+            self.bias = bias
+            self.embedding = embedding
+            self.tied_weights = tied_weights
+
+    class PreTrainedM(M, PreTrainedModel):
+        base_model_prefix = "base"
+        config_class = MConfig
+
+        def __init__(self, config: MConfig):
+            PreTrainedModel.__init__(self, config)
+            M.__init__(
+                self,
+                m=config.m,
+                n=config.n,
+                k=config.k,
+                bias=config.bias,
+                embedding=config.embedding,
+                tied_weights=config.tied_weights,
+            )
+
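+        # PreTrainedModel uses get_input_embeddings to locate the embedding
+        # table, e.g. when tying weights.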
+        def get_input_embeddings(self) -> nn.Module:
+            return self.embed_tokens


 def split_param_groups(model) -> tuple[list, list, list]:
@@ -199,55 +274,9 @@ def apply_activation_quantization(
         pass


-class M(nn.Module):
-    _tied_weights_keys: list[str] = []
-
-    def __init__(
-        self, m=256, n=128, k=16, bias=False, embedding=True, tied_weights=False
-    ):
-        super().__init__()
-        self.embedding = nn.Embedding(k, m) if embedding else nn.Identity()
-        self.linear1 = nn.Linear(m, n, bias=bias)
-        self.linear2 = nn.Linear(n, k, bias=bias)
-        self.relu = nn.ReLU()
-        self.sigmoid = nn.Sigmoid()
-
-        if embedding and tied_weights:
-            assert self.embedding.weight.shape == self.linear2.weight.shape
-            self.tie_weights()
-            self._tied_weights_keys.append("linear2.weight")
-
-    def tie_weights(self):
-        self.linear2.weight = self.embedding.weight
-
-    def reset_parameters(self):
-        for module in (self.linear1, self.linear2):
-            nn.init.xavier_uniform_(module.weight)
-            if module.bias is not None:
-                nn.init.zeros_(module.bias)
-
-    def example_inputs(self, device=None):
-        if isinstance(self.embedding, nn.Identity):
-            inputs = torch.randn(1, self.linear1.in_features, device=device)
-        else:
-            k = self.embedding.num_embeddings
-            inputs = torch.randint(1, k, (1, self.linear1.in_features), device=device)
-        return inputs
-
-    def get_input_embeddings(self) -> nn.Module:
-        return self.embedding
-
-    def forward(self, x):
-        x = self.embedding(x)
-        x = self.relu(self.linear1(x))
-        x = self.sigmoid(self.linear2(x))
-        return x
-
-
 class TestPARQuantization(common_utils.TestCase):
     def setUp(self):
         torch.manual_seed(123)
-        self.model = M(bias=True).to(_DEVICE)

     @common_utils.parametrize("b", [0, 1, 2, 4])
     @common_utils.parametrize("unif_quant", [True, False])
@@ -256,13 +285,13 @@ def setUp(self):
     def test_parq_train_loop(
         self, b: int = 2, unif_quant=True, hard_prox=True, per_group_quantizer=False
     ):
-        self.model.reset_parameters()
+        model = M(bias=True).to(_DEVICE)
         if unif_quant:
             quantizer = TernaryUnifQuantizer() if b == 0 else UnifQuantizer()
         else:
             quantizer = LSBQuantizer()
         param_groups = build_param_groups(
-            self.model, b, quantizer=quantizer if per_group_quantizer else None
+            model, b, quantizer=quantizer if per_group_quantizer else None
         )
         base_optimizer = torch.optim.AdamW(param_groups)

@@ -271,12 +300,12 @@ def test_parq_train_loop(
         )
         optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map)
         for _ in range(3):
-            x = self.model.example_inputs(device=_DEVICE)
-            out = self.model(x)
+            x = model.example_inputs(device=_DEVICE)
+            out = model(x)
             out.sum().backward()
             optimizer.step()

-        for child in self.model.children():
+        for child in model.children():
             if isinstance(child, nn.Linear):
                 self.assertEqual(
                     child.weight.unique().numel(), quantizer.get_quant_size(b)
@@ -295,7 +324,6 @@ def setUp(self):
     @common_utils.parametrize("group_size", [32, 256])
     def test_int4_weight_only(self, group_size: int = 32):
         model = M(m=512, n=512).to(_DEVICE, dtype=torch.bfloat16)
-        model.reset_parameters()

         m_ref = copy.deepcopy(model).eval().to(_DEVICE)
         config = Int4WeightOnlyConfig(group_size=group_size)
@@ -313,7 +341,6 @@ def test_int4_weight_only(self, group_size: int = 32):
     @common_utils.parametrize("group_size", [32, 512])
     def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
         model = M(m=512, n=512).to(_DEVICE)
-        model.reset_parameters()

         m_ref = copy.deepcopy(model).eval().to(_DEVICE)
         quantize_(
@@ -333,7 +360,6 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
     )
     def test_int4_weight_only_e2e(self, group_size: int = 32):
         model = M(m=512, n=512, embedding=False).to(torch.bfloat16).to(_DEVICE)
-        model.reset_parameters()

         m_ref = copy.deepcopy(model).eval().to(_DEVICE)
         config = Int4WeightOnlyConfig(group_size=group_size)
@@ -353,7 +379,6 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
     @common_utils.parametrize("b", [2, 3, 4, 8])
     def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         model = M(m=512, n=512, embedding=False).to(_DEVICE)
-        model.reset_parameters()

         m_ref = copy.deepcopy(model).eval().to(_DEVICE)
         config = IntxWeightOnlyConfig(
@@ -380,7 +405,6 @@ def setUp(self):
     @common_utils.parametrize("group_size", [32, 256])
     def test_intx_weight_only_parq_equivalent(self, b: int = 2, group_size: int = 32):
         model = M(m=512, n=512).to(_DEVICE)
-        model.reset_parameters()

         quantizer_ref = UnifQuantizer()
         quantizer = StretchedUnifTorchaoQuantizer(b)
@@ -403,7 +427,6 @@ def test_intx_weight_only_parq_equivalent(self, b: int = 2, group_size: int = 32
     @common_utils.parametrize("group_size", [32, 512])
     def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
         model = M(m=512, n=512).to(_DEVICE)
-        model.reset_parameters()

         quantizer = StretchedUnifTorchaoQuantizer(b)

@@ -425,7 +448,6 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
     @common_utils.parametrize("b", [2, 3])
     def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         model = M(m=512, n=512, embedding=False).to(_DEVICE)
-        model.reset_parameters()

         quantizer = StretchedUnifTorchaoQuantizer(b)

@@ -470,14 +492,16 @@ def test_intx_weight_only_tied_embed_linear(
         optimizer.torchao_convert(model)
         check_torchao_tensor_subclass(self, model)
         self.assertTrue(
-            torch.equal(model.embedding.weight.qdata, model.linear2.weight.qdata)
+            torch.equal(model.embed_tokens.weight.qdata, model.linear2.weight.qdata)
         )


 class TestInt8DynamicActivationTorchaoQuantizer(common_utils.TestCase):
     def setUp(self):
         torch.manual_seed(123)

+    @unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
+    @unittest.skipIf(not TRANSFORMERS_AVAIL, "Need transformers")
     @common_utils.parametrize("b", [2, 3, 4, 8])
     @common_utils.parametrize(
         "model_dtype", [torch.float16, torch.float32, torch.bfloat16]
@@ -489,7 +513,8 @@ def test_int8_dynamic_activation_intx_e2e(
489513 model_dtype : torch .dtype = torch .float32 ,
490514 group_size : int = 32 ,
491515 ):
492- model = M (embedding = False , bias = True ).to (_DEVICE , dtype = model_dtype )
516+ config = MConfig (embedding = False , bias = True )
517+ model = PreTrainedM (config ).to (_DEVICE , dtype = model_dtype )
493518 x = model .example_inputs (device = _DEVICE ).to (model_dtype )
494519
495520 # reference model using native quantization
@@ -520,7 +545,6 @@ def test_int8_dynamic_activation_intx_e2e(

         attach_hf_config = False
         if TRANSFORMERS_AVAIL:
-            model.config = PretrainedConfig()  # pretend this is a HF model
             attach_hf_config = _is_hf_model(model)
             self.assertTrue(attach_hf_config)

@@ -543,10 +567,11 @@ def test_int8_dynamic_activation_intx_e2e(


 class TestTorchAoConfigIntegration(common_utils.TestCase):
+    @unittest.skipIf(torch.backends.mps.is_available(), "MPS not supported")
     @unittest.skipIf(not TRANSFORMERS_AVAIL, "Need transformers")
     def test_tied_weights_quantization(self, b: int = 4):
-        model = M(m=128, n=128, tied_weights=True).to(_DEVICE)
-        model.config = PretrainedConfig()  # pretend this is a HF model
+        config = MConfig(m=128, n=128, tied_weights=True)
+        model = PreTrainedM(config).to(_DEVICE)

         quantizer = StretchedUnifTorchaoQuantizer(b)
         linear_config = StretchedIntxWeightConfig(
@@ -567,27 +592,20 @@ def test_tied_weights_quantization(self, b: int = 4):
         self.assertTrue(isinstance(quantization_config, TorchAoConfig))
         self.assertTrue(quantization_config.modules_to_not_convert == ["linear2"])

-        # Simulate transformers.PreTrainedModel.from_pretrained
-        hf_quantizer = TorchAoHfQuantizer(
-            quantization_config,
-            pre_quantized=False,
-            modules_to_not_convert=quantization_config.modules_to_not_convert,
-        )
-        state_dict = model.state_dict()
-        unexpected_keys = []
-        for n, p in state_dict.items():
-            if hf_quantizer.check_quantized_param(model, p, n, state_dict):
-                hf_quantizer.create_quantized_param(
-                    model, p, n, _DEVICE, state_dict, unexpected_keys
-                )
-        model.tie_weights()
+        # Let HF apply quantize_ given quantization_config
+        del model.config.quantization_config
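+        # Round-trip through save_pretrained/from_pretrained so transformers
+        # runs torchao quantization from the config and re-ties the weights.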
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir, safe_serialization=False)
+            model = PreTrainedM.from_pretrained(
+                tmp_dir, quantization_config=quantization_config
+            )

         check_torchao_tensor_subclass(self, model.linear1)
         check_torchao_tensor_subclass(self, model.linear2, weight_only=True)
-        check_torchao_tensor_subclass(self, model.embedding, weight_only=True)
+        check_torchao_tensor_subclass(self, model.embed_tokens, weight_only=True)

         self.assertTrue(
-            model.linear2.weight.data_ptr() == model.embedding.weight.data_ptr()
+            model.linear2.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
         )
