Commit

Enable dispatch to tinygemm int4 and int8 kernels for quantized tensor (pytorch#230)

Summary:
This adds dispatch to the tinygemm kernels for CUDA, although the implementation mismatch for tinygemm still needs to be resolved first.

Test Plan:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int4
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int8

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 authored and lancerts committed May 17, 2024
1 parent 9cec68c commit 40b1fde
Showing 5 changed files with 243 additions and 43 deletions.
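
Before the per-file diffs, here is a condensed sketch of the usage pattern the new tests exercise (adapted from test_quantized_tensor_subclass_int4 below; a standalone nn.Linear stands in for the test's ToyLinearModel, so treat it as an illustration of intent rather than a copy of the test):

import torch
from torchao.quantization.subclass import AffineQuantizedTensor
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

# int4 weight-only settings used by the test: uint4 range, groupwise scales/zero points,
# bfloat16 zero point kept in the float domain (the tinygemm convention)
groupsize = 32
block_size = (1, groupsize)

linear = torch.nn.Linear(1024, 1024, bias=False).to(torch.bfloat16).to("cuda")
quantized_weight = AffineQuantizedTensor.from_float(
    linear.weight, MappingType.ASYMMETRIC, block_size, torch.int32,  # weight, mapping_type, block_size, target_dtype
    0, 15, 1e-6,                                                     # quant_min, quant_max, eps
    zero_point_dtype=torch.bfloat16,
    preserve_zero=False,
    zero_point_domain=ZeroPointDomain.FLOAT,
    input_quant_func=None,  # weight-only quantization
)
linear.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
out = linear(torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"))
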
105 changes: 96 additions & 9 deletions test/quantization/test_quant_api.py
@@ -9,7 +9,6 @@
import unittest
import torch
import os
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import (
prepare_pt2e,
convert_pt2e,
@@ -36,7 +35,7 @@


def dynamic_quant(model, example_inputs):
m = capture_pre_autograd_graph(model, example_inputs)
m = torch.export.export(model, example_inputs).module()
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True))
m = prepare_pt2e(m, quantizer)
m = convert_pt2e(m)
@@ -50,14 +49,14 @@ def _apply_dynamic_quant(model):
"""
_replace_with_custom_fn_if_matches_filter(
model,
lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features))),
lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features),)),
lambda mod, fqn: isinstance(mod, torch.nn.Linear),
)
return model


def capture_and_prepare(model, example_inputs):
m = capture_pre_autograd_graph(model, example_inputs)
m = torch.export.export(model, example_inputs)
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True))
m = prepare_pt2e(m, quantizer)
# TODO: we can run the weight observer in convert_pt2e so that users don't need to run this
@@ -88,13 +87,13 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
return model

class ToyLinearModel(torch.nn.Module):
def __init__(self):
def __init__(self, m=64, n=32, k=64):
super().__init__()
self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)
self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)

def example_inputs(self):
return (torch.randn(1, 64).to(torch.float),)
return (torch.randn(1, self.linear1.in_features).to(torch.float),)

def forward(self, x):
x = self.linear1(x)
@@ -104,8 +103,9 @@ def forward(self, x):
class TestQuantFlow(unittest.TestCase):
def test_dynamic_quant_gpu_singleline(self):
m = ToyLinearModel().eval()
example_inputs = m.example_inputs()
m = _apply_dynamic_quant(m)
quantized = m(*m.example_inputs())
quantized = m(*example_inputs)
# AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
# While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
# m = torch.compile(m, mode="max-autotune")
@@ -442,7 +442,94 @@ def get_per_token_block_size(x):
ref = m_copy(*example_inputs)
self.assertTrue(torch.equal(res, ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int4(self):
from torchao.quantization.subclass import AffineQuantizedTensor
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quant_primitives import ZeroPointDomain
import copy

# weight settings
groupsize = 32
mapping_type = MappingType.ASYMMETRIC
block_size = (1, groupsize)
target_dtype = torch.int32
quant_min = 0
quant_max = 15
eps = 1e-6
preserve_zero = False
zero_point_dtype = torch.bfloat16

# weight only quantization
input_quant_func = None

# use 1024 so that we don't need padding
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))

def to_quantized(weight):
return AffineQuantizedTensor.from_float(
weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps,
zero_point_dtype=zero_point_dtype,
preserve_zero=preserve_zero,
zero_point_domain=ZeroPointDomain.FLOAT,
input_quant_func=input_quant_func,
)

m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

# reference
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)

res = m(*example_inputs)
ref = m_copy(*example_inputs)

self.assertTrue(torch.equal(res, ref))


@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int8(self):
from torchao.quantization.subclass import AffineQuantizedTensor
from torchao.quantization.quant_primitives import MappingType
import copy

# weight settings
mapping_type = MappingType.SYMMETRIC
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64

# weight only quantization
input_quant_func = None

m = ToyLinearModel().eval().to(torch.bfloat16)
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))

def to_quantized(weight):
block_size = (1, weight.shape[1])
return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, input_quant_func=input_quant_func)

m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

# reference
from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
change_linear_weights_to_int8_woqtensors(m_copy)

res = m(*example_inputs)
ref = m_copy(*example_inputs)

torch.testing.assert_close(res, ref, rtol=0.00001, atol=1e-2)


if __name__ == "__main__":
11 changes: 4 additions & 7 deletions test/quantization/test_quant_primitives.py
@@ -327,6 +327,8 @@ def test_not_preserve_zero_not_supported(self):


def test_tinygemm_get_groupwise_affine_qparams(self):
from torchao.quantization.quant_primitives import ZeroPointDomain

input = torch.randn(10, 256)
n_bit = 4
scale_ref, zero_point_ref = get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16)
@@ -351,16 +353,11 @@ def test_tinygemm_get_groupwise_affine_qparams(self):
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
preserve_zero=False,
zero_point_domain=ZeroPointDomain.FLOAT,
)

def int_zero_point_to_float(zero_point, scale, quant_min, mid_point):
return (quant_min - zero_point + mid_point) * scale

mid_point = 2 ** (n_bit - 1)
zero_point_float = int_zero_point_to_float(zero_point, scale, quant_min, mid_point)

self.assertTrue(torch.equal(scale, scale_ref))
torch.testing.assert_close(zero_point_float, zero_point_ref, rtol=0.00001, atol=torch.max(scale)*0.03)
self.assertTrue(torch.equal(zero_point, zero_point_ref))


if __name__ == "__main__":
1 change: 1 addition & 0 deletions torchao/quantization/autoquant.py
@@ -9,6 +9,7 @@
quantize_activation_per_token_absmax,
safe_int_mm,
)
from .utils import TORCH_VERSION_AFTER_2_4
import torch.nn.functional as F
try:
from torch._inductor.utils import do_bench
85 changes: 66 additions & 19 deletions torchao/quantization/quant_primitives.py
@@ -72,6 +72,14 @@ def guard_dtype_size(tensor_arg, arg_name, dtype=None, size=None):
torch.uint7: (0, 2**7-1),
})

class MappingType(Enum):
SYMMETRIC = 0
ASYMMETRIC = 1

class ZeroPointDomain(Enum):
INT = 0
FLOAT = 1

# TODO: decide on if we want to allow custom quant_min/quant_max here
def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
"""Get quant_min and quant_max args based on dtype and also
@@ -141,7 +149,8 @@ def quantize_affine(
zero_point: Optional[torch.Tensor],
output_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None
quant_max: Optional[int] = None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
):
"""
Args:
@@ -153,6 +162,12 @@
output_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype
quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
value during quantization
default is ZeroPointDomain.INT
Note:
How can block_size represent different granularities?
@@ -184,9 +199,19 @@
if zero_point is not None:
zero_point = zero_point.view(shape_after_reduction)

quant = torch.clamp(
torch.round(input / scale) + zero_point, quant_min, quant_max
).to(output_dtype)
if zero_point_domain == ZeroPointDomain.INT:
quant = torch.clamp(
torch.round(input / scale) + zero_point, quant_min, quant_max
).to(output_dtype)
else:
assert zero_point_domain == ZeroPointDomain.FLOAT
mid_point = (quant_max + quant_min + 1) / 2
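# e.g. for the uint4 range used by the int4 tests (quant_min=0, quant_max=15), mid_point == 8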
min_val = zero_point - scale * mid_point
quant = (
torch.clamp(
torch.round((input - min_val) / scale),
quant_min, quant_max)
).to(output_dtype)
quant = quant.view(original_shape)

return quant
@@ -199,6 +224,7 @@ def dequantize_affine(
input_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
*,
output_dtype: torch.dtype = torch.float32,
):
@@ -213,6 +239,12 @@
quant_min (Optional[int]): minimum quantized value for input Tensor
quant_max (Optional[int]): maximum quantized value for input Tensor
output_dtype (torch.dtype): dtype for output Tensor, default is fp32
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
value during quantization
default is ZeroPointDomain.INT
Output:
dequantized Tensor, with requested dtype or fp32
@@ -233,18 +265,22 @@
if zero_point is not None:
zero_point = zero_point.view(shape_after_reduction)

dequant = input.to(torch.int32)
if zero_point is not None:
dequant -= zero_point.to(torch.int32)
dequant = dequant.to(output_dtype)
dequant *= scale
dequant = dequant.view(original_shape)
return dequant.to(output_dtype)
if zero_point_domain == ZeroPointDomain.INT:
dequant = input.to(torch.int32)
if zero_point is not None:
dequant -= zero_point.to(torch.int32)
dequant = dequant.to(output_dtype)
dequant *= scale
else:
assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}"
mid_point = (quant_max + quant_min + 1) / 2
dequant = input - mid_point
dequant = dequant.to(output_dtype)
dequant *= scale
if zero_point is not None:
dequant += zero_point


class MappingType(Enum):
SYMMETRIC = 0
ASYMMETRIC = 1
return dequant.view(original_shape).to(output_dtype)

def choose_qparams_affine(
input: torch.Tensor,
@@ -256,7 +292,8 @@ def choose_qparams_affine(
eps: Optional[float] = None,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero = True,
preserve_zero: bool = True,
zero_point_domain = ZeroPointDomain.INT,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
@@ -280,6 +317,13 @@
If we don't need zero to be exactly representable, we won't do rounding and clamping for zero_point
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
value during quantization
default is ZeroPointDomain.INT
Output:
Tuple of scales and zero_points Tensor with requested dtype
"""
@@ -310,15 +354,18 @@
scale = max_val_pos / (float(quant_max - quant_min) / 2)
if not preserve_zero:
raise ValueError("preserve_zero == False is not supported for symmetric quantization")
zero_point = torch.full_like(scale, int((quant_min + quant_max + 1) / 2))
if zero_point_domain != ZeroPointDomain.INT:
raise ValueError("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization")
zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
else:
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
if preserve_zero:
zero_point = quant_min - torch.round(min_val_neg / scale)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
else:
zero_point = quant_min - min_val_neg / scale

assert zero_point_domain == ZeroPointDomain.FLOAT, "if not preserve_zero, zero_point must be in FLOAT domain"
mid_point = (quant_max + quant_min + 1) / 2
zero_point = min_val_neg + scale * mid_point
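# i.e. the float-domain zero point is the dequantized value that the mid point of the quantized range maps back to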

if eps is None:
eps = torch.finfo(input.dtype).eps
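
For readers tracing the new ZeroPointDomain.FLOAT code paths in this file, a minimal round-trip sketch of the changed primitives follows. It mirrors the groupwise uint4 settings from test_tinygemm_get_groupwise_affine_qparams; the positional order of the arguments not visible in the hunks above (mapping_type, block_size, target_dtype, quant_min, quant_max) is inferred from the test calls in this commit, so double-check it against the full source:

import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    ZeroPointDomain,
    choose_qparams_affine,
    quantize_affine,
    dequantize_affine,
)

# Groupwise uint4 settings matching the tinygemm-style qparams test
input = torch.randn(10, 256, dtype=torch.bfloat16)
block_size = (1, 128)            # groupsize = 128 along the last dimension
target_dtype = torch.int32
quant_min, quant_max = 0, 15     # 4-bit unsigned range

scale, zero_point = choose_qparams_affine(
    input, MappingType.ASYMMETRIC, block_size, target_dtype, quant_min, quant_max,
    scale_dtype=torch.bfloat16,
    zero_point_dtype=torch.bfloat16,
    preserve_zero=False,                      # required when the zero point lives in the float domain
    zero_point_domain=ZeroPointDomain.FLOAT,  # zero_point = min_val + scale * mid_point
)

q = quantize_affine(
    input, block_size, scale, zero_point, target_dtype, quant_min, quant_max,
    zero_point_domain=ZeroPointDomain.FLOAT,
)
dq = dequantize_affine(
    q, block_size, scale, zero_point, target_dtype, quant_min, quant_max,
    zero_point_domain=ZeroPointDomain.FLOAT,
    output_dtype=torch.bfloat16,
)
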
