Skip to content

Commit f002586

Browse files
committed
Add float8 FakeQuantizeConfig and FakeQuantizer
**Summary:** This commit adds a QAT path for float8, using the same primitives as `torchao.quantization.Float8Tensor` targeting PTQ configs like: - `Float8DynamicActivationFloat8WeightConfig` - `Float8DynamicActivationInt4WeightConfig` - `Float8WeightOnlyConfig` Usage: ``` from torchao.quantization.granularity import PerRow from torchao.quantization.qat import quantize_, QATConfig base_config = Float8DynamicActivationFloat8WeightConfig( torch.float8_e4m3fn, PerRow(), ) quantize_(model, QATConfig(base_config, step="prepare")) quantize_(model, QATConfig(base_config, step="convert")) ``` OR ``` from torchao.quantization.granularity import PerRow from torchao.quantization.qat import ( Float8FakeQuantizeConfig, QATConfig, quantize_, ) dtype = torch.float8_e4m3fn granularity = PerRow() quantize_(model, QATConfig( activation_config=Float8FakeQuantizeConfig(dtype, granularity), weight_config=Float8FakeQuantizeConfig(dtype, granularity), step="prepare", ) # convert (same as above, not shown) ``` **Test Plan:** ``` python test/quantization/test_qat.py -k test_float8_fake_quantize_config python test/quantization/test_qat.py -k test_float8_fake_quantize python test/quantization/test_qat.py -k test_quantize_api_fp8_fp8 python test/quantization/test_qat.py -k test_quantize_api_fp8_int4 ```
1 parent f01c956 commit f002586

File tree

14 files changed

+229
-109
lines changed

14 files changed

+229
-109
lines changed

docs/source/api_ref_qat.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,12 @@ Custom QAT APIs
2626

2727
FakeQuantizeConfigBase
2828
IntxFakeQuantizeConfig
29+
Float8FakeQuantizeConfig
2930
FakeQuantizedLinear
3031
FakeQuantizedEmbedding
3132
FakeQuantizerBase
3233
IntxFakeQuantizer
34+
Float8FakeQuantizer
3335
linear.enable_linear_fake_quant
3436
linear.disable_linear_fake_quant
3537

test/quantization/pt2e/test_duplicate_dq.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from typing import Any
1212

1313
import torch
14-
from torch.export import export_for_training
1514
from torch.testing._internal.common_quantization import QuantizationTestCase
1615
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests
1716

