
Commit a9e5563

Refactor tensor subclass API to also use parametrization
Summary: Also added tests for the tensor subclass API + AOTI compilation.

Test Plan: python test/integration/test_integration.py -k test_aoti
1 parent c403580 commit a9e5563
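
The refactor leans on torch.nn.utils.parametrize: right_inverse() decomposes the subclass weight into plain tensors, which land in the state_dict as parametrizations.weight.original0/original1, and forward() rebuilds the subclass whenever .weight is read. A minimal sketch of that mechanism, using a hypothetical Scale parametrization rather than this commit's classes:

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Scale(nn.Module):
    def forward(self, base, scale):
        # rebuild the derived weight from its plain components
        return base * scale

    def right_inverse(self, tensor):
        # decompose the weight into the plain tensors that get serialized
        return tensor, torch.ones(())

lin = nn.Linear(4, 4)
parametrize.register_parametrization(lin, "weight", Scale())
# the state_dict now holds the plain components, roughly:
# ['bias', 'parametrizations.weight.original0', 'parametrizations.weight.original1']
print(list(lin.state_dict().keys()))

Because only plain tensors are serialized and the subclass is reconstructed on access, save/load and export paths never need to understand the subclass type itself.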

File tree: 3 files changed, +148 −21 lines

test/integration/test_integration.py

Lines changed: 76 additions & 11 deletions
@@ -66,31 +66,45 @@
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os
 from parameterized import parameterized
+import itertools
+import logging
 from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
 
+logger = logging.getLogger("INFO")
+
 torch.manual_seed(0)
 config.cache_size_limit = 100
 
-COMMON_DEVICE_DTYPE=[
-    ("cpu", torch.float32),
-    ("cpu", torch.float16),
-    ("cpu", torch.bfloat16),
-    ("cuda", torch.float32),
-    ("cuda", torch.float16),
-    ("cuda", torch.bfloat16),
+TENSOR_SUBCLASS_APIS = [
+    change_linear_weights_to_int8_dqtensors,
+    change_linear_weights_to_int8_woqtensors,
+    change_linear_weights_to_int4_woqtensors,
 ]
 
+COMMON_DEVICES = ["cpu", "cuda"]
+
+COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
+
+COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES))
+
 def combine_parameters(a, b):
     new_tuples = []
     for (tuple1, tuple2) in itertools.product(a, b):
         new_tuples.append(tuple1 + tuple2)
     return new_tuples
 
 def run_supported_device_dtype(test_method):
+    """Assumes that the 3rd arg (args[2]) of the decorated method is device and
+    there is a `test_dtype` kwarg or the 4th arg (args[3]) that indicates the dtype for testing
+    """
     def wrapper(*args, **kwargs):
-        if args[2] == "cuda" and not torch.cuda.is_available():
+        if len(args) < 3:
+            raise unittest.SkipTest("Not enough args")
+        device = args[2]
+        dtype = kwargs["test_dtype"] if "test_dtype" in kwargs else args[3]
+        if device == "cuda" and not torch.cuda.is_available():
             raise unittest.SkipTest(f"Need CUDA available.")
-        if args[2] == "cuda" and torch.cuda.is_available() and kwargs['test_dtype'] == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
+        if device == "cuda" and torch.cuda.is_available() and dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
             raise unittest.SkipTest("Need CUDA and SM80+ available.")
         return test_method(*args, **kwargs)
     return wrapper
@@ -1145,6 +1159,7 @@ def _test_handle_save_load_meta_impl(
         min_sqnr=35,
         test_dtype=torch.bfloat16
     ):
+        logger.info(f"TestSaveLoad: {api}, {test_device}, {test_dtype}")
         m, k, n = 32, 64, 32
 
         class test_model(nn.Module):
@@ -1170,7 +1185,9 @@ def forward(self, x):
         api(model)
         torch.save(model.state_dict(), "test.pth")
         # get quantized reference
-        model_qc = torch.compile(model, mode="max-autotune")
+        # model_qc = torch.compile(model, mode="max-autotune")
+        model_qc = torch.export.export(model, (x,)).module()
+        # model_qc = model
         ref_q = model_qc(x).detach()
 
         assert SQNR(ref_f, ref_q) > min_sqnr
@@ -1187,7 +1204,8 @@ def forward(self, x):
         model = model.to(device=test_device, dtype=test_dtype).eval()
 
         # get quantized reference
-        model_qc = torch.compile(model, mode="max-autotune")
+        # model_qc = torch.compile(model, mode="max-autotune")
+        model_qc = model
         test = model_qc(x).detach()
 
         assert SQNR(ref_f, test) > min_sqnr
@@ -1404,5 +1422,52 @@ def test_autoquant_multi_input(self, device, dtype, m1, m2, k, n):
         sqnr = SQNR(out, out2)
         self.assertTrue(sqnr >= 30)
 
+
+class TestAOTI(unittest.TestCase):
+    @run_supported_device_dtype
+    @torch.no_grad()
+    @parameterized.expand(
+        list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
+    )
+    def test_aoti(self, api, test_device, test_dtype):
+        logger.info(f"TestAOTI: {api}, {test_device}, {test_dtype}")
+        if api is change_linear_weights_to_int8_dqtensors and test_device == "cuda":
+            self.skipTest(f"{api} on {test_device} is not supported for AOTI compilation yet")
+        m, k, n = 32, 64, 32
+
+        class test_model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lin1 = nn.Linear(k, n)
+                self.relu = nn.ReLU()
+                self.lin2 = nn.Linear(n, n)
+
+            def forward(self, x):
+                x = self.lin1(x)
+                x = self.relu(x)
+                x = self.lin2(x)
+                return x
+
+        x = torch.randn(m, k, dtype=test_dtype, device=test_device)
+
+        # get float reference
+        model = test_model().to(dtype=test_dtype, device=test_device).eval()
+        ref_f = model(x)
+
+        print("calling quant")
+        api(model)
+
+        # running model
+        print("running model")
+        model(x)
+        print("model:", model)
+        print("model weight:", model.lin1.weight)
+
+        # make sure it compiles
+        example_inputs = (x,)
+        print("compiling model")
+        torch._export.aot_compile(model, example_inputs)
+
 
 if __name__ == "__main__":
     unittest.main()

torchao/quantization/quant_api.py

Lines changed: 44 additions & 6 deletions
@@ -117,13 +117,50 @@ def apply_dynamic_quant(model, filter_fn=None):
     change_linear_weights_to_int8_dqtensors(model, filter_fn)
 
 
-def _get_subclass_inserter(cls, **kwargs):
+import torch.nn.utils.parametrize as parametrize
+
+
+class ConstructTensorSubclass(torch.nn.Module):
+    def __init__(self, tensor_subclass_ctr, transposed, shape, groupsize, inner_k_tiles, dtype):
+        super().__init__()
+        self.tensor_subclass_ctr = tensor_subclass_ctr
+        self.transposed = transposed
+        self.shape = shape
+        self.groupsize = groupsize
+        self.inner_k_tiles = inner_k_tiles
+        self.dtype = dtype
+
+    def forward(self, int_data, scales_and_zeros):
+        return Int4WeightOnlyQuantizedLinearWeight.from_qtensor_components(int_data, scales_and_zeros, self.transposed, self.shape, self.groupsize, self.inner_k_tiles, dtype=self.dtype)
+
+    def right_inverse(self, tensor_subclass_instance):
+        # new_kwargs = {"groupsize": self.groupsize, "inner_k_tiles": self.inner_k_tiles}
+        # tensor_subclass_instance = self.tensor_subclass_ctr.from_float(input_float, **new_kwargs)
+        return tensor_subclass_instance.int_data, tensor_subclass_instance.scales_and_zeros
+
+
+def _get_subclass_inserter(cls, use_param=False, **kwargs):
     method = kwargs.pop("method", "from_float")
     def insert_subclass(lin):
-        lin.weight = torch.nn.Parameter(
-            # cls.from_float(...)
-            getattr(cls, method)(lin.weight, **kwargs), requires_grad=False
-        )
+        if use_param:
+            new_kwargs = {}
+            if "groupsize" in kwargs:
+                new_kwargs["groupsize"] = kwargs["groupsize"]
+            if "inner_k_tiles" in kwargs:
+                new_kwargs["inner_k_tiles"] = kwargs["inner_k_tiles"]
+            int_data, scales_and_zeros, transposed, groupsize, inner_k_tiles = cls.to_qtensor_components(lin.weight, **new_kwargs)
+            kwargs["transposed"] = transposed
+            kwargs["shape"] = lin.weight.shape
+            kwargs["dtype"] = lin.weight.dtype
+            kwargs["groupsize"] = groupsize
+            kwargs["inner_k_tiles"] = inner_k_tiles
+            lin.weight = torch.nn.Parameter(cls(int_data, scales_and_zeros, **kwargs), requires_grad=False)
+            parametrize.register_parametrization(lin, "weight", ConstructTensorSubclass(cls, **kwargs))
+        else:
+            lin.weight = torch.nn.Parameter(
+                # cls.from_float(...)
+                getattr(cls, method)(lin.weight, **kwargs), requires_grad=False
+            )
         return lin
 
     return insert_subclass
@@ -168,9 +205,10 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs):
     """
     filter_fn = kwargs.pop("filter_fn", _is_linear)
 
+    print("kwargs in change linear to int4:", kwargs)
     _replace_with_custom_fn_if_matches_filter(
         model,
-        _get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, **kwargs),
+        _get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, use_param=True, **kwargs),
         filter_fn,
     )

torchao/quantization/subclass.py

Lines changed: 28 additions & 4 deletions
@@ -502,9 +502,21 @@ def from_float(cls, input_float, groupsize=128, inner_k_tiles=8):
             Int4WeightOnlyQuantizedLinearWeight.from_float(model.lin_mod.weight)
         )
         """
+        int_data, scales_and_zeros, transposed, groupsize, inner_k_tiles = cls.to_qtensor_components(input_float, groupsize, inner_k_tiles)
+        return cls(
+            int_data,
+            scales_and_zeros,
+            transposed,
+            input_float.shape,
+            groupsize,
+            inner_k_tiles,
+            dtype=input_float.dtype,
+        )
+
+    @classmethod
+    def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8):
         assert groupsize in [256, 128, 64, 32]
         assert inner_k_tiles in [8, 4, 2]
-        orig_shape = input_float.shape
         orig_out_features, orig_in_features = input_float.shape
 
         # padding
@@ -520,13 +532,25 @@ def from_float(cls, input_float, groupsize=128, inner_k_tiles=8):
             input_float, 4, groupsize
         )
         int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
+        return int_data, scales_and_zeros, False, groupsize, inner_k_tiles
 
+    @classmethod
+    def from_qtensor_components(
+        cls,
+        int_data,
+        scales_and_zeros,
+        transposed,
+        shape,
+        groupsize,
+        inner_k_tiles,
+        **kwargs
+    ):
         return cls(
             int_data,
             scales_and_zeros,
-            False,
-            orig_shape,
+            transposed,
+            shape,
             groupsize,
             inner_k_tiles,
-            dtype=input_float.dtype,
+            **kwargs,
         )
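
The refactor splits from_float() into to_qtensor_components() plus from_qtensor_components(), so callers that need the plain components (like ConstructTensorSubclass above) can use the pieces directly. A sketch of the intended equivalence, assumed from this diff (uses a CUDA bfloat16 weight since the int4 packing kernel expects that):

import torch
from torchao.quantization.subclass import Int4WeightOnlyQuantizedLinearWeight as W4

w = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")

# one-step construction, as before the refactor
a = W4.from_float(w, groupsize=128, inner_k_tiles=8)

# two-step construction, as used by the parametrization path
int_data, scales_and_zeros, transposed, groupsize, inner_k_tiles = (
    W4.to_qtensor_components(w, groupsize=128, inner_k_tiles=8)
)
b = W4.from_qtensor_components(
    int_data, scales_and_zeros, transposed, w.shape,
    groupsize, inner_k_tiles, dtype=w.dtype,
)
# both paths should yield the same quantized weight
torch.testing.assert_close(a.dequantize(), b.dequantize())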
