diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index d67b922f41..f8e07c8954 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -49,6 +49,7 @@
 )
 from torchao.quantization.qat.fake_quantize_config import (
     Float8FakeQuantizeConfig,
+    Int4WeightPreshuffledFakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from torchao.quantization.qat.fake_quantizer import (
@@ -1929,7 +1930,7 @@ def test_quantize_api_fp8_int4(self):
         """
         self._test_quantize_api_against_ptq(
             Float8DynamicActivationInt4WeightConfig(),
-            target_prepare_sqnr=12,
+            target_prepare_sqnr=22,
             target_convert_sqnr=float("inf"),
         )
 
@@ -1950,6 +1951,19 @@ def test_quantize_api_int4(self, version: int):
             target_convert_sqnr=float("inf"),
         )
 
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    def test_quantize_api_int8_int4(self):
+        """
+        Test the following:
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))
+        """
+        self._test_quantize_api_against_ptq(
+            Int8DynamicActivationInt4WeightConfig(group_size=32),
+            target_prepare_sqnr=30,
+            target_convert_sqnr=float("inf"),
+        )
+
     def test_infer_fp8_int4_config(self):
         """
         Test that fake quantize configs are correctly inferred from
@@ -1964,10 +1978,9 @@ def test_infer_fp8_int4_config(self):
         self.assertIsInstance(act_config, Float8FakeQuantizeConfig)
         self.assertEqual(act_config.dtype, torch.float8_e4m3fn)
         self.assertIsInstance(act_config.granularity, PerRow)
-        self.assertIsInstance(weight_config, IntxFakeQuantizeConfig)
-        self.assertEqual(weight_config.dtype, torch.int4)
+        self.assertIsInstance(weight_config, Int4WeightPreshuffledFakeQuantizeConfig)
         self.assertEqual(weight_config.group_size, 128)
