Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[quant] Add per block quantization primitives #159

Merged
merged 3 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 186 additions & 2 deletions test/quantization/test_quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,21 @@
# This test takes a long time to run
import unittest
import torch
from torchao.quantization.quant_primitives import get_group_qparams_symmetric
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
from torchao.quantization.quant_primitives import (
get_group_qparams_symmetric,
quantize_affine,
dequantize_affine,
choose_qparams_affine,
MappingType,
)

from torchao.quantization.utils import (
TORCH_VERSION_AFTER_2_3,
TORCH_VERSION_AFTER_2_4,
)

_SEED = 1234
torch.manual_seed(_SEED)

class TestQuantPrimitives(unittest.TestCase):
SEED = 123
Expand Down Expand Up @@ -46,5 +59,176 @@ def test_get_group_qparams_symmetric(self):
(scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize)
torch.testing.assert_allclose(scale_obs, scale_ao, rtol=0, atol=0)

def test_choose_qparams_group_sym(self):
"""Note: groupwise asymmetric quant is using a different way of computing zero_points, so
we don't include it here. We may just replace it with per block quant
"""
input = torch.randn(10, 10)
mapping_type = MappingType.SYMMETRIC
dtype = torch.int8
block_size = (1, 2)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)

scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2)

self.assertTrue(torch.equal(scale, scale_ref))
self.assertTrue(torch.equal(zero_point, zp_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
def test_choose_qparams_token_asym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.ASYMMETRIC
dtype = torch.int8
block_size = (1, 10)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)

scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(input, dtype)
scale_ref = scale_ref.squeeze()
zp_ref = zp_ref.squeeze()

torch.testing.assert_allclose(scale, scale_ref, atol=10e-3, rtol=10e-3)
self.assertTrue(torch.equal(zero_point, zp_ref))

def test_choose_qparams_tensor_asym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.ASYMMETRIC
dtype = torch.int8
block_size = (10, 10)
eps = torch.finfo(torch.float32).eps
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps)


quant_min = -128
quant_max = 127
scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams(input, quant_min, quant_max, eps, dtype)
scale_ref = scale_ref.squeeze()
zp_ref = zp_ref.squeeze()

self.assertTrue(torch.equal(scale, scale_ref))
self.assertTrue(torch.equal(zero_point, zp_ref))

def test_choose_qparams_tensor_sym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.SYMMETRIC
dtype = torch.int8
block_size = (10, 10)
eps = torch.finfo(torch.float32).eps
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps)

quant_min = -128
quant_max = 127
scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_symmetric(input, quant_min, quant_max, eps, dtype)
scale_ref = scale_ref.squeeze()
zp_ref = zp_ref.squeeze()

self.assertTrue(torch.equal(scale, scale_ref))
self.assertTrue(torch.equal(zero_point, zp_ref))


@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
def test_quantize_dequantize_group_sym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.SYMMETRIC
dtype = torch.int8
block_size = (1, 2)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)

quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)

group_size = 2
quant_min = -128
quant_max = 127
quantized_ref = torch.ops.quantized_decomposed.quantize_per_channel_group(
input, scale, zero_point, quant_min, quant_max, torch.int8, group_size
)
dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel_group(
quantized_ref, scale, zero_point, quant_min, quant_max, torch.int8, group_size, output_dtype=torch.float32
)

self.assertTrue(torch.equal(quantized, quantized_ref))
self.assertTrue(torch.equal(dequantized, dequantized_ref))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this tell you how many elements were different and by how much? Should we use this instead?

torch.testing.assert_close(quantized, quantized_ref, atol=0, rtol=0)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this one is equal actually


@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
def test_quantize_dequantize_channel_asym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.ASYMMETRIC
dtype = torch.int8
block_size = (10, 1)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
output_dtype = torch.float32
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype)

axis = 1
quant_min = -128
quant_max = 127
quantized_ref = torch.ops.quantized_decomposed.quantize_per_channel(
input, scale, zero_point, axis, quant_min, quant_max, torch.int8
)
dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel(
quantized_ref, scale, zero_point, axis, quant_min, quant_max, torch.int8, out_dtype=output_dtype
)
self.assertTrue(torch.equal(quantized, quantized_ref))
self.assertTrue(torch.equal(dequantized, dequantized_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
def test_quantize_dequantize_tensor_asym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.ASYMMETRIC
dtype = torch.int8
block_size = (10, 10)
output_dtype = torch.float32
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype)

axis = 1
quant_min = -128
quant_max = 127
quantized_ref = torch.ops.quantized_decomposed.quantize_per_tensor(
input, scale, zero_point, quant_min, quant_max, torch.int8
)
dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_tensor(
quantized_ref, scale, zero_point, quant_min, quant_max, torch.int8, out_dtype=output_dtype
)
self.assertTrue(torch.equal(quantized, quantized_ref))
self.assertTrue(torch.equal(dequantized, dequantized_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
def test_quantize_dequantize_channel_asym_4d(self):
input = torch.randn(3, 3, 10, 10)
mapping_type = MappingType.ASYMMETRIC
dtype = torch.int8
block_size = (3, 3, 1, 10)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)

axis = 2
quant_min = -128
quant_max = 127
quantized_ref = torch.ops.quantized_decomposed.quantize_per_channel(
input, scale, zero_point, axis, quant_min, quant_max, torch.int8
)
dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel(
quantized_ref, scale, zero_point, axis, quant_min, quant_max, torch.int8, out_dtype=torch.float32
)
self.assertTrue(torch.equal(quantized, quantized_ref))
self.assertTrue(torch.equal(dequantized, dequantized_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self):
input = torch.randn(3, 3, 10, 10)
mapping_type = MappingType.ASYMMETRIC
dtype = torch.int8
block_size = (3, 3, 2, 2)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
# we don't have corresponding ops in existing primitives, so just make sure it runs and it's close to float
torch.testing.assert_allclose(dequantized, input, rtol=2, atol=0.02)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add tests where you expect exceptions thrown



if __name__ == "__main__":
unittest.main()
1 change: 0 additions & 1 deletion torchao/quantization/autoquant.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,4 +391,3 @@ def autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filte
model(*example_input)
change_autoquantizable_to_quantized(model, **kwargs)
return model

Loading
Loading