Refactor tensor subclass API to also use parametrization
Summary:
Also added tests for the tensor subclass API + AOTI compilation

Test Plan:
python test/integration/test_integration.py -k test_aoti

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Apr 19, 2024
1 parent c403580 commit a9e5563
Showing 3 changed files with 148 additions and 21 deletions.
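The core mechanism behind the refactor is torch.nn.utils.parametrize: a parametrization's right_inverse() decomposes a weight into plain tensors that land in the state_dict, and its forward() rebuilds the weight whenever it is accessed. Below is a minimal, self-contained sketch of that pattern with a toy int8 decomposition (not torchao's ConstructTensorSubclass or Int4WeightOnlyQuantizedLinearWeight, just the same mechanism in miniature):

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class ToyInt8WeightParametrization(nn.Module):
    # Toy stand-in for ConstructTensorSubclass: the stored components here are
    # (int_data, scale) instead of (int_data, scales_and_zeros).
    def forward(self, int_data, scale):
        # Rebuild a float weight from the serialized components.
        return int_data.to(scale.dtype) * scale

    def right_inverse(self, weight):
        # Decompose the weight into the plain tensors that should be saved.
        scale = weight.abs().amax(dim=1, keepdim=True) / 127.0
        int_data = torch.clamp(torch.round(weight / scale), -128, 127).to(torch.int8)
        return int_data, scale

lin = nn.Linear(64, 32)
parametrize.register_parametrization(lin, "weight", ToyInt8WeightParametrization())

# The state_dict now holds the components (keys like
# parametrizations.weight.original0 / original1) rather than a tensor subclass,
# which is what torch.export and AOTI compilation can handle.
print(sorted(lin.state_dict().keys()))

In this commit the same two roles are played by ConstructTensorSubclass.forward / right_inverse, with int_data and scales_and_zeros as the stored components.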
87 changes: 76 additions & 11 deletions test/integration/test_integration.py
@@ -66,31 +66,45 @@
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
import os
from parameterized import parameterized
import itertools
import logging
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3

logger = logging.getLogger("INFO")

torch.manual_seed(0)
config.cache_size_limit = 100

COMMON_DEVICE_DTYPE=[
("cpu", torch.float32),
("cpu", torch.float16),
("cpu", torch.bfloat16),
("cuda", torch.float32),
("cuda", torch.float16),
("cuda", torch.bfloat16),
TENSOR_SUBCLASS_APIS = [
change_linear_weights_to_int8_dqtensors,
change_linear_weights_to_int8_woqtensors,
change_linear_weights_to_int4_woqtensors,
]

COMMON_DEVICES = ["cpu", "cuda"]

COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES))

def combine_parameters(a, b):
new_tuples = []
for (tuple1, tuple2) in itertools.product(a, b):
new_tuples.append(tuple1 + tuple2)
return new_tuples

def run_supported_device_dtype(test_method):
"""Assumes that the 3rd arg (args[2]) of the decorated method is device and
there is a `test_dtype` kwarg or the 4th arg (args[3]) that indicates the dtype for testing
"""
def wrapper(*args, **kwargs):
if args[2] == "cuda" and not torch.cuda.is_available():
if len(args) < 3:
raise unittest.SkipTest("Not enoguh args")
device = args[2]
dtype = kwargs["test_dtype"] if "test_dtype" in kwargs else args[3]
if device == "cuda" and not torch.cuda.is_available():
raise unittest.SkipTest(f"Need CUDA available.")
if args[2] == "cuda" and torch.cuda.is_available() and kwargs['test_dtype'] == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
if device == "cuda" and torch.cuda.is_available() and dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
raise unittest.SkipTest("Need CUDA and SM80+ available.")
return test_method(*args, **kwargs)
return wrapper
@@ -1145,6 +1159,7 @@ def _test_handle_save_load_meta_impl(
min_sqnr=35,
test_dtype=torch.bfloat16
):
logger.info(f"TestSaveLoad: {api}, {test_device}, {test_dtype}")
m, k, n = 32, 64, 32

class test_model(nn.Module):
@@ -1170,7 +1185,9 @@ def forward(self, x):
api(model)
torch.save(model.state_dict(), "test.pth")
# get quantized reference
model_qc = torch.compile(model, mode="max-autotune")
# model_qc = torch.compile(model, mode="max-autotune")
model_qc = torch.export.export(model, (x,)).module()
# model_qc = model
ref_q = model_qc(x).detach()

assert SQNR(ref_f, ref_q) > min_sqnr
@@ -1187,7 +1204,8 @@ def forward(self, x):
model = model.to(device=test_device, dtype=test_dtype).eval()

# get quantized reference
model_qc = torch.compile(model, mode="max-autotune")
# model_qc = torch.compile(model, mode="max-autotune")
model_qc = model
test = model_qc(x).detach()

assert SQNR(ref_f, test) > min_sqnr
@@ -1404,5 +1422,52 @@ def test_autoquant_multi_input(self, device, dtype, m1, m2, k, n):
sqnr = SQNR(out, out2)
self.assertTrue(sqnr >= 30)


class TestAOTI(unittest.TestCase):
@run_supported_device_dtype
@torch.no_grad()
@parameterized.expand(
list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
)
def test_aoti(self, api, test_device, test_dtype):
logger.info(f"TestAOTI: {api}, {test_device}, {test_dtype}")
if api is change_linear_weights_to_int8_dqtensors and test_device == "cuda":
self.skipTest(f"{api} in {test_device} is not support for aoti compilation yet")
m, k, n = 32, 64, 32

class test_model(nn.Module):
def __init__(self):
super().__init__()
self.lin1 = nn.Linear(k, n)
self.relu = nn.ReLU()
self.lin2 = nn.Linear(n, n)

def forward(self, x):
x = self.lin1(x)
x = self.relu(x)
x = self.lin2(x)
return x

x = torch.randn(m, k, dtype=test_dtype, device=test_device)

# get float reference
model = test_model().to(dtype=test_dtype, device=test_device).eval()
ref_f = model(x)

print("calling quant")
api(model)

# running model
print("running model")
model(x)
print("model:", model)
print("model weight:", model.lin1.weight)

# make sure it compiles
example_inputs = (x,)
print("compiling model")
torch._export.aot_compile(model, example_inputs)


if __name__ == "__main__":
unittest.main()
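The new TestAOTI test stops at torch._export.aot_compile, which lowers the quantized model to a compiled shared library. A rough end-to-end sketch of that flow follows; the model mirrors the test_model used in test_aoti, and the final load-and-run step assumes torch._export.aot_load is available in the targeted PyTorch build (it is not exercised by this commit's test):

import torch
import torch.nn as nn
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors

class ToyModel(nn.Module):
    # Same lin1 -> relu -> lin2 structure and shapes as test_model in test_aoti.
    def __init__(self, k=64, n=32):
        super().__init__()
        self.lin1 = nn.Linear(k, n)
        self.relu = nn.ReLU()
        self.lin2 = nn.Linear(n, n)

    def forward(self, x):
        return self.lin2(self.relu(self.lin1(x)))

model = ToyModel().to(dtype=torch.bfloat16, device="cuda").eval()
x = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda")

change_linear_weights_to_int4_woqtensors(model)

# aot_compile returns the path of the generated shared library.
so_path = torch._export.aot_compile(model, (x,))

# Assumed load-and-run step; adjust to however your PyTorch version loads
# AOTInductor artifacts.
compiled = torch._export.aot_load(so_path, device="cuda")
out = compiled(x)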
50 changes: 44 additions & 6 deletions torchao/quantization/quant_api.py
@@ -117,13 +117,50 @@ def apply_dynamic_quant(model, filter_fn=None):
change_linear_weights_to_int8_dqtensors(model, filter_fn)


def _get_subclass_inserter(cls, **kwargs):
import torch.nn.utils.parametrize as parametrize


class ConstructTensorSubclass(torch.nn.Module):
def __init__(self, tensor_subclass_ctr, transposed, shape, groupsize, inner_k_tiles, dtype):
super().__init__()
self.tensor_subclass_ctr = tensor_subclass_ctr
self.transposed = transposed
self.shape = shape
self.groupsize = groupsize
self.inner_k_tiles = inner_k_tiles
self.dtype = dtype

def forward(self, int_data, scales_and_zeros):
return Int4WeightOnlyQuantizedLinearWeight.from_qtensor_components(int_data, scales_and_zeros, self.transposed, self.shape, self.groupsize, self.inner_k_tiles, dtype=self.dtype)

def right_inverse(self, tensor_subclass_instance):
# new_kwargs = {"groupsize": self.groupsize, "inner_k_tiles": self.inner_k_tiles}
# tensor_subclass_instance = self.tensor_subclass_ctr.from_float(input_float, **new_kwargs)
return tensor_subclass_instance.int_data, tensor_subclass_instance.scales_and_zeros


def _get_subclass_inserter(cls, use_param=False, **kwargs):
method = kwargs.pop("method", "from_float")
def insert_subclass(lin):
lin.weight = torch.nn.Parameter(
# cls.from_float(...)
getattr(cls, method)(lin.weight, **kwargs), requires_grad=False
)
if use_param:
new_kwargs = {}
if "groupsize" in kwargs:
new_kwargs["groupsize"] = kwargs["groupsize"]
if "inner_k_tiles" in kwargs:
new_kwargs["inner_k_tiles"] = kwargs["inner_k_tiles"]
int_data, scales_and_zeros, transposed, groupsize, inner_k_tiles = cls.to_qtensor_components(lin.weight, **new_kwargs)
kwargs["transposed"] = transposed
kwargs["shape"] = lin.weight.shape
kwargs["dtype"] = lin.weight.dtype
kwargs["groupsize"] = groupsize
kwargs["inner_k_tiles"] = inner_k_tiles
lin.weight = torch.nn.Parameter(cls(int_data, scales_and_zeros, **kwargs), requires_grad=False)
parametrize.register_parametrization(lin, "weight", ConstructTensorSubclass(cls, **kwargs))
else:
lin.weight = torch.nn.Parameter(
# cls.from_float(...)
getattr(cls, method)(lin.weight, **kwargs), requires_grad=False
)
return lin

return insert_subclass
@@ -168,9 +205,10 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs):
"""
filter_fn = kwargs.pop("filter_fn", _is_linear)

print("kwargs in change linear to int4:", kwargs)
_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, **kwargs),
_get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, use_param=True, **kwargs),
filter_fn,
)

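From the caller's perspective, the use_param=True path changes where the quantized data lives: accessing lin.weight goes through ConstructTensorSubclass.forward and yields the Int4WeightOnlyQuantizedLinearWeight instance, while the state_dict stores the plain tensors returned by right_inverse (int_data and scales_and_zeros). A hedged sketch, with key names following torch.nn.utils.parametrize conventions and assuming a CUDA + bfloat16 setup:

import torch
import torch.nn as nn
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors

model = nn.Sequential(nn.Linear(64, 32)).to(dtype=torch.bfloat16, device="cuda").eval()
change_linear_weights_to_int4_woqtensors(model)

# Weight access reconstructs the tensor subclass via the parametrization.
print(type(model[0].weight))  # expected: Int4WeightOnlyQuantizedLinearWeight

# The state_dict holds the packed components rather than the subclass object,
# e.g. '0.parametrizations.weight.original0' (int_data) and
# '0.parametrizations.weight.original1' (scales_and_zeros).
print(sorted(model.state_dict().keys()))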
32 changes: 28 additions & 4 deletions torchao/quantization/subclass.py
@@ -502,9 +502,21 @@ def from_float(cls, input_float, groupsize=128, inner_k_tiles=8):
Int4WeightOnlyQuantizedLinearWeight.from_float(model.lin_mod.weight)
)
"""
int_data, scales_and_zeros, transposed, groupsize, inner_k_tiles = cls.to_qtensor_components(input_float, groupsize, inner_k_tiles)
return cls(
int_data,
scales_and_zeros,
transposed,
input_float.shape,
groupsize,
inner_k_tiles,
dtype=input_float.dtype,
)

@classmethod
def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8):
assert groupsize in [256, 128, 64, 32]
assert inner_k_tiles in [8, 4, 2]
orig_shape = input_float.shape
orig_out_features, orig_in_features = input_float.shape

# padding
@@ -520,13 +532,25 @@ def from_float(cls, input_float, groupsize=128, inner_k_tiles=8):
input_float, 4, groupsize
)
int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
return int_data, scales_and_zeros, False, groupsize, inner_k_tiles

@classmethod
def from_qtensor_components(
cls,
int_data,
scales_and_zeros,
transposed,
shape,
groupsize,
inner_k_tiles,
**kwargs
):
return cls(
int_data,
scales_and_zeros,
False,
orig_shape,
transposed,
shape,
groupsize,
inner_k_tiles,
dtype=input_float.dtype,
**kwargs,
)
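The subclass changes split the old from_float into two composable halves: to_qtensor_components quantizes and packs the weight, and from_qtensor_components rebuilds the subclass from those pieces (this is what ConstructTensorSubclass calls). A small sketch of how the halves compose back to from_float, assuming a CUDA build with the int4 packing kernels:

import torch
from torchao.quantization.subclass import Int4WeightOnlyQuantizedLinearWeight

w = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda")

# First half: quantize/pack and return the components plus layout metadata.
int_data, scales_and_zeros, transposed, groupsize, inner_k_tiles = (
    Int4WeightOnlyQuantizedLinearWeight.to_qtensor_components(w)
)

# Second half: rebuild the subclass; this should match from_float(w).
qw = Int4WeightOnlyQuantizedLinearWeight.from_qtensor_components(
    int_data, scales_and_zeros, transposed, w.shape, groupsize, inner_k_tiles,
    dtype=w.dtype,
)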
