Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor QAT into its own module #555

Merged
merged 1 commit into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions test/quantization/test_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@

import torch
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
from torchao.quantization.prototype.qat import (
from torchao.quantization.prototype.qat.utils import (
_choose_qparams_per_token_asymmetric,
_fake_quantize_per_channel_group,
_fake_quantize_per_token,
_GenericFakeQuantize,
fake_quantize_per_channel_group,
fake_quantize_per_token,
)
from torchao.quantization.quant_primitives import (
fake_quantize_affine,
Expand Down Expand Up @@ -85,7 +85,7 @@ def test_fake_quantize_per_channel_group(self):
x2 = copy.deepcopy(x)

# fake quant op
out = fake_quantize_per_channel_group(
out = _fake_quantize_per_channel_group(
x, s, zp, qmin, qmax, group_size,
)
out.sum().backward()
Expand All @@ -110,7 +110,7 @@ def test_fake_quantize_per_token(self):
(s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32)

# fake quant op
out = fake_quantize_per_token(x, s, zp, qmin, qmax)
out = _fake_quantize_per_token(x, s, zp, qmin, qmax)
out.sum().backward()

# compare against PTQ ops
Expand All @@ -135,7 +135,7 @@ def _set_ptq_weight(
Int8DynActInt4WeightLinear,
WeightOnlyInt4Linear,
)
from torchao.quantization.prototype.qat import (
from torchao.quantization.prototype.qat.api import (
Int8DynActInt4WeightQATLinear,
Int4WeightOnlyQATLinear,
)
Expand Down Expand Up @@ -167,7 +167,7 @@ def _set_ptq_weight(

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_linear(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATLinear
from torchao.quantization.prototype.qat.api import Int8DynActInt4WeightQATLinear
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

group_size = 128
Expand Down Expand Up @@ -425,7 +425,7 @@ def test_qat_4w_primitives(self):
# TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
def test_qat_4w_linear(self):
from torchao.quantization.prototype.qat import Int4WeightOnlyQATLinear
from torchao.quantization.prototype.qat.api import Int4WeightOnlyQATLinear
from torchao.quantization.GPTQ import WeightOnlyInt4Linear

group_size = 128
Expand Down
17 changes: 17 additions & 0 deletions torchao/quantization/prototype/qat/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from .api import (
disable_4w_fake_quant,
disable_8da4w_fake_quant,
enable_4w_fake_quant,
enable_8da4w_fake_quant,
Int4WeightOnlyQATQuantizer,
Int8DynActInt4WeightQATQuantizer,
)

__all__ = [
"disable_4w_fake_quant",
"disable_8da4w_fake_quant",
"enable_4w_fake_quant",
"enable_8da4w_fake_quant",
"Int4WeightOnlyQATQuantizer",
"Int8DynActInt4WeightQATQuantizer",
]
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List, Optional, Tuple
from typing import Any, Optional

import torch
import torch.nn.functional as F
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
from torch.library import impl

from torchao.quantization.GPTQ import (
_check_linear_int4_k,
Expand All @@ -20,14 +18,13 @@
Int8DynActInt4WeightLinear,
WeightOnlyInt4Linear,
)
from torchao.quantization.quant_primitives import (
fake_quantize_affine_cachemask,
ZeroPointDomain,
)
from torchao.quantization.quant_primitives import ZeroPointDomain
from torchao.quantization.unified import TwoStepQuantizer
from torchao.quantization.utils import (
_get_per_token_block_size,
get_group_qparams_symmetric,
from torchao.quantization.utils import get_group_qparams_symmetric
from .utils import (
_choose_qparams_per_token_asymmetric,
_fake_quantize_per_channel_group,
_fake_quantize_per_token,
)


Expand Down Expand Up @@ -163,7 +160,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x, self.scales_precision, self.zero_points_precision,
)
(act_qmin, act_qmax) = self._get_qmin_qmax(8)
x_fq = fake_quantize_per_token(
x_fq = _fake_quantize_per_token(
x, act_scales, act_zp, act_qmin, act_qmax,
)
else:
Expand All @@ -177,7 +174,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# TODO: pass zp dtype to `get_group_qparams_symmetric` instead
weight_zp = weight_zp.to(self.zero_points_precision)
(weight_qmin, weight_qmax) = self._get_qmin_qmax(4)
w_fq = fake_quantize_per_channel_group(
w_fq = _fake_quantize_per_channel_group(
self.weight,
weight_scales,
weight_zp,
Expand Down Expand Up @@ -349,7 +346,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
scales, zero_points = get_groupwise_affine_qparams(
self.weight, n_bit, self.groupsize, self.scales_precision,
)
w_fq = fake_quantize_per_channel_group(
w_fq = _fake_quantize_per_channel_group(
self.weight,
scales,
zero_points,
Expand All @@ -373,135 +370,3 @@ def disable_4w_fake_quant(mod: torch.nn.Module):
"""
if isinstance(mod, Int4WeightOnlyQATLinear):
mod.disable_fake_quant()


# ========================
# | QUANT PRIMITIVES |
# ========================

class _GenericFakeQuantize(torch.autograd.Function):
"""
Implementation of generic fake quantize with backward STE.

With the appropriate input tensor shape, this can be used to express
grouped per channel fake quantize or per token fake quantize.
"""

@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
input: torch.Tensor,
scales: torch.Tensor,
zero_points: torch.Tensor,
quant_min: int,
quant_max: int,
block_size: List[int],
zero_point_domain: ZeroPointDomain=ZeroPointDomain.INT,
) -> torch.Tensor:
# Note: for bf16 inputs, casting them to fp32 has the unexpected
# side effect of reducing memory footprint significantly, presumably
# because bf16 * fp32 kernels are not as memory efficient
assert input.dtype == torch.float32
assert scales.dtype == torch.float32
assert zero_points.dtype == torch.int32

(fq, mask) = fake_quantize_affine_cachemask(
input,
block_size,
scales,
zero_points,
torch.int32,
quant_min,
quant_max,
zero_point_domain,
)

ctx.save_for_backward(mask)
return fq

@staticmethod
def backward(ctx, gy):
(mask,) = ctx.saved_tensors
return gy * mask, None, None, None, None, None, None

def fake_quantize_per_channel_group(
input: torch.Tensor,
scales: torch.Tensor,
zero_points: torch.Tensor,
quant_min: int,
quant_max: int,
group_size: int,
zero_point_domain: ZeroPointDomain=ZeroPointDomain.INT,
) -> torch.Tensor:
assert group_size > 1
assert input.shape[-1] % group_size == 0
assert input.dim() == 2
block_size = (1, group_size)
return _GenericFakeQuantize.apply(
input, scales, zero_points, quant_min, quant_max, block_size, zero_point_domain,
)

def fake_quantize_per_token(
input: torch.Tensor,
scales: torch.Tensor,
zero_points: torch.Tensor,
quant_min: int,
quant_max: int,
) -> torch.Tensor:
from torch.ao.quantization.fx._decomposed import _per_token_quant_qparam_dim_check

_per_token_quant_qparam_dim_check(input, scales, zero_points)
block_size = _get_per_token_block_size(input)
fq_input = input.to(torch.float32)
fq = _GenericFakeQuantize.apply(
fq_input, scales, zero_points, quant_min, quant_max, block_size,
)
return fq.reshape_as(input).to(input.dtype)

# TODO: This is copied from torch/ao/quantization/fx/_decomposed.py.
# The version in pytorch does not have backward support yet so we add
# it here for now until https://github.com/pytorch/pytorch/pull/123452
# is landed.
def _choose_qparams_per_token_asymmetric(
input: torch.Tensor,
scales_precision: torch.dtype = torch.float32,
zero_points_precision: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Choose quantization parameters for per token quantization. This means for a N dimension Tensor
(M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
every N elements with the same quantization parameter. The dimension for scales/zero_points
will be (M1 * M2 ... * Mn)

Args:
input (torch.Tensor): original float32/float16 Tensor
scales_precision (torch.dtype): precision of returned scales
zero_points_precision (torch.dtype): precision of returned zero points

Returns:
scales and zero_points, both float32 Tensors
"""
# Based on https://github.com/google/XNNPACK/blob/df156f0cf3db5a4576cc711123eeb54915f82ffc/src/xnnpack/quantization.h#L18
qmin, qmax = -128, 127
min_val = torch.amin(input, dim=-1, keepdim=True)
max_val = torch.amax(input, dim=-1, keepdim=True)
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
eps = torch.finfo(torch.float32).eps # use xnnpack eps?

# scale
scale = (max_val_pos - min_val_neg) / float(qmax - qmin)
scale = scale.clamp(min=eps)

# zero point
descaled_min = min_val_neg / scale
descaled_max = max_val_pos / scale
zero_point_from_min_error = qmin + descaled_min
zero_point_from_max_error = qmax + descaled_max
zero_point = torch.where(
zero_point_from_min_error + zero_point_from_max_error > 0,
qmin - descaled_min,
qmax - descaled_max,
)
zero_point = torch.clamp(zero_point, qmin, qmax).round()

return scale.to(scales_precision), zero_point.to(zero_points_precision)
Loading
Loading