Skip to content

Commit

Permalink
Add composable QAT quantizer (#938)
Browse files Browse the repository at this point in the history
Summary: This is a utility for users who wish to apply multiple
QAT quantizers to their models. In the near future, we expect
to add an embedding QAT quantizer that composes with the
existing linear QAT quantizers.

Test Plan:
python test/quantization/test_qat.py -k test_composable_qat_quantizer
  • Loading branch information
andrewor14 authored Sep 25, 2024
1 parent fbe97a0 commit 6a4d064
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 7 deletions.
42 changes: 42 additions & 0 deletions test/quantization/test_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from torchao.dtypes import (
TensorCoreTiledLayoutType,
)
from torchao.quantization.prototype.qat.api import (
ComposableQATQuantizer,
)
from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
AffineFakeQuantizedTensor,
)
Expand All @@ -34,6 +37,9 @@
MappingType,
ZeroPointDomain,
)
from torchao.quantization.unified import (
TwoStepQuantizer,
)
from torchao.quantization.utils import (
get_group_qparams_symmetric,
get_groupwise_affine_qparams,
Expand Down Expand Up @@ -626,6 +632,42 @@ def test_qat_4w_quantizer_module_swap(self):
module_swap_out = module_swap_model(*x2)
torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)

class _MyQATQuantizer(TwoStepQuantizer):
"""
Dummy quantizer that attaches a certain value to each nn.Linear's
`_temp_quantizer_values` attribute.
"""
ATTR_NAME = "_temp_quantizer_values"

def __init__(self, value: str):
self.value = value

def _attach_value(self, module: torch.nn.Module):
if isinstance(module, torch.nn.Linear):
if not hasattr(module, self.ATTR_NAME):
setattr(module, self.ATTR_NAME, [])
getattr(module, self.ATTR_NAME).append(self.value)

def prepare(self, model: torch.nn.Module):
model.apply(self._attach_value)
return model

def convert(self, model: torch.nn.Module):
model.apply(self._attach_value)
return model

def test_composable_qat_quantizer(self):
quantizer1 = self._MyQATQuantizer("quantizer1")
quantizer2 = self._MyQATQuantizer("quantizer2")
composable_quantizer = ComposableQATQuantizer([quantizer1, quantizer2])
model = M()
model = composable_quantizer.prepare(model)
self.assertTrue(hasattr(model.linear1, self._MyQATQuantizer.ATTR_NAME))
values_list = getattr(model.linear1, self._MyQATQuantizer.ATTR_NAME)
self.assertEqual(values_list, ["quantizer1", "quantizer2"])
composable_quantizer.convert(model)
values_list = getattr(model.linear1, self._MyQATQuantizer.ATTR_NAME)
self.assertEqual(values_list, ["quantizer1", "quantizer2", "quantizer1", "quantizer2"])

if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions torchao/quantization/prototype/qat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
enable_8da4w_fake_quant,
int4_weight_only_fake_quantize,
int8_dynamic_activation_int4_weight_fake_quantize,
ComposableQATQuantizer,
Int4WeightOnlyQATQuantizer,
Int8DynActInt4WeightQATQuantizer,
)
Expand All @@ -20,6 +21,7 @@
"enable_8da4w_fake_quant",
"int4_weight_only_fake_quantize",
"int8_dynamic_activation_int4_weight_fake_quantize",
"ComposableQATQuantizer",
"Int4WeightOnlyQATQuantizer",
"Int8DynActInt4WeightQATQuantizer",
"Int8DynActInt4WeightQATLinear",
Expand Down
46 changes: 43 additions & 3 deletions torchao/quantization/prototype/qat/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Optional
from typing import Any, List, Optional

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -34,6 +34,44 @@
)


class ComposableQATQuantizer(TwoStepQuantizer):
"""
Composable quantizer that users can use to apply multiple QAT quantizers easily.
Quantizers will be applied in the order they are specified in the constructor.
Note: the quantizers provided must apply to different modules in the model,
e.g. nn.Linear and nn.Embedding, otherwise the behavior will be undefined.
Example usage::
my_quantizer = ComposableQATQuantizer([
QATQuantizer1(),
QATQuantizer2(),
QATQuantizer3(),
])
model = my_quantizer.prepare(model)
train(model)
model = my_quantizer.convert(model)
"""

def __init__(self, quantizers: List[TwoStepQuantizer]):
self.quantizers = quantizers

def prepare(
self, model: torch.nn.Module, *args: Any, **kwargs: Any
) -> torch.nn.Module:
for quantizer in self.quantizers:
model = quantizer.prepare(model)
return model

def convert(
self, model: torch.nn.Module, *args: Any, **kwargs: Any
) -> torch.nn.Module:
for quantizer in self.quantizers:
model = quantizer.convert(model)
return model


# =================
# | 8da4w QAT |
# =================
Expand All @@ -44,7 +82,8 @@ def int8_dynamic_activation_int4_weight_fake_quantize(group_size=32):
int4 per group weight symmetric fake quantization to linear. Please see
:func:`~torchao.quantization.int8_dynamic_activation_int4_weight` for more details.
Example usage:
Example usage::
from torchao.quantization import quantize_
quantize_(model, int8_dynamic_activation_int4_weight_fake_quantize(group_size=32))
"""
Expand Down Expand Up @@ -151,7 +190,8 @@ def int4_weight_only_fake_quantize(group_size=128):
Applies uint4 weight-only asymmetric per-group fake quantization to linear layers.
Please see :func:`~torchao.quantization.int4_weight_only` for more details.
Example usage:
Example usage::
from torchao.quantization import quantize_
quantize_(model, int4_weight_only_fake_quantize(group_size=32))
"""
Expand Down
6 changes: 2 additions & 4 deletions torchao/quantization/unified.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from typing import Any
from typing import Any, List
from abc import ABC, abstractmethod

"""
Expand All @@ -17,7 +17,6 @@ class Quantizer(ABC):
def quantize(
self, model: torch.nn.Module, *args: Any, **kwargs: Any
) -> torch.nn.Module:

pass


Expand All @@ -27,11 +26,10 @@ class TwoStepQuantizer:
def prepare(
self, model: torch.nn.Module, *args: Any, **kwargs: Any
) -> torch.nn.Module:

pass

@abstractmethod
def convert(
self, model: torch.nn.Module, *args: Any, **kwargs: Any
) -> torch.nn.Module:

pass

0 comments on commit 6a4d064

Please sign in to comment.