
Commit

Enable dispatch to tinygemm int4 and int8 kernels for unified quantized tensor

Summary:
This adds dispatch to the tinygemm kernels for CUDA, although we still need to resolve
the implementation mismatch problem for tinygemm first.
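
For illustration, a minimal sketch of the path this dispatch is intended to enable, adapted from the test added below. The layer shape, device placement, and input are assumptions for the example, not part of this commit:

import torch
from torchao.quantization.subclass import AffineQuantizedTensor
from torchao.quantization.quant_primitives import MappingType

# int4 weight-only settings mirroring the new test: a 4-bit range stored in an
# int8 tensor, quantized groupwise along the last dim with block_size = (1, groupsize)
groupsize = 32
linear = torch.nn.Linear(1024, 1024, bias=False).to(torch.bfloat16).cuda()
quantized_weight = AffineQuantizedTensor.from_float(
    linear.weight,
    MappingType.ASYMMETRIC,
    (1, groupsize),
    torch.int8,
    preserve_zero=False,
    quant_min=-8,
    quant_max=7,
    eps=torch.finfo(torch.bfloat16).eps,
)
linear.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)

# The linear call now reaches the aten.mm/addmm handler in subclass.py, which
# packs the weight on the fly with torch.ops.aten._convert_weight_to_int4pack,
# converts the integer zero_point to tinygemm's float convention, and calls
# _weight_int4pack_mm.
x = torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda")
y = linear(x)

Given the note above about the tinygemm implementation mismatch, the output of this path may not yet match the dequantize-based reference numerically.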

Test Plan:
TODO

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed May 8, 2024
1 parent b34d1ac commit 230ebf6
Showing 2 changed files with 95 additions and 4 deletions.
40 changes: 40 additions & 0 deletions test/quantization/test_quant_api.py
@@ -442,6 +442,46 @@ def get_per_token_block_size(x):
        ref = m_copy(*example_inputs)
        self.assertTrue(torch.equal(res, ref))

    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
    def test_quantized_tensor_subclass_int4(self):
        from torchao.quantization.subclass import AffineQuantizedTensor
        from torchao.quantization.quant_primitives import MappingType
        import copy

        # weight settings
        groupsize = 32
        mapping_type = MappingType.ASYMMETRIC
        block_size = (1, groupsize)
        target_dtype = torch.int8
        eps = torch.finfo(torch.bfloat16).eps
        quant_min = -8
        quant_max = 7
        preserve_zero = False

        # weight only quantization
        input_quant_func = None

        m = ToyLinearModel().eval().to(torch.bfloat16)
        m_copy = copy.deepcopy(m)
        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))

        def to_quantized(weight):
            return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, preserve_zero, quant_min, quant_max, eps, input_quant_func=input_quant_func)

        m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
        m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
        assert isinstance(m.linear1.weight, AffineQuantizedTensor)
        assert isinstance(m.linear2.weight, AffineQuantizedTensor)

        # reference
        from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
        change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)

        res = m(*example_inputs)
        ref = m_copy(*example_inputs)

        self.assertTrue(torch.equal(res, ref))




59 changes: 55 additions & 4 deletions torchao/quantization/subclass.py
@@ -626,7 +626,7 @@ class AffineQuantizedTensor(torch.Tensor):
      shape (torch.Size): the shape for the Tensor
      quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
      quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
      input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes input Tensor as input and outputs an AffineQuantizedTensor object
      input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes float Tensor as input and outputs an AffineQuantizedTensor object
      dtype: dtype for external representation of the tensor, e.g. torch.float32
    """

@@ -642,6 +642,7 @@ def __new__(
        quant_max: Optional[int] = None,
        input_quant_func: Optional[Callable] = None,
        dtype=None,
        # TODO: remove args and kwargs
        *args,
        **kwargs
    ):
@@ -684,7 +685,9 @@ def __repr__(self):
            f"device={self.device}, dtype={self.dtype}, input_quant_func={self.input_quant_func}, requires_grad={self.requires_grad})"
        )

    def dequantize(self, output_dtype=torch.float32):
    def dequantize(self, output_dtype=None):
        if output_dtype is None:
            output_dtype = self.dtype
        return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, output_dtype=output_dtype)

    def __tensor_flatten__(self):
@@ -716,13 +719,15 @@ def from_float(
        mapping_type,
        block_size,
        target_dtype,
        preserve_zero = True,
        quant_min = None,
        quant_max = None,
        eps = None,
        scale_dtype = None,
        zero_point_dtype = None,
        input_quant_func = None,
    ):
        # TODO: add preserve_zero arg to choose_qparams_affine
        scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype)
        int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max)
        return cls(
Expand Down Expand Up @@ -810,7 +815,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
if (
func in [aten.mm.default, aten.addmm.default]
and args[0].is_floating_point()
and args[0].is_cuda
):
if func == aten.addmm.default:
assert args[1].shape[-1] == args[2].shape[0], (
@@ -832,7 +836,54 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                    args[1],
                    None if len(args) == 2 else args[2],
                )
            if weight_qtensor.input_quant_func is not None:
            if weight_qtensor.input_quant_func is None:
                is_cuda = args[0].is_cuda
                # weight only quantization
                is_int8 = (
                    weight_qtensor.int_data.dtype == torch.int8 and
                    (weight_qtensor.quant_min is None or weight_qtensor.quant_min == -128) and
                    (weight_qtensor.quant_max is None or weight_qtensor.quant_max == 127)
                )
                is_int4 = (
                    weight_qtensor.int_data.dtype == torch.int8 and
                    (weight_qtensor.quant_min is None or weight_qtensor.quant_min == -8) and
                    (weight_qtensor.quant_max is None or weight_qtensor.quant_max == 7)
                )

                if (
                    is_cuda and
                    is_int4 and
                    len(weight_qtensor.shape) == 2 and
                    weight_qtensor.block_size[0] == 1
                ):
                    # groupwise int4 quantization
                    # TODO: currently doing packing on the fly, we'll need to figure out
                    # the API to do packing before hand
                    # TODO: zero_point transform
                    # TODO: expose the arg
                    innerKTiles = 8
                    packed_weight = torch.ops.aten._convert_weight_to_int4pack(weight_qtensor.int_data, innerKTiles)
                    groupsize = weight_qtensor.block_size[-1]
                    # adjust zero_point to be compatible with tinygemm
                    # tinygemm uses floating point "zeros" expressed relative to the
                    # mid-point of the 4-bit range instead of integer zero_points
                    def int_zero_point_to_float(zero_point, scale, quant_min, mid_point):
                        return (quant_min - zero_point + mid_point) * scale

                    mid_point = 8
                    zero_point_float = int_zero_point_to_float(weight_qtensor.zero_point, weight_qtensor.scale, weight_qtensor.quant_min, mid_point)
                    scales_and_zeros = pack_tinygemm_scales_and_zeros(weight_qtensor.scale, zero_point_float)
                    return _weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scales_and_zeros)
                elif (
                    is_cuda and
                    is_int8 and
                    len(weight_qtensor.shape) == 2 and
                    len(weight_qtensor.block_size) == 2 and
                    weight_qtensor.block_size[0] == 1 and
                    weight_qtensor.block_size[1] == weight_qtensor.shape[1]
                ):
                    # per channel int8 quantization
                    return _weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.int_data, weight_qtensor.scale)
            else:
                # dynamic quantization
                input_tensor = weight_qtensor.input_quant_func(input_tensor)
                input_tensor = input_tensor.dequantize()
            weight_tensor = weight_qtensor.dequantize()
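
For reference, the dynamic quantization branch above relies on the weight's input_quant_func described in the docstring: a callable that takes the float activation and returns an AffineQuantizedTensor, so that .dequantize() can be called on the result. The following sketch is an editor's illustration, not part of this diff; the symmetric per-token int8 settings are assumptions:

import torch
from torchao.quantization.subclass import AffineQuantizedTensor
from torchao.quantization.quant_primitives import MappingType

def example_input_quant_func(x: torch.Tensor) -> AffineQuantizedTensor:
    # per-token block: quantize each row of the last dimension on its own,
    # i.e. block_size = (1, ..., 1, x.shape[-1])
    block_size = (1,) * (x.dim() - 1) + (x.shape[-1],)
    return AffineQuantizedTensor.from_float(
        x,
        MappingType.SYMMETRIC,
        block_size,
        torch.int8,
        eps=torch.finfo(torch.float32).eps,
    )

A weight constructed with AffineQuantizedTensor.from_float(..., input_quant_func=example_input_quant_func) would take this branch instead of the weight-only tinygemm kernels.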
