 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import copy
+import tempfile
 import unittest
 from typing import Optional

     torch_version_at_least,
 )

-_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+if torch.xpu.is_available():
+    _DEVICE = "xpu"
+elif torch.backends.mps.is_available():
+    _DEVICE = "mps"
+elif torch.cuda.is_available():
+    _DEVICE = "cuda"
+else:
+    _DEVICE = "cpu"

 if TRANSFORMERS_AVAIL:
-    from transformers import PretrainedConfig, TorchAoConfig
-    from transformers.quantizers.quantizer_torchao import TorchAoHfQuantizer
+    from transformers import PretrainedConfig, PreTrainedModel, TorchAoConfig
+
+
+class MConfig(PretrainedConfig):
+    def __init__(
+        self,
+        m=256,
+        n=128,
+        k=16,
+        bias=False,
+        embedding=True,
+        tied_weights=False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.m = m
+        self.n = n
+        self.k = k
+        self.bias = bias
+        self.embedding = embedding
+        self.tied_weights = tied_weights
+
+
+class M(nn.Module):
+    _tied_weights_keys: list[str] = []
+
+    def __init__(
+        self, m=256, n=128, k=16, bias=False, embedding=True, tied_weights=False
+    ):
+        nn.Module.__init__(self)
+        self.embed_tokens = nn.Embedding(k, m) if embedding else nn.Identity()
+        self.linear1 = nn.Linear(m, n, bias=bias)
+        self.linear2 = nn.Linear(n, k, bias=bias)
+        self.relu = nn.ReLU()
+        self.sigmoid = nn.Sigmoid()
+
+        if embedding and tied_weights:
+            assert self.embed_tokens.weight.shape == self.linear2.weight.shape
+            self.tie_weights()
+            self._tied_weights_keys.append("linear2.weight")
+
+    def tie_weights(self):
+        self.linear2.weight = self.embed_tokens.weight
+
+    def example_inputs(self, device=None):
+        if isinstance(self.embed_tokens, nn.Identity):
+            inputs = torch.randn(1, self.linear1.in_features, device=device)
+        else:
+            k = self.embed_tokens.num_embeddings
+            inputs = torch.randint(1, k, (1, self.linear1.in_features), device=device)
+        return inputs
+
+    def forward(self, x):
+        x = self.embed_tokens(x)
+        x = self.relu(self.linear1(x))
+        x = self.sigmoid(self.linear2(x))
+        return x
+
+
+class PreTrainedM(M, PreTrainedModel):
+    base_model_prefix = "base"
+    config_class = MConfig
+
+    def __init__(self, config: MConfig):
+        PreTrainedModel.__init__(self, config)
+        M.__init__(
+            self,
+            m=config.m,
+            n=config.n,
+            k=config.k,
+            bias=config.bias,
+            embedding=config.embedding,
+            tied_weights=config.tied_weights,
+        )
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.embed_tokens


 def split_param_groups(model) -> tuple[list, list, list]:
@@ -199,55 +282,9 @@ def apply_activation_quantization(
         pass


-class M(nn.Module):
-    _tied_weights_keys: list[str] = []
-
-    def __init__(
-        self, m=256, n=128, k=16, bias=False, embedding=True, tied_weights=False
-    ):
-        super().__init__()
-        self.embedding = nn.Embedding(k, m) if embedding else nn.Identity()
-        self.linear1 = nn.Linear(m, n, bias=bias)
-        self.linear2 = nn.Linear(n, k, bias=bias)
-        self.relu = nn.ReLU()
-        self.sigmoid = nn.Sigmoid()
-
-        if embedding and tied_weights:
-            assert self.embedding.weight.shape == self.linear2.weight.shape
-            self.tie_weights()
-            self._tied_weights_keys.append("linear2.weight")
-
-    def tie_weights(self):
-        self.linear2.weight = self.embedding.weight
-
-    def reset_parameters(self):
-        for module in (self.linear1, self.linear2):
-            nn.init.xavier_uniform_(module.weight)
-            if module.bias is not None:
-                nn.init.zeros_(module.bias)
-
-    def example_inputs(self, device=None):
-        if isinstance(self.embedding, nn.Identity):
-            inputs = torch.randn(1, self.linear1.in_features, device=device)
-        else:
-            k = self.embedding.num_embeddings
-            inputs = torch.randint(1, k, (1, self.linear1.in_features), device=device)
-        return inputs
-
-    def get_input_embeddings(self) -> nn.Module:
-        return self.embedding
-
-    def forward(self, x):
-        x = self.embedding(x)
-        x = self.relu(self.linear1(x))
-        x = self.sigmoid(self.linear2(x))
-        return x
-
-
 class TestPARQuantization(common_utils.TestCase):
     def setUp(self):
         torch.manual_seed(123)
-        self.model = M(bias=True).to(_DEVICE)

     @common_utils.parametrize("b", [0, 1, 2, 4])
     @common_utils.parametrize("unif_quant", [True, False])
@@ -256,13 +293,13 @@ def setUp(self):
     def test_parq_train_loop(
         self, b: int = 2, unif_quant=True, hard_prox=True, per_group_quantizer=False
     ):
-        self.model.reset_parameters()
+        model = M(bias=True).to(_DEVICE)
         if unif_quant:
             quantizer = TernaryUnifQuantizer() if b == 0 else UnifQuantizer()
         else:
             quantizer = LSBQuantizer()
         param_groups = build_param_groups(
-            self.model, b, quantizer=quantizer if per_group_quantizer else None
+            model, b, quantizer=quantizer if per_group_quantizer else None
         )
         base_optimizer = torch.optim.AdamW(param_groups)

@@ -271,12 +308,12 @@ def test_parq_train_loop(
         )
         optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map)
         for _ in range(3):
-            x = self.model.example_inputs(device=_DEVICE)
-            out = self.model(x)
+            x = model.example_inputs(device=_DEVICE)
+            out = model(x)
             out.sum().backward()
             optimizer.step()

-        for child in self.model.children():
+        for child in model.children():
             if isinstance(child, nn.Linear):
                 self.assertEqual(
                     child.weight.unique().numel(), quantizer.get_quant_size(b)
@@ -295,7 +332,6 @@ def setUp(self):
     @common_utils.parametrize("group_size", [32, 256])
     def test_int4_weight_only(self, group_size: int = 32):
         model = M(m=512, n=512).to(_DEVICE, dtype=torch.bfloat16)
-        model.reset_parameters()

         m_ref = copy.deepcopy(model).eval().to(_DEVICE)
         config = Int4WeightOnlyConfig(group_size=group_size)
@@ -313,7 +349,6 @@ def test_int4_weight_only(self, group_size: int = 32):
     @common_utils.parametrize("group_size", [32, 512])
     def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
         model = M(m=512, n=512).to(_DEVICE)
-        model.reset_parameters()

         m_ref = copy.deepcopy(model).eval().to(_DEVICE)
         quantize_(
@@ -333,7 +368,6 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
     )
     def test_int4_weight_only_e2e(self, group_size: int = 32):
         model = M(m=512, n=512, embedding=False).to(torch.bfloat16).to(_DEVICE)
-        model.reset_parameters()

         m_ref = copy.deepcopy(model).eval().to(_DEVICE)
         config = Int4WeightOnlyConfig(group_size=group_size)
@@ -353,7 +387,6 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
     @common_utils.parametrize("b", [2, 3, 4, 8])
     def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         model = M(m=512, n=512, embedding=False).to(_DEVICE)
-        model.reset_parameters()

         m_ref = copy.deepcopy(model).eval().to(_DEVICE)
         config = IntxWeightOnlyConfig(
@@ -380,7 +413,6 @@ def setUp(self):
     @common_utils.parametrize("group_size", [32, 256])
     def test_intx_weight_only_parq_equivalent(self, b: int = 2, group_size: int = 32):
         model = M(m=512, n=512).to(_DEVICE)
-        model.reset_parameters()

         quantizer_ref = UnifQuantizer()
         quantizer = StretchedUnifTorchaoQuantizer(b)
@@ -403,7 +435,6 @@ def test_intx_weight_only_parq_equivalent(self, b: int = 2, group_size: int = 32
     @common_utils.parametrize("group_size", [32, 512])
     def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
         model = M(m=512, n=512).to(_DEVICE)
-        model.reset_parameters()

         quantizer = StretchedUnifTorchaoQuantizer(b)

@@ -425,7 +456,6 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
     @common_utils.parametrize("b", [2, 3])
     def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         model = M(m=512, n=512, embedding=False).to(_DEVICE)
-        model.reset_parameters()

         quantizer = StretchedUnifTorchaoQuantizer(b)

@@ -470,14 +500,16 @@ def test_intx_weight_only_tied_embed_linear(
         optimizer.torchao_convert(model)
         check_torchao_tensor_subclass(self, model)
         self.assertTrue(
-            torch.equal(model.embedding.weight.qdata, model.linear2.weight.qdata)
+            torch.equal(model.embed_tokens.weight.qdata, model.linear2.weight.qdata)
         )


 class TestInt8DynamicActivationTorchaoQuantizer(common_utils.TestCase):
     def setUp(self):
         torch.manual_seed(123)

+    @unittest.skipIf(_DEVICE in ("mps", "cpu"), "Need GPU available")
+    @unittest.skipIf(not TRANSFORMERS_AVAIL, "Need transformers")
     @common_utils.parametrize("b", [2, 3, 4, 8])
     @common_utils.parametrize(
         "model_dtype", [torch.float16, torch.float32, torch.bfloat16]
@@ -489,7 +521,8 @@ def test_int8_dynamic_activation_intx_e2e(
         self,
         model_dtype: torch.dtype = torch.float32,
         group_size: int = 32,
-        model = M(embedding=False, bias=True).to(_DEVICE, dtype=model_dtype)
+        config = MConfig(embedding=False, bias=True)
+        model = PreTrainedM(config).to(_DEVICE, dtype=model_dtype)
         x = model.example_inputs(device=_DEVICE).to(model_dtype)

         # reference model using native quantization
@@ -520,7 +553,6 @@ def test_int8_dynamic_activation_intx_e2e(

         attach_hf_config = False
         if TRANSFORMERS_AVAIL:
-            model.config = PretrainedConfig()  # pretend this is a HF model
             attach_hf_config = _is_hf_model(model)
             self.assertTrue(attach_hf_config)

@@ -543,10 +575,11 @@ def test_int8_dynamic_activation_intx_e2e(


 class TestTorchAoConfigIntegration(common_utils.TestCase):
+    @unittest.skipIf(_DEVICE in ("mps", "cpu"), "Need GPU available")
     @unittest.skipIf(not TRANSFORMERS_AVAIL, "Need transformers")
     def test_tied_weights_quantization(self, b: int = 4):
-        model = M(m=128, n=128, tied_weights=True).to(_DEVICE)
-        model.config = PretrainedConfig()  # pretend this is a HF model
+        config = MConfig(m=128, n=128, tied_weights=True)
+        model = PreTrainedM(config).to(_DEVICE)

         quantizer = StretchedUnifTorchaoQuantizer(b)
         linear_config = StretchedIntxWeightConfig(
@@ -567,27 +600,20 @@ def test_tied_weights_quantization(self, b: int = 4):
         self.assertTrue(isinstance(quantization_config, TorchAoConfig))
         self.assertTrue(quantization_config.modules_to_not_convert == ["linear2"])

-        # Simulate transformers.PreTrainedModel.from_pretrained
-        hf_quantizer = TorchAoHfQuantizer(
-            quantization_config,
-            pre_quantized=False,
-            modules_to_not_convert=quantization_config.modules_to_not_convert,
-        )
-        state_dict = model.state_dict()
-        unexpected_keys = []
-        for n, p in state_dict.items():
-            if hf_quantizer.check_quantized_param(model, p, n, state_dict):
-                hf_quantizer.create_quantized_param(
-                    model, p, n, _DEVICE, state_dict, unexpected_keys
-                )
-        model.tie_weights()
+        # Let HF apply quantize_ given quantization_config
+        del model.config.quantization_config
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir, safe_serialization=False)
+            model = PreTrainedM.from_pretrained(
+                tmp_dir, quantization_config=quantization_config
+            )

         check_torchao_tensor_subclass(self, model.linear1)
         check_torchao_tensor_subclass(self, model.linear2, weight_only=True)
-        check_torchao_tensor_subclass(self, model.embedding, weight_only=True)
+        check_torchao_tensor_subclass(self, model.embed_tokens, weight_only=True)

         self.assertTrue(
-            model.linear2.weight.data_ptr() == model.embedding.weight.data_ptr()
+            model.linear2.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
         )

