from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
import os
from parameterized import parameterized
+import itertools
+import logging
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3

+logger = logging.getLogger(__name__)
+
torch.manual_seed(0)
config.cache_size_limit = 100

-COMMON_DEVICE_DTYPE = [
-    ("cpu", torch.float32),
-    ("cpu", torch.float16),
-    ("cpu", torch.bfloat16),
-    ("cuda", torch.float32),
-    ("cuda", torch.float16),
-    ("cuda", torch.bfloat16),
+TENSOR_SUBCLASS_APIS = [
+    change_linear_weights_to_int8_dqtensors,
+    change_linear_weights_to_int8_woqtensors,
+    change_linear_weights_to_int4_woqtensors,
]

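+# device/dtype combinations used by the parameterized tests; COMMON_DEVICE_DTYPE below is
+# the cross product of the two lists and covers the same pairs the removed hard-coded list spelled out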
+COMMON_DEVICES = ["cpu", "cuda"]
+
+COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
+
+COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES))
+
def combine_parameters(a, b):
    new_tuples = []
    for (tuple1, tuple2) in itertools.product(a, b):
        new_tuples.append(tuple1 + tuple2)
    return new_tuples

def run_supported_device_dtype(test_method):
+    """Assumes that the 3rd arg (args[2]) of the decorated method is the device, and
+    that the dtype for testing is given either by a `test_dtype` kwarg or by the 4th arg (args[3]).
+    """
    def wrapper(*args, **kwargs):
-        if args[2] == "cuda" and not torch.cuda.is_available():
+        if len(args) < 3:
+            raise unittest.SkipTest("Not enough args")
+        device = args[2]
+        dtype = kwargs["test_dtype"] if "test_dtype" in kwargs else args[3]
+        if device == "cuda" and not torch.cuda.is_available():
            raise unittest.SkipTest(f"Need CUDA available.")
-        if args[2] == "cuda" and torch.cuda.is_available() and kwargs['test_dtype'] == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
+        if device == "cuda" and torch.cuda.is_available() and dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
            raise unittest.SkipTest("Need CUDA and SM80+ available.")
        return test_method(*args, **kwargs)
    return wrapper
@@ -1145,6 +1159,7 @@ def _test_handle_save_load_meta_impl(
        min_sqnr=35,
        test_dtype=torch.bfloat16
    ):
+        logger.info(f"TestSaveLoad: {api}, {test_device}, {test_dtype}")
        m, k, n = 32, 64, 32

        class test_model(nn.Module):
@@ -1170,7 +1185,9 @@ def forward(self, x):
        api(model)
        torch.save(model.state_dict(), "test.pth")
        # get quantized reference
-        model_qc = torch.compile(model, mode="max-autotune")
+        # model_qc = torch.compile(model, mode="max-autotune")
+        model_qc = torch.export.export(model, (x,)).module()
+        # model_qc = model
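+        # torch.export.export traces the quantized model; .module() turns the exported
+        # program back into a runnable nn.Module, so the reference output is produced
+        # through the export path instead of torch.compile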
        ref_q = model_qc(x).detach()

        assert SQNR(ref_f, ref_q) > min_sqnr
@@ -1187,7 +1204,8 @@ def forward(self, x):
        model = model.to(device=test_device, dtype=test_dtype).eval()

        # get quantized reference
-        model_qc = torch.compile(model, mode="max-autotune")
+        # model_qc = torch.compile(model, mode="max-autotune")
+        model_qc = model
        test = model_qc(x).detach()

        assert SQNR(ref_f, test) > min_sqnr
@@ -1404,5 +1422,52 @@ def test_autoquant_multi_input(self, device, dtype, m1, m2, k, n):
        sqnr = SQNR(out, out2)
        self.assertTrue(sqnr >= 30)

+
+class TestAOTI(unittest.TestCase):
+    @run_supported_device_dtype
+    @torch.no_grad()
+    @parameterized.expand(
+        list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
+    )
+    def test_aoti(self, api, test_device, test_dtype):
+        logger.info(f"TestAOTI: {api}, {test_device}, {test_dtype}")
+        if api is change_linear_weights_to_int8_dqtensors and test_device == "cuda":
+            self.skipTest(f"{api} in {test_device} is not supported for aoti compilation yet")
+        m, k, n = 32, 64, 32
+
+        class test_model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lin1 = nn.Linear(k, n)
+                self.relu = nn.ReLU()
+                self.lin2 = nn.Linear(n, n)
+
+            def forward(self, x):
+                x = self.lin1(x)
+                x = self.relu(x)
+                x = self.lin2(x)
+                return x
+
+        x = torch.randn(m, k, dtype=test_dtype, device=test_device)
+
+        # get float reference
+        model = test_model().to(dtype=test_dtype, device=test_device).eval()
+        ref_f = model(x)
+
+        print("calling quant")
+        api(model)
+
+        # running model
+        print("running model")
+        model(x)
+        print("model:", model)
+        print("model weight:", model.lin1.weight)
+
+        # make sure it compiles
+        example_inputs = (x,)
+        print("compiling model")
+        torch._export.aot_compile(model, example_inputs)
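+        # NOTE: torch._export.aot_compile returns the path of the compiled shared library;
+        # this test only checks that AOTInductor compilation succeeds and does not reload
+        # or re-run the compiled artifact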
+
+
if __name__ == "__main__":
    unittest.main()