-        self.assertTrue(weight_config.is_symmetric)
+        self.assertEqual(weight_config.activation_dtype, torch.float8_e4m3fn)
 
     def test_infer_int4_weight_only_config(self):
         """
@@ -2033,6 +2046,126 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         sqnr = compute_error(out, baseline_out).item()
         self.assertGreater(sqnr, 24)
 
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(
+        not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
+    )
+    def test_fbgemm_fp8_primitives(self):
+        """
+        Compare numerics between:
+          (1) fbgemm_gpu.experimental.gen_ai.quantize.quantize_fp8_row
+          (2) Our reference QAT version in `Float8FakeQuantizer`
+        """
+        from fbgemm_gpu.experimental.gen_ai.quantize import quantize_fp8_row
+
+        from torchao.quantization.quant_primitives import (
+            _choose_scale_float8,
+            _quantize_affine_float8,
+        )
+
+        x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda()
+        x2 = copy.deepcopy(x1)
+
+        # (1) Just call `quantize_fp8_row`
+        (q1, scale1) = quantize_fp8_row(x1)
+
+        # (2) Our reference implementation for QAT without the dequantize
+        scale2 = _choose_scale_float8(
+            x2,
+            (1, x2.shape[-1]),
+            torch.float8_e4m3fn,
+            hp_value_lb=1e-12,
+        )
+        q2 = _quantize_affine_float8(x2, scale2, torch.float8_e4m3fn)
+        sqnr = compute_error(q1.to(torch.float32), q2.to(torch.float32))
+        scale_sqnr = compute_error(
+            scale1.to(torch.float32).flatten(),
+            scale2.to(torch.float32).flatten(),
+        )
+        self.assertGreater(sqnr, 40)
+        self.assertGreater(scale_sqnr, 50)
+
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(
+        not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
+    )
+    def test_fbgemm_int4_preshuffled_primitives(self):
+        """
+        Compare numerics between:
+          (1) fbgemm_gpu.experimental.gen_ai.quantize.quantize_int4_preshuffle
+          (2) Our reference QAT version in `Int4WeightPreshuffledFakeQuantizer`
+        """
+        from fbgemm_gpu.experimental.gen_ai.quantize import (
+            int4_row_quantize,
+            pack_int4,
+            quantize_fp8_row,
+            quantize_int4_preshuffle,
+        )
+
+        from torchao.quantization.quant_primitives import (
+            _choose_scale_float8,
+            _quantize_affine_float8,
+            _quantize_affine_no_dtype_cast,
+        )
+
+        group_size = 128
+        x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda()
+        x2 = copy.deepcopy(x1)
+        x3 = copy.deepcopy(x1)
+
+        # (1) Just call `quantize_int4_preshuffle`
+        (q1, (scale1, _)) = quantize_int4_preshuffle(x1, group_size, dtype="fp8")
+
+        # (2) Call `quantize_int4_preshuffle` but skip packing and shuffling
+        (q2, _) = quantize_fp8_row(x2)
+        (q2, scale2) = int4_row_quantize(q2, group_size)
+
+        # (3) Reference implementation for QAT without the dequantize
+        fp8_scale = _choose_scale_float8(
+            x3,
+            (1, x3.shape[-1]),
+            torch.float8_e4m3fn,
+            hp_value_lb=1e-12,
+        )
+        x3_fp8 = _quantize_affine_float8(x3, fp8_scale, torch.float8_e4m3fn)
+        x3_fp8 = x3_fp8.to(torch.float32)
+        x3_fp8_grouped = x3_fp8.view(x3_fp8.shape[0], -1, group_size)
+        max_abs = torch.amax(torch.abs(x3_fp8_grouped), dim=-1, keepdim=False)
+        scale = torch.clamp(max_abs / 8, min=1e-6)
+        zero_point = torch.zeros_like(scale)
+        q3 = _quantize_affine_no_dtype_cast(
+            x3_fp8,
+            (1, group_size),
+            scale,
+            zero_point,
+            quant_min=-8,
+            quant_max=7,
+        )
+        scale3 = scale
+
+        def shuffle_and_pack(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+            t = pack_int4(t.to(torch.int8))
+            return torch.ops.fbgemm.preshuffle_i4(t, scale.to(torch.float8_e4m3fn))[0]
+
+        # First, sanity check that shuffle_and_pack(q2) == q1
+        torch.testing.assert_close(q1, shuffle_and_pack(q2, scale2), atol=0, rtol=0)
+
+        # Now check q2 vs q3 with and without shuffle
+        sqnr_q2_q3 = compute_error(q2.to(torch.float32), q3.to(torch.float32))
+        sqnr_q2_q3_preshuffle = compute_error(
+            shuffle_and_pack(q2, scale2).to(torch.float32),
+            shuffle_and_pack(q3, scale3).to(torch.float32),
+        )
+        self.assertGreater(sqnr_q2_q3, 32)
+        self.assertGreater(sqnr_q2_q3_preshuffle, 32)
+
+        # Now check shuffle_and_pack(q3) vs q1
+        sqnr_q1_q3_preshuffle = compute_error(
+            q1.to(torch.float32),
+            shuffle_and_pack(q3, scale3).to(torch.float32),
+        )
+        self.assertGreater(sqnr_q1_q3_preshuffle, 32)
+
 
 instantiate_parametrized_tests(TestQAT)
 
diff --git a/torchao/quantization/qat/fake_quantize_config.py b/torchao/quantization/qat/fake_quantize_config.py
index 2999af5264..7bc1e69c85 100644
--- a/torchao/quantization/qat/fake_quantize_config.py
+++ b/torchao/quantization/qat/fake_quantize_config.py
@@ -77,6 +77,25 @@ def __post_init__(self):
             )
 
 
+@dataclass
+class Int4WeightPreshuffledFakeQuantizeConfig(FakeQuantizeConfigBase):
+    """
+    Config for int4 weight fake quantization that targets the numerics in the following preshuffled kernel:
+        torch.ops.fbgemm.f8i4bf16_shuffled
+
+    Currently this only supports float8 input activations. It is expected to be used in conjunction with
+    :class:`~torchao.quantization.Float8DynamicActivationInt4WeightConfig`. In the future, we may extend
+    this to support bfloat16 as well.
+    """
+
+    group_size: int = 128
+    activation_dtype: torch.dtype = e4m3_dtype
+
+    def __post_init__(self):
+        if self.activation_dtype != e4m3_dtype:
+            raise ValueError(f"Only {e4m3_dtype} activation is supported currently")
+
+
 @dataclass
 class IntxFakeQuantizeConfig(FakeQuantizeConfigBase):
     """
@@ -404,10 +423,9 @@ def _infer_fake_quantize_configs(
             dtype=torch.float8_e4m3fn,
             granularity=PerRow(),
         )
