-
Notifications
You must be signed in to change notification settings - Fork 100
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
- Loading branch information
1 parent
796fce3
commit 3b7e221
Showing
2 changed files
with
257 additions
and
2 deletions.
There are no files selected for viewing
253 changes: 253 additions & 0 deletions
253
torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,253 @@ | ||
import torch | ||
from typing import Dict, Callable, Any, Tuple, Optional | ||
from collections import defaultdict | ||
import functools | ||
from torchao.quantization.quant_primitives import ( | ||
choose_qparams_affine, | ||
fake_quantize_affine, | ||
quantize_affine, | ||
dequantize_affine, | ||
ZeroPointDomain, | ||
MappingType, | ||
int_scaled_matmul, | ||
) | ||
from torchao.quantization.utils import ( | ||
pack_tinygemm_scales_and_zeros, | ||
) | ||
from torch.utils._python_dispatch import return_and_correct_aliasing | ||
from torchao.utils import find_multiple | ||
from torchao.dtypes.utils import ( | ||
_implements, | ||
_dispatch__torch_function__, | ||
_dispatch__torch_dispatch__, | ||
_register_layout_cls, | ||
_get_layout_tensor_constructor, | ||
LayoutType, | ||
is_device, | ||
) | ||
from typing import ClassVar | ||
from dataclasses import dataclass | ||
from torchao.utils import TORCH_VERSION_AFTER_2_5 | ||
|
||
aten = torch.ops.aten | ||
|
||
class AffineFakeQuantizedTensor(torch.Tensor): | ||
""" | ||
TODO(andrew): rewrite | ||
Affine quantized tensor subclass. Affine quantization means we quantize the floating point tensor with an affine transformation: | ||
quantized_tensor = float_tensor / scale + zero_point | ||
The shape and dtype of the tensor subclass represent how the tensor subclass looks externally, | ||
regardless of the internal representation's type or orientation. | ||
fields: | ||
layout_tensor (AQTLayout): tensor that serves as a general layout storage for the quantized data, | ||
e.g. storing plain tensors (int_data, scale, zero_point) or packed formats depending on device | ||
and operator/kernel | ||
block_size (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam | ||
e.g. when size is the same as the input tensor dimension, we are using per tensor quantization | ||
shape (torch.Size): the shape for the Tensor | ||
quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data` | ||
quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data` | ||
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float | ||
if zero_point is in integer domain, zero point is added to the quantized integer value during | ||
quantization | ||
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) | ||
value during quantization | ||
default is ZeroPointDomain.INT | ||
input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes float Tensor as input and outputs an AffineFakeQuantizedTensor object | ||
dtype: dtype for external representation of the tensor, e.g. torch.float32 | ||
""" | ||
|
||
@staticmethod | ||
def __new__( | ||
cls, | ||
fq_data: torch.Tensor, | ||
block_size: Tuple[int, ...], | ||
shape: torch.Size, | ||
quant_min: Optional[int] = None, | ||
quant_max: Optional[int] = None, | ||
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, | ||
dtype: torch.dtype = None, | ||
): | ||
kwargs = {} | ||
kwargs["dtype"] = dtype | ||
kwargs["requires_grad"] = True | ||
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] | ||
|
||
def __init__( | ||
self, | ||
fq_data: torch.Tensor, | ||
block_size: Tuple[int, ...], | ||
shape: torch.Size, | ||
quant_min: Optional[int] = None, | ||
quant_max: Optional[int] = None, | ||
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, | ||
dtype: torch.dtype = None, | ||
): | ||
self.fq_data = fq_data | ||
self.block_size = block_size | ||
self.quant_min = quant_min | ||
self.quant_max = quant_max | ||
self.zero_point_domain = zero_point_domain | ||
|
||
def __tensor_flatten__(self): | ||
return ["fq_data"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype] | ||
|
||
@classmethod | ||
def __tensor_unflatten__( | ||
cls, tensor_data_dict, tensor_attributes, outer_size, | ||
): | ||
fq_data = tensor_data_dict["fq_data"] | ||
block_size, shape, quant_min, quant_max, zero_point_domain, dtype = tensor_attributes | ||
return cls( | ||
fq_data, | ||
block_size, | ||
shape if outer_size is None else outer_size, | ||
quant_min, | ||
quant_max, | ||
zero_point_domain, | ||
dtype=dtype, | ||
) | ||
|
||
@classmethod | ||
def from_float( | ||
cls, | ||
input_float: torch.Tensor, | ||
mapping_type: MappingType, | ||
block_size: Tuple[int, ...], | ||
target_dtype: torch.dtype, | ||
quant_min: Optional[int] = None, | ||
quant_max: Optional[int] = None, | ||
eps: Optional[float] = None, | ||
scale_dtype: Optional[torch.dtype] = None, | ||
zero_point_dtype: Optional[torch.dtype] = None, | ||
preserve_zero: bool = True, | ||
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, | ||
): | ||
original_shape = input_float.shape | ||
|
||
scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain) | ||
fq_data = fake_quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain) | ||
|
||
return cls( | ||
fq_data, | ||
block_size, | ||
original_shape, | ||
quant_min, | ||
quant_max, | ||
zero_point_domain, | ||
dtype=input_float.dtype | ||
) | ||
|
||
def _get_to_kwargs(self, *args, **kwargs): | ||
device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) | ||
device = self.device if device is None else device | ||
dtype = self.dtype if dtype is None else dtype | ||
memory_format = ( | ||
memory_format if memory_format is not None else torch.preserve_format | ||
) | ||
kwargs = { | ||
"device": device, | ||
"dtype": dtype, | ||
"memory_format": memory_format, | ||
} | ||
return kwargs | ||
|
||
def to(self, *args, **kwargs): | ||
kwargs = self._get_to_kwargs(*args, **kwargs) | ||
device = kwargs.pop("device") | ||
# not supported yet | ||
kwargs.pop("memory_format") | ||
return self.__class__( | ||
self.fq_data.to(device), | ||
self.block_size, | ||
self.shape, | ||
self.quant_min, | ||
self.quant_max, | ||
self.zero_point_domain, | ||
**kwargs, | ||
) | ||
|
||
def _apply_fn_to_data(self, fn): | ||
return self.__class__( | ||
fn(self.fq_data), | ||
self.block_size, | ||
self.shape, | ||
self.quant_min, | ||
self.quant_max, | ||
self.zero_point_domain, | ||
dtype=self.dtype, | ||
) | ||
|
||
implements = classmethod(_implements) | ||
__torch_function__ = classmethod(_dispatch__torch_function__) | ||
__torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) | ||
|
||
implements = AffineFakeQuantizedTensor.implements | ||
|
||
|
||
@implements(torch.nn.functional.linear) | ||
def _(func, types, *args, **kwargs): | ||
input_tensor, weight_tensor, bias = ( | ||
args[0], | ||
args[1], | ||
args[2] if len(args) > 2 else None, | ||
) | ||
with torch._C.DisableTorchFunctionSubclass(): | ||
return torch.nn.functional.linear(input_tensor, weight_tensor, bias) | ||
|
||
|
||
@implements([aten.mm.default, aten.addmm.default]) | ||
def _(func, types, *args, **kwargs): | ||
if func == aten.addmm.default: | ||
input_tensor, weight_tensor, bias = ( | ||
args[1], | ||
args[2], | ||
args[0], | ||
) | ||
print("here aten.addmm") | ||
with torch._C.DisableTorchFunctionSubclass(): | ||
return func(bias, input_tensor, weight_tensor) | ||
else: | ||
input_tensor, weight_tensor, bias = ( | ||
args[0], | ||
args[1], | ||
None | ||
) | ||
print("here aten.mm") | ||
with torch._C.DisableTorchFunctionSubclass(): | ||
return func(input_tensor, weight_tensor) | ||
|
||
|
||
@implements([aten.detach.default]) | ||
def _(func, types, *args, **kwargs): | ||
return return_and_correct_aliasing( | ||
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) | ||
) | ||
|
||
|
||
@implements([aten.clone.default]) | ||
def _(func, types, *args, **kwargs): | ||
return return_and_correct_aliasing( | ||
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) | ||
) | ||
|
||
|
||
@implements([aten._to_copy.default]) | ||
def _(func, types, *args, **kwargs): | ||
return return_and_correct_aliasing( | ||
func, | ||
args, | ||
kwargs, | ||
args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), | ||
) | ||
|
||
@implements([aten.t.default]) | ||
def _(func, types, *args, **kwargs): | ||
return return_and_correct_aliasing( | ||
func, args, kwargs, args[0]._apply_fn_to_data(torch.t) | ||
) | ||
|
||
to_affine_fake_quantized = AffineFakeQuantizedTensor.from_float |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters