141 changes: 137 additions & 4 deletions test/quantization/test_qat.py
@@ -49,6 +49,7 @@
)
from torchao.quantization.qat.fake_quantize_config import (
Float8FakeQuantizeConfig,
Int4WeightPreshuffledFakeQuantizeConfig,
IntxFakeQuantizeConfig,
)
from torchao.quantization.qat.fake_quantizer import (
@@ -1929,7 +1930,7 @@ def test_quantize_api_fp8_int4(self):
"""
self._test_quantize_api_against_ptq(
Float8DynamicActivationInt4WeightConfig(),
target_prepare_sqnr=12,
target_prepare_sqnr=22,
target_convert_sqnr=float("inf"),
)

@@ -1950,6 +1951,19 @@ def test_quantize_api_int4(self, version: int):
target_convert_sqnr=float("inf"),
)

@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
def test_quantize_api_int8_int4(self):
"""
Test the following:
quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))
"""
self._test_quantize_api_against_ptq(
Int8DynamicActivationInt4WeightConfig(group_size=32),
target_prepare_sqnr=30,
jerryzh168 (Contributor), Sep 5, 2025:
Trying to understand why this is not inf. Is this because fake quant does not do dtype conversion?

Author:
Not sure, but this is enough to recover significant accuracy degradation in most cases. I did verify that the qparams dtypes also match, but I haven't investigated further; I may continue this later separately.

Contributor:
Looks like this test is not related to the changes, since it's Int8DynamicActivationInt4WeightConfig.

target_convert_sqnr=float("inf"),
)

def test_infer_fp8_int4_config(self):
"""
Test that fake quantize configs are correctly inferred from
@@ -1964,10 +1978,9 @@ def test_infer_fp8_int4_config(self):
self.assertIsInstance(act_config, Float8FakeQuantizeConfig)
self.assertEqual(act_config.dtype, torch.float8_e4m3fn)
self.assertIsInstance(act_config.granularity, PerRow)
self.assertIsInstance(weight_config, IntxFakeQuantizeConfig)
self.assertEqual(weight_config.dtype, torch.int4)
self.assertIsInstance(weight_config, Int4WeightPreshuffledFakeQuantizeConfig)
self.assertEqual(weight_config.group_size, 128)
self.assertTrue(weight_config.is_symmetric)
self.assertEqual(weight_config.activation_dtype, torch.float8_e4m3fn)

def test_infer_int4_weight_only_config(self):
"""
@@ -2033,6 +2046,126 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
sqnr = compute_error(out, baseline_out).item()
self.assertGreater(sqnr, 24)

@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
@unittest.skipIf(
not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
)
def test_fbgemm_fp8_primitives(self):
"""
Compare numerics between:
(1) fbgemm_gpu.experimental.gen_ai.quantize.quantize_fp8_row
(2) Our reference QAT version in `Float8FakeQuantizer`
"""
from fbgemm_gpu.experimental.gen_ai.quantize import quantize_fp8_row

from torchao.quantization.quant_primitives import (
_choose_scale_float8,
_quantize_affine_float8,
)

x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda()
x2 = copy.deepcopy(x1)

# (1) Just call `quantize_fp8_row`
(q1, scale1) = quantize_fp8_row(x1)

# (2) Our reference implementation for QAT without the dequantize
scale2 = _choose_scale_float8(
x2,
(1, x2.shape[-1]),
torch.float8_e4m3fn,
hp_value_lb=1e-12,
)
q2 = _quantize_affine_float8(x2, scale2, torch.float8_e4m3fn)
sqnr = compute_error(q1.to(torch.float32), q2.to(torch.float32))
scale_sqnr = compute_error(
scale1.to(torch.float32).flatten(),
scale2.to(torch.float32).flatten(),
)
self.assertGreater(sqnr, 40)
self.assertGreater(scale_sqnr, 50)

@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
@unittest.skipIf(
not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
)
def test_fbgemm_int4_preshuffled_primitives(self):
"""
Compare numerics between:
(1) fbgemm_gpu.experimental.gen_ai.quantize.quantize_int4_preshuffle
(2) Our reference QAT version in `Int4WeightPreshuffledFakeQuantizer`
"""
from fbgemm_gpu.experimental.gen_ai.quantize import (
int4_row_quantize,
pack_int4,
quantize_fp8_row,
quantize_int4_preshuffle,
)

from torchao.quantization.quant_primitives import (
_choose_scale_float8,
_quantize_affine_float8,
_quantize_affine_no_dtype_cast,
)

group_size = 128
x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda()
x2 = copy.deepcopy(x1)
x3 = copy.deepcopy(x1)

# (1) Just call `quantize_int4_preshuffle`
(q1, (scale1, _)) = quantize_int4_preshuffle(x1, group_size, dtype="fp8")

# (2) Call `quantize_int4_preshuffle` but skip packing and shuffling
(q2, _) = quantize_fp8_row(x2)
(q2, scale2) = int4_row_quantize(q2, group_size)
Comment on lines +2120 to +2121
Contributor:
Is this what happens in quantize_int4_preshuffle? Why do we first quantize to fp8 and then to int4, instead of just quantizing to int4?

# (3) Reference implementation for QAT without the dequantize
Contributor:
For this, should we just initialize Int4WeightPreshuffledFakeQuantizer and call it here?

Author:
I tried doing that, but it's tricky because we don't want the dequantize steps here (they're needed in the fake quantizer).

fp8_scale = _choose_scale_float8(
x3,
(1, x3.shape[-1]),
torch.float8_e4m3fn,
hp_value_lb=1e-12,
)
x3_fp8 = _quantize_affine_float8(x3, fp8_scale, torch.float8_e4m3fn)
x3_fp8 = x3_fp8.to(torch.float32)
x3_fp8_grouped = x3_fp8.view(x3_fp8.shape[0], -1, group_size)
max_abs = torch.amax(torch.abs(x3_fp8_grouped), dim=-1, keepdim=False)
scale = torch.clamp(max_abs / 8, min=1e-6)
zero_point = torch.zeros_like(scale)
q3 = _quantize_affine_no_dtype_cast(
x3_fp8,
(1, group_size),
scale,
zero_point,
quant_min=-8,
quant_max=7,
)
scale3 = scale

def shuffle_and_pack(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
t = pack_int4(t.to(torch.int8))
return torch.ops.fbgemm.preshuffle_i4(t, scale.to(torch.float8_e4m3fn))[0]

# First, sanity check that shuffle_and_pack(q2) == q1
torch.testing.assert_close(q1, shuffle_and_pack(q2, scale2), atol=0, rtol=0)

# Now check q2 vs q3 with and without shuffle
sqnr_q2_q3 = compute_error(q2.to(torch.float32), q3.to(torch.float32))
sqnr_q2_q3_preshuffle = compute_error(
shuffle_and_pack(q2, scale2).to(torch.float32),
shuffle_and_pack(q3, scale3).to(torch.float32),
)
self.assertGreater(sqnr_q2_q3, 32)
self.assertGreater(sqnr_q2_q3_preshuffle, 32)

# Now check shuffle_and_pack(q3) vs q1
sqnr_q1_q3_preshuffle = compute_error(
q1.to(torch.float32),
shuffle_and_pack(q3, scale3).to(torch.float32),
)
self.assertGreater(sqnr_q1_q3_preshuffle, 32)


instantiate_parametrized_tests(TestQAT)

24 changes: 21 additions & 3 deletions torchao/quantization/qat/fake_quantize_config.py
@@ -77,6 +77,25 @@ def __post_init__(self):
)


@dataclass
class Int4WeightPreshuffledFakeQuantizeConfig(FakeQuantizeConfigBase):
"""
Config for int4 weight fake quantization that targets the numerics in the following preshuffled kernel:
torch.ops.fbgemm.f8i4bf16_shuffled

Currently this only supports float8 input activations. It is expected to be used in conjunction with
:class:`~torchao.quantization.Float8DynamicActivationInt4WeightConfig`. In the future, we may extend
this to support bfloat16 as well.
"""

group_size: int = 128
activation_dtype: torch.dtype = e4m3_dtype

def __post_init__(self):
if self.activation_dtype != e4m3_dtype:
raise ValueError(f"Only {e4m3_dtype} activation is supported currently")
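
A minimal usage sketch for illustration (assuming the module paths in this PR and a 2D bf16 weight on CUDA): construct the config directly and let FakeQuantizerBase.from_config (added in fake_quantizer.py below) dispatch to the new fake quantizer. In the QAT flow this config is normally inferred from Float8DynamicActivationInt4WeightConfig rather than built by hand.

import torch

from torchao.quantization.qat.fake_quantize_config import (
    Int4WeightPreshuffledFakeQuantizeConfig,
)
from torchao.quantization.qat.fake_quantizer import FakeQuantizerBase

# activation_dtype defaults to float8_e4m3fn, the only value currently supported
config = Int4WeightPreshuffledFakeQuantizeConfig(group_size=128)
fake_quantizer = FakeQuantizerBase.from_config(config)

# Fake-quantize a [out_features, in_features] weight; output keeps the input shape and dtype
weight = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
weight_fq = fake_quantizer(weight)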


@dataclass
class IntxFakeQuantizeConfig(FakeQuantizeConfigBase):
"""
@@ -404,10 +423,9 @@ def _infer_fake_quantize_configs(
dtype=torch.float8_e4m3fn,
granularity=PerRow(),
)
weight_config = IntxFakeQuantizeConfig(
dtype=torch.int4,
weight_config = Int4WeightPreshuffledFakeQuantizeConfig(
group_size=128,
is_symmetric=True,
activation_dtype=e4m3_dtype,
)
elif isinstance(base_config, NVFP4InferenceConfig):
# Note: today the PTQ config does not allow the user to specify
66 changes: 63 additions & 3 deletions torchao/quantization/qat/fake_quantizer.py
@@ -11,6 +11,7 @@
from torchao.quantization.granularity import (
PerAxis,
PerGroup,
PerRow,
PerToken,
)
from torchao.quantization.observer import get_block_size
@@ -20,6 +21,7 @@
MappingType,
_choose_scale_float8,
_dequantize_affine_float8,
_fake_quantize_affine,
_quantize_affine_float8,
_Round,
choose_qparams_affine,
@@ -33,6 +35,7 @@
from .fake_quantize_config import (
FakeQuantizeConfigBase,
Float8FakeQuantizeConfig,
Int4WeightPreshuffledFakeQuantizeConfig,
IntxFakeQuantizeConfig,
)
from .utils import (
@@ -65,6 +68,8 @@ def from_config(config: FakeQuantizeConfigBase) -> "FakeQuantizerBase":

if isinstance(config, IntxFakeQuantizeConfig):
return IntxFakeQuantizer(config)
elif isinstance(config, Int4WeightPreshuffledFakeQuantizeConfig):
return Int4WeightPreshuffledFakeQuantizer(config)
elif isinstance(config, Float8FakeQuantizeConfig):
return Float8FakeQuantizer(config)
elif isinstance(config, NVFP4FakeQuantizeConfig):
@@ -93,13 +98,68 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
hp_value_lb=self.config.hp_value_lb,
hp_value_ub=self.config.hp_value_ub,
)
q = _quantize_affine_float8(
x, scale, self.config.dtype, cast_to_float8_dtype=False
)
q = _quantize_affine_float8(x, scale, self.config.dtype)
dq = _dequantize_affine_float8(q, scale, original_dtype)
return dq


class Int4WeightPreshuffledFakeQuantizer(FakeQuantizerBase):
"""
Generic module for applying int4 fake quantization to a weight tensor,
targeting the following FBGEMM kernel:
torch.ops.fbgemm.f8i4bf16_shuffled
"""

def __init__(self, config: Int4WeightPreshuffledFakeQuantizeConfig):
super().__init__()
self.config = config
torch._C._log_api_usage_once(
"torchao.quantization.qat.Int4WeightPreshuffledFakeQuantizer"
)

def forward(self, w: torch.Tensor) -> torch.Tensor:
"""
Apply int4 fake quantization to the weight tensor, using the following as a reference:
https://github.com/pytorch/FBGEMM/blob/80cc48c4b2b7fcc579e53211fc8715a8592cbd2c/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py#L112

Currently, we expect the activations to always be rowwise float8.
"""
assert w.dim() == 2
assert self.config.activation_dtype == torch.float8_e4m3fn

# First quantize weights to fp8 per row
# This simulates the numerics of fbgemm_gpu.experimental.gen_ai.quantize.quantize_fp8_row
per_row_block_size = get_block_size(w.shape, PerRow())
fp8_scale = _choose_scale_float8(
w,
per_row_block_size,
torch.float8_e4m3fn,
hp_value_lb=1e-12,
)
w_fp8 = _quantize_affine_float8(w, fp8_scale, torch.float8_e4m3fn)
w_fp8 = _dequantize_affine_float8(w_fp8, fp8_scale, w.dtype)

# Now quantize to int4 per group
# This simulates the numerics of fbgemm_gpu.experimental.gen_ai.quantize.int4_row_quantize
eps = 1e-6
fbgemm_scale_quant_max = 8
w_fp8_grouped = w_fp8.view(w_fp8.shape[0], -1, self.config.group_size)
max_abs = torch.amax(torch.abs(w_fp8_grouped), dim=-1, keepdim=False)
scale = torch.clamp(max_abs / fbgemm_scale_quant_max, min=eps)
zero_point = torch.zeros_like(scale)
per_group_block_size = (1, self.config.group_size)
fq = _fake_quantize_affine(
w_fp8,
per_group_block_size,
scale,
zero_point,
quant_dtype=torch.int8,
quant_min=-8,
quant_max=7,
)
return fq.to(w.dtype)
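
To make the per-group int4 step above concrete, a small sketch with illustrative values and group_size=4 instead of 128, ignoring the fp8 pre-quantization: each row is viewed as [num_groups, group_size], the scale is max|x| / 8 clamped at eps, and values are rounded into [-8, 7] with zero_point = 0.

import torch

x = torch.tensor([[0.5, -1.0, 2.0, -4.0, 0.1, 0.2, -0.3, 0.4]])
grouped = x.view(x.shape[0], -1, 4)           # [1, 2, 4]
max_abs = grouped.abs().amax(dim=-1)          # per-group max: [[4.0, 0.4]]
scale = torch.clamp(max_abs / 8, min=1e-6)    # [[0.5, 0.05]]
q = torch.clamp(torch.round(grouped / scale.unsqueeze(-1)), -8, 7)
dq = (q * scale.unsqueeze(-1)).view_as(x)     # fake-quantized result, same shape as x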


class IntxFakeQuantizer(FakeQuantizerBase):
"""
Generic module for applying integer fake quantization to a tensor, as specified in the config.
19 changes: 15 additions & 4 deletions torchao/quantization/quant_primitives.py
@@ -219,6 +219,20 @@ def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
return gy


class _RoundToFloat8(torch.autograd.Function):
"""
Implementation of `tensor.to(float8_dtype)` with backward STE.
"""

@staticmethod
def forward(ctx, x: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
return x.to(float8_dtype)

@staticmethod
def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
return gy, None
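
A minimal sketch of the straight-through contract, for illustration and assuming only the class above: the forward pass is a plain cast to float8, and the backward rule returns the incoming gradient unchanged (plus None for the dtype argument), analogous to the existing rounding STE.

import torch

x = torch.randn(4, 8)
y = _RoundToFloat8.apply(x, torch.float8_e4m3fn)
assert y.dtype == torch.float8_e4m3fn  # forward: dtype cast only

g = torch.ones_like(x)
grad_x, grad_dtype = _RoundToFloat8.backward(None, g)  # ctx is unused in backward
assert torch.equal(grad_x, g) and grad_dtype is None   # gradient passes straight through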


# TODO: decide on if we want to allow custom quant_min/quant_max here
def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
"""Get quant_min and quant_max args based on dtype and also verify bounds.
Expand Down Expand Up @@ -2275,7 +2289,6 @@ def _quantize_affine_float8(
tensor: torch.Tensor,
scale: torch.Tensor,
float8_dtype: torch.dtype = torch.float8_e4m3fn,
cast_to_float8_dtype: bool = True,
) -> torch.Tensor:
"""
Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor.
@@ -2288,9 +2301,7 @@
tensor_scaled = tensor_fp32 / scale_expanded
max_value = torch.finfo(float8_dtype).max
tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value)
if cast_to_float8_dtype:
tensor_clamped = tensor_clamped.to(float8_dtype)
return tensor_clamped
return _RoundToFloat8.apply(tensor_clamped, float8_dtype)
Contributor:
So fake quant also does casting? I thought we don't do casting during fake quant.

Author:
I think fp8 numerics are close enough to bf16/fp32 that we can do this during QAT. This helps mimic the convert numerics a bit more closely.

Author:
The fp8-int4 prepare vs convert SQNR drops from 22.375 to 19.25 without this.
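
For context on the thread above, a minimal sketch of the fake-quant round trip these primitives implement, using per-row scaling as in Float8FakeQuantizer; the calls mirror how they are used elsewhere in this PR and are shown for illustration only.

import torch

from torchao.quantization.quant_primitives import (
    _choose_scale_float8,
    _dequantize_affine_float8,
    _quantize_affine_float8,
)

x = torch.randn(128, 256, dtype=torch.bfloat16)
# Per-row scale, matching the (1, x.shape[-1]) block size used in the tests above
scale = _choose_scale_float8(x, (1, x.shape[-1]), torch.float8_e4m3fn, hp_value_lb=1e-12)
q = _quantize_affine_float8(x, scale, torch.float8_e4m3fn)  # float8 values (STE cast)
dq = _dequantize_affine_float8(q, scale, x.dtype)           # back to the original dtype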



# TODO: don't register as custom op?