Enable dispatch to tinygemm int4 and int8 kernels for unified quantized tensor

Summary:
This adds dispatch to the tinygemm kernels for CUDA, although the implementation mismatch problem for tinygemm still needs to be resolved first.

Test Plan:
TODO

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed May 10, 2024
1 parent b91b6be commit 13d643a
Showing 3 changed files with 180 additions and 13 deletions.
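
For orientation, the snippet below is a minimal usage sketch of the path this commit enables; it is not part of the diff. It follows the test_quantized_tensor_subclass_int4 test added below, with import paths and the from_float signature taken from this diff; the sizes and settings are illustrative.

import torch
from torchao.quantization.subclass import TinygemmAffineQuantizedTensor
from torchao.quantization.quant_primitives import MappingType

# bfloat16 linear layer on CUDA; 1024 features so no padding is needed
linear = torch.nn.Linear(1024, 1024, bias=False).to(torch.bfloat16).to("cuda").eval()

# groupwise 4-bit asymmetric weight-only quantization (group size 32)
qweight = TinygemmAffineQuantizedTensor.from_float(
    linear.weight,
    MappingType.ASYMMETRIC,
    block_size=(1, 32),
    eps=1e-6,
    input_quant_func=None,  # weight only, no input quantization
)
linear.weight = torch.nn.Parameter(qweight, requires_grad=False)

# F.linear on the subclass should now dispatch to the tinygemm int4 kernel
# (torch.ops.aten._weight_int4pack_mm) through __torch_function__
x = torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda")
y = linear(x)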
93 changes: 84 additions & 9 deletions test/quantization/test_quant_api.py
@@ -9,7 +9,6 @@
import unittest
import torch
import os
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import (
prepare_pt2e,
convert_pt2e,
@@ -36,7 +35,7 @@


def dynamic_quant(model, example_inputs):
m = capture_pre_autograd_graph(model, example_inputs)
m = torch.export.export(model, example_inputs).module()
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True))
m = prepare_pt2e(m, quantizer)
m = convert_pt2e(m)
@@ -50,14 +49,14 @@ def _apply_dynamic_quant(model):
"""
_replace_with_custom_fn_if_matches_filter(
model,
lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features))),
lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features),)),
lambda mod, fqn: isinstance(mod, torch.nn.Linear),
)
return model


def capture_and_prepare(model, example_inputs):
m = capture_pre_autograd_graph(model, example_inputs)
m = torch.export.export(model, example_inputs)
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True))
m = prepare_pt2e(m, quantizer)
# TODO: we can run the weight observer in convert_pt2e so that user don't need to run this
@@ -88,13 +87,13 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
return model

class ToyLinearModel(torch.nn.Module):
def __init__(self):
def __init__(self, m=64, n=32, k=64):
super().__init__()
self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)
self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)

def example_inputs(self):
return (torch.randn(1, 64).to(torch.float),)
return (torch.randn(1, self.linear1.in_features).to(torch.float),)

def forward(self, x):
x = self.linear1(x)
@@ -104,8 +103,10 @@ def forward(self, x):
class TestQuantFlow(unittest.TestCase):
def test_dynamic_quant_gpu_singleline(self):
m = ToyLinearModel().eval()
example_inputs = m.example_inputs()
m = _apply_dynamic_quant(m)
quantized = m(*m.example_inputs())
example_inputs = (torch.randn(1, 64),)
quantized = m(*example_inputs)
# AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
# While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
# m = torch.compile(m, mode="max-autotune")
@@ -442,7 +443,81 @@ def get_per_token_block_size(x):
ref = m_copy(*example_inputs)
self.assertTrue(torch.equal(res, ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
def test_quantized_tensor_subclass_int4(self):
from torchao.quantization.subclass import TinygemmAffineQuantizedTensor
from torchao.quantization.quant_primitives import MappingType
import copy

# weight settings
groupsize = 32
mapping_type = MappingType.ASYMMETRIC
block_size = (1, groupsize)
eps = 1e-6
preserve_zero = False

# weight only quantization
input_quant_func = None

# use 1024 so that we don't need padding
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))

def to_quantized(weight):
return TinygemmAffineQuantizedTensor.from_float(weight, mapping_type, block_size, eps, input_quant_func=input_quant_func)

m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
assert isinstance(m.linear1.weight, TinygemmAffineQuantizedTensor)
assert isinstance(m.linear2.weight, TinygemmAffineQuantizedTensor)

# reference
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)

res = m(*example_inputs)
ref = m_copy(*example_inputs)

torch.testing.assert_close(res, ref, rtol=0.00001, atol=0.02)


@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
def test_quantized_tensor_subclass_int8(self):
from torchao.quantization.subclass import AffineQuantizedTensor
from torchao.quantization.quant_primitives import MappingType
import copy

# weight settings
mapping_type = MappingType.SYMMETRIC
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64

# weight only quantization
input_quant_func = None

m = ToyLinearModel().eval().to(torch.bfloat16)
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))

def to_quantized(weight):
block_size = (1, weight.shape[1])
return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, input_quant_func=input_quant_func)

m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

# reference
from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
change_linear_weights_to_int8_woqtensors(m_copy)

res = m(*example_inputs)
ref = m_copy(*example_inputs)

torch.testing.assert_close(res, ref, rtol=0.00001, atol=1e-2)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion test/quantization/test_quant_primitives.py
@@ -353,7 +353,7 @@ def test_tinygemm_get_groupwise_affine_qparams(self):
preserve_zero=False,
)

def int_zero_point_to_float(zero_point, scale, qaunt_min, mid_point):
def int_zero_point_to_float(zero_point, scale, quant_min, mid_point):
return (quant_min - zero_point + mid_point) * scale

mid_point = 2 ** (n_bit - 1)
98 changes: 95 additions & 3 deletions torchao/quantization/subclass.py
@@ -14,7 +14,9 @@
dynamically_quantize_per_channel,
groupwise_affine_quantize_tensor,
quant_int8_dynamic_per_token_linear,
pack_tinygemm_scales_and_zeros,
unpack_tinygemm_scales_and_zeros,
groupwise_affine_quantize_tensor_from_qparams,
choose_qparams_affine,
quantize_affine,
dequantize_affine,
@@ -619,7 +621,7 @@ class AffineQuantizedTensor(torch.Tensor):
shape (torch.Size): the shape for the Tensor
quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes input Tensor as input and outputs an AffineQuantizedTensor object
input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes float Tensor as input and outputs an AffineQuantizedTensor object
dtype: dtype for external representation of the tensor, e.g. torch.float32
"""

@@ -635,6 +637,7 @@ def __new__(
quant_max: Optional[int] = None,
input_quant_func: Optional[Callable] = None,
dtype=None,
# TODO: remove args and kwargs
*args,
**kwargs
):
@@ -677,7 +680,9 @@ def __repr__(self):
f"device={self.device}, dtype={self.dtype}, input_quant_func={self.input_quant_func}, requires_grad={self.requires_grad})"
)

def dequantize(self, output_dtype=torch.float32):
def dequantize(self, output_dtype=None):
if output_dtype is None:
output_dtype = self.dtype
return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, output_dtype=output_dtype)

def __tensor_flatten__(self):
@@ -740,7 +745,54 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
args[1],
args[2] if len(args) > 2 else None,
)
if weight_qtensor.input_quant_func is not None:
if weight_qtensor.input_quant_func is None:
is_cuda = args[0].is_cuda
is_cpu = args[0].device == torch.device("cpu")
# weight only quantization
is_int8 = (
weight_qtensor.int_data.dtype == torch.int8 and
(weight_qtensor.quant_min is None or weight_qtensor.quant_min == -128) and
(weight_qtensor.quant_max is None or weight_qtensor.quant_max == 127)
)
is_uint4 = (
weight_qtensor.int_data.dtype == torch.int32 and
weight_qtensor.quant_min == 0 and
weight_qtensor.quant_max == 15
)

# TODO: enable cpu and mps path as well
# TODO: make sure weight dimension matches the expectation of the int4mm kernel
# TODO: move this to TinygemmAffineQuantizedTensor
if (
is_cuda and
is_uint4 and
weight_qtensor.dtype == torch.bfloat16 and
len(weight_qtensor.shape) == 2 and
weight_qtensor.block_size[0] == 1
):
# groupwise int4 quantization
# TODO: currently doing packing on the fly, we'll need to figure out
# the API to do packing before hand
# TODO: expose the arg
innerKTiles = 8
packed_weight = torch.ops.aten._convert_weight_to_int4pack(weight_qtensor.int_data.to(torch.int32), innerKTiles)
scales_and_zeros = pack_tinygemm_scales_and_zeros(weight_qtensor.scale, weight_qtensor.zero_point)
groupsize = weight_qtensor.block_size[-1]
return torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scales_and_zeros)
elif (
is_cpu and
is_int8 and
len(weight_qtensor.shape) == 2 and
len(weight_qtensor.block_size) == 2 and
weight_qtensor.block_size[0] == 1 and
weight_qtensor.block_size[1] == weight_qtensor.shape[1]
):
# TODO: enable mps path as well
# per channel int8 weight only quantized mm
return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.int_data, weight_qtensor.scale)
else:
# dynamic quantization
# TODO: enable int8 dynamic quant dispatch
input_tensor = weight_qtensor.input_quant_func(input_tensor)
input_tensor = input_tensor.dequantize()
weight_tensor = weight_qtensor.dequantize()
@@ -865,3 +917,43 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
kwargs,
args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
)


# TODO: add padding support
class TinygemmAffineQuantizedTensor(AffineQuantizedTensor):
@classmethod
def from_float(
cls,
input_float,
mapping_type,
block_size,
eps = None,
scale_dtype = None,
zero_point_dtype = None,
input_quant_func = None,
):
# TODO: replace this with uint4 dtype
target_dtype = torch.int32
quant_min = 0
quant_max = 15
preserve_zero = False
scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero)
def int_zero_point_to_float(zero_point, scale, quant_min, mid_point):
return (quant_min - zero_point + mid_point) * scale

mid_point = (quant_min + quant_max + 1) / 2
zero_point_float = int_zero_point_to_float(zero_point, scale, quant_min, mid_point)
n_bit = 4
groupsize = block_size[1]
int_data = groupwise_affine_quantize_tensor_from_qparams(input_float, scale, zero_point_float, n_bit, groupsize)
return cls(
int_data,
scale,
zero_point_float,
block_size,
input_float.shape,
quant_min,
quant_max,
input_quant_func=input_quant_func,
dtype=input_float.dtype
)
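
A note on the zero-point conversion in from_float above: tinygemm works with a floating-point zero point expressed relative to the mid-point of the quantized range, so the integer zero point from choose_qparams_affine is mapped through (quant_min - zero_point + mid_point) * scale. The numbers below are made up for illustration, and the dequantization convention (q - mid_point) * scale + zero_point_float is an assumption about the tinygemm path rather than something stated in this diff.

# illustrative values only: 4-bit range with quant_min = 0, quant_max = 15
scale = 0.05
int_zero_point = 3                   # integer zero point in [0, 15]
quant_min = 0
mid_point = (0 + 15 + 1) / 2         # = 8, as computed in from_float

# conversion used in from_float above
zero_point_float = (quant_min - int_zero_point + mid_point) * scale   # 0.25

# assuming tinygemm dequantizes as (q - mid_point) * scale + zero_point_float,
# this agrees with the integer convention (q - int_zero_point) * scale
q = 11
assert abs((q - mid_point) * scale + zero_point_float - (q - int_zero_point) * scale) < 1e-9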
