Refactor tensor subclass API to also use parameterization
Summary:
Also added tests for the tensor subclass API + AOTI compilation

Test Plan:
python test/integration/test_integration.py -k test_aoti

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed May 1, 2024
1 parent e3ed90f commit a50fea5
Showing 5 changed files with 191 additions and 58 deletions.
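For context on the refactor: on PyTorch 2.4+ the existing `change_linear_weights_to_*` entry points now register the quantized tensor subclass on each linear's `weight` through `torch.nn.utils.parametrize`, rebuilding the subclass from its `__tensor_flatten__` pieces, instead of assigning the subclass directly. The snippet below is a minimal standalone sketch of that parametrization mechanism, not the torchao implementation; `FakeQuantizeWeight` is a made-up parametrization used purely for illustration.

```python
# Minimal sketch of weight parametrization (illustration only, not torchao code).
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class FakeQuantizeWeight(nn.Module):
    """Toy parametrization: every access to `weight` round-trips it through int8."""
    def forward(self, w):
        scale = w.abs().amax() / 127.0
        return (w / scale).round().clamp(-128, 127) * scale

lin = nn.Linear(64, 32)
parametrize.register_parametrization(lin, "weight", FakeQuantizeWeight())

x = torch.randn(8, 64)
y = lin(x)  # the forward pass sees the parametrized (fake-quantized) weight

# The underlying tensor is now stored under `parametrizations.weight.original`;
# the refactor in this commit relies on the same mechanism for state_dict and
# AOTI handling of the quantized subclasses.
print(list(lin.state_dict().keys()))
```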
6 changes: 4 additions & 2 deletions .github/workflows/regression_test.yml
@@ -31,9 +31,9 @@ jobs:
torch-spec: 'torch==2.3.0'
gpu-arch-type: "cuda"
gpu-arch-version: "12.1"
- name: CUDA 2.4.0.dev20240421
- name: CUDA 2.4.0.dev20240428
runs-on: linux.g5.12xlarge.nvidia.gpu
torch-spec: '--pre torch==2.4.0.dev20240421+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121'
torch-spec: '--pre torch==2.4.0.dev20240428+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121'
gpu-arch-type: "cuda"
gpu-arch-version: "12.1"
- name: CPU 2.2.2
@@ -58,6 +58,8 @@ jobs:
gpu-arch-type: ${{ matrix.gpu-arch-type }}
gpu-arch-version: ${{ matrix.gpu-arch-version }}
script: |
conda create -n venv python=3.9 -y
conda activate venv
python -m pip install --upgrade pip
pip install ${{ matrix.torch-spec }}
pip install -r requirements.txt
83 changes: 73 additions & 10 deletions test/integration/test_integration.py
@@ -67,31 +67,45 @@
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
import os
from parameterized import parameterized
import itertools
import logging
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4

logger = logging.getLogger("INFO")

torch.manual_seed(0)
config.cache_size_limit = 100

COMMON_DEVICE_DTYPE=[
("cpu", torch.float32),
("cpu", torch.float16),
("cpu", torch.bfloat16),
("cuda", torch.float32),
("cuda", torch.float16),
("cuda", torch.bfloat16),
# TODO: use this to reduce the number of tests
TENSOR_SUBCLASS_APIS = [
change_linear_weights_to_int8_dqtensors,
change_linear_weights_to_int8_woqtensors,
change_linear_weights_to_int4_woqtensors,
]

COMMON_DEVICES = ["cpu", "cuda"]

COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()

def combine_parameters(a, b):
new_tuples = []
for (tuple1, tuple2) in itertools.product(a, b):
new_tuples.append(tuple1 + tuple2)
return new_tuples

def run_supported_device_dtype(test_method):
"""Assumes that the 3rd arg (args[2]) of the decorated method is device and
there is a `test_dtype` kwarg or the 4th arg (args[3]) that indicates the dtype for testing
"""
def wrapper(*args, **kwargs):
if args[2] == "cuda" and not torch.cuda.is_available():
assert len(args) >= 3, f"Not enough args. Expected more than or equal to 3, but got {len(args)}"
device = args[2]
dtype = kwargs["test_dtype"] if "test_dtype" in kwargs else args[3]
if device == "cuda" and not torch.cuda.is_available():
raise unittest.SkipTest(f"Need CUDA available.")
if args[2] == "cuda" and torch.cuda.is_available() and kwargs['test_dtype'] == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
if device == "cuda" and torch.cuda.is_available() and dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
raise unittest.SkipTest("Need CUDA and SM80+ available.")
return test_method(*args, **kwargs)
return wrapper
@@ -1148,6 +1162,7 @@ def _test_handle_save_load_meta_impl(
min_sqnr=35,
test_dtype=torch.bfloat16
):
logger.info(f"TestSaveLoad: {api}, {test_device}, {test_dtype}")
m, k, n = 32, 64, 32

class test_model(nn.Module):
@@ -1180,7 +1195,7 @@ def forward(self, x):

# load model structure
with torch.device('meta'):
model = test_model()
model = test_model().to(dtype=test_dtype)
api(model)

# load quantized state_dict
@@ -1407,5 +1422,53 @@ def test_autoquant_multi_input(self, device, dtype, m1, m2, k, n):
sqnr = SQNR(out, out2)
self.assertTrue(sqnr >= 30)


class TestAOTI(unittest.TestCase):
@parameterized.expand(
list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "aoti compatibility requires 2.4+.")
@torch.no_grad()
# @run_supported_device_dtype
def test_aoti(self, api, test_device, test_dtype):
logger.info(f"TestAOTI: {api}, {test_device}, {test_dtype}")
if api is change_linear_weights_to_int8_dqtensors and test_device == "cuda":
self.skipTest(f"{api} in {test_device} is not support for aoti compilation yet")

if test_dtype != torch.bfloat16:
self.skipTest(f"{api} in {test_dtype} is not support for aoti compilation yet")

m, k, n = 32, 64, 32

class test_model(nn.Module):
def __init__(self):
super().__init__()
self.lin1 = nn.Linear(k, n)
self.relu = nn.ReLU()
self.lin2 = nn.Linear(n, n)

def forward(self, x):
x = self.lin1(x)
x = self.relu(x)
x = self.lin2(x)
return x

x = torch.randn(m, k, dtype=test_dtype, device=test_device)

# get float reference
model = test_model().to(dtype=test_dtype, device=test_device).eval()
ref_f = model(x)

kwargs = {"dtype": test_dtype}
api(model, **kwargs)

# running model
model(x)

# make sure it compiles
example_inputs = (x,)
torch._export.aot_compile(model, example_inputs)


if __name__ == "__main__":
unittest.main()
32 changes: 20 additions & 12 deletions torchao/quantization/quant_api.py
@@ -20,7 +20,7 @@
import torch.nn.functional as F

from .dynamic_quant import DynamicallyPerAxisQuantizedLinear
from .utils import TORCH_VERSION_AFTER_2_3
from .utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4

from .subclass import (
Int4WeightOnlyQuantizedLinearWeight,
@@ -117,19 +117,27 @@ def apply_dynamic_quant(model, filter_fn=None):
change_linear_weights_to_int8_dqtensors(model, filter_fn)


def _get_subclass_inserter(cls, **kwargs):
method = kwargs.pop("method", "from_float")
import torch.nn.utils.parametrize as parametrize

def _get_subclass_inserter(cls, enable_parametrization=False, **kwargs):
constructor = kwargs.pop("constructor", "subclass_constructor")
from_float = kwargs.pop("method", "from_float")
def insert_subclass(lin):
lin.weight = torch.nn.Parameter(
# cls.from_float(...)
getattr(cls, method)(lin.weight, **kwargs), requires_grad=False
)
if enable_parametrization:
lin.weight = torch.nn.Parameter(cls.from_float(lin.weight, **kwargs), requires_grad=False)
_, args = lin.weight.__tensor_flatten__()
parametrize.register_parametrization(lin, "weight", getattr(cls, constructor)(cls, *args))
else:
lin.weight = torch.nn.Parameter(
# cls.from_float(...)
getattr(cls, from_float)(lin.weight, **kwargs), requires_grad=False
)
return lin

return insert_subclass


def change_linear_weights_to_int8_dqtensors(model, filter_fn=None):
def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
"""
Converts all linear weight tensors to the `Int8DynamicallyQuantizedLinearWeight`
Tensor subclass, effectively applying the same form of quantization
@@ -141,11 +149,11 @@ def change_linear_weights_to_int8_dqtensors(model, filter_fn=None):
)

_replace_with_custom_fn_if_matches_filter(
model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight), filter_fn
model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs), filter_fn
)


def change_linear_weights_to_int8_woqtensors(model, filter_fn=None):
def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs):
"""
Converts all linear weight tensors to the
`Int8WeightOnlyQuantizedLinearWeight` tensor subclass,
@@ -154,7 +162,7 @@ def change_linear_weights_to_int8_woqtensors(model, filter_fn=None):
"""
_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight),
_get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs),
_is_linear if filter_fn is None else filter_fn,
)

@@ -170,7 +178,7 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs):

_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, **kwargs),
_get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs),
filter_fn,
)

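A hedged usage sketch of the refactored entry points (assumptions: a PyTorch 2.4 nightly so that `enable_parametrization` takes effect, plus the functions shown in this file; the exact state_dict key names come from the subclass's `__tensor_flatten__` and may differ):

```python
# Sketch only: applying one of the refactored APIs and inspecting the result.
import torch
import torch.nn as nn
from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors

model = nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 32)).eval()
change_linear_weights_to_int8_woqtensors(model)

x = torch.randn(8, 64)
y = model(x)  # linear weights are now Int8WeightOnlyQuantizedLinearWeight tensors

# On 2.4+ the weight is exposed through a registered parametrization, so the
# state_dict is expected to carry the subclass's flattened inner tensors.
print(list(model.state_dict().keys()))
```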
24 changes: 15 additions & 9 deletions torchao/quantization/quant_primitives.py
@@ -47,6 +47,13 @@
] + (_AFTER_TORCH_2_3_ONLY if TORCH_VERSION_AFTER_2_3 else [])


def guard_dtype_size(tensor_arg, arg_name, dtype=None, size=None):
if dtype is not None and tensor_arg.dtype != dtype:
raise ValueError("Expected Tensor argument {arg_name} to have dtype {dtype}, but got {tensor_arg.dtype} instead.")
if size is not None and tensor_arg.size() != size:
raise ValueError("Expected Tensor argument {arg_name} to have size {size}, but got {tensor_arg.size()} instead.")


_DTYPE_TO_QVALUE_BOUNDS = {
torch.uint8: (0, 255),
torch.int8: (-128, 127),
@@ -493,7 +500,7 @@ def quant_int8_dynamic_per_token_linear(
x_vals_int8, x_scales, w_vals_int8_t, w_scales, out_dtype
)
if bias is not None:
mm_out += bias
mm_out = mm_out + bias
return mm_out


@@ -554,7 +561,7 @@ def quant_int8_per_token_matmul(
return y


def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128):
def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
"""This is tinygemm specific, we'll keep this for now"""
if groupsize > w.shape[-1]:
groupsize = w.shape[-1]
@@ -570,15 +577,14 @@ def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128):
max_int = 2**n_bit - 1
scales = (max_val - min_val).clamp(min=1e-6) / max_int
zeros = min_val + scales * (2 ** (n_bit - 1))
return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
torch.bfloat16
return scales.to(dtype=dtype).reshape(w.shape[0], -1), zeros.to(
dtype=dtype
).reshape(w.shape[0], -1)


def pack_tinygemm_scales_and_zeros(scales, zeros):
assert scales.shape == zeros.shape
assert scales.dtype == torch.bfloat16
assert zeros.dtype == torch.bfloat16
guard_dtype_size(scales, "scales", dtype=torch.bfloat16, size=zeros.size())
guard_dtype_size(zeros, "zeros", dtype=torch.bfloat16)
return (
torch.cat(
[
@@ -661,8 +667,8 @@ def groupwise_affine_dequantize_tensor_from_qparams(
return w_dq


def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128):
scales, zeros = get_groupwise_affine_qparams(w, n_bit, groupsize)
def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
scales, zeros = get_groupwise_affine_qparams(w, n_bit, groupsize, dtype)
w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(
w, scales, zeros, n_bit, groupsize
)
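The `get_groupwise_affine_qparams` and `groupwise_affine_quantize_tensor` changes only thread a `dtype` argument through; the qparam math itself is unchanged. As a quick worked example of that math (a standalone sketch following the two formulas visible above, not a call into torchao):

```python
# Recomputing groupwise affine qparams for a single 4-bit group, using the
# formulas from get_groupwise_affine_qparams:
#   scales = (max - min).clamp(min=1e-6) / (2**n_bit - 1)
#   zeros  = min + scales * 2**(n_bit - 1)
import torch

n_bit = 4
group = torch.tensor([-0.8, -0.1, 0.3, 0.9])  # one quantization group

max_val, min_val = group.max(), group.min()
max_int = 2**n_bit - 1                        # 15 = max quantized value for 4-bit
scales = (max_val - min_val).clamp(min=1e-6) / max_int
zeros = min_val + scales * (2 ** (n_bit - 1))

print(round(scales.item(), 4))  # ~0.1133
print(round(zeros.item(), 4))   # ~0.1067
```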