Improve QAT fp8-int4 numerics #2937
Changes to the QAT tests:

```diff
@@ -49,6 +49,7 @@
 )
 from torchao.quantization.qat.fake_quantize_config import (
     Float8FakeQuantizeConfig,
+    Int4WeightPreshuffledFakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from torchao.quantization.qat.fake_quantizer import (
```
```diff
@@ -1929,7 +1930,7 @@ def test_quantize_api_fp8_int4(self):
         """
         self._test_quantize_api_against_ptq(
             Float8DynamicActivationInt4WeightConfig(),
-            target_prepare_sqnr=12,
+            target_prepare_sqnr=22,
             target_convert_sqnr=float("inf"),
         )
```
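For context, these prepare/convert targets are SQNR values from `compute_error`, comparing the QAT model against the PTQ baseline. A minimal sketch of the metric, assuming the standard SQNR-in-dB definition (re-implemented here for illustration, not torchao's actual helper):

```python
import torch

def sqnr_db(reference: torch.Tensor, candidate: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB: higher means closer numerics,
    # and float("inf") means the two tensors match exactly.
    noise = (reference - candidate).float()
    return 20 * torch.log10(torch.linalg.norm(reference.float()) / torch.linalg.norm(noise))
```

Raising `target_prepare_sqnr` from 12 to 22 therefore asserts roughly 10 dB tighter agreement between the prepared (fake-quantized) model and the PTQ reference.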
```diff
@@ -1950,6 +1951,19 @@ def test_quantize_api_int4(self, version: int):
             target_convert_sqnr=float("inf"),
         )
 
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    def test_quantize_api_int8_int4(self):
+        """
+        Test the following:
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))
+        """
+        self._test_quantize_api_against_ptq(
+            Int8DynamicActivationInt4WeightConfig(group_size=32),
+            target_prepare_sqnr=30,
+            target_convert_sqnr=float("inf"),
+        )
+
     def test_infer_fp8_int4_config(self):
         """
         Test that fake quantize configs are correctly inferred from
```
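The docstring of the new test spells out the two-step QAT flow under test. As a usage sketch, assuming a `model` to fine-tune and the usual torchao import locations:

```python
from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_
from torchao.quantization.qat import QATConfig

base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)

# Prepare: insert fake quantization so training sees int8-activation / int4-weight error.
quantize_(model, QATConfig(base_config, step="prepare"))
# ... fine-tune `model` here ...

# Convert: swap the fake quantization for the real quantized weights and kernels.
quantize_(model, QATConfig(base_config, step="convert"))
```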
```diff
@@ -1964,10 +1978,9 @@ def test_infer_fp8_int4_config(self):
         self.assertIsInstance(act_config, Float8FakeQuantizeConfig)
         self.assertEqual(act_config.dtype, torch.float8_e4m3fn)
         self.assertIsInstance(act_config.granularity, PerRow)
-        self.assertIsInstance(weight_config, IntxFakeQuantizeConfig)
-        self.assertEqual(weight_config.dtype, torch.int4)
+        self.assertIsInstance(weight_config, Int4WeightPreshuffledFakeQuantizeConfig)
         self.assertEqual(weight_config.group_size, 128)
-        self.assertTrue(weight_config.is_symmetric)
+        self.assertEqual(weight_config.activation_dtype, torch.float8_e4m3fn)
 
     def test_infer_int4_weight_only_config(self):
         """
```
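In other words, the inferred weight fake-quantize config now mirrors the fbgemm fp8-int4 preshuffled kernel rather than generic intx fake quantization. A hedged sketch of the expected configs, with constructor signatures assumed from the assertions above (not copied from torchao):

```python
# Hypothetical construction for illustration; the field names come from the
# assertions above, but the exact constructor signatures are assumed.
act_config = Float8FakeQuantizeConfig(dtype=torch.float8_e4m3fn, granularity=PerRow())
weight_config = Int4WeightPreshuffledFakeQuantizeConfig(
    group_size=128,
    activation_dtype=torch.float8_e4m3fn,
)
```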
```diff
@@ -2033,6 +2046,126 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         sqnr = compute_error(out, baseline_out).item()
         self.assertGreater(sqnr, 24)
 
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(
+        not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
+    )
+    def test_fbgemm_fp8_primitives(self):
+        """
+        Compare numerics between:
+        (1) fbgemm_gpu.experimental.gen_ai.quantize.quantize_fp8_row
+        (2) Our reference QAT version in `Float8FakeQuantizer`
+        """
+        from fbgemm_gpu.experimental.gen_ai.quantize import quantize_fp8_row
+
+        from torchao.quantization.quant_primitives import (
+            _choose_scale_float8,
+            _quantize_affine_float8,
+        )
+
+        x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda()
+        x2 = copy.deepcopy(x1)
+
+        # (1) Just call `quantize_fp8_row`
+        (q1, scale1) = quantize_fp8_row(x1)
+
+        # (2) Our reference implementation for QAT without the dequantize
+        scale2 = _choose_scale_float8(
+            x2,
+            (1, x2.shape[-1]),
+            torch.float8_e4m3fn,
+            hp_value_lb=1e-12,
+        )
+        q2 = _quantize_affine_float8(x2, scale2, torch.float8_e4m3fn)
+        sqnr = compute_error(q1.to(torch.float32), q2.to(torch.float32))
+        scale_sqnr = compute_error(
+            scale1.to(torch.float32).flatten(),
+            scale2.to(torch.float32).flatten(),
+        )
+        self.assertGreater(sqnr, 40)
+        self.assertGreater(scale_sqnr, 50)
```
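Both paths in this test should reduce to the same per-row fp8 scaling. A minimal pure-PyTorch sketch of that math, assuming per-row max-abs scaling to the e4m3 range (names here are illustrative, not torchao's or fbgemm's API):

```python
import torch

def fp8_rowwise_quantize_ref(x: torch.Tensor):
    # Illustrative reference only: scale each row so its max |value| maps to
    # the largest representable float8_e4m3fn value (448).
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    max_abs = x.abs().amax(dim=-1, keepdim=True).float()
    scale = (max_abs / fp8_max).clamp(min=1e-12)  # lower bound mirrors hp_value_lb above
    q = (x.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q, scale.squeeze(-1)
```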
```diff
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(
+        not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
+    )
+    def test_fbgemm_int4_preshuffled_primitives(self):
+        """
+        Compare numerics between:
+        (1) fbgemm_gpu.experimental.gen_ai.quantize.quantize_int4_preshuffle
+        (2) Our reference QAT version in `Int4WeightPreshuffledFakeQuantizer`
+        """
+        from fbgemm_gpu.experimental.gen_ai.quantize import (
+            int4_row_quantize,
+            pack_int4,
+            quantize_fp8_row,
+            quantize_int4_preshuffle,
+        )
+
+        from torchao.quantization.quant_primitives import (
+            _choose_scale_float8,
+            _quantize_affine_float8,
+            _quantize_affine_no_dtype_cast,
+        )
+
+        group_size = 128
+        x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda()
+        x2 = copy.deepcopy(x1)
+        x3 = copy.deepcopy(x1)
+
+        # (1) Just call `quantize_int4_preshuffle`
+        (q1, (scale1, _)) = quantize_int4_preshuffle(x1, group_size, dtype="fp8")
+
+        # (2) Call `quantize_int4_preshuffle` but skip packing and shuffling
+        (q2, _) = quantize_fp8_row(x2)
+        (q2, scale2) = int4_row_quantize(q2, group_size)
```
Review thread on lines +2120 to +2121 (the `quantize_fp8_row` / `int4_row_quantize` calls above):

Reviewer: Is this what happens inside `quantize_int4_preshuffle`?

Author: Yeah, it happens here, not sure of the reason: https://github.com/pytorch/FBGEMM/blob/3ca2859adc0ae24b1214ccacedff24ea5fce9be5/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py#L138
```diff
+
+        # (3) Reference implementation for QAT without the dequantize
+        fp8_scale = _choose_scale_float8(
+            x3,
+            (1, x3.shape[-1]),
+            torch.float8_e4m3fn,
+            hp_value_lb=1e-12,
+        )
+        x3_fp8 = _quantize_affine_float8(x3, fp8_scale, torch.float8_e4m3fn)
+        x3_fp8 = x3_fp8.to(torch.float32)
+        x3_fp8_grouped = x3_fp8.view(x3_fp8.shape[0], -1, group_size)
+        max_abs = torch.amax(torch.abs(x3_fp8_grouped), dim=-1, keepdim=False)
+        scale = torch.clamp(max_abs / 8, min=1e-6)
+        zero_point = torch.zeros_like(scale)
+        q3 = _quantize_affine_no_dtype_cast(
+            x3_fp8,
+            (1, group_size),
+            scale,
+            zero_point,
+            quant_min=-8,
+            quant_max=7,
+        )
+        scale3 = scale
```

Review thread on the reference implementation above:

Reviewer: For this, should we just initialize the fake quantizer and call it directly?

Author: I tried doing that, but it's tricky because we don't want the dequantize steps here (but they're needed in the fake quantizer).
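A toy worked example of the per-group int4 mapping used in the reference path above (illustrative numbers only): the group scale is `max_abs / 8`, so the group's most negative extreme lands exactly on `quant_min = -8`, and everything else falls inside the clamped [-8, 7] range.

```python
import torch

group = torch.tensor([0.5, -1.2, 3.1, -4.0])        # one toy group of fp8-rounded values
scale = group.abs().max() / 8                        # 4.0 / 8 = 0.5
q = torch.clamp(torch.round(group / scale), -8, 7)   # tensor([ 1., -2.,  6., -8.])
```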
```diff
+
+        def shuffle_and_pack(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+            t = pack_int4(t.to(torch.int8))
+            return torch.ops.fbgemm.preshuffle_i4(t, scale.to(torch.float8_e4m3fn))[0]
+
+        # First, sanity check that shuffle_and_pack(q2) == q1
+        torch.testing.assert_close(q1, shuffle_and_pack(q2, scale2), atol=0, rtol=0)
+
+        # Now check q2 vs q3 with and without shuffle
+        sqnr_q2_q3 = compute_error(q2.to(torch.float32), q3.to(torch.float32))
+        sqnr_q2_q3_preshuffle = compute_error(
+            shuffle_and_pack(q2, scale2).to(torch.float32),
+            shuffle_and_pack(q3, scale3).to(torch.float32),
+        )
+        self.assertGreater(sqnr_q2_q3, 32)
+        self.assertGreater(sqnr_q2_q3_preshuffle, 32)
+
+        # Now check shuffle_and_pack(q3) vs q1
+        sqnr_q1_q3_preshuffle = compute_error(
+            q1.to(torch.float32),
+            shuffle_and_pack(q3, scale3).to(torch.float32),
+        )
+        self.assertGreater(sqnr_q1_q3_preshuffle, 32)
+
+
 instantiate_parametrized_tests(TestQAT)
```
Changes to the quantization primitives:
```diff
@@ -219,6 +219,20 @@ def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
         return gy
 
 
+class _RoundToFloat8(torch.autograd.Function):
+    """
+    Implementation of `tensor.to(float8_dtype)` with backward STE.
+    """
+
+    @staticmethod
+    def forward(ctx, x: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
+        return x.to(float8_dtype)
+
+    @staticmethod
+    def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
+        return gy, None
+
+
 # TODO: decide on if we want to allow custom quant_min/quant_max here
 def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
     """Get quant_min and quant_max args based on dtype and also verify bounds.
```
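For readers unfamiliar with the pattern, `_RoundToFloat8` follows the standard straight-through-estimator (STE) recipe: a non-differentiable rounding step in the forward pass, an identity in the backward pass. A self-contained toy version using an ordinary `round` (not torchao code) shows why gradients still flow:

```python
import torch

class _RoundSTE(torch.autograd.Function):
    """Generic straight-through estimator: round in forward, identity in backward."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        return torch.round(x)

    @staticmethod
    def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
        return gy

x = torch.randn(4, requires_grad=True)
y = _RoundSTE.apply(x)
y.sum().backward()
assert torch.equal(x.grad, torch.ones_like(x))  # rounding did not block gradients
```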
```diff
@@ -2275,7 +2289,6 @@ def _quantize_affine_float8(
     tensor: torch.Tensor,
     scale: torch.Tensor,
     float8_dtype: torch.dtype = torch.float8_e4m3fn,
-    cast_to_float8_dtype: bool = True,
 ) -> torch.Tensor:
     """
     Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor.
```
```diff
@@ -2288,9 +2301,7 @@ def _quantize_affine_float8(
     tensor_scaled = tensor_fp32 / scale_expanded
     max_value = torch.finfo(float8_dtype).max
     tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value)
-    if cast_to_float8_dtype:
-        tensor_clamped = tensor_clamped.to(float8_dtype)
-    return tensor_clamped
+    return _RoundToFloat8.apply(tensor_clamped, float8_dtype)
 
 
 # TODO: don't register as custom op?
```

Review thread on this change:

Reviewer: So fake quant also does casting? I thought we don't do casting during fake quant.

Author: I think fp8 numerics is close enough to bf16/fp32 that we can do this during QAT. This helps mimic the convert numerics a bit more closely.

Author: fp8-int4 prepare vs convert SQNR drops from 22.375 to 19.25 without this.
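To make the discussion above concrete, a hedged sketch of the two behaviors in pure PyTorch (`x` and a per-row `scale` are assumed inputs, as elsewhere in this PR; 448 is the float8_e4m3fn max):

```python
# Without the cast (what fake quant effectively did before, per the thread above):
# scaled and clamped values keep full precision, so fake quant never sees fp8
# rounding error.
fake_q_no_cast = (x.float() / scale).clamp(-448, 448)

# With the cast: values are additionally rounded to representable fp8 codes,
# matching what the converted (PTQ) kernel computes at inference time.
fake_q_cast = (x.float() / scale).clamp(-448, 448).to(torch.float8_e4m3fn).float()
```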
Additional review comments:

Reviewer: Trying to understand why this is not inf: is this because fake quant does not do dtype conversion?

Author: Not sure, but this is enough to recover significant accuracy degradation in most cases. I did verify that the qparams dtypes also match, but I haven't investigated further; I may continue this separately later.

Reviewer: Looks like this test is not related to the changes, since it's Int8DynamicActivationInt4WeightConfig.