@@ -5,17 +5,17 @@
 # LICENSE file in the root directory of this source tree.
 import copy
 import tempfile
-import unittest

 import torch
+from parameterized import parameterized
 from torch.testing._internal.common_utils import (
     TestCase,
     run_tests,
 )

 from torchao.prototype.awq import AWQConfig, AWQStep
 from torchao.quantization import Int4WeightOnlyConfig, quantize_
-from torchao.utils import _is_fbgemm_genai_gpu_available
+from torchao.utils import _is_fbgemm_genai_gpu_available, torch_version_at_least


 class ToyLinearModel(torch.nn.Module):
@@ -42,11 +42,15 @@ def forward(self, x):
         return x


-@unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
-@unittest.skipIf(
-    not _is_fbgemm_genai_gpu_available(),
-    reason="need to install fbgemm_gpu_genai package",
-)
+devices = ["cpu"]
+if (
+    torch.cuda.is_available()
+    and _is_fbgemm_genai_gpu_available()
+    and torch_version_at_least("2.6.0")
+):
+    devices.append("cuda")
+
+
 class TestAWQ(TestCase):
     def test_awq_config(self):
         base_config = Int4WeightOnlyConfig()
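Aside: `parameterized.expand` turns each `(device,)` tuple into its own test method, so the CPU variant always runs and the CUDA variants appear only when the gating above passes. Below is a minimal sketch of that pattern using only the `parameterized` package; `Demo` and `test_runs_on` are illustrative names, not part of this PR:

```python
import unittest

from parameterized import parameterized

devices = ["cpu"]  # "cuda" would be appended only when the runtime checks pass


class Demo(unittest.TestCase):
    @parameterized.expand([(device,) for device in devices])
    def test_runs_on(self, device):
        # expand() generates one method per tuple, named e.g. test_runs_on_0_cpu
        self.assertIn(device, devices)


if __name__ == "__main__":
    unittest.main()
```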
@@ -61,8 +65,8 @@ def test_awq_config(self):
         with self.assertRaisesRegex(ValueError, "is not one of"):
             AWQConfig(base_config, step="not_supported")

-    def test_awq_functionality(self):
-        device = "cuda"
+    @parameterized.expand([(device,) for device in devices])
+    def test_awq_functionality(self, device):
         dataset_size = 100
         l1, l2, l3 = 512, 256, 128
         original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
@@ -73,7 +77,15 @@ def test_awq_functionality(self):
         m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)

         # baseline quantization
-        base_config = Int4WeightOnlyConfig(group_size=group_size)
+        if device == "cuda":
+            base_config = Int4WeightOnlyConfig(group_size=group_size)
+        elif device == "cpu":
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, int4_packing_format="opaque"
+            )
+            torch.manual_seed(1234)
+        else:
+            assert False, "Unsupported device: {}".format(device)
         m_baseline = copy.deepcopy(m)
         quantize_(m_baseline, base_config)

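The same per-device `base_config` dispatch recurs in all three tests touched by this diff. A hypothetical helper (not part of this PR; the name `_int4_base_config` is illustrative) could factor it out:

```python
from torchao.quantization import Int4WeightOnlyConfig


def _int4_base_config(device: str, group_size: int) -> Int4WeightOnlyConfig:
    # Hypothetical refactor of the repeated dispatch above; "opaque" is the
    # CPU int4 packing format this diff selects for the CPU path.
    if device == "cuda":
        return Int4WeightOnlyConfig(group_size=group_size)
    elif device == "cpu":
        return Int4WeightOnlyConfig(
            group_size=group_size, int4_packing_format="opaque"
        )
    raise AssertionError("Unsupported device: {}".format(device))
```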
@@ -104,8 +116,8 @@ def test_awq_functionality(self):
         loss_base = (ref_out - baseline_out).pow(2).mean().item()
         assert loss_awq < loss_base

-    def test_awq_loading(self):
-        device = "cuda"
+    @parameterized.expand([(device,) for device in devices])
+    def test_awq_loading(self, device):
         dataset_size = 100
         l1, l2, l3 = 512, 256, 128
         original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
@@ -123,7 +135,14 @@ def test_awq_loading(self):
         calibration_data = dataset[:n_calibration_examples]

         # calibrate
-        base_config = Int4WeightOnlyConfig(group_size=group_size)
+        if device == "cuda":
+            base_config = Int4WeightOnlyConfig(group_size=group_size)
+        elif device == "cpu":
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, int4_packing_format="opaque"
+            )
+        else:
+            assert False, "Unsupported device: {}".format(device)
         quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
         quantize_(m, quant_config)

@@ -152,14 +171,14 @@ def test_awq_loading(self):
         assert awq_save_load_out is not None
         assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)

-    def test_awq_loading_vllm(self):
+    @parameterized.expand([(device,) for device in devices])
+    def test_awq_loading_vllm(self, device):
         """Simulate weight loading in vllm:
         * prepare model weight to the same format (awq weight)
         * use weight.copy_(state_dict["weight"]) to copy over the quantized weights from checkpoint

         There is also a slicing op that is omitted here; the overall e2e flow is tested in the vllm repo
         """
-        device = "cuda"
         dataset_size = 100
         l1, l2, l3 = 512, 256, 128
         original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
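For reference, here is a minimal sketch of the loading flow that docstring describes, assuming the target model has already been quantized into the same tensor layout as the checkpoint; the helper name and state-dict handling are illustrative, not vllm's actual API:

```python
import torch


def copy_checkpoint_weights(model: torch.nn.Module, state_dict: dict) -> None:
    # Mirror vllm-style loading: the model's parameters are already in the
    # quantized (awq) format, so checkpoint tensors are copied over in place.
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in state_dict:
                param.copy_(state_dict[name])
```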
@@ -177,7 +196,14 @@ def test_awq_loading_vllm(self):
         calibration_data = dataset[:n_calibration_examples]

         # calibrate
-        base_config = Int4WeightOnlyConfig(group_size=group_size)
+        if device == "cuda":
+            base_config = Int4WeightOnlyConfig(group_size=group_size)
+        elif device == "cpu":
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, int4_packing_format="opaque"
+            )
+        else:
+            assert False, "Unsupported device: {}".format(device)
         quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
         quantize_(m, quant_config)

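Not shown in the diff: the file imports `run_tests` at the top, so it presumably keeps the standard torch test entry point:

```python
if __name__ == "__main__":
    run_tests()
```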