Enable dispatch to tinygemm int4 and int8 kernels for quantized tensor #230

Merged
merged 1 commit on May 14, 2024
105 changes: 96 additions & 9 deletions test/quantization/test_quant_api.py
@@ -9,7 +9,6 @@
import unittest
import torch
import os
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import (
prepare_pt2e,
convert_pt2e,
@@ -36,7 +35,7 @@


def dynamic_quant(model, example_inputs):
m = capture_pre_autograd_graph(model, example_inputs)
m = torch.export.export(model, example_inputs).module()
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True))
m = prepare_pt2e(m, quantizer)
m = convert_pt2e(m)
@@ -50,14 +49,14 @@ def _apply_dynamic_quant(model):
"""
_replace_with_custom_fn_if_matches_filter(
model,
lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features))),
lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features),)),
Contributor
Interesting. Why is that extra comma needed now?

Contributor
By all means don't be blocked on this comment haha
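Likely because torch.export.export(model, args) requires args to be a tuple of example inputs, while (tensor) without a trailing comma is just a parenthesized tensor, not a 1-tuple. A minimal illustration (not part of the diff):

import torch

x = torch.randn(1, 64)
print(isinstance((x), tuple))    # False: parentheses alone don't make a tuple
print(isinstance((x,), tuple))   # True: the trailing comma does

# torch.export.export(model, args) validates that args is a tuple of example
# inputs, so the call site now needs (tensor,) instead of (tensor).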

lambda mod, fqn: isinstance(mod, torch.nn.Linear),
)
return model


def capture_and_prepare(model, example_inputs):
m = capture_pre_autograd_graph(model, example_inputs)
m = torch.export.export(model, example_inputs)
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True))
m = prepare_pt2e(m, quantizer)
# TODO: we can run the weight observer in convert_pt2e so that users don't need to run this
@@ -88,13 +87,13 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
return model

class ToyLinearModel(torch.nn.Module):
def __init__(self):
def __init__(self, m=64, n=32, k=64):
super().__init__()
self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)
self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)

def example_inputs(self):
return (torch.randn(1, 64).to(torch.float),)
return (torch.randn(1, self.linear1.in_features).to(torch.float),)

def forward(self, x):
x = self.linear1(x)
@@ -104,8 +103,9 @@ def forward(self, x):
class TestQuantFlow(unittest.TestCase):
def test_dynamic_quant_gpu_singleline(self):
m = ToyLinearModel().eval()
example_inputs = m.example_inputs()
m = _apply_dynamic_quant(m)
quantized = m(*m.example_inputs())
quantized = m(*example_inputs)
# AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
# While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
# m = torch.compile(m, mode="max-autotune")
@@ -442,7 +442,94 @@ def get_per_token_block_size(x):
ref = m_copy(*example_inputs)
self.assertTrue(torch.equal(res, ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int4(self):
from torchao.quantization.subclass import AffineQuantizedTensor
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quant_primitives import ZeroPointDomain
import copy

# weight settings
groupsize = 32
mapping_type = MappingType.ASYMMETRIC
block_size = (1, groupsize)
target_dtype = torch.int32
quant_min = 0
quant_max = 15
eps = 1e-6
preserve_zero = False
zero_point_dtype = torch.bfloat16

# weight only quantization
input_quant_func = None

# use 1024 so that we don't need padding
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))

def to_quantized(weight):
return AffineQuantizedTensor.from_float(
weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps,
zero_point_dtype=zero_point_dtype,
preserve_zero=preserve_zero,
zero_point_domain=ZeroPointDomain.FLOAT,
input_quant_func=input_quant_func,
)

m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

# reference
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)

res = m(*example_inputs)
ref = m_copy(*example_inputs)

self.assertTrue(torch.equal(res, ref))


@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int8(self):
from torchao.quantization.subclass import AffineQuantizedTensor
from torchao.quantization.quant_primitives import MappingType
import copy

# weight settings
mapping_type = MappingType.SYMMETRIC
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64

# weight only quantization
input_quant_func = None

m = ToyLinearModel().eval().to(torch.bfloat16)
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))

def to_quantized(weight):
block_size = (1, weight.shape[1])
return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, input_quant_func=input_quant_func)

m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

# reference
from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
change_linear_weights_to_int8_woqtensors(m_copy)

res = m(*example_inputs)
ref = m_copy(*example_inputs)

torch.testing.assert_close(res, ref, rtol=0.00001, atol=1e-2)


if __name__ == "__main__":
11 changes: 4 additions & 7 deletions test/quantization/test_quant_primitives.py
@@ -327,6 +327,8 @@ def test_not_preserve_zero_not_supported(self):


def test_tinygemm_get_groupwise_affine_qparams(self):
from torchao.quantization.quant_primitives import ZeroPointDomain

input = torch.randn(10, 256)
n_bit = 4
scale_ref, zero_point_ref = get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16)
@@ -351,16 +353,11 @@ def test_tinygemm_get_groupwise_affine_qparams(self):
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
preserve_zero=False,
zero_point_domain=ZeroPointDomain.FLOAT,
)

def int_zero_point_to_float(zero_point, scale, qaunt_min, mid_point):
return (quant_min - zero_point + mid_point) * scale

mid_point = 2 ** (n_bit - 1)
zero_point_float = int_zero_point_to_float(zero_point, scale, quant_min, mid_point)

self.assertTrue(torch.equal(scale, scale_ref))
torch.testing.assert_close(zero_point_float, zero_point_ref, rtol=0.00001, atol=torch.max(scale)*0.03)
self.assertTrue(torch.equal(zero_point, zero_point_ref))
Contributor
Neat
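Presumably the removed int_zero_point_to_float helper is no longer needed because choose_qparams_affine with preserve_zero=False and zero_point_domain=ZeroPointDomain.FLOAT now returns the zero point already in the float domain used by get_groupwise_affine_qparams, so the test can compare zero_point to zero_point_ref directly.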



if __name__ == "__main__":
1 change: 1 addition & 0 deletions torchao/quantization/autoquant.py
@@ -9,6 +9,7 @@
quantize_activation_per_token_absmax,
safe_int_mm,
)
from .utils import TORCH_VERSION_AFTER_2_4
import torch.nn.functional as F
try:
from torch._inductor.utils import do_bench
85 changes: 66 additions & 19 deletions torchao/quantization/quant_primitives.py
@@ -72,6 +72,14 @@ def guard_dtype_size(tensor_arg, arg_name, dtype=None, size=None):
torch.uint7: (0, 2**7-1),
})

class MappingType(Enum):
SYMMETRIC = 0
ASYMMETRIC = 1

class ZeroPointDomain(Enum):
INT = 0
FLOAT = 1

# TODO: decide on if we want to allow custom quant_min/quant_max here
def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
"""Get quant_min and quant_max args based on dtype and also
@@ -141,7 +149,8 @@ def quantize_affine(
zero_point: Optional[torch.Tensor],
output_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None
quant_max: Optional[int] = None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
):
"""
Args:
@@ -153,6 +162,12 @@
output_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype
quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
value during quantization
default is ZeroPointDomain.INT

Note:
How can block_size represent different granularities?
@@ -184,9 +199,19 @@
if zero_point is not None:
zero_point = zero_point.view(shape_after_reduction)

quant = torch.clamp(
torch.round(input / scale) + zero_point, quant_min, quant_max
).to(output_dtype)
if zero_point_domain == ZeroPointDomain.INT:
quant = torch.clamp(
torch.round(input / scale) + zero_point, quant_min, quant_max
).to(output_dtype)
else:
assert zero_point_domain == ZeroPointDomain.FLOAT
mid_point = (quant_max + quant_min + 1) / 2
min_val = zero_point - scale * mid_point
quant = (
torch.clamp(
torch.round((input - min_val) / scale),
quant_min, quant_max)
).to(output_dtype)
quant = quant.view(original_shape)

return quant
@@ -199,6 +224,7 @@ def dequantize_affine(
input_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
*,
output_dtype: torch.dtype = torch.float32,
):
@@ -213,6 +239,12 @@
quant_min (Optional[int]): minimum quantized value for input Tensor
quant_max (Optional[int]): maximum quantized value for input Tensor
output_dtype (torch.dtype): dtype for output Tensor, default is fp32
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
value during quantization
default is ZeroPointDomain.INT

Output:
dequantized Tensor, with requested dtype or fp32
@@ -233,18 +265,22 @@
if zero_point is not None:
zero_point = zero_point.view(shape_after_reduction)

dequant = input.to(torch.int32)
if zero_point is not None:
dequant -= zero_point.to(torch.int32)
dequant = dequant.to(output_dtype)
dequant *= scale
dequant = dequant.view(original_shape)
return dequant.to(output_dtype)
if zero_point_domain == ZeroPointDomain.INT:
dequant = input.to(torch.int32)
if zero_point is not None:
dequant -= zero_point.to(torch.int32)
dequant = dequant.to(output_dtype)
dequant *= scale
else:
assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}"
mid_point = (quant_max + quant_min + 1) / 2
dequant = input - mid_point
dequant = dequant.to(output_dtype)
dequant *= scale
if zero_point is not None:
dequant += zero_point


class MappingType(Enum):
SYMMETRIC = 0
ASYMMETRIC = 1
return dequant.view(original_shape).to(output_dtype)
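
To make the new ZeroPointDomain.FLOAT branches in quantize_affine and dequantize_affine concrete, here is a hand-traced sketch with made-up per-tensor values (the real functions operate block-wise; this is illustrative, not part of the diff). The float-domain zero point is the dequantized value of the mid point of the quantized range, so quantization recenters inputs around zero_point - scale * mid_point before rounding, and dequantization mirrors it:

import torch

quant_min, quant_max = 0, 15
mid_point = (quant_max + quant_min + 1) / 2        # 8.0
scale = torch.tensor(0.5)
zero_point = torch.tensor(1.0)                     # float-domain zero point
x = torch.tensor([-3.0, 0.0, 2.5])

# quantize_affine, FLOAT branch: recenter around min_val, round, clamp.
min_val = zero_point - scale * mid_point
q = torch.clamp(torch.round((x - min_val) / scale), quant_min, quant_max)

# dequantize_affine, FLOAT branch: subtract mid_point, rescale, add zero_point.
x_hat = (q - mid_point) * scale + zero_point       # recovers [-3.0, 0.0, 2.5]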

def choose_qparams_affine(
input: torch.Tensor,
@@ -256,7 +292,8 @@ def choose_qparams_affine(
eps: Optional[float] = None,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero = True,
preserve_zero: bool = True,
zero_point_domain = ZeroPointDomain.INT,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
@@ -280,6 +317,13 @@

If we don't need zero to be exactly representable, we won't do rounding and clamping for zero_point

zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
value during quantization
default is ZeroPointDomain.INT

Output:
Tuple of scales and zero_points Tensor with requested dtype
"""
@@ -310,15 +354,18 @@
scale = max_val_pos / (float(quant_max - quant_min) / 2)
if not preserve_zero:
raise ValueError("preserve_zero == False is not supported for symmetric quantization")
zero_point = torch.full_like(scale, int((quant_min + quant_max + 1) / 2))
if zero_point_domain != ZeroPointDomain.INT:
raise ValueError("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization")
zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
else:
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
if preserve_zero:
zero_point = quant_min - torch.round(min_val_neg / scale)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
else:
zero_point = quant_min - min_val_neg / scale

assert zero_point_domain == ZeroPointDomain.FLOAT, "if not preserve_zero, zero_point must be in FLOAT domain"
mid_point = (quant_max + quant_min + 1) / 2
zero_point = min_val_neg + scale * mid_point

if eps is None:
eps = torch.finfo(input.dtype).eps
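
For reference, a sketch of how the asymmetric, preserve_zero=False path above yields tinygemm-style parameters (hypothetical group values, not from the diff): the float-domain zero point is placed at the dequantized mid point, which makes the group minimum map exactly to quant_min in the round-trip sketch shown earlier.

import torch

quant_min, quant_max = 0, 15                       # 4-bit range used by tinygemm
w = torch.tensor([-1.0, -0.25, 0.5, 2.0])          # one quantization group

min_val_neg, max_val_pos = w.min(), w.max()
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
mid_point = (quant_max + quant_min + 1) / 2        # 8.0
zero_point = min_val_neg + scale * mid_point       # float-domain zero point

# The group minimum lands exactly on quant_min when quantized:
assert torch.isclose(zero_point - scale * mid_point, min_val_neg)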