Refactor tensor subclass API to also use parameterization #146

Merged 2 commits on May 3, 2024
9 changes: 7 additions & 2 deletions .github/workflows/regression_test.yml
@@ -31,9 +31,9 @@ jobs:
torch-spec: 'torch==2.3.0'
gpu-arch-type: "cuda"
gpu-arch-version: "12.1"
- name: CUDA 2.4.0.dev20240421
- name: CUDA 2.4.0.dev20240428
runs-on: linux.g5.12xlarge.nvidia.gpu
torch-spec: '--pre torch==2.4.0.dev20240421+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121'
torch-spec: '--pre torch==2.4.0.dev20240428+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121'
gpu-arch-type: "cuda"
gpu-arch-version: "12.1"
- name: CPU 2.2.2
@@ -58,6 +58,11 @@ jobs:
gpu-arch-type: ${{ matrix.gpu-arch-type }}
gpu-arch-version: ${{ matrix.gpu-arch-version }}
script: |
conda create -n venv python=3.9 -y
conda activate venv
echo "::group::Install newer objcopy that supports --set-section-alignment"
yum install -y devtoolset-10-binutils
export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH
python -m pip install --upgrade pip
pip install ${{ matrix.torch-spec }}
pip install -r requirements.txt
92 changes: 82 additions & 10 deletions test/integration/test_integration.py
@@ -67,31 +67,46 @@
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
import os
from parameterized import parameterized
import itertools
import logging
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4

logger = logging.getLogger("INFO")

torch.manual_seed(0)
config.cache_size_limit = 100

COMMON_DEVICE_DTYPE=[
("cpu", torch.float32),
("cpu", torch.float16),
("cpu", torch.bfloat16),
("cuda", torch.float32),
("cuda", torch.float16),
("cuda", torch.bfloat16),
# TODO: use this to reduce the number of tests
TENSOR_SUBCLASS_APIS = [
change_linear_weights_to_int8_dqtensors,
change_linear_weights_to_int8_woqtensors,
change_linear_weights_to_int4_woqtensors,
]

COMMON_DEVICES = ["cpu", "cuda"]

COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()

def combine_parameters(a, b):
new_tuples = []
for (tuple1, tuple2) in itertools.product(a, b):
new_tuples.append(tuple1 + tuple2)
return new_tuples

def run_supported_device_dtype(test_method):
"""Assumes that the 3rd arg (args[2]) of the decorated method is device and
there is a `test_dtype` kwarg or the 4th arg (args[3]) that indicates the dtype for testing
"""
def wrapper(*args, **kwargs):
if args[2] == "cuda" and not torch.cuda.is_available():
if len(args) < 3:
raise unittest.SkipTest(f"Not enough args. Expected more than or equal to 3, but got {len(args)}")
device = args[2]
dtype = kwargs["test_dtype"] if "test_dtype" in kwargs else args[3]
if device == "cuda" and not torch.cuda.is_available():
raise unittest.SkipTest(f"Need CUDA available.")
if args[2] == "cuda" and torch.cuda.is_available() and kwargs['test_dtype'] == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
if device == "cuda" and torch.cuda.is_available() and dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
raise unittest.SkipTest("Need CUDA and SM80+ available.")
return test_method(*args, **kwargs)
return wrapper
@@ -1148,6 +1163,7 @@ def _test_handle_save_load_meta_impl(
min_sqnr=35,
test_dtype=torch.bfloat16
):
logger.info(f"TestSaveLoad: {api}, {test_device}, {test_dtype}")
m, k, n = 32, 64, 32

class test_model(nn.Module):
@@ -1180,7 +1196,7 @@ def forward(self, x):

# load model structure
with torch.device('meta'):
model = test_model()
model = test_model().to(dtype=test_dtype)
api(model)

# load quantized state_dict
@@ -1407,5 +1423,61 @@ def test_autoquant_multi_input(self, device, dtype, m1, m2, k, n):
sqnr = SQNR(out, out2)
self.assertTrue(sqnr >= 30)


class TestAOTI(unittest.TestCase):
@parameterized.expand(
list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
)
def test_aoti(self, api, test_device, test_dtype):
if not TORCH_VERSION_AFTER_2_4:
self.skipTest("aoti compatibility requires 2.4+.")

if test_device == "cuda":
self.skipTest("AOTI has some issues in cuda test right now, skipping")

logger.info(f"TestAOTI: {api}, {test_device}, {test_dtype}")
if api is change_linear_weights_to_int8_dqtensors and test_device == "cuda":
self.skipTest(f"{api} in {test_device} is not support for aoti compilation yet")

if test_dtype != torch.bfloat16:
self.skipTest(f"{api} in {test_dtype} is not support for aoti compilation yet")

if test_device == "cuda" and not torch.cuda.is_available():
self.skipTest(f"Need CUDA available.")
if test_device == "cuda" and torch.cuda.is_available() and test_dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
self.skipTest("Need CUDA and SM80+ available.")

m, k, n = 32, 64, 32

class test_model(nn.Module):
def __init__(self):
super().__init__()
self.lin1 = nn.Linear(k, n)
self.relu = nn.ReLU()
self.lin2 = nn.Linear(n, n)

def forward(self, x):
x = self.lin1(x)
x = self.relu(x)
x = self.lin2(x)
return x

x = torch.randn(m, k, dtype=test_dtype, device=test_device)

# get float reference
model = test_model().to(dtype=test_dtype, device=test_device).eval()
ref_f = model(x)

kwargs = {"dtype": test_dtype}
api(model, **kwargs)

# running model
model(x)

# make sure it compiles
example_inputs = (x,)
torch._export.aot_compile(model, example_inputs)


if __name__ == "__main__":
unittest.main()
32 changes: 20 additions & 12 deletions torchao/quantization/quant_api.py
@@ -20,7 +20,7 @@
import torch.nn.functional as F

from .dynamic_quant import DynamicallyPerAxisQuantizedLinear
from .utils import TORCH_VERSION_AFTER_2_3
from .utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4

from .subclass import (
Int4WeightOnlyQuantizedLinearWeight,
@@ -117,19 +117,27 @@ def apply_dynamic_quant(model, filter_fn=None):
change_linear_weights_to_int8_dqtensors(model, filter_fn)


def _get_subclass_inserter(cls, **kwargs):
method = kwargs.pop("method", "from_float")
import torch.nn.utils.parametrize as parametrize

def _get_subclass_inserter(cls, enable_parametrization=False, **kwargs):
constructor = kwargs.pop("constructor", "subclass_constructor")
from_float = kwargs.pop("method", "from_float")
def insert_subclass(lin):
lin.weight = torch.nn.Parameter(
# cls.from_float(...)
getattr(cls, method)(lin.weight, **kwargs), requires_grad=False
)
if enable_parametrization:
lin.weight = torch.nn.Parameter(cls.from_float(lin.weight, **kwargs), requires_grad=False)
_, args = lin.weight.__tensor_flatten__()
parametrize.register_parametrization(lin, "weight", getattr(cls, constructor)(*args))
else:
lin.weight = torch.nn.Parameter(
# cls.from_float(...)
getattr(cls, from_float)(lin.weight, **kwargs), requires_grad=False
)
return lin

return insert_subclass


def change_linear_weights_to_int8_dqtensors(model, filter_fn=None):
def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
Contributor: I'm not a big fan of undocumented, generic kwargs. I think it's hard to tell users the intent, because you can't write documentation for it.

Contributor Author (@jerryzh168, Apr 30, 2024): These are args for the `from_float` function of the tensor subclass; this is just to be consistent with the existing int4 tensor subclass APIs. Also, I don't think we want to use this function as a top-level API; can we refactor this a bit later?

"""
Converts all linear weight tensors to the `Int8DynamicallyQuantizedLinearWeight`
Tensor subclass, effectively applying the same form of quantization
@@ -141,11 +149,11 @@ def change_linear_weights_to_int8_dqtensors(model, filter_fn=None):
)

_replace_with_custom_fn_if_matches_filter(
model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight), filter_fn
model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs), filter_fn
Reviewer: I would expect we use parametrization only for AOTI, as some kind of "pre-processing" there, especially given that my understanding of the long-term plan is that AOTI will do this pre-processing itself and we will be able to remove it from here.

Contributor Author (@jerryzh168): We also need this for torch.export (used by executorch); I'll add a test in the next PR. Also, I think we want a consistent code path for all backends/runtimes. Are there any problems with enabling this for all use cases?

Contributor: @albanD do you think that long term we want export to do the pre-processing? I think if that's the case, then we might just want to figure out that story now (it might be less work than getting dynamo to handle parametrizations). The main contentious bit is probably just where this pre-processing should live. One possible answer is that it should happen transparently as part of torch.export.export(): automatically search the created state dict for subclasses and flatten them (although this might be a problem if the user expects the state dict of the ExportedProgram to alias the original model's state dict).
)
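A minimal, self-contained sketch of the parametrization idea discussed in the thread above. This is an illustration, not the torchao implementation: the real code registers `cls.subclass_constructor(*args)` built from the subclass's `__tensor_flatten__` metadata, whereas the hypothetical `DecomposedWeight` below uses a plain bfloat16-data-plus-scale decomposition as a stand-in for the quantized tensor subclass. The mechanism it demonstrates is the same: `right_inverse` splits the weight into plain tensors that end up in the state dict, and `forward` rebuilds the weight the module actually uses, so export/AOTI only ever serializes plain tensors.

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class DecomposedWeight(nn.Module):
    # Hypothetical parametrization, for illustration only.
    def forward(self, data, scale):
        # Called on each access to `lin.weight`: rebuild the tensor the module computes with.
        return data.to(torch.float32) * scale

    def right_inverse(self, weight):
        # Called once at registration time. Returning a tuple makes parametrize store
        # several plain tensors (original0, original1, ...) in the state dict.
        scale = weight.abs().amax()
        data = (weight / scale).to(torch.bfloat16)
        return data, scale

lin = nn.Linear(64, 32)
parametrize.register_parametrization(lin, "weight", DecomposedWeight())

# The checkpoint now holds only plain tensors, which torch.export / AOT Inductor can handle.
print({k: tuple(v.shape) for k, v in lin.state_dict().items()})
# e.g. {'bias': (32,), 'parametrizations.weight.original0': (32, 64),
#       'parametrizations.weight.original1': ()}
out = lin(torch.randn(8, 64))  # forward still works; the weight is reconstructed on the fly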


def change_linear_weights_to_int8_woqtensors(model, filter_fn=None):
def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs):
"""
Converts all linear weight tensors to the
`Int8WeightOnlyQuantizedLinearWeight` tensor subclass,
@@ -154,7 +162,7 @@ def change_linear_weights_to_int8_woqtensors(model, filter_fn=None):
"""
_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight),
_get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs),
_is_linear if filter_fn is None else filter_fn,
)

@@ -170,7 +178,7 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs):

_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, **kwargs),
_get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs),
filter_fn,
)

24 changes: 15 additions & 9 deletions torchao/quantization/quant_primitives.py
@@ -47,6 +47,13 @@
] + (_AFTER_TORCH_2_3_ONLY if TORCH_VERSION_AFTER_2_3 else [])


def guard_dtype_size(tensor_arg, arg_name, dtype=None, size=None):
if dtype is not None and tensor_arg.dtype != dtype:
raise ValueError("Expected Tensor argument {arg_name} to have dtype {dtype}, but got {tensor_arg.dtype} instead.")
if size is not None and tensor_arg.size() != size:
raise ValueError("Expected Tensor argument {arg_name} to have size {size}, but got {tensor_arg.size()} instead.")


_DTYPE_TO_QVALUE_BOUNDS = {
torch.uint8: (0, 255),
torch.int8: (-128, 127),
@@ -493,7 +500,7 @@ def quant_int8_dynamic_per_token_linear(
x_vals_int8, x_scales, w_vals_int8_t, w_scales, out_dtype
)
if bias is not None:
mm_out += bias
mm_out = mm_out + bias
Contributor: Why is this needed?

Contributor Author (@jerryzh168): There is some issue with this in AOT Inductor, I think. cc @desertfire

Reviewer: I think @cpuhrsch's question is why rewrite "+=". I am not aware of any AOTI restriction that needs this rewrite.

Contributor Author (@jerryzh168): @desertfire the error I'm getting with "+=" is this: https://gist.github.com/jerryzh168/d4ea2fb8138376cff903c38aaef8f5ef. Is this expected?

return mm_out


@@ -554,7 +561,7 @@ def quant_int8_per_token_matmul(
return y


def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128):
def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
"""This is tinygemm specific, we'll keep this for now"""
if groupsize > w.shape[-1]:
groupsize = w.shape[-1]
@@ -570,15 +577,14 @@ def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128):
max_int = 2**n_bit - 1
scales = (max_val - min_val).clamp(min=1e-6) / max_int
zeros = min_val + scales * (2 ** (n_bit - 1))
return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
torch.bfloat16
return scales.to(dtype=dtype).reshape(w.shape[0], -1), zeros.to(
dtype=dtype
).reshape(w.shape[0], -1)


def pack_tinygemm_scales_and_zeros(scales, zeros):
assert scales.shape == zeros.shape
assert scales.dtype == torch.bfloat16
assert zeros.dtype == torch.bfloat16
guard_dtype_size(scales, "scales", dtype=torch.bfloat16, size=zeros.size())
guard_dtype_size(zeros, "zeros", dtype=torch.bfloat16)
return (
torch.cat(
[
@@ -661,8 +667,8 @@ def groupwise_affine_dequantize_tensor_from_qparams(
return w_dq


def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128):
scales, zeros = get_groupwise_affine_qparams(w, n_bit, groupsize)
def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
scales, zeros = get_groupwise_affine_qparams(w, n_bit, groupsize, dtype)
w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(
w, scales, zeros, n_bit, groupsize
)