 )
 from torchao.quantization.qat.fake_quantize_config import (
     Float8FakeQuantizeConfig,
+    Int4WeightFBGEMMFakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from torchao.quantization.qat.fake_quantizer import (
@@ -1929,7 +1930,7 @@ def test_quantize_api_fp8_int4(self):
         """
         self._test_quantize_api_against_ptq(
             Float8DynamicActivationInt4WeightConfig(),
-            target_prepare_sqnr=12,
+            target_prepare_sqnr=22,
             target_convert_sqnr=float("inf"),
         )

@@ -1950,6 +1951,19 @@ def test_quantize_api_int4(self, version: int):
             target_convert_sqnr=float("inf"),
         )

+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    def test_quantize_api_int8_int4(self):
+        """
+        Test the following:
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))
+        """
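+        # The SQNR targets below are read as follows (assumption): the prepared model
+        # only needs to approximate the PTQ numerics, while the converted model
+        # should match the PTQ reference exactly (target of inf).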
+        self._test_quantize_api_against_ptq(
+            Int8DynamicActivationInt4WeightConfig(group_size=32),
+            target_prepare_sqnr=30,
+            target_convert_sqnr=float("inf"),
+        )
+
     def test_infer_fp8_int4_config(self):
         """
         Test that fake quantize configs are correctly inferred from
@@ -1964,10 +1978,9 @@ def test_infer_fp8_int4_config(self):
         self.assertIsInstance(act_config, Float8FakeQuantizeConfig)
         self.assertEqual(act_config.dtype, torch.float8_e4m3fn)
         self.assertIsInstance(act_config.granularity, PerRow)
-        self.assertIsInstance(weight_config, IntxFakeQuantizeConfig)
-        self.assertEqual(weight_config.dtype, torch.int4)
+        self.assertIsInstance(weight_config, Int4WeightFBGEMMFakeQuantizeConfig)
         self.assertEqual(weight_config.group_size, 128)
-        self.assertTrue(weight_config.is_symmetric)
+        self.assertEqual(weight_config.activation_dtype, torch.float8_e4m3fn)

     def test_infer_int4_weight_only_config(self):
         """
@@ -2033,6 +2046,128 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         sqnr = compute_error(out, baseline_out).item()
         self.assertGreater(sqnr, 24)

+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(
+        not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
+    )
+    def test_fbgemm_fp8_primitives(self):
+        """
+        Compare numerics between:
+            (1) fbgemm_gpu.experimental.gen_ai.quantize.quantize_fp8_row
+            (2) Our reference QAT version in `Float8FakeQuantizer`
+        """
+        from fbgemm_gpu.experimental.gen_ai.quantize import quantize_fp8_row
+
+        from torchao.quantization.quant_primitives import (
+            _choose_scale_float8,
+            _quantize_affine_float8,
+        )
+
+        x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda()
+        x2 = copy.deepcopy(x1)
+
+        # (1) Just call `quantize_fp8_row`
+        (q1, scale1) = quantize_fp8_row(x1)
+
+        # (2) Our reference implementation for QAT without the dequantize
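+        # A block size of (1, x2.shape[-1]) yields one scale per row, mirroring the
+        # per-row granularity of `quantize_fp8_row` (assumption checked by the SQNR
+        # assertions below).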
+        scale2 = _choose_scale_float8(
+            x2,
+            (1, x2.shape[-1]),
+            torch.float8_e4m3fn,
+            hp_value_lb=1e-12,
+        )
+        q2 = _quantize_affine_float8(
+            x2, scale2, torch.float8_e4m3fn, cast_to_float8_dtype=False
+        )
+        sqnr = compute_error(q1.to(torch.float32), q2.to(torch.float32))
+        scale_sqnr = compute_error(
+            scale1.to(torch.float32).flatten(),
+            scale2.to(torch.float32).flatten(),
+        )
+        self.assertGreater(sqnr, 30)
+        self.assertGreater(scale_sqnr, 50)
+
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(
+        not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
+    )
+    def test_fbgemm_int4_primitives(self):
+        """
+        Compare numerics between:
+            (1) fbgemm_gpu.experimental.gen_ai.quantize.quantize_int4_preshuffle
+            (2) Our reference QAT version in `Int4WeightFBGEMMFakeQuantizer`
+        """
+        from fbgemm_gpu.experimental.gen_ai.quantize import (
+            int4_row_quantize,
+            pack_int4,
+            quantize_fp8_row,
+            quantize_int4_preshuffle,
+        )
+
+        from torchao.quantization.quant_primitives import (
+            _choose_scale_float8,
+            _quantize_affine_float8,
+            _quantize_affine_no_dtype_cast,
+        )
+
+        group_size = 128
+        x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda()
+        x2 = copy.deepcopy(x1)
+        x3 = copy.deepcopy(x1)
+
+        # (1) Just call `quantize_int4_preshuffle`
+        (q1, (scale1, _)) = quantize_int4_preshuffle(x1, group_size, dtype="fp8")
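+        # With dtype="fp8", the rows are first quantized to fp8 and then group-quantized
+        # to int4 before packing and preshuffling (assumption; path (2) below reproduces
+        # these steps explicitly and is asserted equal to q1 after repacking).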
+
+        # (2) Call `quantize_int4_preshuffle` but skip packing and shuffling
+        (q2, _) = quantize_fp8_row(x2)
+        (q2, scale2) = int4_row_quantize(q2, group_size)
+
+        # (3) Reference implementation for QAT without the dequantize
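+        # Mirrors path (2) using torchao primitives: quantize each row to fp8, then
+        # fake-quantize the fp8 values to int4 in groups of `group_size`, without
+        # packing or shuffling.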
+        fp8_scale = _choose_scale_float8(
+            x3,
+            (1, x3.shape[-1]),
+            torch.float8_e4m3fn,
+            hp_value_lb=1e-12,
+        )
+        x3_fp8 = _quantize_affine_float8(x3, fp8_scale, torch.float8_e4m3fn)
+        x3_fp8 = x3_fp8.to(torch.float32)
+        x3_fp8_grouped = x3_fp8.view(x3_fp8.shape[0], -1, group_size)
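+        # Symmetric int4 per group: the per-group max |x| maps to the edge of the
+        # [-8, 7] range (scale = max_abs / 8), so the zero point stays at 0.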
+        max_abs = torch.amax(torch.abs(x3_fp8_grouped), dim=-1, keepdim=False)
+        scale = torch.clamp(max_abs / 8, min=1e-6)
+        zero_point = torch.zeros_like(scale)
+        q3 = _quantize_affine_no_dtype_cast(
+            x3_fp8,
+            (1, group_size),
+            scale,
+            zero_point,
+            quant_min=-8,
+            quant_max=7,
+        )
+        scale3 = scale
+
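+        # Helper that applies the same packing + preshuffling as
+        # `quantize_int4_preshuffle`, so q2 and q3 can be compared against q1 in the
+        # same layout.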
+        def shuffle_and_pack(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+            t = pack_int4(t.to(torch.int8))
+            return torch.ops.fbgemm.preshuffle_i4(t, scale.to(torch.float8_e4m3fn))[0]
+
+        # First, sanity check that shuffle_and_pack(q2) == q1
+        torch.testing.assert_close(q1, shuffle_and_pack(q2, scale2), atol=0, rtol=0)
+
+        # Now check q2 vs q3 with and without shuffle
+        sqnr_q2_q3 = compute_error(q2.to(torch.float32), q3.to(torch.float32))
+        sqnr_q2_q3_preshuffle = compute_error(
+            shuffle_and_pack(q2, scale2).to(torch.float32),
+            shuffle_and_pack(q3, scale3).to(torch.float32),
+        )
+        self.assertGreater(sqnr_q2_q3, 32)
+        self.assertGreater(sqnr_q2_q3_preshuffle, 32)
+
+        # Now check shuffle_and_pack(q3) vs q1
+        sqnr_q1_q3_preshuffle = compute_error(
+            q1.to(torch.float32),
+            shuffle_and_pack(q3, scale3).to(torch.float32),
+        )
+        self.assertGreater(sqnr_q1_q3_preshuffle, 32)
+

 instantiate_parametrized_tests(TestQAT)
