77import tempfile
88
99import torch
10- from parameterized import parameterized
1110from torch .testing ._internal .common_utils import (
1211 TestCase ,
12+ instantiate_parametrized_tests ,
13+ parametrize ,
1314 run_tests ,
1415)
1516
@@ -55,6 +56,17 @@ def forward(self, x):
5556 devices .append ("xpu" )
5657
5758
59+ device_to_base_configs = {
60+ "cuda" : [
61+ Int4WeightOnlyConfig (group_size = 128 ),
62+ # Note: the functionality unit test doesn't work for hqq
63+ Int4WeightOnlyConfig (group_size = 128 , int4_packing_format = "tile_packed_to_4d" ),
64+ ],
65+ "cpu" : [Int4WeightOnlyConfig (group_size = 128 , int4_packing_format = "opaque" )],
66+ "xpu" : [Int4WeightOnlyConfig (group_size = 128 , int4_packing_format = "plain_int32" )],
67+ }
68+
69+
5870class TestAWQ (TestCase ):
5971 def test_awq_config (self ):
6072 base_config = Int4WeightOnlyConfig ()
@@ -69,190 +81,178 @@ def test_awq_config(self):
6981 with self .assertRaisesRegex (ValueError , "is not one of" ):
7082 AWQConfig (base_config , step = "not_supported" )
7183
72- @parameterized . expand ([( device ,) for device in devices ] )
84+ @parametrize ( " device" , devices )
7385 def test_awq_functionality (self , device ):
74- dataset_size = 100
86+ dataset_size = 10
7587 l1 , l2 , l3 = 512 , 256 , 128
7688 original_dtype = torch .bfloat16 # tinygemm kernel only uses bfloat16 inputs
77- group_size = 128
78- n_calibration_examples = 10
7989 sequence_length = 5
8090
81- m = ToyLinearModel (l1 , l2 , l3 ).eval ().to (original_dtype ).to (device )
91+ assert device in device_to_base_configs , "Unsupported device: {}" .format (device )
92+ base_configs = device_to_base_configs [device ]
8293
83- # baseline quantization
84- if device == "cuda" :
85- base_config = Int4WeightOnlyConfig (group_size = group_size )
86- elif device == "xpu" :
87- base_config = Int4WeightOnlyConfig (
88- group_size = group_size , int4_packing_format = "plain_int32"
89- )
90- elif device == "cpu" :
91- base_config = Int4WeightOnlyConfig (
92- group_size = group_size , int4_packing_format = "opaque"
93- )
94- torch .manual_seed (1234 )
95- else :
96- assert False , "Unsupported device: {}" .format (device )
97- m_baseline = copy .deepcopy (m )
98- quantize_ (m_baseline , base_config )
94+ for base_config in base_configs :
95+ m = ToyLinearModel (l1 , l2 , l3 ).eval ().to (original_dtype ).to (device )
96+ m_baseline = copy .deepcopy (m )
9997
100- # awq quantization
101- dataset = m .example_inputs (
102- dataset_size ,
103- sequence_length = sequence_length ,
104- dtype = original_dtype ,
105- device = device ,
106- )
107- ref_out = torch .cat ([m (d .squeeze (0 )) for d in dataset ])
98+ dataset = m .example_inputs (
99+ dataset_size ,
100+ sequence_length = sequence_length ,
101+ dtype = original_dtype ,
102+ device = device ,
103+ )
104+ # for test, we use calibration_data = dataset so that awq is
105+ # guaranteed to be better than baseline
106+ # in reality, calibration_data will be a small subset or a different
107+ # dataset
108+ calibration_data = dataset
109+ # concatenated inputs
110+ input_cat = torch .cat (calibration_data , dim = - 2 )
111+ ref_out = m (input_cat )
108112
109- calibration_data = dataset [:n_calibration_examples ]
113+ # baseline quantization
114+ quantize_ (m_baseline , base_config )
110115
111- quant_config = AWQConfig (base_config , step = AWQStep .PREPARE )
112- quantize_ (m , quant_config )
116+ # awq quantization
117+ quant_config = AWQConfig (base_config , step = AWQStep .PREPARE )
118+ quantize_ (m , quant_config )
113119
114- for example in calibration_data :
115- m (example )
120+ for example in calibration_data :
121+ m (example )
116122
117- quant_config = AWQConfig (base_config , step = AWQStep .CONVERT )
118- quantize_ (m , quant_config )
123+ quant_config = AWQConfig (base_config , step = AWQStep .CONVERT )
124+ quantize_ (m , quant_config )
119125
120- awq_out = torch .cat ([m (d .squeeze (0 )) for d in dataset ])
121- baseline_out = torch .cat ([m_baseline (d .squeeze (0 )) for d in dataset ])
126+ # evaluating on calibration data set to remove any uncertainty
127+ awq_out = m (input_cat )
128+ baseline_out = m_baseline (input_cat )
122129
123- loss_awq = (ref_out - awq_out ).pow (2 ).mean ().item ()
124- loss_base = (ref_out - baseline_out ).pow (2 ).mean ().item ()
125- assert loss_awq < loss_base
130+ loss_awq = (ref_out - awq_out ).pow (2 ).mean ().item ()
131+ loss_base = (ref_out - baseline_out ).pow (2 ).mean ().item ()
132+ assert loss_awq <= loss_base
126133
127- @parameterized . expand ([( device ,) for device in devices ] )
134+ @parametrize ( " device" , devices )
128135 def test_awq_loading (self , device ):
129- dataset_size = 100
136+ dataset_size = 10
130137 l1 , l2 , l3 = 512 , 256 , 128
131138 original_dtype = torch .bfloat16 # tinygemm kernel only uses bfloat16 inputs
132- group_size = 128
133- n_calibration_examples = 10
134139 sequence_length = 5
135140
136- m = ToyLinearModel (l1 , l2 , l3 ).eval ().to (original_dtype ).to (device )
137- dataset = m .example_inputs (
138- dataset_size ,
139- sequence_length = sequence_length ,
140- dtype = original_dtype ,
141- device = device ,
142- )
143- calibration_data = dataset [:n_calibration_examples ]
144-
145- # calibrate
146- if device == "cuda" :
147- base_config = Int4WeightOnlyConfig (group_size = group_size )
148- elif device == "xpu" :
149- base_config = Int4WeightOnlyConfig (
150- group_size = group_size , int4_packing_format = "plain_int32"
151- )
152- elif device == "cpu" :
153- base_config = Int4WeightOnlyConfig (
154- group_size = group_size , int4_packing_format = "opaque"
141+ assert device in device_to_base_configs , "Unsupported device: {}" .format (device )
142+ base_configs = device_to_base_configs [device ]
143+
144+ for base_config in base_configs :
145+ m = ToyLinearModel (l1 , l2 , l3 ).eval ().to (original_dtype ).to (device )
146+ dataset = m .example_inputs (
147+ dataset_size ,
148+ sequence_length = sequence_length ,
149+ dtype = original_dtype ,
150+ device = device ,
155151 )
156- else :
157- assert False , "Unsupported device: {}" . format ( device )
158- quant_config = AWQConfig ( base_config , step = AWQStep . PREPARE )
159- quantize_ ( m , quant_config )
152+ # for test purpose, we don't need to get a subset
153+ calibration_data = dataset
154+ # concatenated inputs
155+ input_cat = torch . cat ( calibration_data , dim = - 2 )
160156
161- for example in calibration_data :
162- m (example )
157+ # calibrate
163158
164- # quantize
165- quant_config = AWQConfig (base_config , step = AWQStep .CONVERT )
166- quantize_ (m , quant_config )
159+ quant_config = AWQConfig (base_config , step = AWQStep .PREPARE )
160+ quantize_ (m , quant_config )
167161
168- with tempfile .NamedTemporaryFile () as f :
169- torch .save (m .state_dict (), f )
170- f .seek (0 )
171- state_dict = torch .load (f )
162+ for example in calibration_data :
163+ m (example )
172164
173- loaded_model = ToyLinearModel (l1 , l2 , l3 ).eval ().to (original_dtype ).to (device )
174- loaded_model .load_state_dict (state_dict , assign = True )
165+ # quantize
166+ quant_config = AWQConfig (base_config , step = AWQStep .CONVERT )
167+ quantize_ (m , quant_config )
175168
176- m = torch .compile (m , fullgraph = True )
177- loaded_model = torch .compile (loaded_model , fullgraph = True )
169+ with tempfile .NamedTemporaryFile () as f :
170+ torch .save (m .state_dict (), f )
171+ f .seek (0 )
172+ state_dict = torch .load (f )
178173
179- awq_out = torch .cat ([m (d .squeeze (0 )) for d in dataset ])
180- awq_save_load_out = torch .cat ([loaded_model (d .squeeze (0 )) for d in dataset ])
174+ loaded_model = (
175+ ToyLinearModel (l1 , l2 , l3 ).eval ().to (original_dtype ).to (device )
176+ )
177+ loaded_model .load_state_dict (state_dict , assign = True )
181178
182- assert awq_out is not None
183- assert awq_save_load_out is not None
184- assert torch .allclose (awq_out , awq_save_load_out , atol = 1e-2 )
179+ m = torch .compile (m , fullgraph = True )
180+ loaded_model = torch .compile (loaded_model , fullgraph = True )
185181
186- @parameterized .expand ([(device ,) for device in devices ])
182+ awq_out = m (input_cat )
183+ awq_save_load_out = loaded_model (input_cat )
184+
185+ assert awq_out is not None
186+ assert awq_save_load_out is not None
187+ assert torch .allclose (awq_out , awq_save_load_out , atol = 1e-2 )
188+
189+ @parametrize ("device" , devices )
187190 def test_awq_loading_vllm (self , device ):
188191 """Simulate weight loading in vllm:
189192 * prepare model weight to the same format (awq weight)
190193 * use weight.copy_(state_dict["weight"]) to copy over the quantized weights from checkpoint
191194
192195 There is also a slicing op that is omitted here, overall e2e is tested in tests in vllm repo
193196 """
194- dataset_size = 100
197+ dataset_size = 10
195198 l1 , l2 , l3 = 512 , 256 , 128
196199 original_dtype = torch .bfloat16 # tinygemm kernel only uses bfloat16 inputs
197- group_size = 128
198- n_calibration_examples = 10
199200 sequence_length = 5
200201
201- m = ToyLinearModel (l1 , l2 , l3 ).eval ().to (original_dtype ).to (device )
202- dataset = m .example_inputs (
203- dataset_size ,
204- sequence_length = sequence_length ,
205- dtype = original_dtype ,
206- device = device ,
207- )
208- calibration_data = dataset [:n_calibration_examples ]
209-
210- # calibrate
211- if device == "cuda" :
212- base_config = Int4WeightOnlyConfig (group_size = group_size )
213- elif device == "xpu" :
214- base_config = Int4WeightOnlyConfig (
215- group_size = group_size , int4_packing_format = "plain_int32"
216- )
217- elif device == "cpu" :
218- base_config = Int4WeightOnlyConfig (
219- group_size = group_size , int4_packing_format = "opaque"
202+ assert device in device_to_base_configs , "Unsupported device: {}" .format (device )
203+ base_configs = device_to_base_configs [device ]
204+
205+ for base_config in base_configs :
206+ m = ToyLinearModel (l1 , l2 , l3 ).eval ().to (original_dtype ).to (device )
207+ dataset = m .example_inputs (
208+ dataset_size ,
209+ sequence_length = sequence_length ,
210+ dtype = original_dtype ,
211+ device = device ,
220212 )
221- else :
222- assert False , "Unsupported device: {}" . format ( device )
223- quant_config = AWQConfig ( base_config , step = AWQStep . PREPARE )
224- quantize_ ( m , quant_config )
213+ # for test purpose, we don't need to get a subset
214+ calibration_data = dataset
215+ # concatenated inputs
216+ input_cat = torch . cat ( calibration_data , dim = - 2 )
225217
226- for example in calibration_data :
227- m (example )
218+ # calibrate
219+ quant_config = AWQConfig (base_config , step = AWQStep .PREPARE )
220+ quantize_ (m , quant_config )
228221
229- # quantize
230- quant_config = AWQConfig (base_config , step = AWQStep .CONVERT )
231- quantize_ (m , quant_config )
222+ for example in calibration_data :
223+ m (example )
232224
233- with tempfile .NamedTemporaryFile () as f :
234- torch .save (m .state_dict (), f )
235- f .seek (0 )
236- state_dict = torch .load (f )
225+ # quantize
226+ quant_config = AWQConfig (base_config , step = AWQStep .CONVERT )
227+ quantize_ (m , quant_config )
228+
229+ with tempfile .NamedTemporaryFile () as f :
230+ torch .save (m .state_dict (), f )
231+ f .seek (0 )
232+ state_dict = torch .load (f )
233+
234+ loaded_model = (
235+ ToyLinearModel (l1 , l2 , l3 ).eval ().to (original_dtype ).to (device )
236+ )
237+ quant_config = AWQConfig (base_config , step = AWQStep .PREPARE_FOR_LOADING )
238+ quantize_ (loaded_model , quant_config )
237239
238- loaded_model = ToyLinearModel ( l1 , l2 , l3 ). eval (). to ( original_dtype ). to ( device )
239- quant_config = AWQConfig ( base_config , step = AWQStep . PREPARE_FOR_LOADING )
240- quantize_ ( loaded_model , quant_config )
240+ loaded_model . linear1 . weight . copy_ ( state_dict [ "linear1.weight" ] )
241+ loaded_model . linear2 . weight . copy_ ( state_dict [ "linear2.weight" ] )
242+ loaded_model . linear3 . weight . copy_ ( state_dict [ "linear3.weight" ] )
241243
242- loaded_model .linear1 .weight .copy_ (state_dict ["linear1.weight" ])
243- loaded_model .linear2 .weight .copy_ (state_dict ["linear2.weight" ])
244- loaded_model .linear3 .weight .copy_ (state_dict ["linear3.weight" ])
244+ m = torch .compile (m , fullgraph = True )
245+ loaded_model = torch .compile (loaded_model , fullgraph = True )
245246
246- m = torch . compile ( m , fullgraph = True )
247- loaded_model = torch . compile ( loaded_model , fullgraph = True )
247+ awq_out = m ( input_cat )
248+ awq_save_load_out = loaded_model ( input_cat )
248249
249- awq_out = torch .cat ([m (d .squeeze (0 )) for d in dataset ])
250- awq_save_load_out = torch .cat ([loaded_model (d .squeeze (0 )) for d in dataset ])
250+ assert awq_out is not None
251+ assert awq_save_load_out is not None
252+ assert torch .allclose (awq_out , awq_save_load_out , atol = 1e-2 )
251253
252- assert awq_out is not None
253- assert awq_save_load_out is not None
254- assert torch .allclose (awq_out , awq_save_load_out , atol = 1e-2 )
255254
255+ instantiate_parametrized_tests (TestAWQ )
256256
257257if __name__ == "__main__" :
258258 run_tests ()
0 commit comments