
Commit e1823c6

Add float8 FakeQuantizeConfig and FakeQuantizer
**Summary:** This commit adds a QAT path for float8, using the same primitives as `torchao.quantization.Float8Tensor` and targeting PTQ configs such as:

- `Float8DynamicActivationFloat8WeightConfig`
- `Float8DynamicActivationInt4WeightConfig`
- `Float8WeightOnlyConfig`

Usage:

```
import torch

from torchao import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization.granularity import PerRow
from torchao.quantization.qat import QATConfig

base_config = Float8DynamicActivationFloat8WeightConfig(
    activation_dtype=torch.float8_e4m3fn,
    granularity=PerRow(),
)
quantize_(model, QATConfig(base_config, step="prepare"))
quantize_(model, QATConfig(base_config, step="convert"))
```

OR

```
import torch

from torchao import quantize_
from torchao.quantization.granularity import PerRow
from torchao.quantization.qat import (
    Float8FakeQuantizeConfig,
    QATConfig,
)

dtype = torch.float8_e4m3fn
granularity = PerRow()
quantize_(model, QATConfig(
    activation_config=Float8FakeQuantizeConfig(dtype, granularity),
    weight_config=Float8FakeQuantizeConfig(dtype, granularity),
    step="prepare",
))
# convert (same as above, not shown)
```

**Test Plan:**

```
python test/quantization/test_qat.py -k test_float8_fake_quantize_config
python test/quantization/test_qat.py -k test_float8_fake_quantize
python test/quantization/test_qat.py -k test_quantize_api_fp8_fp8
python test/quantization/test_qat.py -k test_quantize_api_fp8_int4
```
1 parent 46ba24c commit e1823c6
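For intuition about what the new `Float8FakeQuantizer` simulates during the prepare step, here is a minimal sketch of per-row float8 fake quantization: scale each row by its absmax, cast to `torch.float8_e4m3fn`, then dequantize back to the original dtype. This is an illustrative approximation only, not the torchao implementation (which also supports per-tensor granularity and keeps gradients flowing via a straight-through estimator); the helper name below is hypothetical.

```python
import torch


def fake_quantize_fp8_per_row(
    x: torch.Tensor, float8_dtype: torch.dtype = torch.float8_e4m3fn
) -> torch.Tensor:
    """Illustrative per-row float8 fake quantization (hypothetical helper)."""
    # Per-row absmax scale so each row maps onto the representable float8 range.
    max_repr = torch.finfo(float8_dtype).max  # 448.0 for e4m3fn
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax / max_repr
    # Quantize: divide by the scale, clamp into range, cast down to float8...
    x_fp8 = (x / scale).clamp(-max_repr, max_repr).to(float8_dtype)
    # ...then dequantize back to the original precision ("fake" quantization).
    return x_fp8.to(x.dtype) * scale


x = torch.randn(32, 64)
err = (x - fake_quantize_fp8_per_row(x)).abs().max().item()
print(f"max abs fake-quantization error: {err:.4f}")
```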

8 files changed, +229 −100 lines


docs/source/api_ref_qat.rst

Lines changed: 2 additions & 0 deletions
```diff
@@ -26,10 +26,12 @@ Custom QAT APIs
 
     FakeQuantizeConfigBase
     IntxFakeQuantizeConfig
+    Float8FakeQuantizeConfig
     FakeQuantizedLinear
     FakeQuantizedEmbedding
     FakeQuantizerBase
     IntxFakeQuantizer
+    Float8FakeQuantizer
     linear.enable_linear_fake_quant
     linear.disable_linear_fake_quant
 
```

test/quantization/test_qat.py

Lines changed: 99 additions & 30 deletions
```diff
@@ -14,17 +14,21 @@
 
 import torch
 import torch.nn.functional as F
-from parameterized import parameterized
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
+from torch.testing._internal.common_utils import (
+    TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
+)
 
 from torchao import quantize_
-from torchao.float8.config import ScalingGranularity
-from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
-from torchao.float8.float8_training_tensor import LinearMMConfig
+from torchao.quantization import Float8Tensor
 from torchao.quantization.granularity import (
+    Granularity,
     PerAxis,
     PerGroup,
     PerRow,
+    PerTensor,
     PerToken,
 )
 from torchao.quantization.linear_quant_modules import (
@@ -43,11 +47,12 @@
     FakeQuantizedEmbedding,
 )
 from torchao.quantization.qat.fake_quantize_config import (
+    Float8FakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from torchao.quantization.qat.fake_quantizer import (
+    Float8FakeQuantizer,
     IntxFakeQuantizer,
-    _Float8RowwiseActivationFakeQuantizer,
 )
 from torchao.quantization.qat.linear import (
     FakeQuantizedLinear,
@@ -58,10 +63,11 @@
 from torchao.quantization.qat.utils import (
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
-    _Float8RowwiseFakeQuantize,
     _get_qmin_qmax,
 )
 from torchao.quantization.quant_api import (
+    Float8DynamicActivationFloat8WeightConfig,
+    Float8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationInt4WeightConfig,
 )
 from torchao.quantization.quant_primitives import (
@@ -83,6 +89,10 @@
     get_groupwise_affine_qparams,
     groupwise_affine_quantize_tensor,
 )
+from torchao.utils import (
+    _is_fbgemm_genai_gpu_available,
+    is_sm_at_least_89,
+)
 
 # TODO: put this in a common test utils file
 _CUDA_IS_AVAILABLE = torch.cuda.is_available()
@@ -193,7 +203,7 @@ def forward(self, x):
         return x
 
 
-class TestQAT(unittest.TestCase):
+class TestQAT(TestCase):
     SEED = 123
 
     def test_fake_quantize_per_channel_group(self):
@@ -1420,7 +1430,7 @@ def test_qat_linear_bias(self):
         example_inputs = m.example_inputs()
         m(*example_inputs)
 
-    @parameterized.expand([(torch.float32,), (torch.bfloat16,), (torch.float16,)])
+    @parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
     def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
         """
         Test that the following produce the exact same numerics:
@@ -1437,7 +1447,7 @@ def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
         baseline_out = per_token_dynamic_quant(x)
         torch.testing.assert_close(fake_quantizer_out, baseline_out, atol=0, rtol=0)
 
-    @parameterized.expand([(torch.float32,), (torch.bfloat16,), (torch.float16,)])
+    @parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
     def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
         """
         Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produces
@@ -1548,7 +1558,7 @@ def test_qat_8da4w_eps(self):
         actual_out = converted_model.linear1(x)
         torch.testing.assert_close(expected_out, actual_out, atol=0, rtol=0)
 
-    @parameterized.expand([(True,), (False,)])
+    @parametrize("is_symmetric", [True, False])
    def test_fake_quantizer_range_learning(self, is_symmetric):
         """
         Test that range learning requires `IntxFakeQuantizer`s to be initialized correctly.
@@ -1589,7 +1599,7 @@ def test_fake_quantizer_range_learning(self, is_symmetric):
         self.assertTrue(fake_quantizer.zero_point.requires_grad)
         fake_quantizer(*example_inputs)
 
-    @parameterized.expand([(True,), (False,)])
+    @parametrize("is_symmetric", [True, False])
     def test_qat_range_learning(self, is_symmetric):
         """
         Test end-to-end QAT flow with range learning.
@@ -1664,24 +1674,6 @@ def test_qat_range_learning(self, is_symmetric):
         self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
         self.assertFalse(torch.equal(new_weight, prev_weight))
 
-    def test_float8_rowwise_fake_quantize(self):
-        """
-        Test that `_Float8RowwiseFakeQuantize` is numerically close to `Float8TrainingTensor`.
-        """
-        torch.manual_seed(self.SEED)
-        dtype = torch.float8_e4m3fn
-        x = torch.randn(32, 64)
-        axiswise_dim = 0
-        out = _Float8RowwiseFakeQuantize.apply(x, dtype, axiswise_dim)
-        out_expected = hp_tensor_to_float8_dynamic(
-            x,
-            dtype,
-            LinearMMConfig(),
-            scaling_granularity=ScalingGranularity.AXISWISE,
-            axiswise_dim=axiswise_dim,
-        ).to_original_precision()
-        torch.testing.assert_close(out, out_expected, atol=0, rtol=0)
-
     def test_qat_fp8a4w_quantizer(self):
         """
         Test basic model training with `Float8ActInt4WeightQATQuantizer`.
@@ -1693,7 +1685,8 @@ def test_qat_fp8a4w_quantizer(self):
         for linear in [m.linear1, m.sub.linear, m.linear2]:
             self.assertIsInstance(linear, FakeQuantizedLinear)
             self.assertIsInstance(
-                linear.activation_fake_quantizer, _Float8RowwiseActivationFakeQuantizer
+                linear.activation_fake_quantizer,
+                Float8FakeQuantizer,
             )
             self.assertIsInstance(linear.weight_fake_quantizer, IntxFakeQuantizer)
         prev_weight = copy.deepcopy(m.linear1.weight)
@@ -1833,6 +1826,82 @@ def test_qat_api_convert_no_quantization(self):
         baseline_out = baseline_model(*x2)
         torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
 
+    def test_float8_fake_quantize_config(self):
+        """
+        Test that the correct errors are thrown if `Float8FakeQuantizeConfig` is not instantiated properly.
+        """
+        # OK
+        Float8FakeQuantizeConfig(torch.float8_e4m3fn)
+        Float8FakeQuantizeConfig(torch.float8_e4m3fn, PerRow())
+        Float8FakeQuantizeConfig(torch.float8_e4m3fn, PerTensor())
+
+        with self.assertRaisesRegex(ValueError, "not a float8 dtype"):
+            Float8FakeQuantizeConfig(torch.int8)
+        with self.assertRaisesRegex(
+            ValueError, "Please specify the granularity object instead of the class"
+        ):
+            Float8FakeQuantizeConfig(granularity=PerRow)
+        with self.assertRaisesRegex(
+            ValueError, "Expected PerRow or PerTensor granularity"
+        ):
+            Float8FakeQuantizeConfig(granularity=PerToken())
+
+    @parametrize("granularity", [PerTensor(), PerRow()])
+    def test_float8_fake_quantize(self, granularity: Granularity):
+        """
+        Test that `Float8FakeQuantizer` is numerically close to `Float8Tensor`.
+        """
+        dtype = torch.float8_e4m3fn
+        fq_config = Float8FakeQuantizeConfig(dtype, granularity)
+        fake_quantizer = Float8FakeQuantizer(fq_config)
+        torch.manual_seed(self.SEED)
+        x = torch.randn(32, 64)
+        out = fake_quantizer(x)
+        out_expected = Float8Tensor.to_float8(x, dtype, granularity).dequantize()
+        sqnr = compute_error(out, out_expected)
+        self.assertGreater(sqnr, 16)
+
+    @parametrize("granularity", [PerTensor(), PerRow()])
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
+    def test_quantize_api_fp8_fp8(self, granularity: Granularity):
+        """
+        Test the following:
+            quantize_(model, QATConfig(Float8DynamicActivationFloat8Weight(), step="prepare"))
+            quantize_(model, QATConfig(Float8DynamicActivationFloat8Weight(), step="convert"))
+        """
+        torch.manual_seed(self.SEED)
+        m = M().to(torch.bfloat16).cuda()
+        example_inputs = (m.example_inputs()[0].to(torch.bfloat16).cuda(),)
+        base_config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+        quantize_(m, QATConfig(base_config, step="prepare"))
+        m(*example_inputs)
+        quantize_(m, QATConfig(base_config, step="convert"))
+        m(*example_inputs)
+
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
+    @unittest.skipIf(
+        not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
+    )
+    def test_quantize_api_fp8_int4(self):
+        """
+        Test the following:
+            quantize_(model, QATConfig(Float8DynamicActivationInt4WeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(Float8DynamicActivationInt4WeightConfig(), step="convert"))
+        """
+        torch.manual_seed(self.SEED)
+        m = M().to(torch.bfloat16).cuda()
+        example_inputs = (m.example_inputs()[0].to(torch.bfloat16).cuda(),)
+        base_config = Float8DynamicActivationInt4WeightConfig(group_size=128)
+        quantize_(m, QATConfig(base_config, step="prepare"))
+        m(*example_inputs)
+        quantize_(m, QATConfig(base_config, step="convert"))
+        m(*example_inputs)
+
+
+instantiate_parametrized_tests(TestQAT)
+
 
 if __name__ == "__main__":
     unittest.main()
```
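A note on the test-harness change above: the file migrates from `parameterized.expand` to PyTorch's internal `parametrize` decorator, which only generates the test variants once `instantiate_parametrized_tests(...)` is called at module scope (hence the new call near the bottom of the file). A minimal, self-contained sketch of that pattern, independent of torchao:

```python
import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)


class ExampleTest(TestCase):
    # Replaces `@parameterized.expand([(torch.float32,), (torch.bfloat16,)])`;
    # the named parameter is passed to the test method as a keyword argument.
    @parametrize("dtype", [torch.float32, torch.bfloat16])
    def test_dtype_roundtrip(self, dtype: torch.dtype):
        x = torch.randn(4, 4, dtype=dtype)
        self.assertEqual(x.dtype, dtype)


# Without this call the parametrized variants are never materialized.
instantiate_parametrized_tests(ExampleTest)

if __name__ == "__main__":
    run_tests()
```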

torchao/quantization/qat/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -15,11 +15,13 @@
 from .fake_quantize_config import (
     FakeQuantizeConfig,
     FakeQuantizeConfigBase,
+    Float8FakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from .fake_quantizer import (
     FakeQuantizer,
     FakeQuantizerBase,
+    Float8FakeQuantizer,
     IntxFakeQuantizer,
 )
 from .linear import (
@@ -34,6 +36,8 @@
     "QATStep",
     "FakeQuantizeConfigBase",
     "FakeQuantizerBase",
+    "Float8FakeQuantizeConfig",
+    "Float8FakeQuantizer",
     "IntxFakeQuantizeConfig",
     "IntxFakeQuantizer",
     "FakeQuantizedLinear",
```

torchao/quantization/qat/fake_quantize_config.py

Lines changed: 77 additions & 6 deletions
```diff
@@ -11,10 +11,17 @@
 import torch
 
 from torchao.core.config import AOBaseConfig
+from torchao.float8.config import e4m3_dtype
+from torchao.float8.inference import (
+    FP8Granularity,
+    _normalize_granularity,
+)
 from torchao.quantization.granularity import (
     Granularity,
     PerAxis,
     PerGroup,
+    PerRow,
+    PerTensor,
     PerToken,
 )
 from torchao.quantization.quant_primitives import (
@@ -24,6 +31,7 @@
     TorchAODType,
     ZeroPointDomain,
 )
+from torchao.utils import _is_float8_type
 
 from .utils import _log_deprecation_warning
 
@@ -36,6 +44,39 @@ class FakeQuantizeConfigBase(abc.ABC):
     pass
 
 
+@dataclass
+class Float8FakeQuantizeConfig(FakeQuantizeConfigBase):
+    """
+    Config for float8 fake quantization, targeting :class:`~torchao.quantization.Float8Tensor`.
+
+    Args:
+        dtype (torch.dtype): the dtype for float8 Tensor
+        granularity (FP8Granularity): the granularity for the Tensor, currently either PerRow() or PerTensor()
+        hp_value_lb (Optional[float]): the lower bound for high precision floating point value for calculating scale
+        hp_value_ub (Optional[float]): the upper bound for high precision floating point value for calculating scale
+    """
+
+    dtype: torch.dtype = e4m3_dtype
+    granularity: FP8Granularity = PerRow()
+    hp_value_lb: Optional[float] = None
+    hp_value_ub: Optional[float] = None
+
+    def __post_init__(self):
+        """
+        Verify dtype and granularity are the ones we support.
+        """
+        if not _is_float8_type(self.dtype):
+            raise ValueError(f"{self.dtype} is not a float8 dtype")
+        if isinstance(self.granularity, type):
+            raise ValueError(
+                "Please specify the granularity object instead of the class, e.g. PerRow() instead of PerRow"
+            )
+        if type(self.granularity) not in [PerRow, PerTensor]:
+            raise ValueError(
+                f"Expected PerRow or PerTensor granularity, got {self.granularity}"
+            )
+
+
 @dataclass
 class IntxFakeQuantizeConfig(FakeQuantizeConfigBase):
     """
@@ -279,6 +320,7 @@ def __post_init__(self):
         _log_deprecation_warning(self)
 
 
+# TODO: rewrite using registration API?
 def _infer_fake_quantize_configs(
     base_config: AOBaseConfig,
 ) -> Tuple[Optional[FakeQuantizeConfigBase], Optional[FakeQuantizeConfigBase]]:
@@ -291,6 +333,8 @@ def _infer_fake_quantize_configs(
     """
     # avoid circular imports
     from torchao.quantization import (
+        Float8DynamicActivationFloat8WeightConfig,
+        Float8DynamicActivationInt4WeightConfig,
         Int4WeightOnlyConfig,
         Int8DynamicActivationInt4WeightConfig,
     )
@@ -302,18 +346,45 @@
             is_symmetric=base_config.act_mapping_type == MappingType.SYMMETRIC,
         )
         weight_config = IntxFakeQuantizeConfig(
-            dtype=TorchAODType.INT4,
+            dtype=torch.int4,
             group_size=base_config.group_size,
             is_symmetric=base_config.mapping_type == MappingType.SYMMETRIC,
         )
-        return (act_config, weight_config)
     elif isinstance(base_config, Int4WeightOnlyConfig):
+        if base_config.version != 2:
+            raise ValueError(f"Only version 2 of {type(base_config)} is supported")
+        act_config = None
+        weight_config = IntxFakeQuantizeConfig(
+            dtype=torch.int4,
+            group_size=base_config.group_size,
+            is_symmetric=True,
+        )
+    elif isinstance(base_config, Float8DynamicActivationFloat8WeightConfig):
+        if base_config.version != 2:
+            raise ValueError(f"Only version 2 of {type(base_config)} is supported")
+        (act_granularity, weight_granularity) = _normalize_granularity(
+            base_config.granularity
+        )
+        act_config = Float8FakeQuantizeConfig(
+            dtype=base_config.activation_dtype,
+            granularity=act_granularity,
+            hp_value_lb=base_config.activation_value_lb,
+            hp_value_ub=base_config.activation_value_ub,
+        )
+        weight_config = Float8FakeQuantizeConfig(
+            dtype=base_config.weight_dtype,
+            granularity=weight_granularity,
+        )
+    elif isinstance(base_config, Float8DynamicActivationInt4WeightConfig):
+        act_config = Float8FakeQuantizeConfig(
+            dtype=torch.float8_e4m3fn,
+            granularity=PerRow(),
+        )
         weight_config = IntxFakeQuantizeConfig(
-            dtype=torch.uint4,
+            dtype=torch.int4,
             group_size=base_config.group_size,
-            is_symmetric=False,
-            zero_point_domain=base_config.zero_point_domain,
+            is_symmetric=True,
         )
-        return (None, weight_config)
     else:
         raise ValueError("Unexpected base config: %s" % base_config)
+    return (act_config, weight_config)
```
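To make the new config-inference branch concrete, here is a hedged sketch of what `_infer_fake_quantize_configs` (an internal helper, so the call below is for illustration only) should return for a float8-activation/float8-weight base config. It assumes the installed torchao includes this commit and that the base config's default `version` satisfies the version-2 check added above:

```python
import torch

from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization.granularity import PerRow
from torchao.quantization.qat.fake_quantize_config import (
    Float8FakeQuantizeConfig,
    _infer_fake_quantize_configs,
)

base = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
act_cfg, weight_cfg = _infer_fake_quantize_configs(base)

# Both halves come back as float8 fake-quantize configs carrying the base
# config's activation/weight dtypes and normalized granularities.
assert isinstance(act_cfg, Float8FakeQuantizeConfig)
assert isinstance(weight_cfg, Float8FakeQuantizeConfig)
assert act_cfg.dtype == torch.float8_e4m3fn  # e4m3 is the default on CUDA builds
assert isinstance(weight_cfg.granularity, PerRow)
```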
