@@ -14,17 +14,21 @@
 
 import torch
 import torch.nn.functional as F
-from parameterized import parameterized
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
+from torch.testing._internal.common_utils import (
+    TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
+)
 
 from torchao import quantize_
-from torchao.float8.config import ScalingGranularity
-from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
-from torchao.float8.float8_training_tensor import LinearMMConfig
+from torchao.quantization import Float8Tensor
 from torchao.quantization.granularity import (
+    Granularity,
     PerAxis,
     PerGroup,
     PerRow,
+    PerTensor,
     PerToken,
 )
 from torchao.quantization.linear_quant_modules import (
@@ -43,11 +47,12 @@
     FakeQuantizedEmbedding,
 )
 from torchao.quantization.qat.fake_quantize_config import (
+    Float8FakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from torchao.quantization.qat.fake_quantizer import (
+    Float8FakeQuantizer,
     IntxFakeQuantizer,
-    _Float8RowwiseActivationFakeQuantizer,
 )
 from torchao.quantization.qat.linear import (
     FakeQuantizedLinear,
@@ -58,10 +63,11 @@
 from torchao.quantization.qat.utils import (
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
-    _Float8RowwiseFakeQuantize,
     _get_qmin_qmax,
 )
 from torchao.quantization.quant_api import (
+    Float8DynamicActivationFloat8WeightConfig,
+    Float8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationInt4WeightConfig,
 )
 from torchao.quantization.quant_primitives import (
@@ -83,6 +89,10 @@
     get_groupwise_affine_qparams,
     groupwise_affine_quantize_tensor,
 )
+from torchao.utils import (
+    _is_fbgemm_genai_gpu_available,
+    is_sm_at_least_89,
+)
 
 # TODO: put this in a common test utils file
 _CUDA_IS_AVAILABLE = torch.cuda.is_available()
@@ -193,7 +203,7 @@ def forward(self, x):
         return x
 
 
-class TestQAT(unittest.TestCase):
+class TestQAT(TestCase):
     SEED = 123
 
     def test_fake_quantize_per_channel_group(self):
@@ -1420,7 +1430,7 @@ def test_qat_linear_bias(self):
         example_inputs = m.example_inputs()
         m(*example_inputs)
 
-    @parameterized.expand([(torch.float32,), (torch.bfloat16,), (torch.float16,)])
+    @parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
     def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
         """
         Test that the following produce the exact same numerics:
@@ -1437,7 +1447,7 @@ def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
         baseline_out = per_token_dynamic_quant(x)
         torch.testing.assert_close(fake_quantizer_out, baseline_out, atol=0, rtol=0)
 
-    @parameterized.expand([(torch.float32,), (torch.bfloat16,), (torch.float16,)])
+    @parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
     def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
         """
         Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produce
@@ -1548,7 +1558,7 @@ def test_qat_8da4w_eps(self):
         actual_out = converted_model.linear1(x)
         torch.testing.assert_close(expected_out, actual_out, atol=0, rtol=0)
 
-    @parameterized.expand([(True,), (False,)])
+    @parametrize("is_symmetric", [True, False])
     def test_fake_quantizer_range_learning(self, is_symmetric):
         """
         Test that range learning requires `IntxFakeQuantizer`s to be initialized correctly.
@@ -1589,7 +1599,7 @@ def test_fake_quantizer_range_learning(self, is_symmetric):
         self.assertTrue(fake_quantizer.zero_point.requires_grad)
         fake_quantizer(*example_inputs)
 
-    @parameterized.expand([(True,), (False,)])
+    @parametrize("is_symmetric", [True, False])
     def test_qat_range_learning(self, is_symmetric):
         """
         Test end-to-end QAT flow with range learning.
@@ -1664,24 +1674,6 @@ def test_qat_range_learning(self, is_symmetric):
         self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
         self.assertFalse(torch.equal(new_weight, prev_weight))
 
-    def test_float8_rowwise_fake_quantize(self):
-        """
-        Test that `_Float8RowwiseFakeQuantize` is numerically close to `Float8TrainingTensor`.
-        """
-        torch.manual_seed(self.SEED)
-        dtype = torch.float8_e4m3fn
-        x = torch.randn(32, 64)
-        axiswise_dim = 0
-        out = _Float8RowwiseFakeQuantize.apply(x, dtype, axiswise_dim)
-        out_expected = hp_tensor_to_float8_dynamic(
-            x,
-            dtype,
-            LinearMMConfig(),
-            scaling_granularity=ScalingGranularity.AXISWISE,
-            axiswise_dim=axiswise_dim,
-        ).to_original_precision()
-        torch.testing.assert_close(out, out_expected, atol=0, rtol=0)
-
     def test_qat_fp8a4w_quantizer(self):
         """
         Test basic model training with `Float8ActInt4WeightQATQuantizer`.
@@ -1693,7 +1685,8 @@ def test_qat_fp8a4w_quantizer(self):
         for linear in [m.linear1, m.sub.linear, m.linear2]:
             self.assertIsInstance(linear, FakeQuantizedLinear)
             self.assertIsInstance(
-                linear.activation_fake_quantizer, _Float8RowwiseActivationFakeQuantizer
+                linear.activation_fake_quantizer,
+                Float8FakeQuantizer,
             )
             self.assertIsInstance(linear.weight_fake_quantizer, IntxFakeQuantizer)
         prev_weight = copy.deepcopy(m.linear1.weight)
@@ -1805,9 +1798,6 @@ def test_qat_api_deprecation(self):
                 str(w.message),
             )
 
-    @unittest.skipIf(
-        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
-    )
     def test_qat_api_convert_no_quantization(self):
         """
         Test that `QATConfig(step="convert")` swaps back to nn modules without quantization.
@@ -1836,6 +1826,82 @@ def test_qat_api_convert_no_quantization(self):
         baseline_out = baseline_model(*x2)
         torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
 
+    def test_float8_fake_quantize_config(self):
+        """
+        Test that the correct errors are thrown if `Float8FakeQuantizeConfig` is not instantiated properly.
+        """
+        # OK
+        Float8FakeQuantizeConfig(torch.float8_e4m3fn)
+        Float8FakeQuantizeConfig(torch.float8_e4m3fn, PerRow())
+        Float8FakeQuantizeConfig(torch.float8_e4m3fn, PerTensor())
+
+        with self.assertRaisesRegex(ValueError, "not a float8 dtype"):
+            Float8FakeQuantizeConfig(torch.int8)
+        with self.assertRaisesRegex(
+            ValueError, "Please specify the granularity object instead of the class"
+        ):
+            Float8FakeQuantizeConfig(granularity=PerRow)
+        with self.assertRaisesRegex(
+            ValueError, "Expected PerRow or PerTensor granularity"
+        ):
+            Float8FakeQuantizeConfig(granularity=PerToken())
+
+    @parametrize("granularity", [PerTensor(), PerRow()])
+    def test_float8_fake_quantize(self, granularity: Granularity):
+        """
+        Test that `Float8FakeQuantizer` is numerically close to `Float8Tensor`.
+        """
+        dtype = torch.float8_e4m3fn
+        fq_config = Float8FakeQuantizeConfig(dtype, granularity)
+        fake_quantizer = Float8FakeQuantizer(fq_config)
+        torch.manual_seed(self.SEED)
+        x = torch.randn(32, 64)
+        out = fake_quantizer(x)
+        out_expected = Float8Tensor.to_float8(x, dtype, granularity).dequantize()
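+        # compute_error returns the SQNR (in dB) between the two outputs; fake
+        # quantization should closely track real quantize -> dequantize numerics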
+        sqnr = compute_error(out, out_expected)
+        self.assertGreater(sqnr, 16)
+
+    @parametrize("granularity", [PerTensor(), PerRow()])
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
+    def test_quantize_api_fp8_fp8(self, granularity: Granularity):
+        """
+        Test the following:
+            quantize_(model, QATConfig(Float8DynamicActivationFloat8WeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(Float8DynamicActivationFloat8WeightConfig(), step="convert"))
+        """
+        torch.manual_seed(self.SEED)
+        m = M().to(torch.bfloat16).cuda()
+        example_inputs = (m.example_inputs()[0].to(torch.bfloat16).cuda(),)
+        base_config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
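+        # prepare inserts fake quantizers that emulate the base config's numerics
+        # during training; convert then swaps in the real quantized tensors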
+        quantize_(m, QATConfig(base_config, step="prepare"))
+        m(*example_inputs)
+        quantize_(m, QATConfig(base_config, step="convert"))
+        m(*example_inputs)
+
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
+    @unittest.skipIf(
+        not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
+    )
+    def test_quantize_api_fp8_int4(self):
+        """
+        Test the following:
+            quantize_(model, QATConfig(Float8DynamicActivationInt4WeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(Float8DynamicActivationInt4WeightConfig(), step="convert"))
+        """
+        torch.manual_seed(self.SEED)
+        m = M().to(torch.bfloat16).cuda()
+        example_inputs = (m.example_inputs()[0].to(torch.bfloat16).cuda(),)
+        base_config = Float8DynamicActivationInt4WeightConfig(group_size=128)
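+        # same prepare/convert flow as above, with fp8 activations and int4 weights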
+        quantize_(m, QATConfig(base_config, step="prepare"))
+        m(*example_inputs)
+        quantize_(m, QATConfig(base_config, step="convert"))
+        m(*example_inputs)
+
+
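+# expand the @parametrize cases on TestQAT into concrete test methods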
+instantiate_parametrized_tests(TestQAT)
+
 
 if __name__ == "__main__":
     unittest.main()