Skip to content

Commit

Permalink
Check dequantize_affine is idempotent (pytorch#309)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpuhrsch committed Jun 3, 2024
1 parent c44f077 commit e461c25
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 16 deletions.
44 changes: 29 additions & 15 deletions test/quantization/test_quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,24 @@
_SEED = 1234
torch.manual_seed(_SEED)

# Helper function to run a function twice
# and verify that the result is the same.
# Adds some verification to avoid side effects.
# NOTE:
# - Does not verify the args and kwargs are unchanged.
# - Assumes the output is a single Tensor
def check_idempotent(self, fn, *args, **kwargs):
output0 = fn(*args, **kwargs)
assert torch.is_tensor(output0)
output1 = fn(*args, **kwargs)
self.assertTrue(torch.equal(output0, output1), f"Expected given function {fn} to be idempotent.")
return output1


class TestQuantPrimitives(unittest.TestCase):
SEED = 123

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
def test_get_group_qparams_symmetric(self):
"""
Test that `get_group_qparams_symmetric` produces the exact same scales as
Expand Down Expand Up @@ -77,7 +91,7 @@ def test_choose_qparams_group_sym(self):
self.assertTrue(torch.equal(scale, scale_ref))
self.assertTrue(torch.equal(zero_point, zp_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
def test_choose_qparams_token_asym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.ASYMMETRIC
Expand Down Expand Up @@ -127,7 +141,7 @@ def test_choose_qparams_tensor_sym(self):
self.assertTrue(torch.equal(scale, scale_ref))
self.assertTrue(torch.equal(zero_point, zp_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_quantize_activation_per_token_abs_max(self):
from torchao.quantization.quant_primitives import quantize_activation_per_token_absmax
input = torch.randn(10, 10)
Expand All @@ -148,15 +162,15 @@ def test_quantize_activation_per_token_abs_max(self):
self.assertTrue(torch.equal(quantized, quantized_ref))
self.assertTrue(torch.equal(scale, scale_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_quantize_activation_per_token_abs_max_zero_input(self):
from torchao.quantization.quant_primitives import quantize_activation_per_token_absmax
input = torch.zeros(10, 10)
# make sure it still works
quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)


@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_quantize_activation_per_token_abs_max_dtype(self):
from torchao.quantization.quant_primitives import quantize_activation_per_token_absmax
input = torch.zeros(10, 10, dtype=torch.bfloat16)
Expand All @@ -172,7 +186,7 @@ def test_quantize_activation_per_token_abs_max_dtype(self):
self.assertTrue(scale_ref.dtype, torch.float32)


@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_quantize_dequantize_group_sym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.SYMMETRIC
Expand All @@ -181,7 +195,7 @@ def test_quantize_dequantize_group_sym(self):
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)

quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)

group_size = 2
quant_min = -128
Expand All @@ -196,7 +210,7 @@ def test_quantize_dequantize_group_sym(self):
self.assertTrue(torch.equal(quantized, quantized_ref))
self.assertTrue(torch.equal(dequantized, dequantized_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_quantize_dequantize_channel_asym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.ASYMMETRIC
Expand All @@ -205,7 +219,7 @@ def test_quantize_dequantize_channel_asym(self):
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
output_dtype = torch.float32
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype)
dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype)

axis = 1
quant_min = -128
Expand All @@ -219,7 +233,7 @@ def test_quantize_dequantize_channel_asym(self):
self.assertTrue(torch.equal(quantized, quantized_ref))
self.assertTrue(torch.equal(dequantized, dequantized_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_quantize_dequantize_tensor_asym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.ASYMMETRIC
Expand All @@ -228,7 +242,7 @@ def test_quantize_dequantize_tensor_asym(self):
output_dtype = torch.float32
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype)
dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype)

axis = 1
quant_min = -128
Expand All @@ -242,15 +256,15 @@ def test_quantize_dequantize_tensor_asym(self):
self.assertTrue(torch.equal(quantized, quantized_ref))
self.assertTrue(torch.equal(dequantized, dequantized_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_quantize_dequantize_channel_asym_4d(self):
input = torch.randn(3, 3, 10, 10)
mapping_type = MappingType.ASYMMETRIC
dtype = torch.int8
block_size = (3, 3, 1, 10)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)

axis = 2
quant_min = -128
Expand All @@ -264,15 +278,15 @@ def test_quantize_dequantize_channel_asym_4d(self):
self.assertTrue(torch.equal(quantized, quantized_ref))
self.assertTrue(torch.equal(dequantized, dequantized_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self):
input = torch.randn(3, 3, 10, 10)
mapping_type = MappingType.ASYMMETRIC
dtype = torch.int8
block_size = (3, 3, 2, 2)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
# we don't have corresponding ops in existing primitives, so just make sure it runs and it's close to float
torch.testing.assert_close(dequantized, input, rtol=2, atol=0.02)

Expand Down
6 changes: 5 additions & 1 deletion torchao/quantization/quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def dequantize_affine(
Output:
dequantized Tensor, with requested dtype or fp32
"""

# TODO: validations
# TODO: validate scale/zero_point dimensions are compatible with block_size
assert input.dtype == input_dtype
Expand All @@ -266,14 +267,17 @@ def dequantize_affine(
zero_point = zero_point.view(shape_after_reduction)

if zero_point_domain == ZeroPointDomain.INT:
dequant = input.to(torch.int32)
# Force a copy to avoid input modification due
# to upcoming in-place operations.
dequant = input.to(torch.int32, copy=True)
if zero_point is not None:
dequant -= zero_point.to(torch.int32)
dequant = dequant.to(output_dtype)
dequant *= scale
else:
assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}"
mid_point = (quant_max + quant_min + 1) / 2
# This should allocate new memory and avoid input modification
dequant = input - mid_point
dequant = dequant.to(output_dtype)
dequant *= scale
Expand Down

0 comments on commit e461c25

Please sign in to comment.