temp
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
andrewor14 committed Aug 1, 2024
1 parent 796fce3 commit 3b7e221
Showing 2 changed files with 257 additions and 2 deletions.
253 changes: 253 additions & 0 deletions torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py
@@ -0,0 +1,253 @@
import torch
from typing import Optional, Tuple
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.quantization.quant_primitives import (
choose_qparams_affine,
fake_quantize_affine,
ZeroPointDomain,
MappingType,
)
from torchao.dtypes.utils import (
_implements,
_dispatch__torch_function__,
_dispatch__torch_dispatch__,
)

aten = torch.ops.aten

class AffineFakeQuantizedTensor(torch.Tensor):
"""
TODO(andrew): rewrite
Affine quantized tensor subclass. Affine quantization means we quantize the floating point tensor with an affine transformation:
quantized_tensor = float_tensor / scale + zero_point
The shape and dtype of the tensor subclass represent how the tensor subclass looks externally,
regardless of the internal representation's type or orientation.
fields:
layout_tensor (AQTLayout): tensor that serves as a general layout storage for the quantized data,
e.g. storing plain tensors (int_data, scale, zero_point) or packed formats depending on device
and operator/kernel
block_size (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam
e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
shape (torch.Size): the shape for the Tensor
quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
value during quantization
default is ZeroPointDomain.INT
input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes float Tensor as input and outputs an AffineFakeQuantizedTensor object
dtype: dtype for external representation of the tensor, e.g. torch.float32
"""

@staticmethod
def __new__(
cls,
fq_data: torch.Tensor,
block_size: Tuple[int, ...],
shape: torch.Size,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
dtype: Optional[torch.dtype] = None,
):
kwargs = {}
kwargs["dtype"] = dtype
kwargs["requires_grad"] = True
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(
self,
fq_data: torch.Tensor,
block_size: Tuple[int, ...],
shape: torch.Size,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
dtype: Optional[torch.dtype] = None,
):
self.fq_data = fq_data
self.block_size = block_size
self.quant_min = quant_min
self.quant_max = quant_max
self.zero_point_domain = zero_point_domain

def __tensor_flatten__(self):
return ["fq_data"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size,
):
fq_data = tensor_data_dict["fq_data"]
block_size, shape, quant_min, quant_max, zero_point_domain, dtype = tensor_attributes
return cls(
fq_data,
block_size,
shape if outer_size is None else outer_size,
quant_min,
quant_max,
zero_point_domain,
dtype=dtype,
)

@classmethod
def from_float(
cls,
input_float: torch.Tensor,
mapping_type: MappingType,
block_size: Tuple[int, ...],
target_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
eps: Optional[float] = None,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero: bool = True,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
):
original_shape = input_float.shape

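# fake_quantize_affine performs quantize followed by dequantize in one step, so fq_data
# stays in the input's floating point dtype but carries the rounding/clamping error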
scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
fq_data = fake_quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)

return cls(
fq_data,
block_size,
original_shape,
quant_min,
quant_max,
zero_point_domain,
dtype=input_float.dtype
)

def _get_to_kwargs(self, *args, **kwargs):
device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
device = self.device if device is None else device
dtype = self.dtype if dtype is None else dtype
memory_format = (
memory_format if memory_format is not None else torch.preserve_format
)
kwargs = {
"device": device,
"dtype": dtype,
"memory_format": memory_format,
}
return kwargs

def to(self, *args, **kwargs):
kwargs = self._get_to_kwargs(*args, **kwargs)
device = kwargs.pop("device")
# not supported yet
kwargs.pop("memory_format")
return self.__class__(
self.fq_data.to(device),
self.block_size,
self.shape,
self.quant_min,
self.quant_max,
self.zero_point_domain,
**kwargs,
)

def _apply_fn_to_data(self, fn):
return self.__class__(
fn(self.fq_data),
self.block_size,
self.shape,
self.quant_min,
self.quant_max,
self.zero_point_domain,
dtype=self.dtype,
)

implements = classmethod(_implements)
__torch_function__ = classmethod(_dispatch__torch_function__)
__torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)

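# `implements` registers per-op overrides for this subclass; they are routed through the
# __torch_function__ / __torch_dispatch__ handlers assigned above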
implements = AffineFakeQuantizedTensor.implements


@implements(torch.nn.functional.linear)
def _(func, types, *args, **kwargs):
input_tensor, weight_tensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
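# run the regular float linear on the fake quantized values; DisableTorchFunctionSubclass
# skips this __torch_function__ override for the inner call (the aten handlers below still apply)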
with torch._C.DisableTorchFunctionSubclass():
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)


@implements([aten.mm.default, aten.addmm.default])
def _(func, types, *args, **kwargs):
if func == aten.addmm.default:
bias, input_tensor, weight_tensor = args[0], args[1], args[2]
with torch._C.DisableTorchFunctionSubclass():
return func(bias, input_tensor, weight_tensor)
else:
input_tensor, weight_tensor = args[0], args[1]
with torch._C.DisableTorchFunctionSubclass():
return func(input_tensor, weight_tensor)


@implements([aten.detach.default])
def _(func, types, *args, **kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)


@implements([aten.clone.default])
def _(func, types, *args, **kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)


@implements([aten._to_copy.default])
def _(func, types, *args, **kwargs):
return return_and_correct_aliasing(
func,
args,
kwargs,
args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
)

@implements([aten.t.default])
def _(func, types, *args, **kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.t)
)

to_affine_fake_quantized = AffineFakeQuantizedTensor.from_float
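
A minimal usage sketch of the new constructor (illustrative only; the symmetric int8, per-row settings below are assumptions, not part of this commit):

import torch
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
    AffineFakeQuantizedTensor,
    to_affine_fake_quantized,
)

weight = torch.randn(64, 128)
fq_weight = to_affine_fake_quantized(
    weight,
    mapping_type=MappingType.SYMMETRIC,
    block_size=(1, weight.shape[1]),  # one set of qparams per output row
    target_dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
)
assert isinstance(fq_weight, AffineFakeQuantizedTensor)
assert fq_weight.dtype == torch.float32  # values are fake quantized but stay in float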
6 changes: 4 additions & 2 deletions torchao/quantization/quant_api.py
@@ -183,6 +183,7 @@ def _replace_with_custom_fn_if_matches_filter(
def _is_linear(mod, *args):
# avoid circular dep
from torchao.dtypes import AffineQuantizedTensor
from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import AffineFakeQuantizedTensor

# adding weight tensor subclass isinstance check to make sure the weight is only quantized once
# when it is shared by multiple linear modules
@@ -193,6 +194,7 @@ def _is_linear(mod, *args):
and not isinstance(mod.weight, AutoQuantizableLinearWeight)
and not isinstance(mod.weight, AffineQuantizedTensor)
and not isinstance(mod.weight, LinearActivationQuantizedTensor)
and not isinstance(mod.weight, AffineFakeQuantizedTensor)
)

import torch.nn.utils.parametrize as parametrize
@@ -257,9 +259,9 @@ def replace_conv2d_1x1(conv):
)


-def _get_linear_subclass_inserter(constructor):
+def _get_linear_subclass_inserter(constructor, requires_grad=False):
def insert_subclass(lin):
-lin.weight = torch.nn.Parameter(constructor(lin.weight), requires_grad=False)
+lin.weight = torch.nn.Parameter(constructor(lin.weight), requires_grad=requires_grad)
return lin

return insert_subclass
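
A sketch of how the new requires_grad flag might be wired up for QAT. This is hypothetical usage, not part of this commit; it assumes the existing _replace_with_custom_fn_if_matches_filter(model, replacement_fn, filter_fn) helper and illustrative quantization settings:

import torch
from torchao.quantization.quant_api import (
    _get_linear_subclass_inserter,
    _is_linear,
    _replace_with_custom_fn_if_matches_filter,
)
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import to_affine_fake_quantized

def fake_quantize_weight(weight):
    # symmetric int8 with one set of qparams per output row (illustrative settings)
    return to_affine_fake_quantized(
        weight,
        MappingType.SYMMETRIC,
        (1, weight.shape[1]),
        torch.int8,
        quant_min=-128,
        quant_max=127,
    )

model = torch.nn.Sequential(torch.nn.Linear(128, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10))
# requires_grad=True keeps the swapped-in fake quantized weights trainable during QAT
_replace_with_custom_fn_if_matches_filter(
    model,
    _get_linear_subclass_inserter(fake_quantize_weight, requires_grad=True),
    _is_linear,
)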
