diff --git a/README.md b/README.md
index 93b844c1c6..189217626d 100644
--- a/README.md
+++ b/README.md
@@ -179,12 +179,17 @@ With this quantization flow, we achieve **67% VRAM reduction and 12-20% speedup*
 Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization-Aware Training (QAT) to overcome this limitation, especially for lower bit-width dtypes such as int4. In collaboration with [TorchTune](https://github.com/pytorch/torchtune/blob/main/recipes/quantization.md#quantization-aware-training-qat), we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). For more details, please refer to the [QAT README](torchao/quantization/qat/README.md) and the [original blog](https://pytorch.org/blog/quantization-aware-training/):

 ```python
-from torchao.quantization import quantize_
-from torchao.quantization.qat import IntxFakeQuantizeConfig, IntXQuantizationAwareTrainingConfig
-activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
-weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
-qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
-quantize_(my_model, qat_config)
+from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
+from torchao.quantization.qat import QATConfig
+
+# prepare
+base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+quantize_(my_model, QATConfig(base_config, step="prepare"))
+
+# train model (not shown)
+
+# convert
+quantize_(my_model, QATConfig(base_config, step="convert"))
 ```

 Users can also combine LoRA + QAT to speed up training by [1.89x](https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700) compared to vanilla QAT using this [fine-tuning recipe](https://github.com/pytorch/torchtune/blob/main/recipes/qat_lora_finetune_distributed.py).
diff --git a/docs/source/api_ref_qat.rst b/docs/source/api_ref_qat.rst
index b912e6ffef..bfac8f398d 100644
--- a/docs/source/api_ref_qat.rst
+++ b/docs/source/api_ref_qat.rst
@@ -6,7 +6,7 @@ torchao.quantization.qat

 .. currentmodule:: torchao.quantization.qat

-QAT Configs for quantize_
+Main Config for quantize_
 ---------------------------------------

 For a full example of how to use QAT with our main `quantize_` API,
 please refer to the `QAT README `__.
@@ -15,8 +15,8 @@ please refer to the `QAT README `FakeQuantizedLinear`
+    base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+    quantize_(model, QATConfig(base_config, step="prepare"))

     # fine-tune
     train_loop(model)

@@ -232,18 +225,12 @@ The next step is to actually quantize the model:

 .. code:: py

-    from torchao.quantization import (
-        Int8DynamicActivationInt4WeightConfig,
-    )
-    from torchao.quantization.qat import (
-        FromIntXQuantizationAwareTrainingConfig,
-    )
+    from torchao.quantization import Int8DynamicActivationInt4WeightConfig

-    # convert: transform fake quantization ops into actual quantized ops
-    # swap `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts
-    # quantized activation and weight tensor subclasses
-    quantize_(model, FromIntXQuantizationAwareTrainingConfig())
-    quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
+    # convert: swap `FakeQuantizedLinear` -> `torch.nn.Linear`, then quantize using `base_config`
+    quantize_(model, QATConfig(base_config, step="convert"))
+
+    # inference or generate

 Now our model is ready for serving, and will typically have higher quantized
 accuracy than if we did not apply the prepare step (fake quantization) during
diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index c83f64022b..bd6ede0af5 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -34,6 +34,8 @@
     ComposableQATQuantizer,
     FromIntXQuantizationAwareTrainingConfig,
     IntXQuantizationAwareTrainingConfig,
+    QATConfig,
+    QATStep,
     initialize_fake_quantizers,
 )
 from torchao.quantization.qat.embedding import (
@@ -59,7 +61,7 @@
     _get_qmin_qmax,
 )
 from torchao.quantization.quant_api import (
-    int8_dynamic_activation_int4_weight,
+    Int8DynamicActivationInt4WeightConfig,
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
@@ -1261,11 +1263,67 @@ def test_qat_prototype_bc(self):
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test_quantize_api_standalone(self):
+    def test_qat_config_init(self):
+        """
+        Test that the correct errors are thrown if `QATConfig` is not instantiated properly.
+        """
+        base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+        fq_config = IntxFakeQuantizeConfig(torch.int8, "per_channel")
+
+        # OK
+        QATConfig(base_config, step="prepare")
+        QATConfig(base_config, step="convert")
+        QATConfig(base_config, step=QATStep.PREPARE)
+        QATConfig(base_config, step=QATStep.CONVERT)
+        QATConfig(activation_config=fq_config, weight_config=fq_config, step="prepare")
+        QATConfig(weight_config=fq_config, step="prepare")
+
+        # OK: good step values
+        self.assertEqual(QATConfig(base_config).step, "prepare")
+        self.assertEqual(QATConfig(base_config, step="Prepare").step, "prepare")
+        self.assertEqual(QATConfig(base_config, step="CONVERT").step, "convert")
+
+        # Bad step
+        with self.assertRaisesRegex(ValueError, "`step` must be one of"):
+            QATConfig(base_config, step="blah")
+
+        # Step was not a keyword arg
+        with self.assertRaisesRegex(
+            TypeError, "4 positional arguments but 5 were given"
+        ):
+            QATConfig(base_config, None, None, "prepare")
+
+        # No configs are provided
+        with self.assertRaisesRegex(
+            ValueError, "One of `base_config` or `weight_config` must be specified"
+        ):
+            QATConfig(step="prepare")
+
+        # Clashing configs are provided
+        with self.assertRaisesRegex(ValueError, "Cannot specify both"):
+            QATConfig(base_config, weight_config=fq_config, step="prepare")
+        with self.assertRaisesRegex(ValueError, "Cannot specify both"):
+            QATConfig(base_config, activation_config=fq_config, step="prepare")
+        with self.assertRaisesRegex(
+            ValueError, "must be specified in the convert step"
+        ):
+            QATConfig(weight_config=fq_config, step="convert")
+
+        # FakeQuantizeConfigBase was specified as base_config
+        with self.assertRaisesRegex(
+            ValueError,
+            "was passed as `base_config`. Did you mean to do the following instead?",
+        ):
+            QATConfig(fq_config, step="prepare")
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_quantize_api_prepare(self):
         """
         Test that the following:

-            quantize_(model, IntXQuantizationAwareTrainingConfig(...))
+            quantize_(model, QATConfig(...))

         can produce the same results as `ComposableQATQuantizer`.
         """
@@ -1290,20 +1348,15 @@ def test_quantize_api_standalone(self):
         baseline_model = baseline_quantizer.prepare(baseline_model)

         # quantize_ API
-        activation_config = IntxFakeQuantizeConfig(
-            torch.int8,
-            "per_token",
-            is_symmetric=False,
-        )
+        act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
         weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
-        quantize_(
-            m,
-            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
+        qat_config1 = QATConfig(
+            activation_config=act_config, weight_config=weight_config
         )
+        qat_config2 = QATConfig(weight_config=weight_config)
+        quantize_(m, qat_config1)
         quantize_(
-            m,
-            IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
-            filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
+            m, qat_config2, filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding)
         )

         # Compare model values
@@ -1322,37 +1375,29 @@ def test_quantize_api_errors(self):
         Test that we throw exceptions with helpful error messages if `quantize_`
         runs into unexpected configurations.
         """
-        my_config = IntxFakeQuantizeConfig(torch.int8, group_size=32)
+        fq_config = IntxFakeQuantizeConfig(torch.int8, group_size=32)
+        qat_config = QATConfig(activation_config=fq_config, weight_config=fq_config)
         m = M3()

         # Embedding currently only supports weight-only quantization
         with self.assertRaisesRegex(
             ValueError, "Activation fake quantization is not supported for embedding"
         ):
-            quantize_(
-                m,
-                IntXQuantizationAwareTrainingConfig(my_config, my_config),
-                lambda m, _: isinstance(m, torch.nn.Embedding),
-            )
+            quantize_(m, qat_config, lambda m, _: isinstance(m, torch.nn.Embedding))

         # Only linear and embedding are supported currently
         with self.assertRaisesRegex(ValueError, "does not have QAT support"):
-            quantize_(
-                m,
-                IntXQuantizationAwareTrainingConfig(my_config, my_config),
-                lambda m, _: isinstance(m, torch.nn.ReLU),
-            )
+            quantize_(m, qat_config, lambda m, _: isinstance(m, torch.nn.ReLU))

     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test_quantize_api_convert_path(self):
+    def test_quantize_api_e2e(self):
         """
         Test that the following:

-            quantize_(model, IntXQuantizationAwareTrainingConfig(...))
-            quantize_(model, FromIntXQuantizationAwareTrainingConfig(...))
-            quantize_(model, int8_dynamic_activation_int4_weight())
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))

         can produce the same results as `Int8DynActInt4WeightQATQuantizer` prepare + convert.
         """
@@ -1370,16 +1415,8 @@ def test_quantize_api_convert_path(self):
         baseline_model = baseline_quantizer.prepare(baseline_model)

         # quantize_ prepare
-        activation_config = IntxFakeQuantizeConfig(
-            torch.int8,
-            "per_token",
-            is_symmetric=False,
-        )
-        weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
-        quantize_(
-            m,
-            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
-        )
+        base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size)
+        quantize_(m, QATConfig(base_config, step="prepare"))

         # Compare prepared values
         torch.manual_seed(self.SEED)
@@ -1393,8 +1430,7 @@
         baseline_model = baseline_quantizer.convert(baseline_model)

         # quantize_ convert
-        quantize_(m, FromIntXQuantizationAwareTrainingConfig())
-        quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size))
+        quantize_(m, QATConfig(base_config, step="convert"))

         # Compare converted values
         torch.manual_seed(self.SEED)
@@ -1447,14 +1483,12 @@ def test_qat_linear_bias(self):
         Test that QAT supports linear bias.
         """
         m = ModelWithLinearBias()
-        activation_config = IntxFakeQuantizeConfig(
-            torch.int8, "per_token", is_symmetric=False
-        )
+        act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
         weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=32)
-        quantize_(
-            m,
-            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
+        qat_config = QATConfig(
+            activation_config=act_config, weight_config=weight_config
         )
+        quantize_(m, qat_config)
         example_inputs = m.example_inputs()
         m(*example_inputs)

@@ -1653,7 +1687,7 @@ def test_qat_range_learning(self):
         )
         m = M()
         example_inputs = m.example_inputs()
-        quantize_(m, IntXQuantizationAwareTrainingConfig(weight_config=config))
+        quantize_(m, QATConfig(weight_config=config))

         # Not initialized, should fail
         for t in m._get_all_weight_qparams():
@@ -1756,6 +1790,60 @@ def test_qat_fp8a4w_quantizer(self):
         self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
         self.assertFalse(torch.equal(new_weight, prev_weight))

+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_legacy_quantize_api_e2e(self):
+        """
+        Test that the following two APIs are numerically equivalent:
+
+        New API:
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))
+
+        Old API:
+            quantize_(model, IntXQuantizationAwareTrainingConfig(...))
+            quantize_(model, FromIntXQuantizationAwareTrainingConfig())
+            quantize_(model, Int8DynamicActivationInt4WeightConfig())
+        """
+        group_size = 16
+        torch.manual_seed(self.SEED)
+        m = M()
+        baseline_model = copy.deepcopy(m)
+
+        # Baseline prepare
+        act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
+        weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
+        old_qat_config = IntXQuantizationAwareTrainingConfig(act_config, weight_config)
+        quantize_(baseline_model, old_qat_config)
+
+        # QATConfig prepare
+        base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size)
+        quantize_(m, QATConfig(base_config, step="prepare"))
+
+        # Compare prepared values
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs()
+        x2 = copy.deepcopy(x)
+        out = m(*x)
+        baseline_out = baseline_model(*x2)
+        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
+
+        # Baseline convert
+        quantize_(baseline_model, FromIntXQuantizationAwareTrainingConfig())
+        quantize_(baseline_model, base_config)
+
+        # quantize_ convert
+        quantize_(m, QATConfig(base_config, step="convert"))
+
+        # Compare converted values
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs()
+        x2 = copy.deepcopy(x)
+        out = m(*x)
+        baseline_out = baseline_model(*x2)
+        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
+

 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/qat/README.md b/torchao/quantization/qat/README.md
index 777181b67e..9a11aa7b51 100644
--- a/torchao/quantization/qat/README.md
+++ b/torchao/quantization/qat/README.md
@@ -67,76 +67,85 @@ def train_loop(m: torch.nn.Module):
         optimizer.zero_grad()
 ```

+
 ### quantize_ API (recommended)

-The recommended way to run QAT in torchao is through the `quantize_` API:
-1. **Prepare:** specify how weights and/or activations are to be quantized through
-[`IntxFakeQuantizeConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.IntxFakeQuantizeConfig.html#torchao.quantization.qat.IntxFakeQuantizeConfig) and passing these to [`IntXQuantizationAwareTrainingConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.IntXQuantizationAwareTrainingConfig.html#torchao.quantization.qat.IntXQuantizationAwareTrainingConfig)
-2. **Convert:** quantize the model using the standard post-training quantization (PTQ)
-functions such as [`Int8DynamicActivationInt4WeightConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.Int8DynamicActivationInt4WeightConfig.html#torchao.quantization.Int8DynamicActivationInt4WeightConfig)
+The recommended way to run QAT in torchao is through the `quantize_` API.

-For example:
+1. **Prepare:** The main [`QATConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.QATConfig.html)
+accepts a post-training quantization (PTQ) config and automatically infers
+the corresponding fake quantization configs to use.
+2. **Convert:** quantize the model using the provided base config.
+Currently only the following PTQ base configs are supported:
+- [`Int8DynamicActivationInt4WeightConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.Int8DynamicActivationInt4WeightConfig.html)
+- [`Int4WeightOnlyConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.Int4WeightOnlyConfig.html)
+
+For example (covering most use cases):

 ```python
-from torchao.quantization import (
-    quantize_,
-    Int8DynamicActivationInt4WeightConfig,
-)
-from torchao.quantization.qat import (
-    IntxFakeQuantizeConfig,
-    FromIntXQuantizationAwareTrainingConfig,
-    IntXQuantizationAwareTrainingConfig,
-)
+from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
+from torchao.quantization.qat import QATConfig
+
 model = get_model()

-# prepare: insert fake quantization ops
-# swaps `torch.nn.Linear` with `FakeQuantizedLinear`
-activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
-weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
-quantize_(
-    model,
-    IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
-)
+# prepare: swap `torch.nn.Linear` -> `FakeQuantizedLinear`
+base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+quantize_(model, QATConfig(base_config, step="prepare"))

 # train
 train_loop(model)

-# convert: transform fake quantization ops into actual quantized ops
-# swap `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts
-# quantized activation and weight tensor subclasses
-quantize_(model, FromIntXQuantizationAwareTrainingConfig())
-quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
+# convert: swap `FakeQuantizedLinear` -> `torch.nn.Linear`, then quantize using `base_config`
+quantize_(model, QATConfig(base_config, step="convert"))

 # inference or generate
 ```

-To fake quantize embedding in addition to linear, you can additionally call
-the following with a filter function during the prepare step:
+The `quantize_` API also allows more general quantization settings that
+may not have a corresponding PTQ base config, e.g. for experimentation
+purposes. Users can specify custom fake quantization configs for activations
+and/or weights. For example, the following usage is numerically equivalent
+to the above:

-```
-# first apply linear transformation to the model as above
+```python
+from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
+from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig
+
+model = get_model()
+
+# prepare: swap `torch.nn.Linear` -> `FakeQuantizedLinear`
 activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
 weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
-quantize_(
-    model,
-    IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
+qat_config = QATConfig(
+    activation_config=activation_config,
+    weight_config=weight_config,
+    step="prepare",
 )
+quantize_(model, qat_config)

-# then apply weight-only transformation to embedding layers
-# activation fake quantization is not supported for embedding layers
-quantize_(
-    m,
-    IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
-    filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding)
-)
+# train
+train_loop(model)
+
+# convert: (not shown, same as before)
+```
+
+To fake quantize embedding in addition to linear, you can additionally call
+the following with a filter function during the prepare step:
+
+```
+# First apply linear transformation to the model as above
+# Then apply weight-only transformation to embedding layers
+# (activation fake quantization is not supported for embedding layers)
+qat_config = QATConfig(weight_config=weight_config, step="prepare")
+quantize_(m, qat_config, filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding))
 ```

 ### Quantizer API (legacy)

 Alternatively, torchao provides a few hardcoded quantization settings through
-the following Quantizers:
+the following Quantizers, but these may be removed soon:
 - [Int8DynActInt4QATQuantizer](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer.html#torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer) (linear), targeting int8 per-token dynamic asymmetric activation + int4 per-group symmetric weight
 - [Int4WeightOnlyQATQuantizer](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.Int4WeightOnlyQATQuantizer.html#torchao.quantization.qat.Int4WeightOnlyQATQuantizer) (linear), targeting int4 per-group asymmetric weight using the efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097) after training)
 - [Int4WeightOnlyEmbeddingQATQuantizer](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.Int4WeightOnlyEmbeddingQATQuantizer.html#torchao.quantization.qat.Int4WeightOnlyEmbeddingQATQuantizer) (embedding), targeting int4 per-group symmetric weight
diff --git a/torchao/quantization/qat/__init__.py b/torchao/quantization/qat/__init__.py
index 1035cd8a38..5d3d0996d0 100644
--- a/torchao/quantization/qat/__init__.py
+++ b/torchao/quantization/qat/__init__.py
@@ -2,6 +2,8 @@
     ComposableQATQuantizer,
     FromIntXQuantizationAwareTrainingConfig,
     IntXQuantizationAwareTrainingConfig,
+    QATConfig,
+    QATStep,
     from_intx_quantization_aware_training,
     initialize_fake_quantizers,
     intx_quantization_aware_training,
@@ -24,21 +26,25 @@
 )

 __all__ = [
-    "ComposableQATQuantizer",
+    "QATConfig",
+    "QATStep",
     "FakeQuantizeConfigBase",
+    "IntxFakeQuantizeConfig",
+    "FakeQuantizer",
     "FakeQuantizedLinear",
     "FakeQuantizedEmbedding",
-    "FakeQuantizer",
+    # Prototype
+    "initialize_fake_quantizers",
+    # Legacy quantizers
+    "ComposableQATQuantizer",
     "Float8ActInt4WeightQATQuantizer",
-    "FromIntXQuantizationAwareTrainingConfig",
     "Int4WeightOnlyEmbeddingQATQuantizer",
     "Int4WeightOnlyQATQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
-    "IntxFakeQuantizeConfig",
-    "IntXQuantizationAwareTrainingConfig",
-    "initialize_fake_quantizers",
     # for BC
     "FakeQuantizeConfig",
     "from_intx_quantization_aware_training",
+    "FromIntXQuantizationAwareTrainingConfig",
     "intx_quantization_aware_training",
+    "IntXQuantizationAwareTrainingConfig",
 ]
diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py
index 22607269c8..0b7c1228b0 100644
--- a/torchao/quantization/qat/api.py
+++ b/torchao/quantization/qat/api.py
@@ -5,25 +5,230 @@
 # LICENSE file in the root directory of this source tree.

 from dataclasses import dataclass
+from enum import Enum
 from typing import Any, List, Optional, Tuple

 import torch

 from torchao.core.config import AOBaseConfig
 from torchao.quantization.transform_module import (
+    _QUANTIZE_CONFIG_HANDLER,
     register_quantize_module_handler,
 )
 from torchao.quantization.unified import TwoStepQuantizer

+from .embedding import FakeQuantizedEmbedding
 from .fake_quantize_config import (
     FakeQuantizeConfig,  # noqa: F401, for BC
     FakeQuantizeConfigBase,
+    _infer_fake_quantize_configs,
 )
+from .linear import FakeQuantizedLinear
+
+
+class QATStep(str, Enum):
+    """
+    Enum value for the `step` field in :class:`~torchao.quantization.qat.QATConfig`.
+    """
+
+    PREPARE = "prepare"
+    CONVERT = "convert"
+
+
+@dataclass
+class QATConfig(AOBaseConfig):
+    """
+    Config for applying quantization-aware training (QAT) to a `torch.nn.Module`,
+    to be used with :func:`~torchao.quantization.quant_api.quantize_`.
+
+    This config has two steps, "prepare" and "convert". The prepare step applies
+    "fake" quantization to the model and should be applied before training, while
+    the convert step converts the model into an actual quantized model. Fake
+    quantization here refers to simulating the quantization numerics (e.g. int4)
+    using high precision arithmetic (e.g. bf16), with the goal of reducing
+    eventual degradation from quantization.
+
+    There are two ways to use this config. The first involves passing a base
+    post-training quantization (PTQ) config, which we will use to automatically
+    infer the corresponding fake quantization schemes to use in the prepare phase.
+    In the convert phase, we will then apply the base PTQ config to the model.
+    This will be the most common use case.
+
+    Example usage::
+
+        from torchao.quantization import (
+            quantize_,
+            Int8DynamicActivationInt4WeightConfig,
+        )
+        from torchao.quantization.qat import QATConfig
+
+        base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+        quantize_(model, QATConfig(base_config, step="prepare"))
+        train_loop(model)
+        quantize_(model, QATConfig(base_config, step="convert"))
+
+    Currently only the following are supported as base configs:
+
+    - :class:`~torchao.quantization.Int8DynamicActivationInt4WeightConfig`
+    - :class:`~torchao.quantization.Int4WeightOnlyConfig`
+
+    The second way to use this config involves specifying the fake quantization
+    schemes directly. Users will pass in :class:`~torchao.quantization.qat.FakeQuantizeConfigBase`
+    for weights and/or activations instead of the base PTQ config. This use case
+    is mostly for experimentation, e.g. when the corresponding PTQ config does
+    not exist yet.
+
+    Example usage::
+
+        from torchao.quantization import quantize_
+        from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig
+
+        activation_config = IntxFakeQuantizeConfig(
+            torch.int8, "per_token", is_symmetric=False,
+        )
+        weight_config = IntxFakeQuantizeConfig(
+            torch.int4, group_size=32, is_symmetric=True,
+        )
+        qat_config = QATConfig(
+            # must specify one of `base_config` or `weight_config`
+            activation_config=activation_config,
+            weight_config=weight_config,
+            step="prepare",
+        )
+        quantize_(model, qat_config)
+
+    Args:
+        base_config (Optional[AOBaseConfig]): Base PTQ config to infer the fake
+            quantization configs during the prepare phase, and to apply directly
+            during the convert phase.
+        activation_config (Optional[FakeQuantizeConfigBase]): Custom fake
+            quantization config for input activations, always optional.
+            Must be None if `base_config` is used.
+        weight_config (Optional[FakeQuantizeConfigBase]): Custom fake quantization
+            config for weights. Must be None if `base_config` is used.
+
+    Keyword args:
+        step (str): One of "prepare" or "convert", determines the QAT phase
+
+    Raises:
+        ValueError: If `base_config` and `activation_config` are both specified
+        ValueError: If `base_config` and `weight_config` are both specified
+        ValueError: If neither `base_config` nor `weight_config` is specified
+        ValueError: If `step` is not one of "prepare" or "convert"
+        ValueError: If `base_config` is None but `step` is "convert"
+        ValueError: If the config is applied on a module that is not a
+            `torch.nn.Linear` or `torch.nn.Embedding`, or it is applied on
+            `torch.nn.Embedding` with an activation config
+    """
+
+    base_config: Optional[AOBaseConfig]
+    activation_config: Optional[FakeQuantizeConfigBase]
+    weight_config: Optional[FakeQuantizeConfigBase]
+    step: QATStep
+
+    # Express `step` as a keyword argument
+    # TODO: Use `kw_only=True` instead, added in python 3.10
+    def __init__(
+        self,
+        base_config: Optional[AOBaseConfig] = None,
+        activation_config: Optional[FakeQuantizeConfigBase] = None,
+        weight_config: Optional[FakeQuantizeConfigBase] = None,
+        *,
+        step: QATStep = "prepare",
+    ):
+        self.base_config = base_config
+        self.activation_config = activation_config
+        self.weight_config = weight_config
+        self.step = step
+        self.__post_init__()
+
+    def __post_init__(self):
+        self.step = self.step.lower()
+        all_step_values = [s.value for s in QATStep]
+        if self.step not in all_step_values:
+            raise ValueError(f"`step` must be one of {all_step_values}")
+        if self.base_config is None and self.weight_config is None:
+            raise ValueError(
+                "One of `base_config` or `weight_config` must be specified"
+            )
+        if self.base_config is not None and self.activation_config is not None:
+            raise ValueError(
+                "Cannot specify both `base_config` and `activation_config`"
+            )
+        if self.base_config is not None and self.weight_config is not None:
+            raise ValueError("Cannot specify both `base_config` and `weight_config`")
+        if self.base_config is None and self.step == "convert":
+            raise ValueError("`base_config` must be specified in the convert step")
+        if isinstance(self.base_config, FakeQuantizeConfigBase):
+            config_type = self.base_config.__class__.__name__
+            raise ValueError(
+                f"{config_type} was passed as `base_config`. Did you mean to do the following instead?\n"
+                "    qat_config = QATConfig(\n"
+                f"        activation_config={config_type}(...),\n"
+                f"        weight_config={config_type}(...),\n"
+                '        step="prepare",\n'
+                "    )"
+            )
+
+
+@register_quantize_module_handler(QATConfig)
+def _qat_config_transform(
+    module: torch.nn.Module,
+    config: QATConfig,
+) -> torch.nn.Module:
+    """
+    During the prepare step, perform module swap to apply fake quantization.
+    If the base PTQ config is specified, derive the fake quantization configs from it.
+
+    During the convert step, first perform module swap to revert all fake quantized
+    modules to the corresponding built-in `torch.nn.Module`s, then apply the
+    base config directly to quantize the module.
+    """
+    # Prepare step
+    # Swap nn.Linear -> FakeQuantizedLinear
+    # Swap nn.Embedding -> FakeQuantizedEmbedding
+    base_config = config.base_config
+    step = config.step
+    if step == QATStep.PREPARE:
+        if base_config is not None:
+            (act_config, weight_config) = _infer_fake_quantize_configs(base_config)
+        else:
+            act_config = config.activation_config
+            weight_config = config.weight_config
+        if isinstance(module, torch.nn.Linear):
+            return FakeQuantizedLinear.from_linear(module, act_config, weight_config)
+        elif isinstance(module, torch.nn.Embedding):
+            if act_config is not None:
+                raise ValueError(
+                    "Activation fake quantization is not supported for embedding"
+                )
+            return FakeQuantizedEmbedding.from_embedding(module, weight_config)
+        else:
+            raise ValueError(
+                "Module of type '%s' does not have QAT support" % type(module)
+            )
+    else:
+        # Convert step
+        # Swap FakeQuantizedLinear -> nn.Linear
+        # Swap FakeQuantizedEmbedding -> nn.Embedding
+        # Then apply the base config's transform function to quantize the model
+        assert step == QATStep.CONVERT, "unexpected step '%s' in QATConfig" % step
+        assert base_config is not None, "expected `base_config` in convert step"
+        if isinstance(module, FakeQuantizedLinear):
+            module = module.to_linear()
+        elif isinstance(module, FakeQuantizedEmbedding):
+            module = module.to_embedding()
+        else:
+            # Unrelated module, ignore
+            return module
+        return _QUANTIZE_CONFIG_HANDLER[type(base_config)](module, base_config)
+
+
+# TODO: deprecate
 @dataclass
 class IntXQuantizationAwareTrainingConfig(AOBaseConfig):
     """
+    (Will be deprecated soon)
     Config for applying fake quantization to a `torch.nn.Module`.
     to be used with :func:`~torchao.quantization.quant_api.quantize_`.

@@ -61,9 +266,6 @@ def _intx_quantization_aware_training_transform(
     module: torch.nn.Module,
     config: IntXQuantizationAwareTrainingConfig,
 ) -> torch.nn.Module:
-    from .embedding import FakeQuantizedEmbedding
-    from .linear import FakeQuantizedLinear
-
     mod = module
     activation_config = config.activation_config
     weight_config = config.weight_config
@@ -84,8 +286,10 @@ def _intx_quantization_aware_training_transform(
     raise ValueError("Module of type '%s' does not have QAT support" % type(mod))


+# TODO: deprecate
 class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig):
     """
+    (Will be deprecated soon)
     Config for converting a model with fake quantized modules,
     such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear`
     and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`,
@@ -118,9 +322,6 @@ def _from_intx_quantization_aware_training_transform(
     If the given module is a fake quantized module, return the original
     corresponding version of the module without fake quantization.
     """
-    from .embedding import FakeQuantizedEmbedding
-    from .linear import FakeQuantizedLinear
-
     if isinstance(mod, FakeQuantizedLinear):
         return mod.to_linear()
     elif isinstance(mod, FakeQuantizedEmbedding):
@@ -173,7 +374,7 @@ def initialize_fake_quantizers(
 ) -> None:
     """
     (Prototype) Initialize the scales and zero points on all
-    :class:`~`torchao.quantization.qat.fake_quantizer.FakeQuantizer`
+    :class:`~torchao.quantization.qat.fake_quantizer.FakeQuantizer`
     in the model based on the provided example inputs.
     """
     # avoid circular dependencies
diff --git a/torchao/quantization/qat/fake_quantize_config.py b/torchao/quantization/qat/fake_quantize_config.py
index 7369c02148..77b40267ad 100644
--- a/torchao/quantization/qat/fake_quantize_config.py
+++ b/torchao/quantization/qat/fake_quantize_config.py
@@ -6,10 +6,11 @@

 import abc
 from dataclasses import dataclass
-from typing import Any, Optional, Union
+from typing import Any, Optional, Tuple, Union

 import torch

+from torchao.core.config import AOBaseConfig
 from torchao.quantization.granularity import (
     Granularity,
     PerAxis,
@@ -36,7 +37,8 @@ class FakeQuantizeConfigBase(abc.ABC):
 @dataclass
 class IntxFakeQuantizeConfig(FakeQuantizeConfigBase):
     """
-    Config for how to fake quantize weights or activations.
+    Config for how to fake quantize weights or activations,
+    targeting integer dtypes up to torch.int8.

     Args:
         dtype: dtype to simulate during fake quantization, e.g. torch.int8.
@@ -259,3 +261,43 @@ def __setattr__(self, name: str, value: Any):

 # For BC
 FakeQuantizeConfig = IntxFakeQuantizeConfig
+
+
+def _infer_fake_quantize_configs(
+    base_config: AOBaseConfig,
+) -> Tuple[Optional[FakeQuantizeConfigBase], Optional[FakeQuantizeConfigBase]]:
+    """
+    Given a base post-training quantization (PTQ) config, infer the corresponding
+    `FakeQuantizeConfigBase`s for both the activations and the weights.
+    This is called during the prepare phase of QAT.
+
+    Return a 2-tuple of (activation_config, weight_config) for fake quantization.
+    """
+    # avoid circular imports
+    from torchao.quantization import (
+        Int4WeightOnlyConfig,
+        Int8DynamicActivationInt4WeightConfig,
+    )
+
+    if isinstance(base_config, Int8DynamicActivationInt4WeightConfig):
+        act_config = IntxFakeQuantizeConfig(
+            dtype=torch.int8,
+            granularity="per_token",
+            is_symmetric=base_config.act_mapping_type == MappingType.SYMMETRIC,
+        )
+        weight_config = IntxFakeQuantizeConfig(
+            dtype=TorchAODType.INT4,
+            group_size=base_config.group_size,
+            is_symmetric=base_config.mapping_type == MappingType.SYMMETRIC,
+        )
+        return (act_config, weight_config)
+    elif isinstance(base_config, Int4WeightOnlyConfig):
+        weight_config = IntxFakeQuantizeConfig(
+            dtype=torch.uint4,
+            group_size=base_config.group_size,
+            is_symmetric=False,
+            zero_point_domain=base_config.zero_point_domain,
+        )
+        return (None, weight_config)
+    else:
+        raise ValueError("Unexpected base config: %s" % base_config)
diff --git a/torchao/quantization/transform_module.py b/torchao/quantization/transform_module.py
index 339d46be35..52bc721f1f 100644
--- a/torchao/quantization/transform_module.py
+++ b/torchao/quantization/transform_module.py
@@ -4,14 +4,14 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

 import functools
-from typing import Callable, Dict
+from typing import Callable, Dict, Type

 import torch

 from torchao.core.config import AOBaseConfig

 _QUANTIZE_CONFIG_HANDLER: Dict[
-    AOBaseConfig,
+    Type[AOBaseConfig],
     Callable[[torch.nn.Module, AOBaseConfig], torch.nn.Module],
 ] = {}
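
For reference, the end-to-end flow this change documents in the two READMEs, written out as a single runnable sketch. The toy two-layer model, the SGD loop, the random data, and the bfloat16/CUDA setup below are illustrative assumptions and are not part of the patch:

```python
import torch
from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import QATConfig

# Toy stand-in for a real model and fine-tuning setup (assumed, not from the patch)
model = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.Linear(256, 256))
model = model.bfloat16().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)

# Prepare: swap `torch.nn.Linear` -> `FakeQuantizedLinear` so training sees the quantization numerics
quantize_(model, QATConfig(base_config, step="prepare"))

# Fine-tune with fake quantization in the loop
for _ in range(10):
    x = torch.randn(16, 256, dtype=torch.bfloat16, device="cuda")
    loss = model(x).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# Convert: swap `FakeQuantizedLinear` back to `torch.nn.Linear`, then apply the base PTQ config
quantize_(model, QATConfig(base_config, step="convert"))
```

After the convert step the linear weights hold the actual quantized tensor subclasses from the base config, and the model is ready for inference or generation.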