Add support for asymmetric act quant for int8 dynamic quant (#1131)
* Add support for asymmetric activation quantization for int8 dynamic quant

Summary:
This is needed for executorch: https://github.com/pytorch/executorch/blob/01d878310a1e22791bc6be65566382cd5632ff10/examples/models/llama/source_transformation/quantize.py#L416

Test Plan:
python test/dtypes/test_affine_quantized.py


* remove qmin/qmax

* remove eps

* fix if branch
jerryzh168 authored Oct 22, 2024
1 parent 85ec209 commit 12e4acf
Showing 3 changed files with 17 additions and 6 deletions.
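
For orientation before the diff: with this change, the asymmetric activation path is opt-in through the existing one-shot quantization API. A minimal usage sketch, assuming torchao's public quantize_ entry point; the toy Sequential module is illustrative, not from this commit:

    import torch
    import torch.nn as nn
    from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
    from torchao.quantization.quant_primitives import MappingType

    # Any model with linear layers; bfloat16 matches the existing tests.
    model = nn.Sequential(nn.Linear(1024, 1024)).eval().to(torch.bfloat16)

    # act_mapping_type defaults to MappingType.SYMMETRIC; pass ASYMMETRIC
    # to select the new activation quantization path.
    quantize_(model, int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC))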
2 changes: 2 additions & 0 deletions test/dtypes/test_affine_quantized.py
@@ -10,6 +10,7 @@
     int8_dynamic_activation_int8_semi_sparse_weight,
     float8_weight_only,
 )
+from torchao.quantization.quant_primitives import MappingType
 from torchao.dtypes import SemiSparseLayout
 from torch.testing._internal import common_utils
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
@@ -26,6 +27,7 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool):
         int8_weight_only(),
         int8_dynamic_activation_int4_weight(),
         int8_dynamic_activation_int8_weight(),
+        int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC),
     ]
     if do_int4:
         base_functions.append(int4_weight_only(group_size=32))
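
The new entry exercises asymmetric activation quantization. The practical difference: symmetric mapping forces zero_point = 0 and scales by the maximum absolute value, while asymmetric mapping spends the full int8 range on the observed min..max span, wasting less resolution on one-sided distributions such as post-ReLU activations. A rough per-token sketch in plain PyTorch, illustrative rather than torchao's exact implementation:

    import torch

    x = torch.relu(torch.randn(4, 1024))  # skewed, non-negative activations
    qmin, qmax = -128, 127
    eps = 1e-8  # guard against all-zero tokens

    # Symmetric: scale from max |x| per token, zero_point fixed at 0.
    s_sym = (x.abs().amax(dim=-1, keepdim=True) / qmax).clamp(min=eps)
    q_sym = torch.clamp(torch.round(x / s_sym), qmin, qmax)

    # Asymmetric: scale from the min..max span, plus a per-token zero_point.
    span = x.amax(dim=-1, keepdim=True) - x.amin(dim=-1, keepdim=True)
    s_asym = (span / (qmax - qmin)).clamp(min=eps)
    zp = qmin - torch.round(x.amin(dim=-1, keepdim=True) / s_asym)
    q_asym = torch.clamp(torch.round(x / s_asym) + zp, qmin, qmax)

    # Round-trip error is lower for the asymmetric mapping on one-sided data.
    print((q_sym * s_sym - x).abs().mean(), ((q_asym - zp) * s_asym - x).abs().mean())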
8 changes: 5 additions & 3 deletions test/quantization/test_quant_api.py
@@ -47,6 +47,7 @@
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_6,
 )
 from pathlib import Path
 from torchao._models.llama.tokenizer import get_tokenizer
@@ -576,6 +577,7 @@ def test_quantized_tensor_subclass_int8_wo(self):

     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.5 and below")
     def test_quantized_tensor_subclass_int8_dyn_quant(self):
         # use multiples of 1024 so that we don't need padding
         m = ToyLinearModel(1024, 1024, 2048).eval().to(torch.bfloat16).to("cuda")
@@ -732,8 +734,8 @@ def test_multitensor_pad_unpad(self):
         self.assertEqual(mt.count, 3)
         mt.unpad()
         self.assertEqual(mt.count, 1)
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")

+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_multitensor_inplace_operation(self):
         from torchao.quantization.GPTQ_MT import MultiTensor
@@ -742,7 +744,7 @@ def test_multitensor_inplace_operation(self):
         mt += 1  # In-place addition
         self.assertTrue(torch.equal(mt.values[0], torch.full((3, 3), 2)))


 common_utils.instantiate_parametrized_tests(TestQuantFlow)
13 changes: 10 additions & 3 deletions torchao/quantization/quant_api.py
@@ -38,6 +38,7 @@
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_6,
     unwrap_tensor_subclass,
 )
 from .subclass import (
@@ -480,7 +481,10 @@ def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
     """
     mapping_type = MappingType.ASYMMETRIC
     target_dtype = torch.int8
-    return to_affine_quantized_intx(x, mapping_type, _get_per_token_block_size(x), target_dtype)
+    if TORCH_VERSION_AT_LEAST_2_6:
+        return to_affine_quantized_intx(x, mapping_type, _get_per_token_block_size(x), target_dtype, scale_dtype=torch.float64, zero_point_dtype=torch.int64)
+    else:
+        return to_affine_quantized_intx(x, mapping_type, _get_per_token_block_size(x), target_dtype)

 def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32, mapping_type=MappingType.SYMMETRIC):
     """This is defined here instead of local function to support serialization
@@ -589,7 +593,7 @@ def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
     return to_affine_quantized_intx(x, mapping_type, _get_per_token_block_size(x), target_dtype, eps=eps, quant_min=quant_min, quant_max=quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)


-def int8_dynamic_activation_int8_weight(layout=PlainLayout()):
+def int8_dynamic_activation_int8_weight(layout=PlainLayout(), act_mapping_type=MappingType.SYMMETRIC):
     """
     Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
     quantization to linear layers
@@ -612,7 +616,10 @@ def get_weight_block_size(x):
         zero_point_dtype = torch.int64

         # input settings
-        input_quant_func = _int8_symm_per_token_reduced_range_quant
+        if act_mapping_type == MappingType.SYMMETRIC:
+            input_quant_func = _int8_symm_per_token_reduced_range_quant
+        else:
+            input_quant_func = _int8_asymm_per_token_quant

         block_size = get_weight_block_size(weight)
         weight = to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, _layout=layout)
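
The selection above threads act_mapping_type through to the function that quantizes activations at each forward call. A hedged end-to-end sanity check in the spirit of the updated tests (the toy model and loose tolerance are illustrative, not taken from the test suite; the repo's own tests exercise this path on CUDA):

    import torch
    import torch.nn as nn
    from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
    from torchao.quantization.quant_primitives import MappingType

    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(128, 64)).eval()
    x = torch.randn(4, 128)
    ref = model(x)

    quantize_(model, int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC))
    out = model(x)

    # Dynamic int8 quantization is lossy: expect agreement with the float
    # baseline only to a loose tolerance.
    print((ref - out).abs().max())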