-        weight_config = IntxFakeQuantizeConfig(
-            dtype=torch.int4,
+        weight_config = Int4WeightPreshuffledFakeQuantizeConfig(
             group_size=128,
-            is_symmetric=True,
+            activation_dtype=e4m3_dtype,
         )
     elif isinstance(base_config, NVFP4InferenceConfig):
         # Note: today the PTQ config does not allow the user to specify
diff --git a/torchao/quantization/qat/fake_quantizer.py b/torchao/quantization/qat/fake_quantizer.py
index 7bf27f4719..8a63a0d0ad 100644
--- a/torchao/quantization/qat/fake_quantizer.py
+++ b/torchao/quantization/qat/fake_quantizer.py
@@ -11,6 +11,7 @@
 from torchao.quantization.granularity import (
     PerAxis,
     PerGroup,
+    PerRow,
     PerToken,
 )
 from torchao.quantization.observer import get_block_size
@@ -20,6 +21,7 @@
     MappingType,
     _choose_scale_float8,
     _dequantize_affine_float8,
+    _fake_quantize_affine,
     _quantize_affine_float8,
     _Round,
     choose_qparams_affine,
@@ -33,6 +35,7 @@
 from .fake_quantize_config import (
     FakeQuantizeConfigBase,
     Float8FakeQuantizeConfig,
+    Int4WeightPreshuffledFakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from .utils import (
@@ -65,6 +68,8 @@ def from_config(config: FakeQuantizeConfigBase) -> "FakeQuantizerBase":
         if isinstance(config, IntxFakeQuantizeConfig):
             return IntxFakeQuantizer(config)
+        elif isinstance(config, Int4WeightPreshuffledFakeQuantizeConfig):
+            return Int4WeightPreshuffledFakeQuantizer(config)
         elif isinstance(config, Float8FakeQuantizeConfig):
             return Float8FakeQuantizer(config)
         elif isinstance(config, NVFP4FakeQuantizeConfig):
@@ -93,13 +98,68 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             hp_value_lb=self.config.hp_value_lb,
             hp_value_ub=self.config.hp_value_ub,
         )
-        q = _quantize_affine_float8(
-            x, scale, self.config.dtype, cast_to_float8_dtype=False
-        )
+        q = _quantize_affine_float8(x, scale, self.config.dtype)
         dq = _dequantize_affine_float8(q, scale, original_dtype)
         return dq
 
 
+class Int4WeightPreshuffledFakeQuantizer(FakeQuantizerBase):
+    """
+    Generic module for applying int4 fake quantization to a weight tensor,
+    targeting the following FBGEMM kernel:
+        torch.ops.fbgemm.f8i4bf16_shuffled
+    """
+
+    def __init__(self, config: Int4WeightPreshuffledFakeQuantizeConfig):
+        super().__init__()
+        self.config = config
+        torch._C._log_api_usage_once(
+            "torchao.quantization.qat.Int4WeightPreshuffledFakeQuantizer"
+        )
+
+    def forward(self, w: torch.Tensor) -> torch.Tensor:
+        """
+        Apply int4 fake quantization to the weight tensor, using the following as a reference:
+        https://github.com/pytorch/FBGEMM/blob/80cc48c4b2b7fcc579e53211fc8715a8592cbd2c/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py#L112
+
+        Currently, we expect the activations to always be rowwise float8.
+        """
+        assert w.dim() == 2
+        assert self.config.activation_dtype == torch.float8_e4m3fn
+
+        # First quantize weights to fp8 per row
+        # This simulates the numerics of fbgemm_gpu.experimental.gen_ai.quantize.quantize_fp8_row
+        per_row_block_size = get_block_size(w.shape, PerRow())
+        fp8_scale = _choose_scale_float8(
+            w,
+            per_row_block_size,
+            torch.float8_e4m3fn,
+            hp_value_lb=1e-12,
+        )
+        w_fp8 = _quantize_affine_float8(w, fp8_scale, torch.float8_e4m3fn)
+        w_fp8 = _dequantize_affine_float8(w_fp8, fp8_scale, w.dtype)
+
+        # Now quantize to int4 per group
+        # This simulates the numerics of fbgemm_gpu.experimental.gen_ai.quantize.int4_row_quantize
+        eps = 1e-6
+        fbgemm_scale_quant_max = 8
+        w_fp8_grouped = w_fp8.view(w_fp8.shape[0], -1, self.config.group_size)
+        max_abs = torch.amax(torch.abs(w_fp8_grouped), dim=-1, keepdim=False)
+        scale = torch.clamp(max_abs / fbgemm_scale_quant_max, min=eps)
+        zero_point = torch.zeros_like(scale)
+        per_group_block_size = (1, self.config.group_size)
+        fq = _fake_quantize_affine(
+            w_fp8,
+            per_group_block_size,
+            scale,
+            zero_point,
+            quant_dtype=torch.int8,
+            quant_min=-8,
+            quant_max=7,
+        )
+        return fq.to(w.dtype)
+
+
 class IntxFakeQuantizer(FakeQuantizerBase):
     """
     Generic module for applying integer fake quantization to a tensor, as specified in the config.
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index c118e0b4ce..6298344745 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -219,6 +219,20 @@ def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
         return gy
 
 
+class _RoundToFloat8(torch.autograd.Function):
+    """
+    Implementation of `tensor.to(float8_dtype)` with backward STE.
+    """
+
+    @staticmethod
+    def forward(ctx, x: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
+        return x.to(float8_dtype)
+
+    @staticmethod
+    def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
+        return gy, None
+
+
 # TODO: decide on if we want to allow custom quant_min/quant_max here
 def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
     """Get quant_min and quant_max args based on dtype and also verify bounds.
@@ -2275,7 +2289,6 @@ def _quantize_affine_float8(
     tensor: torch.Tensor,
     scale: torch.Tensor,
     float8_dtype: torch.dtype = torch.float8_e4m3fn,
-    cast_to_float8_dtype: bool = True,
 ) -> torch.Tensor:
     """
     Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor.
@@ -2288,9 +2301,7 @@
     tensor_scaled = tensor_fp32 / scale_expanded
     max_value = torch.finfo(float8_dtype).max
     tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value)
-    if cast_to_float8_dtype:
-        tensor_clamped = tensor_clamped.to(float8_dtype)
-    return tensor_clamped
+    return _RoundToFloat8.apply(tensor_clamped, float8_dtype)
 
 
 # TODO: don't register as custom op?