Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check and fix dequantize_affine is idempotent #309

Merged
merged 1 commit into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 29 additions & 15 deletions test/quantization/test_quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,24 @@
_SEED = 1234
torch.manual_seed(_SEED)

# Helper function to run a function twice
# and verify that the result is the same.
# Adds some verification to avoid side effects.
# NOTE:
# - Does not verify the args and kwargs are unchanged.
# - Assumes the output is a single Tensor
def check_idempotent(self, fn, *args, **kwargs):
output0 = fn(*args, **kwargs)
assert torch.is_tensor(output0)
output1 = fn(*args, **kwargs)
self.assertTrue(torch.equal(output0, output1), f"Expected given function {fn} to be idempotent.")
return output1


class TestQuantPrimitives(unittest.TestCase):
SEED = 123

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
def test_get_group_qparams_symmetric(self):
"""
Test that `get_group_qparams_symmetric` produces the exact same scales as
Expand Down Expand Up @@ -77,7 +91,7 @@ def test_choose_qparams_group_sym(self):
self.assertTrue(torch.equal(scale, scale_ref))
self.assertTrue(torch.equal(zero_point, zp_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
def test_choose_qparams_token_asym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.ASYMMETRIC
Expand Down Expand Up @@ -127,7 +141,7 @@ def test_choose_qparams_tensor_sym(self):
self.assertTrue(torch.equal(scale, scale_ref))
self.assertTrue(torch.equal(zero_point, zp_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_quantize_activation_per_token_abs_max(self):
from torchao.quantization.quant_primitives import quantize_activation_per_token_absmax
input = torch.randn(10, 10)
Expand All @@ -148,15 +162,15 @@ def test_quantize_activation_per_token_abs_max(self):
self.assertTrue(torch.equal(quantized, quantized_ref))
self.assertTrue(torch.equal(scale, scale_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_quantize_activation_per_token_abs_max_zero_input(self):
from torchao.quantization.quant_primitives import quantize_activation_per_token_absmax
input = torch.zeros(10, 10)
# make sure it still works
quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)


@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_quantize_activation_per_token_abs_max_dtype(self):
from torchao.quantization.quant_primitives import quantize_activation_per_token_absmax
input = torch.zeros(10, 10, dtype=torch.bfloat16)
Expand All @@ -172,7 +186,7 @@ def test_quantize_activation_per_token_abs_max_dtype(self):
self.assertTrue(scale_ref.dtype, torch.float32)


@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_quantize_dequantize_group_sym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.SYMMETRIC
Expand All @@ -181,7 +195,7 @@ def test_quantize_dequantize_group_sym(self):
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)

quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)

group_size = 2
quant_min = -128
Expand All @@ -196,7 +210,7 @@ def test_quantize_dequantize_group_sym(self):
self.assertTrue(torch.equal(quantized, quantized_ref))
self.assertTrue(torch.equal(dequantized, dequantized_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_quantize_dequantize_channel_asym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.ASYMMETRIC
Expand All @@ -205,7 +219,7 @@ def test_quantize_dequantize_channel_asym(self):
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
output_dtype = torch.float32
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype)
dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype)

axis = 1
quant_min = -128
Expand All @@ -219,7 +233,7 @@ def test_quantize_dequantize_channel_asym(self):
self.assertTrue(torch.equal(quantized, quantized_ref))
self.assertTrue(torch.equal(dequantized, dequantized_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_quantize_dequantize_tensor_asym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.ASYMMETRIC
Expand All @@ -228,7 +242,7 @@ def test_quantize_dequantize_tensor_asym(self):
output_dtype = torch.float32
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype)
dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype)

axis = 1
quant_min = -128
Expand All @@ -242,15 +256,15 @@ def test_quantize_dequantize_tensor_asym(self):
self.assertTrue(torch.equal(quantized, quantized_ref))
self.assertTrue(torch.equal(dequantized, dequantized_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_quantize_dequantize_channel_asym_4d(self):
input = torch.randn(3, 3, 10, 10)
mapping_type = MappingType.ASYMMETRIC
dtype = torch.int8
block_size = (3, 3, 1, 10)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)

axis = 2
quant_min = -128
Expand All @@ -264,15 +278,15 @@ def test_quantize_dequantize_channel_asym_4d(self):
self.assertTrue(torch.equal(quantized, quantized_ref))
self.assertTrue(torch.equal(dequantized, dequantized_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self):
input = torch.randn(3, 3, 10, 10)
mapping_type = MappingType.ASYMMETRIC
dtype = torch.int8
block_size = (3, 3, 2, 2)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
# we don't have corresponding ops in existing primitives, so just make sure it runs and it's close to float
torch.testing.assert_close(dequantized, input, rtol=2, atol=0.02)

Expand Down
6 changes: 5 additions & 1 deletion torchao/quantization/quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def dequantize_affine(
Output:
dequantized Tensor, with requested dtype or fp32
"""

# TODO: validations
# TODO: validate scale/zero_point dimensions are compatible with block_size
assert input.dtype == input_dtype
Expand All @@ -266,14 +267,17 @@ def dequantize_affine(
zero_point = zero_point.view(shape_after_reduction)

if zero_point_domain == ZeroPointDomain.INT:
dequant = input.to(torch.int32)
# Force a copy to avoid input modification due
# to upcoming in-place operations.
dequant = input.to(torch.int32, copy=True)
if zero_point is not None:
dequant -= zero_point.to(torch.int32)
dequant = dequant.to(output_dtype)
dequant *= scale
else:
assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}"
mid_point = (quant_max + quant_min + 1) / 2
# This should allocate new memory and avoid input modification
dequant = input - mid_point
dequant = dequant.to(output_dtype)
dequant *= scale
Expand Down
Loading