Add decorator for custom op and inductor decomp registration
Summary:
This PR adds a decorator to register a custom op and also an inductor decomposition for it.

The goal is for the torch.export path to be able to see high-level ops like quantize_affine instead of having them broken down into lower-level ops,
because some backends like xnnpack want to work with these higher-level ops.
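
For illustration only, the registration pattern roughly looks like this: define the op in a torch.library namespace so torch.export keeps it as a single node, and register the same Python function as an inductor decomposition so torch.compile can still lower it. The decorator name, the `my_quant` namespace, and the toy op below are hypothetical and are not the API added by this commit.

```python
import torch
from torch._inductor.decomposition import register_decomposition

# Hypothetical namespace and names, for illustration only.
_lib = torch.library.Library("my_quant", "FRAGMENT")

def register_custom_op_and_decomp(name: str, schema: str):
    """Register `fn` as a custom op (so torch.export sees one high-level node)
    and also as an inductor decomposition (so inductor can lower it)."""
    def decorator(fn):
        _lib.define(name + schema)
        # One Python implementation serves all backends.
        _lib.impl(name, fn, "CompositeExplicitAutograd")
        op = getattr(torch.ops.my_quant, name)
        # Reuse the same Python body as the inductor decomposition.
        register_decomposition([op])(fn)
        return fn
    return decorator

@register_custom_op_and_decomp(
    "quantize_affine_example", "(Tensor x, float scale, int zero_point) -> Tensor"
)
def quantize_affine_example(x, scale, zero_point):
    # Toy stand-in for quantize_affine with an integer-domain zero point.
    return torch.clamp(torch.round(x / scale) + zero_point, -128, 127).to(torch.int8)
```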

Test Plan:
regression tests:
`python test/quantization/test_quant_api.py`
`python test/integration/test_integration.py`

also need to check performance with `python tutorials/quantize_vit/run_vit_b_quant.py`
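
Not part of the commit's test plan, but as a quick, hypothetical smoke check (reusing the made-up names from the sketch above) one could confirm that torch.export keeps the high-level op as a single node:

```python
import torch

class ToyModule(torch.nn.Module):
    def forward(self, x):
        # Call the (hypothetical) registered op from the sketch above.
        return torch.ops.my_quant.quantize_affine_example(x, 0.05, 3)

ep = torch.export.export(ToyModule(), (torch.randn(4, 4),))
# The op should appear as one call_function node instead of being decomposed.
assert any(
    "quantize_affine_example" in str(node.target)
    for node in ep.graph.nodes
    if node.op == "call_function"
)
print(ep.graph_module.code)
```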

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Jun 21, 2024
1 parent e6460c2 commit cf3234c
Showing 7 changed files with 92 additions and 94 deletions.
3 changes: 1 addition & 2 deletions test/integration/test_integration.py
@@ -37,7 +37,6 @@
choose_qparams_affine,
quantize_affine,
dequantize_affine,
MappingType,
)
from torchao.quantization.utils import (
dequantize_per_channel,
@@ -1436,7 +1435,7 @@ def test_get_model_size_aqt(self, api, test_device, test_dtype):
api(model)
size2 = torchao.utils.get_model_size_in_bytes(model)
self.assertTrue(size2 < size)




4 changes: 0 additions & 4 deletions test/quantization/test_quant_api.py
@@ -22,10 +22,6 @@
from torchao.dtypes import (
AffineQuantizedTensor,
)
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
)
from torchao.quantization.subclass import (
LinearActQuantizedTensor,
Int8WeightOnlyQuantizedLinearWeight,
20 changes: 9 additions & 11 deletions torchao/dtypes/aqt.py
@@ -6,8 +6,6 @@
choose_qparams_affine,
quantize_affine,
dequantize_affine,
ZeroPointDomain,
MappingType,
int_scaled_matmul,
)
from torchao.quantization.utils import (
@@ -98,12 +96,12 @@ class AffineQuantizedTensor(torch.Tensor):
shape (torch.Size): the shape for the Tensor
quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
zero_point_domain (str): the domain that zero_point is in, should be either "int" or "float"
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
value during quantization
default is ZeroPointDomain.INT
default is "int"
input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes float Tensor as input and outputs an AffineQuantizedTensor object
dtype: dtype for external representation of the tensor, e.g. torch.float32
"""
@@ -116,7 +114,7 @@ def __new__(
shape: torch.Size,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
zero_point_domain: str = "int",
dtype=None,
strides=None,
):
Expand All @@ -138,7 +136,7 @@ def __init__(
shape: torch.Size,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
zero_point_domain: str = "int",
dtype=None,
strides=None,
):
@@ -184,7 +182,7 @@ def __tensor_unflatten__(
def from_float(
cls,
input_float: torch.Tensor,
mapping_type: MappingType,
mapping_type: str,
block_size: Tuple[int, ...],
target_dtype: torch.dtype,
quant_min: Optional[int] = None,
@@ -193,7 +191,7 @@ def from_float(
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero: bool = True,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
zero_point_domain: str = "int",
extended_layout: str = "plain",
# TODO: this is only for "tensor_core_tiled", need to figure out
# the proper API for this arg
@@ -520,7 +518,7 @@ def get_plain(self):
target_dtype = torch.int32
quant_min = 0
quant_max = 15
zero_point_domain = ZeroPointDomain.FLOAT
zero_point_domain = "int"
assert len(block_size) == 2 and block_size[0] == 1
groupsize = block_size[-1]
dequantized = torch.ops.aten._weight_int4pack_mm(torch.eye(eye_shape, device=device, dtype=original_dtype), self.packed_weight, groupsize, self.scale_and_zero)
@@ -597,7 +595,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
weight_is_uint4 and
weight_qtensor.dtype == torch.bfloat16 and
len(weight_qtensor.shape) == 2 and
weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
weight_qtensor.zero_point_domain == "float" and
weight_qtensor.extended_layout == "tensor_core_tiled"
):
assert weight_qtensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}"
@@ -640,7 +638,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
len(weight_qtensor.block_size) == 2 and
weight_qtensor.block_size[0] == 1 and
weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
weight_qtensor.zero_point_domain == ZeroPointDomain.INT and
weight_qtensor.zero_point_domain == "int" and
weight_qtensor.extended_layout == "plain"
):
# TODO: enable cpu and mps efficient path
22 changes: 9 additions & 13 deletions torchao/quantization/quant_api.py
@@ -31,10 +31,6 @@
to_linear_act_quantized,
)

from .quant_primitives import (
MappingType,
ZeroPointDomain,
)
from .weight_only import WeightOnlyInt8QuantLinear
from .unified import Quantizer, TwoStepQuantizer
from .GPTQ import (
@@ -272,15 +268,15 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Union[str, Callable[
# weight settings
groupsize = 32
mapping_type = MappingType.ASYMMETRIC
mapping_type = "asymmetric"
block_size = (1, groupsize)
target_dtype = torch.int32
quant_min = 0
quant_max = 15
eps = 1e-6
preserve_zero = False
zero_point_dtype = torch.bfloat16
zero_point_domain = ZeroPointDomain.FLOAT
zero_point_domain = "float"
apply_weight_quant = lambda x: to_affine_quantized(
x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps,
@@ -321,7 +317,7 @@ def apply_8da4w_quant(weight):
from torchao.dtypes import to_affine_quantized

# weight settings
mapping_type = MappingType.SYMMETRIC
mapping_type = "symmetric"
block_size = (1, groupsize)
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
@@ -338,7 +334,7 @@ def get_per_token_block_size(x):
return block_size

# input settings
input_mapping_type = MappingType.ASYMMETRIC
input_mapping_type = "asymmetric"
input_target_dtype = torch.int8
input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)

@@ -363,15 +359,15 @@ def apply_int4wo_quant(weight):
# avoid circular dep
from torchao.dtypes import to_affine_quantized

mapping_type = MappingType.ASYMMETRIC
mapping_type = "asymmetric"
block_size = (1, groupsize)
target_dtype = torch.int32
quant_min = 0
quant_max = 15
eps = 1e-6
preserve_zero = False
zero_point_dtype = torch.bfloat16
zero_point_domain = ZeroPointDomain.FLOAT
zero_point_domain = "float"
return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, extended_layout="tensor_core_tiled", inner_k_tiles=inner_k_tiles)

return apply_int4wo_quant
@@ -385,7 +381,7 @@ def apply_int8wo_quant(weight):
# avoid circular dep
from torchao.dtypes import to_affine_quantized

mapping_type = MappingType.SYMMETRIC
mapping_type = "symmetric"
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64
@@ -407,7 +403,7 @@ def apply_int8dyn_quant(weight):
# avoid circular dep
from torchao.dtypes import to_affine_quantized
# weight settings
mapping_type = MappingType.SYMMETRIC
mapping_type = "symmetric"
def get_weight_block_size(x):
return (1, x.shape[1])
target_dtype = torch.int8
@@ -421,7 +417,7 @@ def get_per_token_block_size(x):
block_size[i] = 1
return block_size

input_mapping_type = MappingType.SYMMETRIC
input_mapping_type = "symmetric"
input_target_dtype = torch.int8
input_eps = 1e-5
input_quant_min = -127
