Skip to content

Commit

Permalink
Revert "Enable dispatch to tinygemm int4 and int8 kernels for quantiz…
Browse files Browse the repository at this point in the history
…ed tenso…"

This reverts commit 10da375.
  • Loading branch information
cpuhrsch authored May 15, 2024
1 parent 10da375 commit dba8e10
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 243 deletions.
105 changes: 9 additions & 96 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import unittest
import torch
import os
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import (
prepare_pt2e,
convert_pt2e,
Expand All @@ -35,7 +36,7 @@


def dynamic_quant(model, example_inputs):
m = torch.export.export(model, example_inputs).module()
m = capture_pre_autograd_graph(model, example_inputs)
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True))
m = prepare_pt2e(m, quantizer)
m = convert_pt2e(m)
Expand All @@ -49,14 +50,14 @@ def _apply_dynamic_quant(model):
"""
_replace_with_custom_fn_if_matches_filter(
model,
lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features),)),
lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features))),
lambda mod, fqn: isinstance(mod, torch.nn.Linear),
)
return model


def capture_and_prepare(model, example_inputs):
m = torch.export.export(model, example_inputs)
m = capture_pre_autograd_graph(model, example_inputs)
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True))
m = prepare_pt2e(m, quantizer)
# TODO: we can run the weight observer in convert_pt2e so that user don't need to run this
Expand Down Expand Up @@ -87,13 +88,13 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
return model

class ToyLinearModel(torch.nn.Module):
def __init__(self, m=64, n=32, k=64):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)
self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)

def example_inputs(self):
return (torch.randn(1, self.linear1.in_features).to(torch.float),)
return (torch.randn(1, 64).to(torch.float),)

def forward(self, x):
x = self.linear1(x)
Expand All @@ -103,9 +104,8 @@ def forward(self, x):
class TestQuantFlow(unittest.TestCase):
def test_dynamic_quant_gpu_singleline(self):
m = ToyLinearModel().eval()
example_inputs = m.example_inputs()
m = _apply_dynamic_quant(m)
quantized = m(*example_inputs)
quantized = m(*m.example_inputs())
# AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
# While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
# m = torch.compile(m, mode="max-autotune")
Expand Down Expand Up @@ -442,94 +442,7 @@ def get_per_token_block_size(x):
ref = m_copy(*example_inputs)
self.assertTrue(torch.equal(res, ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int4(self):
from torchao.quantization.subclass import AffineQuantizedTensor
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quant_primitives import ZeroPointDomain
import copy

# weight settings
groupsize = 32
mapping_type = MappingType.ASYMMETRIC
block_size = (1, groupsize)
target_dtype = torch.int32
quant_min = 0
quant_max = 15
eps = 1e-6
preserve_zero = False
zero_point_dtype = torch.bfloat16

# weight only quantization
input_quant_func = None

# use 1024 so that we don't need padding
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))

def to_quantized(weight):
return AffineQuantizedTensor.from_float(
weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps,
zero_point_dtype=zero_point_dtype,
preserve_zero=preserve_zero,
zero_point_domain=ZeroPointDomain.FLOAT,
input_quant_func=input_quant_func,
)

m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

# reference
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)

res = m(*example_inputs)
ref = m_copy(*example_inputs)

self.assertTrue(torch.equal(res, ref))


@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int8(self):
from torchao.quantization.subclass import AffineQuantizedTensor
from torchao.quantization.quant_primitives import MappingType
import copy

# weight settings
mapping_type = MappingType.SYMMETRIC
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64

# weight only quantization
input_quant_func = None

m = ToyLinearModel().eval().to(torch.bfloat16)
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))

def to_quantized(weight):
block_size = (1, weight.shape[1])
return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, input_quant_func=input_quant_func)

m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

# reference
from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
change_linear_weights_to_int8_woqtensors(m_copy)

res = m(*example_inputs)
ref = m_copy(*example_inputs)

torch.testing.assert_close(res, ref, rtol=0.00001, atol=1e-2)


if __name__ == "__main__":
Expand Down
11 changes: 7 additions & 4 deletions test/quantization/test_quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,6 @@ def test_not_preserve_zero_not_supported(self):


def test_tinygemm_get_groupwise_affine_qparams(self):
from torchao.quantization.quant_primitives import ZeroPointDomain

input = torch.randn(10, 256)
n_bit = 4
scale_ref, zero_point_ref = get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16)
Expand All @@ -353,11 +351,16 @@ def test_tinygemm_get_groupwise_affine_qparams(self):
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
preserve_zero=False,
zero_point_domain=ZeroPointDomain.FLOAT,
)

def int_zero_point_to_float(zero_point, scale, qaunt_min, mid_point):
return (quant_min - zero_point + mid_point) * scale

mid_point = 2 ** (n_bit - 1)
zero_point_float = int_zero_point_to_float(zero_point, scale, quant_min, mid_point)

self.assertTrue(torch.equal(scale, scale_ref))
self.assertTrue(torch.equal(zero_point, zero_point_ref))
torch.testing.assert_close(zero_point_float, zero_point_ref, rtol=0.00001, atol=torch.max(scale)*0.03)


if __name__ == "__main__":
Expand Down
1 change: 0 additions & 1 deletion torchao/quantization/autoquant.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
quantize_activation_per_token_absmax,
safe_int_mm,
)
from .utils import TORCH_VERSION_AFTER_2_4
import torch.nn.functional as F
try:
from torch._inductor.utils import do_bench
Expand Down
85 changes: 19 additions & 66 deletions torchao/quantization/quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,6 @@ def guard_dtype_size(tensor_arg, arg_name, dtype=None, size=None):
torch.uint7: (0, 2**7-1),
})

class MappingType(Enum):
SYMMETRIC = 0
ASYMMETRIC = 1

class ZeroPointDomain(Enum):
INT = 0
FLOAT = 1

# TODO: decide on if we want to allow custom quant_min/quant_max here
def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
"""Get quant_min and quant_max args based on dtype and also
Expand Down Expand Up @@ -149,8 +141,7 @@ def quantize_affine(
zero_point: Optional[torch.Tensor],
output_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
quant_max: Optional[int] = None
):
"""
Args:
Expand All @@ -162,12 +153,6 @@ def quantize_affine(
output_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype
quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
value during quantization
default is ZeroPointDomain.INT
Note:
How can block_size represent different granularities?
Expand Down Expand Up @@ -199,19 +184,9 @@ def quantize_affine(
if zero_point is not None:
zero_point = zero_point.view(shape_after_reduction)

if zero_point_domain == ZeroPointDomain.INT:
quant = torch.clamp(
torch.round(input / scale) + zero_point, quant_min, quant_max
).to(output_dtype)
else:
assert zero_point_domain == ZeroPointDomain.FLOAT
mid_point = (quant_max + quant_min + 1) / 2
min_val = zero_point - scale * mid_point
quant = (
torch.clamp(
torch.round((input - min_val) / scale),
quant_min, quant_max)
).to(output_dtype)
quant = torch.clamp(
torch.round(input / scale) + zero_point, quant_min, quant_max
).to(output_dtype)
quant = quant.view(original_shape)

return quant
Expand All @@ -224,7 +199,6 @@ def dequantize_affine(
input_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
*,
output_dtype: torch.dtype = torch.float32,
):
Expand All @@ -239,12 +213,6 @@ def dequantize_affine(
quant_min (Optional[int]): minimum quantized value for input Tensor
quant_max (Optional[int]): maximum quantized value for input Tensor
output_dtype (torch.dtype): dtype for output Tensor, default is fp32
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
value during quantization
default is ZeroPointDomain.INT
Output:
dequantized Tensor, with requested dtype or fp32
Expand All @@ -265,22 +233,18 @@ def dequantize_affine(
if zero_point is not None:
zero_point = zero_point.view(shape_after_reduction)

if zero_point_domain == ZeroPointDomain.INT:
dequant = input.to(torch.int32)
if zero_point is not None:
dequant -= zero_point.to(torch.int32)
dequant = dequant.to(output_dtype)
dequant *= scale
else:
assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}"
mid_point = (quant_max + quant_min + 1) / 2
dequant = input - mid_point
dequant = dequant.to(output_dtype)
dequant *= scale
if zero_point is not None:
dequant += zero_point
dequant = input.to(torch.int32)
if zero_point is not None:
dequant -= zero_point.to(torch.int32)
dequant = dequant.to(output_dtype)
dequant *= scale
dequant = dequant.view(original_shape)
return dequant.to(output_dtype)

return dequant.view(original_shape).to(output_dtype)

class MappingType(Enum):
SYMMETRIC = 0
ASYMMETRIC = 1

def choose_qparams_affine(
input: torch.Tensor,
Expand All @@ -292,8 +256,7 @@ def choose_qparams_affine(
eps: Optional[float] = None,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero: bool = True,
zero_point_domain = ZeroPointDomain.INT,
preserve_zero = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
Expand All @@ -317,13 +280,6 @@ def choose_qparams_affine(
If we don't need zero to be exactly representable, we won't do rounding and clamping for zero_point
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
value during quantization
default is ZeroPointDomain.INT
Output:
Tuple of scales and zero_points Tensor with requested dtype
"""
Expand Down Expand Up @@ -354,18 +310,15 @@ def choose_qparams_affine(
scale = max_val_pos / (float(quant_max - quant_min) / 2)
if not preserve_zero:
raise ValueError("preserve_zero == False is not supported for symmetric quantization")
if zero_point_domain != ZeroPointDomain.INT:
raise ValueError("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization")
zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
zero_point = torch.full_like(scale, int((quant_min + quant_max + 1) / 2))
else:
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
if preserve_zero:
zero_point = quant_min - torch.round(min_val_neg / scale)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
else:
assert zero_point_domain == ZeroPointDomain.FLOAT, "if not preserve_zero, zero_point must be in FLOAT domain"
mid_point = (quant_max + quant_min + 1) / 2
zero_point = min_val_neg + scale * mid_point
zero_point = quant_min - min_val_neg / scale


if eps is None:
eps = torch.finfo(input.dtype).eps
Expand Down
Loading

0 comments on commit dba8e10

Please sign in to comment.