Add decorator for custom op and inductor decomp registration
Summary:
This PR adds a decorator that registers a function both as a custom op and as an inductor decomposition.

The goal is for the torch.export path to be able to see high-level ops like quantize_affine instead of their decomposed form, because some backends, such as XNNPACK, want to work with these higher-level ops.
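
For illustration, here is a minimal sketch of what such a decorator could look like, built only on the public `torch.library` and `torch._decomp` APIs. The namespace, decorator name, and example op below are assumptions for illustration, not the exact code in this PR:

```python
import torch
from torch._decomp import register_decomposition

# Hypothetical library fragment; the "torchao" namespace here is illustrative.
_lib = torch.library.Library("torchao", "FRAGMENT")

def register_custom_op_and_decomp(schema: str):
    """Register a function as a custom op (so torch.export preserves it as a
    single high-level node) and as a decomposition (so inductor can still
    break it down into lower-level ops when compiling)."""
    def decorator(fn):
        op_name = schema.split("(")[0]
        _lib.define(schema)
        # CompositeExplicitAutograd prevents export from tracing through the
        # Python body, so the exported graph keeps the high-level op.
        _lib.impl(op_name, fn, "CompositeExplicitAutograd")
        op = getattr(torch.ops.torchao, op_name)
        # Registering the same Python body as a decomposition lets inductor
        # lower the op instead of treating it as opaque.
        register_decomposition([op])(fn)
        return fn
    return decorator

# Hypothetical example op using the decorator.
@register_custom_op_and_decomp("my_quantize(Tensor input, Tensor scale) -> Tensor")
def my_quantize(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return torch.round(input / scale).to(torch.int8)
```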

Test Plan:
regression tests:
`python test/quantization/test_quant_api.py`
`python test/integration/test_integration.py`

Also need to check performance with `python tutorials/quantize_vit/run_vit_b_quant.py`.
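
For reference, a rough sketch of how the affine primitives are invoked with the new string-based arguments; the call pattern is taken from the updated tests below, but treat the exact quantize_affine/dequantize_affine signatures as assumptions:

```python
import torch
from torchao.quantization.quant_primitives import (
    choose_qparams_affine,
    quantize_affine,
    dequantize_affine,
)

input = torch.randn(10, 10)
# mapping_type is now the string "symmetric" or "asymmetric"
# rather than a MappingType enum value.
mapping_type = "symmetric"
block_size = (1, 2)
dtype = torch.int8
eps = torch.finfo(torch.float32).eps

scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps)
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
```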

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Jun 21, 2024
1 parent bc8599f commit ac2e283
Showing 8 changed files with 106 additions and 111 deletions.
1 change: 0 additions & 1 deletion test/integration/test_integration.py
@@ -37,7 +37,6 @@
choose_qparams_affine,
quantize_affine,
dequantize_affine,
MappingType,
)
from torchao.quantization.utils import (
dequantize_per_channel,
4 changes: 0 additions & 4 deletions test/quantization/test_quant_api.py
@@ -22,10 +22,6 @@
from torchao.dtypes import (
AffineQuantizedTensor,
)
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
)
from torchao.quantization.subclass import (
LinearActQuantizedTensor,
Int8WeightOnlyQuantizedLinearWeight,
32 changes: 15 additions & 17 deletions test/quantization/test_quant_primitives.py
@@ -12,8 +12,6 @@
quantize_affine,
dequantize_affine,
choose_qparams_affine,
MappingType,
ZeroPointDomain,
)
# TODO: remove test for utils?
from torchao.quantization.utils import (
@@ -167,7 +165,7 @@ def test_choose_qparams_group_sym(self):
we don't include it here. We may just replace it with per block quant
"""
input = torch.randn(10, 10)
mapping_type = MappingType.SYMMETRIC
mapping_type = "symmetric"
dtype = torch.int8
block_size = (1, 2)
eps = torch.finfo(torch.float32).eps
@@ -183,7 +181,7 @@ def test_choose_qparams_token_asym(self):
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_choose_qparams_token_asym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.ASYMMETRIC
mapping_type = "asymmetric"
dtype = torch.int8
block_size = (1, 10)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
@@ -198,7 +196,7 @@ def test_choose_qparams_token_asym(self):
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_choose_qparams_tensor_asym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.ASYMMETRIC
mapping_type = "asymmetric"
dtype = torch.int8
block_size = (10, 10)
eps = torch.finfo(torch.float32).eps
@@ -217,7 +215,7 @@ def test_choose_qparams_tensor_asym(self):
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_choose_qparams_tensor_sym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.SYMMETRIC
mapping_type = "symmetric"
dtype = torch.int8
block_size = (10, 10)
eps = torch.finfo(torch.float32).eps
@@ -237,7 +235,7 @@ def test_quantize_activation_per_token_abs_max(self):
input = torch.randn(10, 10)
quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)

mapping_type = MappingType.SYMMETRIC
mapping_type = "symmetric"
block_size = list(input.shape)
for i in range(len(block_size) - 1):
block_size[i] = 1
@@ -278,7 +276,7 @@ def test_quantize_activation_per_token_abs_max_dtype(self):
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_quantize_dequantize_group_sym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.SYMMETRIC
mapping_type = "symmetric"
dtype = torch.int8
block_size = (1, 2)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
@@ -303,7 +301,7 @@ def test_quantize_dequantize_group_sym(self):
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_quantize_dequantize_channel_asym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.ASYMMETRIC
mapping_type = "asymmetric"
dtype = torch.int8
block_size = (10, 1)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
@@ -327,7 +325,7 @@ def test_quantize_dequantize_channel_asym(self):
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_quantize_dequantize_tensor_asym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.ASYMMETRIC
mapping_type = "asymmetric"
dtype = torch.int8
block_size = (10, 10)
output_dtype = torch.float32
@@ -351,7 +349,7 @@ def test_quantize_dequantize_tensor_asym(self):
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_quantize_dequantize_channel_asym_4d(self):
input = torch.randn(3, 3, 10, 10)
mapping_type = MappingType.ASYMMETRIC
mapping_type = "asymmetric"
dtype = torch.int8
block_size = (3, 3, 1, 10)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
@@ -373,7 +371,7 @@ def test_quantize_dequantize_channel_asym_4d(self):
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self):
input = torch.randn(3, 3, 10, 10)
mapping_type = MappingType.ASYMMETRIC
mapping_type = "asymmetric"
dtype = torch.int8
block_size = (3, 3, 2, 2)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
@@ -384,7 +382,7 @@ def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self):

def test_choose_qparams_tensor_asym_eps(self):
input = torch.zeros(10, 10)
mapping_type = MappingType.ASYMMETRIC
mapping_type = "asymmetric"
dtype = torch.int8
block_size = (10, 10)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype)
Expand All @@ -406,7 +404,7 @@ def test_raises(self):
"""Make sure some errors are raised when user requested an unsupported type of quantization
"""
input = torch.randn(10, 10)
mapping_type = MappingType.ASYMMETRIC
mapping_type = "asymmetric"
dtype = torch.int8
block_size = (10, 10)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype)
@@ -425,7 +423,7 @@ def test_not_preserve_zero_not_supported(self):
"""Making sure preserve_zero == False is not supported for symmetric quant"""
input = torch.randn(10, 256)
n_bit = 4
mapping_type = MappingType.SYMMETRIC
mapping_type = "symmetric"
dtype = torch.int8
block_size = (1, 128)
quant_min = 0
@@ -453,7 +451,7 @@ def test_get_groupwise_affine_qparams(self):
n_bit = 4
scale_ref, zero_point_ref = _get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16)

mapping_type = MappingType.ASYMMETRIC
mapping_type = "asymmetric"
dtype = torch.int8
block_size = (1, 128)
quant_min = 0
@@ -473,7 +471,7 @@ def test_get_groupwise_affine_qparams(self):
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
preserve_zero=False,
zero_point_domain=ZeroPointDomain.FLOAT,
zero_point_domain="float",
)

self.assertTrue(torch.equal(scale, scale_ref))
20 changes: 9 additions & 11 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -6,8 +6,6 @@
choose_qparams_affine,
quantize_affine,
dequantize_affine,
ZeroPointDomain,
MappingType,
int_scaled_matmul,
)
from torchao.quantization.utils import (
@@ -98,12 +96,12 @@ class AffineQuantizedTensor(torch.Tensor):
shape (torch.Size): the shape for the Tensor
quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
zero_point_domain (str): the domain that zero_point is in, should be either "int" or "float"
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
value during quantization
default is ZeroPointDomain.INT
default is "int"
input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes float Tensor as input and outputs an AffineQuantizedTensor object
dtype: dtype for external representation of the tensor, e.g. torch.float32
"""
@@ -116,7 +114,7 @@ def __new__(
shape: torch.Size,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
zero_point_domain: str = "int",
dtype=None,
strides=None,
):
@@ -138,7 +136,7 @@ def __init__(
shape: torch.Size,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
zero_point_domain: str = "int",
dtype=None,
strides=None,
):
@@ -184,7 +182,7 @@ def __tensor_unflatten__(
def from_float(
cls,
input_float: torch.Tensor,
mapping_type: MappingType,
mapping_type: str,
block_size: Tuple[int, ...],
target_dtype: torch.dtype,
quant_min: Optional[int] = None,
@@ -193,7 +191,7 @@ def from_float(
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero: bool = True,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
zero_point_domain: str = "int",
extended_layout: str = "plain",
# TODO: this is only for "tensor_core_tiled", need to figure out
# the proper API for this arg
@@ -520,7 +518,7 @@ def get_plain(self):
target_dtype = torch.int32
quant_min = 0
quant_max = 15
zero_point_domain = ZeroPointDomain.FLOAT
zero_point_domain = "int"
assert len(block_size) == 2 and block_size[0] == 1
groupsize = block_size[-1]
dequantized = torch.ops.aten._weight_int4pack_mm(torch.eye(eye_shape, device=device, dtype=original_dtype), self.packed_weight, groupsize, self.scale_and_zero)
@@ -597,7 +595,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
weight_is_uint4 and
weight_qtensor.dtype == torch.bfloat16 and
len(weight_qtensor.shape) == 2 and
weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
weight_qtensor.zero_point_domain == "float" and
weight_qtensor.extended_layout == "tensor_core_tiled"
):
assert weight_qtensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}"
@@ -640,7 +638,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
len(weight_qtensor.block_size) == 2 and
weight_qtensor.block_size[0] == 1 and
weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
weight_qtensor.zero_point_domain == ZeroPointDomain.INT and
weight_qtensor.zero_point_domain == "int" and
weight_qtensor.extended_layout == "plain"
):
# TODO: enable cpu and mps efficient path
23 changes: 9 additions & 14 deletions torchao/quantization/quant_api.py
@@ -31,10 +31,6 @@
to_linear_act_quantized,
)

from .quant_primitives import (
MappingType,
ZeroPointDomain,
)
from .weight_only import WeightOnlyInt8QuantLinear
from .unified import Quantizer, TwoStepQuantizer
from .GPTQ import (
@@ -270,15 +266,15 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tens
# weight settings
groupsize = 32
mapping_type = MappingType.ASYMMETRIC
mapping_type = "asymmetric"
block_size = (1, groupsize)
target_dtype = torch.int32
quant_min = 0
quant_max = 15
eps = 1e-6
preserve_zero = False
zero_point_dtype = torch.bfloat16
zero_point_domain = ZeroPointDomain.FLOAT
zero_point_domain = "float"
apply_weight_quant = lambda x: to_affine_quantized(
x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps,
@@ -319,7 +315,7 @@ def apply_int8_dynamic_activation_int4_weight_quant(weight):
from torchao.dtypes import to_affine_quantized

# weight settings
mapping_type = MappingType.SYMMETRIC
mapping_type = "symmetric"
block_size = (1, group_size)
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
@@ -336,7 +332,7 @@ def get_per_token_block_size(x):
return block_size

# input settings
input_mapping_type = MappingType.ASYMMETRIC
input_mapping_type = "asymmetric"
input_target_dtype = torch.int8
input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)

@@ -360,16 +356,15 @@ def int4_weight_only(group_size=128, inner_k_tiles=8):
def apply_int4_weight_only_quant(weight):
# avoid circular dep
from torchao.dtypes import to_affine_quantized

mapping_type = MappingType.ASYMMETRIC
mapping_type = "asymmetric"
block_size = (1, group_size)
target_dtype = torch.int32
quant_min = 0
quant_max = 15
eps = 1e-6
preserve_zero = False
zero_point_dtype = torch.bfloat16
zero_point_domain = ZeroPointDomain.FLOAT
zero_point_domain = "float"
return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, extended_layout="tensor_core_tiled", inner_k_tiles=inner_k_tiles)

return apply_int4_weight_only_quant
@@ -383,7 +378,7 @@ def apply_int8wo_quant(weight):
# avoid circular dep
from torchao.dtypes import to_affine_quantized

mapping_type = MappingType.SYMMETRIC
mapping_type = "symmetric"
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64
@@ -406,7 +401,7 @@ def apply_int8_dynamic_activation_int8_weight_quant(weight):
# avoid circular dep
from torchao.dtypes import to_affine_quantized
# weight settings
mapping_type = MappingType.SYMMETRIC
mapping_type = "symmetric"
def get_weight_block_size(x):
return (1, x.shape[1])
target_dtype = torch.int8
@@ -420,7 +415,7 @@ def get_per_token_block_size(x):
block_size[i] = 1
return block_size

input_mapping_type = MappingType.SYMMETRIC
input_mapping_type = "symmetric"
input_target_dtype = torch.int8
input_eps = 1e-5
input_quant_min = -127