 )
 from torchao.quantization.qat.fake_quantize_config import (
     Float8FakeQuantizeConfig,
+    Int4WeightFBGEMMFakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from torchao.quantization.qat.fake_quantizer import (
@@ -1929,7 +1930,7 @@ def test_quantize_api_fp8_int4(self):
         """
         self._test_quantize_api_against_ptq(
             Float8DynamicActivationInt4WeightConfig(),
-            target_prepare_sqnr=12,
+            target_prepare_sqnr=22,
             target_convert_sqnr=float("inf"),
         )

@@ -1950,6 +1951,19 @@ def test_quantize_api_int4(self, version: int):
             target_convert_sqnr=float("inf"),
         )

+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    def test_quantize_api_int8_int4(self):
+        """
+        Test the following:
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))
+        """
+        self._test_quantize_api_against_ptq(
+            Int8DynamicActivationInt4WeightConfig(group_size=32),
+            target_prepare_sqnr=30,
+            target_convert_sqnr=float("inf"),
+        )
+
     def test_infer_fp8_int4_config(self):
         """
         Test that fake quantize configs are correctly inferred from
@@ -1964,10 +1978,9 @@ def test_infer_fp8_int4_config(self):
         self.assertIsInstance(act_config, Float8FakeQuantizeConfig)
         self.assertEqual(act_config.dtype, torch.float8_e4m3fn)
         self.assertIsInstance(act_config.granularity, PerRow)
-        self.assertIsInstance(weight_config, IntxFakeQuantizeConfig)
-        self.assertEqual(weight_config.dtype, torch.int4)
+        self.assertIsInstance(weight_config, Int4WeightFBGEMMFakeQuantizeConfig)
         self.assertEqual(weight_config.group_size, 128)
-        self.assertTrue(weight_config.is_symmetric)
+        self.assertEqual(weight_config.activation_dtype, torch.float8_e4m3fn)

     def test_infer_int4_weight_only_config(self):
         """
@@ -2033,6 +2046,128 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         sqnr = compute_error(out, baseline_out).item()
         self.assertGreater(sqnr, 24)

+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(
+        not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
+    )
+    def test_fbgemm_fp8_primitives(self):
+        """
+        Compare numerics between:
+          (1) fbgemm_gpu.experimental.gen_ai.quantize.quantize_fp8_row
+          (2) Our reference QAT version in `Float8FakeQuantizer`
+        """
+        from fbgemm_gpu.experimental.gen_ai.quantize import quantize_fp8_row
+
+        from torchao.quantization.quant_primitives import (
+            _choose_scale_float8,
+            _quantize_affine_float8,
+        )
+
+        x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda()
+        x2 = copy.deepcopy(x1)
+
+        # (1) Just call `quantize_fp8_row`
+        (q1, scale1) = quantize_fp8_row(x1)
+
+        # (2) Our reference implementation for QAT without the dequantize
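+        # Block size (1, x2.shape[-1]) yields one scale per row, matching the
+        # row-wise granularity of `quantize_fp8_row` above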
+        scale2 = _choose_scale_float8(
+            x2,
+            (1, x2.shape[-1]),
+            torch.float8_e4m3fn,
+            hp_value_lb=1e-12,
+        )
+        q2 = _quantize_affine_float8(
+            x2, scale2, torch.float8_e4m3fn, cast_to_float8_dtype=False
+        )
+        sqnr = compute_error(q1.to(torch.float32), q2.to(torch.float32))
+        scale_sqnr = compute_error(
+            scale1.to(torch.float32).flatten(),
+            scale2.to(torch.float32).flatten(),
+        )
+        self.assertGreater(sqnr, 30)
+        self.assertGreater(scale_sqnr, 50)
+
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(
+        not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
+    )
+    def test_fbgemm_int4_primitives(self):
+        """
+        Compare numerics between:
+          (1) fbgemm_gpu.experimental.gen_ai.quantize.quantize_int4_preshuffle
+          (2) Our reference QAT version in `Int4WeightFBGEMMFakeQuantizer`
+        """
+        from fbgemm_gpu.experimental.gen_ai.quantize import (
+            int4_row_quantize,
+            pack_int4,
+            quantize_fp8_row,
+            quantize_int4_preshuffle,
+        )
+
+        from torchao.quantization.quant_primitives import (
+            _choose_scale_float8,
+            _quantize_affine_float8,
+            _quantize_affine_no_dtype_cast,
+        )
+
+        group_size = 128
+        x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda()
+        x2 = copy.deepcopy(x1)
+        x3 = copy.deepcopy(x1)
+
+        # (1) Just call `quantize_int4_preshuffle`
+        (q1, (scale1, _)) = quantize_int4_preshuffle(x1, group_size, dtype="fp8")
+
+        # (2) Call `quantize_int4_preshuffle` but skip packing and shuffling
+        (q2, _) = quantize_fp8_row(x2)
+        (q2, scale2) = int4_row_quantize(q2, group_size)
+
+        # (3) Reference implementation for QAT without the dequantize
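+        # Mirror the two steps in (2): quantize to fp8 per row first, then to
+        # int4 per group of `group_size` values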
+        fp8_scale = _choose_scale_float8(
+            x3,
+            (1, x3.shape[-1]),
+            torch.float8_e4m3fn,
+            hp_value_lb=1e-12,
+        )
+        x3_fp8 = _quantize_affine_float8(x3, fp8_scale, torch.float8_e4m3fn)
+        x3_fp8 = x3_fp8.to(torch.float32)
+        x3_fp8_grouped = x3_fp8.view(x3_fp8.shape[0], -1, group_size)
+        max_abs = torch.amax(torch.abs(x3_fp8_grouped), dim=-1, keepdim=False)
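+        # Divide by 8 because the symmetric int4 range used below is [-8, 7]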
+        scale = torch.clamp(max_abs / 8, min=1e-6)
+        zero_point = torch.zeros_like(scale)
+        q3 = _quantize_affine_no_dtype_cast(
+            x3_fp8,
+            (1, group_size),
+            scale,
+            zero_point,
+            quant_min=-8,
+            quant_max=7,
+        )
+        scale3 = scale
+
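+        # Pack pairs of int4 values into int8 and preshuffle into the layout
+        # expected by the fbgemm kernels, so all three results can be compared
+        # in the same format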
+        def shuffle_and_pack(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+            t = pack_int4(t.to(torch.int8))
+            return torch.ops.fbgemm.preshuffle_i4(t, scale.to(torch.float8_e4m3fn))[0]
+
+        # First, sanity check that shuffle_and_pack(q2) == q1
+        torch.testing.assert_close(q1, shuffle_and_pack(q2, scale2), atol=0, rtol=0)
+
+        # Now check q2 vs q3 with and without shuffle
+        sqnr_q2_q3 = compute_error(q2.to(torch.float32), q3.to(torch.float32))
+        sqnr_q2_q3_preshuffle = compute_error(
+            shuffle_and_pack(q2, scale2).to(torch.float32),
+            shuffle_and_pack(q3, scale3).to(torch.float32),
+        )
+        self.assertGreater(sqnr_q2_q3, 32)
+        self.assertGreater(sqnr_q2_q3_preshuffle, 32)
+
+        # Now check shuffle_and_pack(q3) vs q1
+        sqnr_q1_q3_preshuffle = compute_error(
+            q1.to(torch.float32),
+            shuffle_and_pack(q3, scale3).to(torch.float32),
+        )
+        self.assertGreater(sqnr_q1_q3_preshuffle, 32)
+

 instantiate_parametrized_tests(TestQAT)
