 )
 from torchao.quantization.qat.fake_quantize_config import (
     Float8FakeQuantizeConfig,
+    Int4WeightPreshuffledFakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from torchao.quantization.qat.fake_quantizer import (
@@ -1929,7 +1930,7 @@ def test_quantize_api_fp8_int4(self):
         """
         self._test_quantize_api_against_ptq(
             Float8DynamicActivationInt4WeightConfig(),
-            target_prepare_sqnr=12,
+            target_prepare_sqnr=22,
             target_convert_sqnr=float("inf"),
         )

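The target_prepare_sqnr values above are in dB and come from compute_error. As a rough reference, here is a minimal SQNR sketch assuming the usual 20·log10 ratio-of-norms definition (not necessarily torchao's exact implementation); higher values mean the prepared (fake-quantized) model tracks the quantized baseline more closely:

import torch

def sqnr_db(reference: torch.Tensor, candidate: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB.
    signal = torch.linalg.norm(reference.float())
    noise = torch.linalg.norm(reference.float() - candidate.float())
    return 20 * torch.log10(signal / noise)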
@@ -1950,6 +1951,19 @@ def test_quantize_api_int4(self, version: int):
             target_convert_sqnr=float("inf"),
         )

+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    def test_quantize_api_int8_int4(self):
+        """
+        Test the following:
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))
+        """
+        self._test_quantize_api_against_ptq(
+            Int8DynamicActivationInt4WeightConfig(group_size=32),
+            target_prepare_sqnr=30,
+            target_convert_sqnr=float("inf"),
+        )
+
     def test_infer_fp8_int4_config(self):
         """
         Test that fake quantize configs are correctly inferred from
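For context, a minimal sketch of the prepare/convert flow the new test exercises, assuming torchao's public quantize_/QATConfig API (import paths and the model/training loop are placeholders, not part of this change):

import torch
from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_
from torchao.quantization.qat import QATConfig

model = torch.nn.Sequential(torch.nn.Linear(256, 256)).cuda().to(torch.bfloat16)
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)

# Step 1 (prepare): swap in fake-quantized modules used during fine-tuning
quantize_(model, QATConfig(base_config, step="prepare"))
# ... fine-tune the model here ...
# Step 2 (convert): replace fake quantization with real quantized weights
quantize_(model, QATConfig(base_config, step="convert"))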
@@ -1964,10 +1978,9 @@ def test_infer_fp8_int4_config(self):
         self.assertIsInstance(act_config, Float8FakeQuantizeConfig)
         self.assertEqual(act_config.dtype, torch.float8_e4m3fn)
         self.assertIsInstance(act_config.granularity, PerRow)
-        self.assertIsInstance(weight_config, IntxFakeQuantizeConfig)
-        self.assertEqual(weight_config.dtype, torch.int4)
+        self.assertIsInstance(weight_config, Int4WeightPreshuffledFakeQuantizeConfig)
         self.assertEqual(weight_config.group_size, 128)
-        self.assertTrue(weight_config.is_symmetric)
+        self.assertEqual(weight_config.activation_dtype, torch.float8_e4m3fn)

     def test_infer_int4_weight_only_config(self):
         """
@@ -2033,6 +2046,126 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         sqnr = compute_error(out, baseline_out).item()
         self.assertGreater(sqnr, 24)

+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(
+        not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
+    )
+    def test_fbgemm_fp8_primitives(self):
+        """
+        Compare numerics between:
+            (1) fbgemm_gpu.experimental.gen_ai.quantize.quantize_fp8_row
+            (2) Our reference QAT version in `Float8FakeQuantizer`
+        """
+        from fbgemm_gpu.experimental.gen_ai.quantize import quantize_fp8_row
+
+        from torchao.quantization.quant_primitives import (
+            _choose_scale_float8,
+            _quantize_affine_float8,
+        )
+
+        x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda()
+        x2 = copy.deepcopy(x1)
+
+        # (1) Just call `quantize_fp8_row`
+        (q1, scale1) = quantize_fp8_row(x1)
+
+        # (2) Our reference implementation for QAT without the dequantize
+        scale2 = _choose_scale_float8(
+            x2,
+            (1, x2.shape[-1]),
+            torch.float8_e4m3fn,
+            hp_value_lb=1e-12,
+        )
+        q2 = _quantize_affine_float8(x2, scale2, torch.float8_e4m3fn)
+        sqnr = compute_error(q1.to(torch.float32), q2.to(torch.float32))
+        scale_sqnr = compute_error(
+            scale1.to(torch.float32).flatten(),
+            scale2.to(torch.float32).flatten(),
+        )
+        self.assertGreater(sqnr, 40)
+        self.assertGreater(scale_sqnr, 50)
+
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(
+        not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
+    )
+    def test_fbgemm_int4_preshuffled_primitives(self):
+        """
+        Compare numerics between:
+            (1) fbgemm_gpu.experimental.gen_ai.quantize.quantize_int4_preshuffle
+            (2) Our reference QAT version in `Int4WeightPreshuffledFakeQuantizer`
+        """
+        from fbgemm_gpu.experimental.gen_ai.quantize import (
+            int4_row_quantize,
+            pack_int4,
+            quantize_fp8_row,
+            quantize_int4_preshuffle,
+        )
+
+        from torchao.quantization.quant_primitives import (
+            _choose_scale_float8,
+            _quantize_affine_float8,
+            _quantize_affine_no_dtype_cast,
+        )
+
+        group_size = 128
+        x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda()
+        x2 = copy.deepcopy(x1)
+        x3 = copy.deepcopy(x1)
+
+        # (1) Just call `quantize_int4_preshuffle`
+        (q1, (scale1, _)) = quantize_int4_preshuffle(x1, group_size, dtype="fp8")
+
+        # (2) Call `quantize_int4_preshuffle` but skip packing and shuffling
+        (q2, _) = quantize_fp8_row(x2)
+        (q2, scale2) = int4_row_quantize(q2, group_size)
+
+        # (3) Reference implementation for QAT without the dequantize
+        fp8_scale = _choose_scale_float8(
+            x3,
+            (1, x3.shape[-1]),
+            torch.float8_e4m3fn,
+            hp_value_lb=1e-12,
+        )
+        x3_fp8 = _quantize_affine_float8(x3, fp8_scale, torch.float8_e4m3fn)
+        x3_fp8 = x3_fp8.to(torch.float32)
+        x3_fp8_grouped = x3_fp8.view(x3_fp8.shape[0], -1, group_size)
+        max_abs = torch.amax(torch.abs(x3_fp8_grouped), dim=-1, keepdim=False)
+        scale = torch.clamp(max_abs / 8, min=1e-6)
+        zero_point = torch.zeros_like(scale)
+        q3 = _quantize_affine_no_dtype_cast(
+            x3_fp8,
+            (1, group_size),
+            scale,
+            zero_point,
+            quant_min=-8,
+            quant_max=7,
+        )
+        scale3 = scale
+
+        def shuffle_and_pack(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+            t = pack_int4(t.to(torch.int8))
+            return torch.ops.fbgemm.preshuffle_i4(t, scale.to(torch.float8_e4m3fn))[0]
+
+        # First, sanity check that shuffle_and_pack(q2) == q1
+        torch.testing.assert_close(q1, shuffle_and_pack(q2, scale2), atol=0, rtol=0)
+
+        # Now check q2 vs q3 with and without shuffle
+        sqnr_q2_q3 = compute_error(q2.to(torch.float32), q3.to(torch.float32))
+        sqnr_q2_q3_preshuffle = compute_error(
+            shuffle_and_pack(q2, scale2).to(torch.float32),
+            shuffle_and_pack(q3, scale3).to(torch.float32),
+        )
+        self.assertGreater(sqnr_q2_q3, 32)
+        self.assertGreater(sqnr_q2_q3_preshuffle, 32)
+
+        # Now check shuffle_and_pack(q3) vs q1
+        sqnr_q1_q3_preshuffle = compute_error(
+            q1.to(torch.float32),
+            shuffle_and_pack(q3, scale3).to(torch.float32),
+        )
+        self.assertGreater(sqnr_q1_q3_preshuffle, 32)
+

 instantiate_parametrized_tests(TestQAT)

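For reference, a standalone sketch of the per-row fp8 quantization that `_choose_scale_float8` and `_quantize_affine_float8` are used for in the tests above. The max-abs-over-fp8-max scale rule is an assumption for illustration, not torchao's exact code:

import torch

FP8_E4M3_MAX = 448.0  # largest finite value representable in torch.float8_e4m3fn

def rowwise_fp8_quantize_ref(x: torch.Tensor, hp_value_lb: float = 1e-12):
    # One scale per row: scale = max(|row|) / fp8_max, with a lower bound on
    # max(|row|) mirroring the hp_value_lb argument passed in the tests above.
    max_abs = x.abs().amax(dim=-1, keepdim=True).float()
    scale = torch.clamp(max_abs, min=hp_value_lb) / FP8_E4M3_MAX
    q = torch.clamp(x.float() / scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q.to(torch.float8_e4m3fn), scale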