Some follow up fixes for quant primitives (pytorch#220)
Summary:
att

Test Plan:
python test/quantization/test_quant_primitives.py -k test_raises

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 authored May 8, 2024
1 parent e8ee479 commit 2e252d0
Showing 2 changed files with 30 additions and 6 deletions.
19 changes: 19 additions & 0 deletions test/quantization/test_quant_primitives.py
@@ -279,5 +279,24 @@ def test_get_group_qparams_symmetric_memory(self):
after_choose_qparams_mem_use = torch.cuda.memory_allocated()
self.assertTrue(after_choose_qparams_mem_use < 1.2 * original_mem_use)

def test_raises(self):
"""Make sure some errors are raised when user requested an unsupported type of quantization
"""
input = torch.randn(10, 10)
mapping_type = MappingType.ASYMMETRIC
dtype = torch.int8
block_size = (10, 10)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype)


# make sure we can't quantize int32 tensors:
with self.assertRaisesRegex(AssertionError, "Unsupported input dtype:"):
_ = quantize_affine(input.to(torch.int32), block_size, scale, zero_point, dtype)

# block_size and scale/zero_point shape mismatch
block_size = (1, 1)
with self.assertRaisesRegex(RuntimeError, "is invalid for input of size 1"):
_ = quantize_affine(input, block_size, scale, zero_point, dtype)

if __name__ == "__main__":
unittest.main()
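For reference, a minimal sketch of the happy path these primitives support (the new test above only exercises the error paths). The import path for MappingType and the primitives is assumed to be torchao.quantization.quant_primitives as in the test file, and the positional argument order for dequantize_affine is inferred from its docstring, so it may differ slightly.

import torch
from torchao.quantization.quant_primitives import (  # import path assumed
    MappingType,
    choose_qparams_affine,
    quantize_affine,
    dequantize_affine,
)

input = torch.randn(10, 10)
block_size = (10, 10)  # same as the tensor shape -> per-tensor quantization
target_dtype = torch.int8

# Pick scale/zero_point for asymmetric per-tensor quantization.
scale, zero_point = choose_qparams_affine(
    input, MappingType.ASYMMETRIC, block_size, target_dtype
)

# Quantize to int8, then dequantize back to float32.
quantized = quantize_affine(input, block_size, scale, zero_point, target_dtype)
dequantized = dequantize_affine(
    quantized, block_size, scale, zero_point, target_dtype, output_dtype=torch.float32
)

# Round-trip error should be on the order of the scale.
print((input - dequantized).abs().max())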
17 changes: 11 additions & 6 deletions torchao/quantization/quant_primitives.py
@@ -145,9 +145,9 @@ def quantize_affine(
):
"""
Args:
input (torch.Tensor): original float32 or bfloat16 Tensor
input (torch.Tensor): original float32, float16 or bfloat16 Tensor
block_size (Tuple[int, ...]): granularity of quantization; this means the size of the tensor elements that share the same qparam
e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
scale (float): quantization parameter for affine quantization
zero_point (int): quantization parameter for affine quantization
output_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
@@ -171,6 +171,8 @@ def quantize_affine(
quantized tensor with requested dtype
"""
# TODO: validations
# TODO: validate scale/zero_point dimensions are compatible with block_size
assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported input dtype: {input.dtype}"
quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max)
shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
original_shape = input.shape
@@ -198,7 +200,7 @@ def dequantize_affine(
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
*,
output_dtype: Optional[torch.dtype] = None,
output_dtype: torch.dtype = torch.float32,
):
"""
Args:
@@ -210,13 +212,15 @@ def dequantize_affine(
dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
quant_min (Optional[int]): minimum quantized value for input Tensor
quant_max (Optional[int]): maximum quantized value for input Tensor
output_dtype (torch.dtype?): optional dtype for output Tensor, default is fp32
output_dtype (torch.dtype): dtype for output Tensor, default is fp32
Output:
dequantized Tensor, with requested dtype or fp32
"""
# TODO: validations
# TODO: validate scale/zero_point dimensions are compatible with block_size
assert input.dtype == input_dtype
assert output_dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported output dtype: {output_dtype}"
quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max)

shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
@@ -229,9 +233,10 @@ def dequantize_affine(
if zero_point is not None:
zero_point = zero_point.view(shape_after_reduction)

dequant = input.to(output_dtype)
dequant = input.to(torch.int32)
if zero_point is not None:
dequant -= zero_point
dequant -= zero_point.to(torch.int32)
dequant = dequant.to(output_dtype)
dequant *= scale
dequant = dequant.view(original_shape)
return dequant.to(output_dtype)
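A note on the dequantization change above: the zero_point subtraction now happens in int32 before the cast to output_dtype, so the signed integer intermediate stays exact regardless of the requested float dtype. A small illustration with hypothetical values (not the library code):

import torch

# Hypothetical values for illustration only.
q = torch.tensor([0, 255], dtype=torch.uint8)      # quantized values
zero_point = torch.tensor(128, dtype=torch.int32)   # affine zero point
scale = torch.tensor(0.1)                            # affine scale

# Affine dequantization: (q - zero_point) * scale.
# Mirroring the updated code path: promote to int32 and subtract the zero
# point exactly in integer arithmetic, then cast to the requested output
# dtype before applying the scale. The integer intermediate stays exact no
# matter which float dtype (float32, float16, bfloat16) is requested.
dequant = (q.to(torch.int32) - zero_point).to(torch.float32) * scale
print(dequant)  # approximately tensor([-12.8000, 12.7000])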
