Refactor tensor subclass API to also use parametrization
Summary:
Also added tests for the tensor subclass API + AOTI compilation

Test Plan:
python test/integration/test_integration.py -k test_aoti

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Apr 30, 2024
1 parent e3ed90f commit a906c53
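
For readers unfamiliar with the mechanism, the sketch below shows what "using parametrization" means here; it is an illustrative stand-in, not the torchao implementation. torch.nn.utils.parametrize.register_parametrization makes every access to lin.weight pass through a small module, which is presumably what lets the quantized tensor subclass be rebuilt from plain stored tensors for save/load and AOTI export.

# Minimal sketch of weight parametrization (illustrative only, not the torchao code).
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class QuantView(nn.Module):
    """Hypothetical parametrization: recomputes the weight view from the plain
    tensor actually stored on the module."""
    def forward(self, w: torch.Tensor) -> torch.Tensor:
        # In torchao this step would rebuild the quantized tensor subclass;
        # here it is a no-op placeholder.
        return w

lin = nn.Linear(64, 32)
parametrize.register_parametrization(lin, "weight", QuantView())
y = lin(torch.randn(8, 64))  # every access to lin.weight now runs QuantView.forward
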
Showing 4 changed files with 195 additions and 51 deletions.
82 changes: 72 additions & 10 deletions test/integration/test_integration.py
@@ -67,31 +67,45 @@
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
import os
from parameterized import parameterized
import itertools
import logging
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4

logger = logging.getLogger("INFO")

torch.manual_seed(0)
config.cache_size_limit = 100

COMMON_DEVICE_DTYPE=[
("cpu", torch.float32),
("cpu", torch.float16),
("cpu", torch.bfloat16),
("cuda", torch.float32),
("cuda", torch.float16),
("cuda", torch.bfloat16),
TENSOR_SUBCLASS_APIS = [
change_linear_weights_to_int8_dqtensors,
change_linear_weights_to_int8_woqtensors,
change_linear_weights_to_int4_woqtensors,
]

COMMON_DEVICES = ["cpu", "cuda"]

COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()

def combine_parameters(a, b):
new_tuples = []
for (tuple1, tuple2) in itertools.product(a, b):
new_tuples.append(tuple1 + tuple2)
return new_tuples

def run_supported_device_dtype(test_method):
"""Assumes that the 3rd arg (args[2]) of the decorated method is device and
there is a `test_dtype` kwarg or the 4th arg (args[3]) that indicates the dtype for testing
"""
def wrapper(*args, **kwargs):
if args[2] == "cuda" and not torch.cuda.is_available():
if len(args) < 3:
raise unittest.SkipTest("Not enoguh args")
device = args[2]
dtype = kwargs["test_dtype"] if "test_dtype" in kwargs else args[3]
if device == "cuda" and not torch.cuda.is_available():
raise unittest.SkipTest(f"Need CUDA available.")
if args[2] == "cuda" and torch.cuda.is_available() and kwargs['test_dtype'] == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
if device == "cuda" and torch.cuda.is_available() and dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
raise unittest.SkipTest("Need CUDA and SM80+ available.")
return test_method(*args, **kwargs)
return wrapper
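
A hypothetical decorated test (not part of the diff) showing the positional layout the decorator assumes, i.e. (self, api, device, dtype) as produced by the parameterized.expand tuples used in this module:

# Hypothetical example within this test module; mirrors the layout of TestAOTI below.
class TestDecoratorSketch(unittest.TestCase):
    @run_supported_device_dtype
    @parameterized.expand(list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)))
    def test_layout(self, api, test_device, test_dtype):
        # args[2] is the device and args[3] the dtype, so unsupported
        # device/dtype combinations are skipped before the body runs.
        self.assertIn(test_device, ("cpu", "cuda"))
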
@@ -1148,6 +1162,7 @@ def _test_handle_save_load_meta_impl(
min_sqnr=35,
test_dtype=torch.bfloat16
):
logger.info(f"TestSaveLoad: {api}, {test_device}, {test_dtype}")
m, k, n = 32, 64, 32

class test_model(nn.Module):
Expand Down Expand Up @@ -1180,7 +1195,7 @@ def forward(self, x):

# load model structure
with torch.device('meta'):
model = test_model()
model = test_model().to(dtype=test_dtype)
api(model)

# load quantized state_dict
@@ -1407,5 +1422,52 @@ def test_autoquant_multi_input(self, device, dtype, m1, m2, k, n):
sqnr = SQNR(out, out2)
self.assertTrue(sqnr >= 30)


class TestAOTI(unittest.TestCase):
@run_supported_device_dtype
@torch.no_grad()
@parameterized.expand(
list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
)
def test_aoti(self, api, test_device, test_dtype):
logger.info(f"TestAOTI: {api}, {test_device}, {test_dtype}")
if api is change_linear_weights_to_int8_dqtensors and test_device == "cuda":
self.skipTest(f"{api} in {test_device} is not support for aoti compilation yet")

if test_dtype != torch.bfloat16:
self.skipTest(f"{api} in {test_dtype} is not support for aoti compilation yet")

m, k, n = 32, 64, 32

class test_model(nn.Module):
def __init__(self):
super().__init__()
self.lin1 = nn.Linear(k, n)
self.relu = nn.ReLU()
self.lin2 = nn.Linear(n, n)

def forward(self, x):
x = self.lin1(x)
x = self.relu(x)
x = self.lin2(x)
return x

x = torch.randn(m, k, dtype=test_dtype, device=test_device)

# get float reference
model = test_model().to(dtype=test_dtype, device=test_device).eval()
ref_f = model(x)

kwargs = {"dtype": test_dtype}
api(model, **kwargs)

# running model
model(x)

# make sure it compiles
example_inputs = (x,)
torch._export.aot_compile(model, example_inputs)


if __name__ == "__main__":
unittest.main()
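
A possible continuation of the test body above (not part of this commit; torch._export.aot_load is assumed to exist in the PyTorch version under test, and these private _export APIs do change across releases):

# Hypothetical follow-up inside test_aoti: run the compiled artifact.
so_path = torch._export.aot_compile(model, example_inputs)  # returns a path to the generated .so
runner = torch._export.aot_load(so_path, device=test_device)  # assumed API; may differ by version
compiled_out = runner(x)
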
30 changes: 19 additions & 11 deletions torchao/quantization/quant_api.py
@@ -117,19 +117,27 @@ def apply_dynamic_quant(model, filter_fn=None):
change_linear_weights_to_int8_dqtensors(model, filter_fn)


def _get_subclass_inserter(cls, **kwargs):
method = kwargs.pop("method", "from_float")
import torch.nn.utils.parametrize as parametrize

def _get_subclass_inserter(cls, enable_parametrization=False, **kwargs):
constructor = kwargs.pop("constructor", "subclass_constructor")
from_float = kwargs.pop("method", "from_float")
def insert_subclass(lin):
lin.weight = torch.nn.Parameter(
# cls.from_float(...)
getattr(cls, method)(lin.weight, **kwargs), requires_grad=False
)
if enable_parametrization:
lin.weight = torch.nn.Parameter(cls.from_float(lin.weight), requires_grad=False)
_, args = lin.weight.__tensor_flatten__()
parametrize.register_parametrization(lin, "weight", getattr(cls, constructor)(cls, *args))
else:
lin.weight = torch.nn.Parameter(
# cls.from_float(...)
getattr(cls, from_float)(lin.weight, **kwargs), requires_grad=False
)
return lin

return insert_subclass
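
One property the parametrization path appears to rely on (hedged sketch, not torchao code): once a parametrization is registered, the module serializes the plain underlying tensor under parametrizations.weight.original, while ordinary weight accesses go through the registered constructor, which would rebuild the tensor subclass from the __tensor_flatten__ metadata captured above.

# Sketch: state_dict stores the plain tensor once a parametrization is registered.
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class NoOpConstructor(nn.Module):
    """Toy stand-in for the subclass constructor; a real one would call the
    subclass's unflatten/constructor with the captured metadata."""
    def forward(self, plain_weight):
        return plain_weight

lin = nn.Linear(8, 8)
parametrize.register_parametrization(lin, "weight", NoOpConstructor())
print(sorted(lin.state_dict().keys()))  # ['bias', 'parametrizations.weight.original']
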


def change_linear_weights_to_int8_dqtensors(model, filter_fn=None):
def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
"""
Converts all linear weight tensors to the `Int8DynamicallyQuantizedLinearWeight`
Tensor subclass, effectively applying the same form of quantization
@@ -141,11 +149,11 @@ def change_linear_weights_to_int8_dqtensors(model, filter_fn=None):
)

_replace_with_custom_fn_if_matches_filter(
model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight), filter_fn
model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=True, **kwargs), filter_fn
)


def change_linear_weights_to_int8_woqtensors(model, filter_fn=None):
def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs):
"""
Converts all linear weight tensors to the
`Int8WeightOnlyQuantizedLinearWeight` tensor subclass,
@@ -154,7 +162,7 @@ def change_linear_weights_to_int8_woqtensors(model, filter_fn=None):
"""
_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight),
_get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight, enable_parametrization=True, **kwargs),
_is_linear if filter_fn is None else filter_fn,
)

@@ -170,7 +178,7 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs):

_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, **kwargs),
_get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, enable_parametrization=True, **kwargs),
filter_fn,
)

16 changes: 8 additions & 8 deletions torchao/quantization/quant_primitives.py
@@ -493,7 +493,7 @@ def quant_int8_dynamic_per_token_linear(
x_vals_int8, x_scales, w_vals_int8_t, w_scales, out_dtype
)
if bias is not None:
mm_out += bias
mm_out = mm_out + bias
return mm_out


@@ -554,7 +554,7 @@ def quant_int8_per_token_matmul(
return y


def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128):
def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
"""This is tinygemm specific, we'll keep this for now"""
if groupsize > w.shape[-1]:
groupsize = w.shape[-1]
@@ -570,15 +570,15 @@ def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128):
max_int = 2**n_bit - 1
scales = (max_val - min_val).clamp(min=1e-6) / max_int
zeros = min_val + scales * (2 ** (n_bit - 1))
return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
torch.bfloat16
return scales.to(dtype=dtype).reshape(w.shape[0], -1), zeros.to(
dtype=dtype
).reshape(w.shape[0], -1)
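
A small hedged check of the new dtype argument (illustrative; assumes torchao is importable and that the scales/zeros now follow the requested dtype rather than being hard-coded to bfloat16):

# Sketch: verify the qparams come back in the requested dtype.
import torch
from torchao.quantization.quant_primitives import get_groupwise_affine_qparams

w = torch.randn(32, 128)
scales, zeros = get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.float16)
assert scales.dtype == torch.float16 and zeros.dtype == torch.float16
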


def pack_tinygemm_scales_and_zeros(scales, zeros):
assert scales.shape == zeros.shape
assert scales.dtype == torch.bfloat16
assert zeros.dtype == torch.bfloat16
assert scales.dtype == torch.bfloat16, f" got dtype: {scales.dtype}"
assert zeros.dtype == torch.bfloat16, f" got dtype: {zeros.dtype}"
return (
torch.cat(
[
@@ -661,8 +661,8 @@ def groupwise_affine_dequantize_tensor_from_qparams(
return w_dq


def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128):
scales, zeros = get_groupwise_affine_qparams(w, n_bit, groupsize)
def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
scales, zeros = get_groupwise_affine_qparams(w, n_bit, groupsize, dtype)
w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(
w, scales, zeros, n_bit, groupsize
)