test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
per_channel_weight_observer_range_neg_127_to_127,
2020
weight_observer_range_neg_127_to_127,
2121
)
22-
from torch.export import export_for_training
2322
from torch.fx import Node
2423
from torch.testing._internal.common_quantization import (
2524
NodeSpec as ns,

test/quantization/pt2e/test_quantize_pt2e_qat.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
default_symmetric_qnnpack_qat_qconfig,
1919
)
2020
from torch.ao.quantization.quantize_fx import prepare_qat_fx
21-
from torch.export import export_for_training
2221
from torch.testing._internal.common_cuda import TEST_CUDA
2322
from torch.testing._internal.common_quantization import (
2423
NodeSpec as ns,

test/quantization/pt2e/test_representation.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
import torch
1313
from torch._higher_order_ops.out_dtype import out_dtype # noqa: F401
14-
from torch.export import export_for_training
1514
from torch.testing._internal.common_quantization import (
1615
NodeSpec as ns,
1716
)

test/quantization/pt2e/test_x86inductor_quantizer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
import torch
1414
import torch.nn as nn
15-
from torch.export import export_for_training
1615
from torch.testing._internal.common_quantization import (
1716
NodeSpec as ns,
1817
)

test/quantization/test_qat.py

Lines changed: 99 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,21 @@
1414

1515
import torch
1616
import torch.nn.functional as F
17-
from parameterized import parameterized
1817
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
18+
from torch.testing._internal.common_utils import (
19+
TestCase,
20+
instantiate_parametrized_tests,
21+
parametrize,
22+
)
1923

2024
from torchao import quantize_
21-
from torchao.float8.config import ScalingGranularity
22-
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
23-
from torchao.float8.float8_training_tensor import LinearMMConfig
25+
from torchao.quantization import Float8Tensor
2426
from torchao.quantization.granularity import (
27+
Granularity,
2528
PerAxis,
2629
PerGroup,
2730
PerRow,
31+
PerTensor,
2832
PerToken,
2933
)
3034
from torchao.quantization.linear_quant_modules import (
@@ -43,11 +47,12 @@
4347
FakeQuantizedEmbedding,
4448
)
4549
from torchao.quantization.qat.fake_quantize_config import (
50+
Float8FakeQuantizeConfig,
4651
IntxFakeQuantizeConfig,
4752
)
4853
from torchao.quantization.qat.fake_quantizer import (
54+
Float8FakeQuantizer,
4955
IntxFakeQuantizer,
50-
_Float8RowwiseActivationFakeQuantizer,
5156
)
5257
from torchao.quantization.qat.linear import (
5358
FakeQuantizedLinear,
@@ -58,10 +63,11 @@
5863
from torchao.quantization.qat.utils import (
5964
_fake_quantize_per_channel_group,
6065
_fake_quantize_per_token,
61-
_Float8RowwiseFakeQuantize,
6266
_get_qmin_qmax,
6367
)
6468
from torchao.quantization.quant_api import (
69+
Float8DynamicActivationFloat8WeightConfig,
70+
Float8DynamicActivationInt4WeightConfig,
6571
Int8DynamicActivationInt4WeightConfig,
6672
)
6773
from torchao.quantization.quant_primitives import (
@@ -83,6 +89,10 @@
8389
get_groupwise_affine_qparams,
8490
groupwise_affine_quantize_tensor,
8591
)
92+
from torchao.utils import (
93+
_is_fbgemm_genai_gpu_available,
94+
is_sm_at_least_89,
95+
)
8696

8797
# TODO: put this in a common test utils file
8898
_CUDA_IS_AVAILABLE = torch.cuda.is_available()
@@ -193,7 +203,7 @@ def forward(self, x):
193203
return x
194204

195205

196-
class TestQAT(unittest.TestCase):
206+
class TestQAT(TestCase):
197207
SEED = 123
198208

199209
def test_fake_quantize_per_channel_group(self):
@@ -1420,7 +1430,7 @@ def test_qat_linear_bias(self):
14201430
example_inputs = m.example_inputs()
14211431
m(*example_inputs)
14221432

1423-
@parameterized.expand([(torch.float32,), (torch.bfloat16,), (torch.float16,)])
1433+
@parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
14241434
def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
14251435
"""
14261436
Test that the following produce the exact same numerics:
@@ -1437,7 +1447,7 @@ def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
14371447
baseline_out = per_token_dynamic_quant(x)
14381448
torch.testing.assert_close(fake_quantizer_out, baseline_out, atol=0, rtol=0)
14391449

1440-
@parameterized.expand([(torch.float32,), (torch.bfloat16,), (torch.float16,)])
1450+
@parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
14411451
def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
14421452
"""
14431453
Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produces
@@ -1548,7 +1558,7 @@ def test_qat_8da4w_eps(self):
15481558
actual_out = converted_model.linear1(x)
15491559
torch.testing.assert_close(expected_out, actual_out, atol=0, rtol=0)
15501560

1551-
@parameterized.expand([(True,), (False,)])
1561+
@parametrize("is_symmetric", [True, False])
15521562
def test_fake_quantizer_range_learning(self, is_symmetric):
15531563
"""
15541564
Test that range learning requires `IntxFakeQuantizer`s to be initialized correctly.
@@ -1589,7 +1599,7 @@ def test_fake_quantizer_range_learning(self, is_symmetric):
15891599
self.assertTrue(fake_quantizer.zero_point.requires_grad)
15901600
fake_quantizer(*example_inputs)
15911601

1592-
@parameterized.expand([(True,), (False,)])
1602+
@parametrize("is_symmetric", [True, False])
15931603
def test_qat_range_learning(self, is_symmetric):
15941604
"""
15951605
Test end-to-end QAT flow with range learning.
@@ -1664,24 +1674,6 @@ def test_qat_range_learning(self, is_symmetric):
16641674
self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
16651675
self.assertFalse(torch.equal(new_weight, prev_weight))
16661676

1667-
def test_float8_rowwise_fake_quantize(self):
1668-
"""
1669-
Test that `_Float8RowwiseFakeQuantize` is numerically close to `Float8TrainingTensor`.
1670-
"""
1671-
torch.manual_seed(self.SEED)
1672-
dtype = torch.float8_e4m3fn
1673-
x = torch.randn(32, 64)
1674-
axiswise_dim = 0
1675-
out = _Float8RowwiseFakeQuantize.apply(x, dtype, axiswise_dim)
1676-
out_expected = hp_tensor_to_float8_dynamic(
1677-
x,
1678-
dtype,
1679-
LinearMMConfig(),
1680-
scaling_granularity=ScalingGranularity.AXISWISE,
1681-
axiswise_dim=axiswise_dim,
1682-
).to_original_precision()
1683-
torch.testing.assert_close(out, out_expected, atol=0, rtol=0)
1684-
16851677
def test_qat_fp8a4w_quantizer(self):
16861678
"""
16871679
Test basic model training with `Float8ActInt4WeightQATQuantizer`.
@@ -1693,7 +1685,8 @@ def test_qat_fp8a4w_quantizer(self):
16931685
for linear in [m.linear1, m.sub.linear, m.linear2]:
16941686
self.assertIsInstance(linear, FakeQuantizedLinear)
16951687
self.assertIsInstance(
1696-
linear.activation_fake_quantizer, _Float8RowwiseActivationFakeQuantizer
1688+
linear.activation_fake_quantizer,
1689+
Float8FakeQuantizer,
16971690
)
16981691
self.assertIsInstance(linear.weight_fake_quantizer, IntxFakeQuantizer)
16991692
prev_weight = copy.deepcopy(m.linear1.weight)
@@ -1805,9 +1798,6 @@ def test_qat_api_deprecation(self):
18051798
str(w.message),
18061799
)
18071800

1808-
@unittest.skipIf(
1809-
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
1810-
)
18111801
def test_qat_api_convert_no_quantization(self):
18121802
"""
18131803
Test that `QATConfig(step="convert")` swaps back to nn modules without quantization.
@@ -1836,6 +1826,82 @@ def test_qat_api_convert_no_quantization(self):
18361826
baseline_out = baseline_model(*x2)
18371827
torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
18381828

1829+
def test_float8_fake_quantize_config(self):
1830+
"""
1831+
Test that the correct errors are thrown if `Float8FakeQuantizeConfig` is not instantiated properly.
1832+
"""
1833+
# OK
1834+
Float8FakeQuantizeConfig(torch.float8_e4m3fn)
1835+
Float8FakeQuantizeConfig(torch.float8_e4m3fn, PerRow())
1836+
Float8FakeQuantizeConfig(torch.float8_e4m3fn, PerTensor())
1837+
1838+
with self.assertRaisesRegex(ValueError, "not a float8 dtype"):
1839+
Float8FakeQuantizeConfig(torch.int8)
1840+
with self.assertRaisesRegex(
1841+
ValueError, "Please specify the granularity object instead of the class"
1842+
):
1843+
Float8FakeQuantizeConfig(granularity=PerRow)
1844+
with self.assertRaisesRegex(
1845+
ValueError, "Expected PerRow or PerTensor granularity"
1846+
):
1847+
Float8FakeQuantizeConfig(granularity=PerToken())
1848+
1849+
@parametrize("granularity", [PerTensor(), PerRow()])
1850+
def test_float8_fake_quantize(self, granularity: Granularity):
1851+
"""
1852+
Test that `Float8FakeQuantizer` is numerically close to `Float8Tensor`.
1853+
"""
1854+
dtype = torch.float8_e4m3fn
1855+
fq_config = Float8FakeQuantizeConfig(dtype, granularity)
1856+
fake_quantizer = Float8FakeQuantizer(fq_config)
1857+
torch.manual_seed(self.SEED)
1858+
x = torch.randn(32, 64)
1859+
out = fake_quantizer(x)
1860+
out_expected = Float8Tensor.to_float8(x, dtype, granularity).dequantize()
1861+
sqnr = compute_error(out, out_expected)
1862+
self.assertGreater(sqnr, 16)
1863+
1864+
@parametrize("granularity", [PerTensor(), PerRow()])
1865+
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
1866+
@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
1867+
def test_quantize_api_fp8_fp8(self, granularity: Granularity):
1868+
"""
1869+
Test the following:
1870+
quantize_(model, QATConfig(Float8DynamicActivationFloat8Weight(), step="prepare"))
1871+
quantize_(model, QATConfig(Float8DynamicActivationFloat8Weight(), step="convert"))
1872+
"""
1873+
torch.manual_seed(self.SEED)
1874+
m = M().to(torch.bfloat16).cuda()
1875+
example_inputs = (m.example_inputs()[0].to(torch.bfloat16).cuda(),)
1876+
base_config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
1877+
quantize_(m, QATConfig(base_config, step="prepare"))
1878+
m(*example_inputs)
1879+
quantize_(m, QATConfig(base_config, step="convert"))
1880+
m(*example_inputs)
1881+
1882+
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
1883+
@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
1884+
@unittest.skipIf(
1885+
not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
1886+
)
1887+
def test_quantize_api_fp8_int4(self):
1888+
"""
1889+
Test the following:
1890+
quantize_(model, QATConfig(Float8DynamicActivationInt4WeightConfig(), step="prepare"))
1891+
quantize_(model, QATConfig(Float8DynamicActivationInt4WeightConfig(), step="convert"))
1892+
"""
1893+
torch.manual_seed(self.SEED)
1894+
m = M().to(torch.bfloat16).cuda()
1895+
example_inputs = (m.example_inputs()[0].to(torch.bfloat16).cuda(),)
1896+
base_config = Float8DynamicActivationInt4WeightConfig(group_size=128)
1897+
quantize_(m, QATConfig(base_config, step="prepare"))
1898+
m(*example_inputs)
1899+
quantize_(m, QATConfig(base_config, step="convert"))
1900+
m(*example_inputs)
1901+
1902+
1903+
instantiate_parametrized_tests(TestQAT)
1904+
18391905

18401906
if __name__ == "__main__":
18411907
unittest.main()

torchao/quantization/qat/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515
from .fake_quantize_config import (
1616
FakeQuantizeConfig,
1717
FakeQuantizeConfigBase,
18+
Float8FakeQuantizeConfig,
1819
IntxFakeQuantizeConfig,
1920
)
2021
from .fake_quantizer import (
2122
FakeQuantizer,
2223
FakeQuantizerBase,
24+
Float8FakeQuantizer,
2325
IntxFakeQuantizer,
2426
)
2527
from .linear import (
@@ -34,6 +36,8 @@
3436
"QATStep",
3537
"FakeQuantizeConfigBase",
3638
"FakeQuantizerBase",
39+
"Float8FakeQuantizeConfig",
40+
"Float8FakeQuantizer",
3741
"IntxFakeQuantizeConfig",
3842
"IntxFakeQuantizer",
3943
"FakeQuantizedLinear",

0 commit comments

Comments
 (0)