Skip to content

Commit add867a

Browse files
committed
Update base for Update on "New multi-step QAT API"
**Summary:** This commit adds a new multi-step QAT API with the main goal of simplifying the existing UX. The new API uses the same `QATConfig` for both the prepare and convert steps, and automatically infers the fake quantization configs based on a PTQ base config provided by the user: ```Py from torchao.quantization import ( quantize_, Int8DynamicActivationInt4WeightConfig ) from torchao.quantization.qat import QATConfig \# prepare base_config = Int8DynamicActivationInt4WeightConfig(group_size=32) qat_config = QATConfig(base_config, step="prepare") quantize_(m, qat_config) \# train (not shown) \# convert quantize_(m, QATConfig(base_config, step="convert")) ``` The main improvements include: - A single config for both prepare and convert steps - A single quantize_ for convert (instead of 2) - No chance for incompatible prepare vs convert configs - Much less boilerplate code for most common use case - Simpler config names For less common use cases such as experimentation, users can still specify arbitrary fake quantization configs for activations and/or weights as before. This is still important since there may not always be a corresponding PTQ base config. For example: ```Py from torchao.quantization import quantize_ from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32) qat_config = QATConfig( activation_config=activation_config, weight_config=weight_config, step="prepare", ) quantize_(model, qat_config) \# train and convert same as above (not shown) ``` **BC-breaking notes:** This change by itself is technically not BC-breaking since we keep around the old path, but will become so when we deprecate and remove the old path in the future. Before: ```Py \# prepare activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32) qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config), quantize_(model, qat_config) \# train (not shown) \# convert quantize_(model, FromIntXQuantizationAwareTrainingConfig()) quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32)) ``` After: (see above) **Test Plan:** ``` python test/quantization/test_qat.py ``` [ghstack-poisoned]
2 parents 81096ae + 97b090d commit add867a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+2744
-175
lines changed

.github/workflows/torchao_experimental_test.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ jobs:
5353
pytest torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py
5454
python torchao/experimental/tests/test_embedding_xbit_quantizer.py
5555
python torchao/experimental/tests/test_quant_passes.py
56+
pytest -s test/prototype/test_dynamic_activation_lut.py
5657
- name: Run kernels/cpu/aarch64/tests
5758
if: runner.os == 'macOS'
5859
run: |
@@ -106,7 +107,7 @@ jobs:
106107
# conda run -n test-mps-ops-env pip install torch --index-url "https://download.pytorch.org/whl/nightly/cpu"
107108
# - name: Print torch version
108109
# run: |
109-
110+
110111
# conda run -n test-mps-ops-env python -c "import torch; print(torch.__version__)"
111112
# - name: Install requirements
112113
# run: |

