temp enable/disable fq
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
andrewor14 committed Aug 6, 2024
1 parent 886c05f commit 6288d74
Showing 5 changed files with 150 additions and 94 deletions.
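
For orientation before the per-file diffs: after this change, a QAT-prepared linear's weight is a LinearActivationQuantizedTensor wrapping an AffineFakeQuantizedTensor, and each layer carries its own enable flag. The sketch below is illustrative only; the describe_qat_linear helper is hypothetical and not part of this commit.

import torch
from torchao.quantization.linear_activation_quantized_tensor import (
    LinearActivationQuantizedTensor,
)
from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
    AffineFakeQuantizedTensor,
)

def describe_qat_linear(linear: torch.nn.Linear) -> None:
    """Hypothetical helper: walk the nesting produced by quantizer.prepare()."""
    weight = linear.weight
    # Outer wrapper: fake-quantizes activations on the fly; gated by the new flag.
    assert isinstance(weight, LinearActivationQuantizedTensor)
    print("activation fake quant enabled:", weight.input_quant_func_enabled)
    # Inner wrapper: fake-quantizes the weight lazily at dispatch time.
    inner = weight.original_weight_tensor
    assert isinstance(inner, AffineFakeQuantizedTensor)
    print("weight fake quant enabled:", inner.fake_quant_enabled)
    # The trainable float values live at the bottom.
    print("raw weight shape:", tuple(inner.original_tensor.shape))
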
42 changes: 30 additions & 12 deletions test/quantization/test_qat.py
@@ -12,6 +12,12 @@

import torch
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
from torchao.quantization.linear_activation_quantized_tensor import (
LinearActivationQuantizedTensor,
)
from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
AffineFakeQuantizedTensor,
)
from torchao.quantization.prototype.qat.utils import (
_choose_qparams_per_token_asymmetric,
_fake_quantize_per_channel_group,
@@ -252,6 +258,13 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
enable_8da4w_fake_quant,
)

def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool):
assert isinstance(m.weight, LinearActivationQuantizedTensor)
self.assertEqual(m.weight.input_quant_func_enabled, enabled)
weight = m.weight.original_weight_tensor
self.assertTrue(isinstance(weight, AffineFakeQuantizedTensor))
self.assertEqual(weight.fake_quant_enabled, enabled)

group_size = 16
torch.manual_seed(self.SEED)
m = M()
@@ -260,9 +273,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
qat_model = quantizer.prepare(m)
qat_model.apply(disable_8da4w_fake_quant)
self.assertFalse(qat_model.linear1._fake_quant_enabled)
self.assertFalse(qat_model.linear2._fake_quant_enabled)
self.assertFalse(qat_model.sub.linear._fake_quant_enabled)
assert_fake_quant_enabled(qat_model.linear1, enabled=False)
assert_fake_quant_enabled(qat_model.linear2, enabled=False)
assert_fake_quant_enabled(qat_model.sub.linear, enabled=False)

# Disabled fake quant is just a normal linear
m2.linear1.weight = qat_model.linear1.weight
@@ -277,9 +290,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):

# Re-enable fake quant
qat_model.apply(enable_8da4w_fake_quant)
self.assertTrue(qat_model.linear1._fake_quant_enabled)
self.assertTrue(qat_model.linear2._fake_quant_enabled)
self.assertTrue(qat_model.sub.linear._fake_quant_enabled)
assert_fake_quant_enabled(qat_model.linear1, enabled=True)
assert_fake_quant_enabled(qat_model.linear2, enabled=True)
assert_fake_quant_enabled(qat_model.sub.linear, enabled=True)

# Fake quant should be applied as normal
quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
@@ -304,16 +317,21 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
disable_8da4w_fake_quant,
)

def get_qat_weight(qat_linear: torch.nn.Linear):
assert isinstance(qat_linear.weight, LinearActivationQuantizedTensor)
assert isinstance(qat_linear.weight.original_weight_tensor, AffineFakeQuantizedTensor)
return qat_linear.weight.original_weight_tensor.original_tensor

group_size = 16
torch.manual_seed(self.SEED)
m = M()
nn_model = copy.deepcopy(m)
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
qat_model = quantizer.prepare(m)
qat_model.apply(disable_8da4w_fake_quant)
nn_model.linear1.weight = qat_model.linear1.weight
nn_model.linear2.weight = qat_model.linear2.weight
nn_model.sub.linear.weight = qat_model.sub.linear.weight
nn_model.linear1.weight = torch.nn.Parameter(get_qat_weight(qat_model.linear1))
nn_model.linear2.weight = torch.nn.Parameter(get_qat_weight(qat_model.linear2))
nn_model.sub.linear.weight = torch.nn.Parameter(get_qat_weight(qat_model.sub.linear))

# Simulate training for both models
optimizer1 = torch.optim.SGD(nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
@@ -335,9 +353,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
optimizer2.step()

# After 1 training step, weights should match exactly
torch.testing.assert_close(nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0)
torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0)
torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0)
torch.testing.assert_close(nn_model.linear1.weight, get_qat_weight(qat_model.linear1), atol=0, rtol=0)
torch.testing.assert_close(nn_model.linear2.weight, get_qat_weight(qat_model.linear2), atol=0, rtol=0)
torch.testing.assert_close(nn_model.sub.linear.weight, get_qat_weight(qat_model.sub.linear), atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_generic_fake_quantize(self):
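
The helpers enable_8da4w_fake_quant and disable_8da4w_fake_quant used above are updated in files whose diffs are not shown here. As a hedged sketch only, they presumably flip both new flags on each QAT-prepared linear, roughly as follows; the actual implementations in torchao.quantization.prototype.qat may differ.

import torch
from torchao.quantization.linear_activation_quantized_tensor import (
    LinearActivationQuantizedTensor,
)

def _set_8da4w_fake_quant(mod: torch.nn.Module, enabled: bool) -> None:
    # Only QAT-prepared linears carry the wrapped weight subclass.
    if isinstance(mod, torch.nn.Linear) and isinstance(
        mod.weight, LinearActivationQuantizedTensor
    ):
        # Toggle activation fake quantization on the outer wrapper...
        mod.weight.input_quant_func_enabled = enabled
        # ...and weight fake quantization on the inner AffineFakeQuantizedTensor.
        mod.weight.original_weight_tensor.fake_quant_enabled = enabled

def disable_8da4w_fake_quant(mod: torch.nn.Module) -> None:
    _set_8da4w_fake_quant(mod, enabled=False)

def enable_8da4w_fake_quant(mod: torch.nn.Module) -> None:
    _set_8da4w_fake_quant(mod, enabled=True)

# Usage mirrors the test: qat_model.apply(disable_8da4w_fake_quant)
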
25 changes: 17 additions & 8 deletions torchao/quantization/linear_activation_quantized_tensor.py
@@ -22,6 +22,7 @@ def __new__(
cls,
original_weight_tensor: torch.Tensor,
input_quant_func: Callable,
input_quant_func_enabled: bool = True,
):
kwargs = {}
dtype = original_weight_tensor.dtype
@@ -35,22 +36,25 @@ def __init__(
self,
original_weight_tensor: torch.Tensor,
input_quant_func: Callable,
input_quant_func_enabled: bool = True,
):
self.original_weight_tensor = original_weight_tensor
self.input_quant_func = input_quant_func
self.input_quant_func_enabled = input_quant_func_enabled

def __tensor_flatten__(self):
return ["original_weight_tensor"], [self.input_quant_func]
return ["original_weight_tensor"], [self.input_quant_func, self.input_quant_func_enabled]

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
original_weight_tensor = tensor_data_dict["original_weight_tensor"]
input_quant_func, = tensor_attributes
(input_quant_func, input_quant_func_enabled) = tensor_attributes
return cls(
original_weight_tensor,
input_quant_func,
input_quant_func_enabled,
)

@classmethod
@@ -61,8 +65,15 @@ def _apply_fn_to_data(self, fn):
return self.__class__(
fn(self.original_weight_tensor),
self.input_quant_func,
self.input_quant_func_enabled,
)

def apply_input_quant_func(self, t: torch.Tensor):
if self.input_quant_func_enabled:
return self.input_quant_func(t)
else:
return t

def _get_to_kwargs(self, *args, **kwargs):
device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
device = self.device if device is None else device
@@ -82,6 +93,7 @@ def to(self, *args, **kwargs):
return self.__class__(
self.original_weight_tensor.to(**kwargs),
self.input_quant_func,
self.input_quant_func_enabled,
)

implements = classmethod(_implements)
@@ -98,9 +110,8 @@ def _(func, types, *args, **kwargs):
args[2] if len(args) > 2 else None,
)
if isinstance(weight_tensor, LinearActivationQuantizedTensor):
input_quant_func = weight_tensor.input_quant_func
original_weight_tensor = weight_tensor.original_weight_tensor
aqt = input_quant_func(input_tensor)
aqt = weight_tensor.apply_input_quant_func(input_tensor)
return torch.nn.functional.linear(aqt, original_weight_tensor, bias)

raise NotImplementedError("LinearActivationQuantizedTensor: No specialized dispatch found for linear op")
@@ -120,9 +131,8 @@ def _(func, types, *args, **kwargs):
args[2],
args[0],
)
input_quant_func = weight_tensor.input_quant_func
original_weight_tensor = weight_tensor.original_weight_tensor
aqt = input_quant_func(input_tensor)
aqt = weight_tensor.apply_input_quant_func(input_tensor)
return func(bias, aqt, original_weight_tensor)
else:
# aten.mm.default
@@ -134,9 +144,8 @@ def _(func, types, *args, **kwargs):
args[0],
args[1],
)
input_quant_func = weight_tensor.input_quant_func
original_weight_tensor = weight_tensor.original_weight_tensor
aqt = input_quant_func(input_tensor)
aqt = weight_tensor.apply_input_quant_func(input_tensor)
return func(aqt, original_weight_tensor)


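
To make the new input_quant_func_enabled flag concrete, here is a small self-contained sketch of how it gates activation quantization at dispatch time. The fake_int8_activation function is a stand-in for illustration, not torchao's real input quant function, and the tensor is constructed directly rather than via from_float.

import torch
from torchao.quantization.linear_activation_quantized_tensor import (
    LinearActivationQuantizedTensor,
)

def fake_int8_activation(t: torch.Tensor) -> torch.Tensor:
    # Stand-in "quantizer": round-trip through a symmetric int8 grid.
    scale = t.abs().amax().clamp_min(1e-8) / 127.0
    return torch.clamp(torch.round(t / scale), -128, 127) * scale

weight = torch.randn(8, 16)
qw = LinearActivationQuantizedTensor(weight, fake_int8_activation)

x = torch.randn(2, 16)
y_fq = torch.nn.functional.linear(x, qw)      # input passes through apply_input_quant_func

qw.input_quant_func_enabled = False           # the flag added in this commit
y_plain = torch.nn.functional.linear(x, qw)   # input now bypasses the quant function

# With the flag off, the result matches a plain float linear on the inner weight.
assert torch.allclose(y_plain, torch.nn.functional.linear(x, weight))
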
123 changes: 66 additions & 57 deletions torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py
@@ -1,5 +1,5 @@
import torch
from typing import Tuple, Optional
from typing import Callable, Optional, Tuple
from torchao.quantization.quant_primitives import (
_get_and_check_qmin_qmax,
choose_qparams_affine,
@@ -30,50 +30,49 @@ class AffineFakeQuantizedTensor(torch.Tensor):
regardless of the internal representation's type or orientation.
fields:
float_data (torch.Tensor): tensor holding the original float values, needed for actual quantization later
fq_data (torch.Tensor): tensor holding the fake quantized values
block_size (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam
e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
shape (torch.Size): the shape for the Tensor
quant_min (Optional[int]): minimum quantized value for the Tensor
quant_max (Optional[int]): maximum quantized value for the Tensor
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
if zero_point is in integer domain, zero point is added to the quantized integer value during quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
value during quantization
default is ZeroPointDomain.INT
original_tensor (torch.Tensor): tensor holding the original float values, needed for actual quantization later
apply_fake_quant_fn (Callable): function that transforms `original_tensor` to fake quantized values
"""

@staticmethod
def __new__(
cls,
float_data: torch.Tensor,
fq_data: torch.Tensor,
original_tensor: torch.Tensor,
apply_fake_quant_fn: Callable,
fake_quant_enabled: bool = True,
):
kwargs = {}
kwargs["device"] = float_data.device
kwargs["dtype"] = float_data.dtype
kwargs["device"] = original_tensor.device
kwargs["dtype"] = original_tensor.dtype
kwargs["requires_grad"] = True
return torch.Tensor._make_wrapper_subclass(cls, float_data.shape, **kwargs) # type: ignore[attr-defined]
return torch.Tensor._make_wrapper_subclass(cls, original_tensor.shape, **kwargs) # type: ignore[attr-defined]

def __init__(
self,
float_data: torch.Tensor,
fq_data: torch.Tensor,
original_tensor: torch.Tensor,
apply_fake_quant_fn: Callable,
fake_quant_enabled: bool = True,
):
self.float_data = float_data
self.fq_data = fq_data
# TODO: original_tensor is not getting updated!
original_tensor.requires_grad_(self.requires_grad)
self.original_tensor = original_tensor
self.apply_fake_quant_fn = apply_fake_quant_fn
self.fake_quant_enabled = fake_quant_enabled

def __tensor_flatten__(self):
return ["float_data", "fq_data"], []
return ["original_tensor"], [self.apply_fake_quant_fn, self.fake_quant_enabled]

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride,
):
float_data = tensor_data_dict["float_data"]
fq_data = tensor_data_dict["fq_data"]
return cls(float_data, fq_data)
original_tensor = tensor_data_dict["original_tensor"]
(apply_fake_quant_fn, fake_quant_enabled) = tensor_attributes
return cls(
original_tensor,
apply_fake_quant_fn,
fake_quant_enabled,
)

@classmethod
def from_float(
@@ -90,30 +89,35 @@ def from_float(
preserve_zero: bool = True,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
):
quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
scale, zero_point = choose_qparams_affine(
input_float,
mapping_type,
block_size,
target_dtype,
quant_min,
quant_max,
eps,
scale_dtype,
zero_point_dtype,
preserve_zero,
zero_point_domain,
)
fq_data = _GenericFakeQuantize.apply(
input_float,
block_size,
scale,
zero_point,
quant_min,
quant_max,
zero_point_domain,
)
return cls(input_float, fq_data)
def apply_fake_quant_fn(t: torch.Tensor):
qmin, qmax = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
scale, zero_point = choose_qparams_affine(
t,
mapping_type,
block_size,
target_dtype,
qmin,
qmax,
eps,
scale_dtype,
zero_point_dtype,
preserve_zero,
zero_point_domain,
)
fq = _GenericFakeQuantize.apply(
t,
block_size,
scale,
zero_point,
qmin,
qmax,
zero_point_domain,
)
return fq
return cls(input_float, apply_fake_quant_fn)

def to_fake_quantized(self) -> torch.Tensor:
return self.apply_fake_quant_fn(self.original_tensor)

def _get_to_kwargs(self, *args, **kwargs):
device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
Expand All @@ -135,13 +139,18 @@ def to(self, *args, **kwargs):
# not supported yet
kwargs.pop("memory_format")
return self.__class__(
self.float_data.to(device),
self.fq_data.to(device),
self.original_tensor.to(device),
self.apply_fake_quant_fn,
self.fake_quant_enabled,
**kwargs,
)

def _apply_fn_to_data(self, fn):
return self.__class__(self.float_data, fn(self.fq_data))
return self.__class__(
fn(self.original_tensor),
self.apply_fake_quant_fn,
self.fake_quant_enabled,
)

implements = classmethod(_implements)
__torch_function__ = classmethod(_dispatch__torch_function__)
Expand All @@ -158,9 +167,9 @@ def _(func, types, *args, **kwargs):
args[2] if len(args) > 2 else None,
)
if isinstance(input_tensor, AffineFakeQuantizedTensor):
input_tensor = input_tensor.fq_data
input_tensor = input_tensor.to_fake_quantized()
if isinstance(weight_tensor, AffineFakeQuantizedTensor):
weight_tensor = weight_tensor.fq_data
weight_tensor = weight_tensor.to_fake_quantized()
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

@implements([aten.mm.default, aten.addmm.default])
Expand All @@ -174,9 +183,9 @@ def _(func, types, *args, **kwargs):
input_tensor = args[input_index]
weight_tensor = args[input_index + 1]
if isinstance(input_tensor, AffineFakeQuantizedTensor):
input_tensor = input_tensor.fq_data
input_tensor = input_tensor.to_fake_quantized()
if isinstance(weight_tensor, AffineFakeQuantizedTensor):
weight_tensor = weight_tensor.fq_data
weight_tensor = weight_tensor.to_fake_quantized()
if bias is not None:
return func(bias, input_tensor, weight_tensor)
else:
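
The main structural change in this file: from_float no longer precomputes fq_data; it captures the quantization logic in a closure, and to_fake_quantized() applies it lazily at dispatch, so scales and zero points are recomputed from the tensor's current values. Below is a minimal standalone sketch of that pattern with simplified per-tensor int8 math in place of the library's affine primitives; LazyFakeQuant is a hypothetical stand-in, not the real subclass.

import torch
from typing import Callable

class LazyFakeQuant:
    # Simplified stand-in for AffineFakeQuantizedTensor's lazy design
    # (plain object, not a tensor subclass).
    def __init__(
        self,
        original_tensor: torch.Tensor,
        apply_fake_quant_fn: Callable[[torch.Tensor], torch.Tensor],
        fake_quant_enabled: bool = True,
    ):
        self.original_tensor = original_tensor
        self.apply_fake_quant_fn = apply_fake_quant_fn
        self.fake_quant_enabled = fake_quant_enabled

    def to_fake_quantized(self) -> torch.Tensor:
        # Where the library consults fake_quant_enabled may differ;
        # a simple guard suffices for this sketch.
        if not self.fake_quant_enabled:
            return self.original_tensor
        # qparams are derived from the *current* values on every call,
        # so they stay in sync as the underlying weight trains.
        return self.apply_fake_quant_fn(self.original_tensor)

def per_tensor_int8_fake_quant(t: torch.Tensor) -> torch.Tensor:
    scale = t.abs().amax().clamp_min(1e-8) / 127.0
    return torch.clamp(torch.round(t / scale), -128, 127) * scale

w = torch.randn(4, 4)
lazy = LazyFakeQuant(w, per_tensor_int8_fake_quant)
fq_before = lazy.to_fake_quantized()
w.mul_(2.0)                          # simulate a training update in place
fq_after = lazy.to_fake_quantized()  # recomputed with a fresh scale
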
(Diffs for the remaining 2 changed files were not loaded in this view.)