CITATION.cff

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@ message: "If you use this software, please cite it as below."
44
type: software
55
authors:
66
- given-names: "torchao maintainers and contributors"
7-
url: "https//github.com/pytorch/torchao"
7+
url: "https//github.com/pytorch/ao"
88
license: "BSD-3-Clause"
99
date-released: "2024-10-25"

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ If you find the torchao library useful, please cite it in your work as below.
278278
@software{torchao,
279279
title={TorchAO: PyTorch-Native Training-to-Serving Model Optimization},
280280
author={torchao},
281-
url={https://github.com/pytorch/torchao},
281+
url={https://github.com/pytorch/ao},
282282
license={BSD-3-Clause},
283283
month={oct},
284284
year={2024}

benchmarks/benchmark_blockwise_scaled_linear_triton.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from triton.testing import do_bench
1414

1515
from torchao.float8.float8_utils import compute_error
16-
from torchao.prototype.blockwise_fp8.blockwise_quantization import (
16+
from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import (
1717
blockwise_fp8_gemm,
1818
fp8_blockwise_act_quant,
1919
fp8_blockwise_weight_quant,

benchmarks/float8/bench_grouped_mm.py

Lines changed: 1 addition & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
#
44
# This source code is licensed under the BSD 3-Clause license found in the
55
# LICENSE file in the root directory of this source tree.
6-
import random
76
from typing import Optional
87

98
import fire
109
import pandas as pd
1110
import torch
1211
from utils import do_benchmarks, get_name_to_moe_shapes_iter
1312

13+
from torchao.prototype.moe_training.utils import generate_jagged_offs
1414
from torchao.testing.training.roofline_utils import get_specs
1515

1616

@@ -146,39 +146,6 @@ def do_scaled_grouped_mm(A, B):
146146
data_df.to_csv(out_filename)
147147

148148

149-
def generate_jagged_offs(E, M, dtype=torch.int32, device="cuda"):
150-
"""
151-
Generates a tensor of length E, containing random values divisible by 16,
152-
from 0 to M, in sorted order, and where the final value in the tensor is always M.
153-
Args:
154-
E (int): The length of the tensor.
155-
M (int): The maximum value in the tensor.
156-
Returns:
157-
torch.Tensor: A tensor of length E with the specified properties.
158-
"""
159-
# Ensure M is divisible by 16
160-
if M % 16 != 0:
161-
raise ValueError("M must be divisible by 16")
162-
163-
# Generate a list of possible values
164-
possible_values = [i for i in range(0, M + 1, 16)]
165-
166-
# If E is larger than the number of possible values, raise an error
167-
if E > len(possible_values):
168-
raise ValueError("E cannot be larger than the number of possible values")
169-
170-
# Randomly select E - 1 values from the possible values (excluding M)
171-
selected_values = torch.tensor(random.sample(possible_values[:-1], E - 1))
172-
173-
# Append M to the selected values
174-
selected_values = torch.cat((selected_values, torch.tensor([M])))
175-
176-
# Sort the selected values
177-
selected_values, _ = torch.sort(selected_values)
178-
179-
return selected_values.to(dtype).to(device)
180-
181-
182149
def main() -> None:
183150
fire.Fire(run)
184151

docs/source/tutorials_source/pt2e_quant_openvino_inductor.rst

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ OpenVINO and NNCF could be easily installed via `pip distribution <https://docs.
7474
.. code-block:: bash
7575
7676
pip install -U pip
77-
pip install openvino, nncf
77+
pip install openvino nncf
7878
7979
8080
1. Capture FX Graph
@@ -84,7 +84,6 @@ We will start by performing the necessary imports, capturing the FX Graph from t
8484

8585
.. code-block:: python
8686
87-
import copy
8887
import openvino.torch
8988
import torch
9089
import torchvision.models as models
@@ -106,7 +105,7 @@ We will start by performing the necessary imports, capturing the FX Graph from t
106105
example_inputs = (x,)
107106
108107
# Capture the FX Graph to be quantized
109-
with torch.no_grad(), nncf.torch.disable_patching():
108+
with torch.no_grad():
110109
exported_model = torch.export.export(model, example_inputs).module()
111110
112111
@@ -204,7 +203,7 @@ After that the FX Graph can utilize OpenVINO optimizations using `torch.compile(
204203

205204
.. code-block:: python
206205
207-
with torch.no_grad(), nncf.torch.disable_patching():
206+
with torch.no_grad():
208207
optimized_model = torch.compile(quantized_model, backend="openvino")
209208
210209
# Running some benchmark
@@ -235,6 +234,10 @@ These advanced NNCF algorithms can be accessed via the NNCF `quantize_pt2e` API:
235234
236235
237236
calibration_dataset = nncf.Dataset(calibration_loader, transform_fn)
237+
238+
with torch.no_grad():
239+
exported_model = torch.export.export(model, example_inputs).module()
240+
238241
quantized_model = quantize_pt2e(
239242
exported_model, quantizer, calibration_dataset, smooth_quant=True, fast_bias_correction=False
240243
)

test/quantization/test_config_serialization.py renamed to test/core/test_config.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,5 +187,19 @@ def test_version_mismatch():
187187
config_from_dict(reconstructable)
188188

189189

190+
def test_default_version():
191+
"""Making sure the default version for a new config inheriting from AOBaseConfig is always 1
192+
because it's the default VERSION that all children has when they haven't explicitly
193+
defined a VERSION class variable
194+
"""
195+
196+
@dataclass
197+
class DummyConfig(AOBaseConfig):
198+
pass
199+
200+
config = DummyConfig()
201+
assert config.VERSION == 1, "Default version must be 1"
202+
203+
190204
if __name__ == "__main__":
191205
pytest.main([__file__])

test/float8/test_base.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import random
99
import re
1010
import unittest
11-
import warnings
1211

1312
import pytest
1413
import torch
@@ -381,6 +380,9 @@ def test_linear_from_config_params(
381380
"linear_dtype", [torch.bfloat16, torch.float16, torch.float32]
382381
)
383382
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
383+
@unittest.skipIf(
384+
torch.cuda.is_available() and not is_sm_at_least_90(), "CUDA capability < 9.0"
385+
)
384386
@skip_if_rocm("ROCm enablement in progress")
385387
def test_linear_from_recipe(
386388
self,
@@ -389,12 +391,6 @@ def test_linear_from_recipe(
389391
linear_dtype: torch.dtype,
390392
linear_bias: bool,
391393
):
392-
if torch.cuda.get_device_capability() < (9, 0):
393-
warnings.warn(
394-
f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
395-
)
396-
pytest.skip()
397-
398394
x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
399395
m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
400396
config = Float8LinearConfig.from_recipe_name(recipe_name)

0 commit comments

Comments
 (0)