Fix CI after quantize op change in PyTorch core #244

Merged: 1 commit, May 15, 2024
2 changes: 1 addition & 1 deletion test/integration/test_integration.py
```diff
@@ -1124,7 +1124,7 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype):
         m_c = torch.compile(m, mode="max-autotune")
         y_wo, (code,) = run_and_get_code(m_c, x)
         sqnr = compute_error(y_ref, y_wo)
-        self.assertGreater(sqnr, 43.0)
+        self.assertGreaterEqual(sqnr, 42.75)
         if device == "cuda":
             self.assertTrue("mixed_mm" in code)
```
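The `compute_error` call here reports an SQNR-style error metric between the reference output and the weight-only-quantized output, and the threshold is loosened from 43.0 to 42.75, presumably because the new rounding path can move a few values by one quantization step. A minimal sketch of such a metric, assuming the usual signal-to-quantized-noise ratio in dB rather than quoting torchao's exact helper:

```python
import torch

def sqnr_db(reference: torch.Tensor, candidate: torch.Tensor) -> float:
    # SQNR in dB: power of the reference signal relative to the power of the error.
    error = reference - candidate
    return (20 * torch.log10(torch.linalg.norm(reference) / torch.linalg.norm(error))).item()

y_ref = torch.randn(128, 128)
y_quant = y_ref + 1e-2 * torch.randn(128, 128)  # stand-in for a quantized model's output
print(f"SQNR: {sqnr_db(y_ref, y_quant):.2f} dB")
```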
16 changes: 8 additions & 8 deletions test/quantization/test_qat.py
```diff
@@ -18,7 +18,7 @@
     fake_quantize_per_token,
 )
 from torchao.quantization.quant_primitives import get_group_qparams_symmetric
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
+from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
 
 
 # TODO: put this in a common test utils file
@@ -58,7 +58,7 @@ def _get_qmin_qmax(self, n_bit: int):
         qmax = 2 ** (n_bit - 1) - 1
         return (qmin, qmax)
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
     def test_fake_quantize_per_channel_group(self):
         n_bit = 4
         (qmin, qmax) = self._get_qmin_qmax(n_bit)
@@ -84,7 +84,7 @@ def test_fake_quantize_per_channel_group(self):
         )
         torch.testing.assert_close(out, out_ptq, atol=0, rtol=0)
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
     def test_fake_quantize_per_token(self):
         (qmin, qmax) = self._get_qmin_qmax(8)
 
@@ -130,7 +130,7 @@ def _set_ptq_weight(
         ptq_linear.scales = s
         ptq_linear.zeros = zp
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
     def test_qat_8da4w_linear(self):
         from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATLinear
         from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
@@ -155,7 +155,7 @@ def test_qat_8da4w_linear(self):
         ptq_out = ptq_linear(x2)
         torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
     def test_qat_8da4w_quantizer(self):
        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
        from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer
@@ -189,7 +189,7 @@ def test_qat_8da4w_quantizer(self):
         for k in ptq_state_dict.keys():
             torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
     def test_qat_8da4w_quantizer_meta_weights(self):
         from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
 
@@ -201,7 +201,7 @@ def test_qat_8da4w_quantizer_meta_weights(self):
         qat_model = qat_quantizer.prepare(m)
         self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values()))
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
     def test_qat_8da4w_quantizer_disable_fake_quant(self):
         """
         Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
@@ -254,7 +254,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
         qat_out2 = qat_model2(*x2)
         torch.testing.assert_close(qat_out, qat_out2, atol=0, rtol=0)
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
     def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         """
         Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
```
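Every QAT test here is now gated on `TORCH_VERSION_AFTER_2_4` instead of `TORCH_VERSION_AFTER_2_3`, since the new op behavior only exists in newer PyTorch builds. The flag itself is defined in `torchao.quantization.utils`; the sketch below only illustrates the `skipIf` gating pattern with a hypothetical version check, not torchao's actual implementation (the exact comparison and handling of nightly builds may differ).

```python
import unittest
import torch
from packaging.version import parse

# Hypothetical stand-in for a flag like TORCH_VERSION_AFTER_2_4:
# strip any local build suffix (e.g. "+cu121") before comparing.
TORCH_VERSION_AFTER_2_4 = parse(torch.__version__.split("+")[0]) >= parse("2.4.0")

class ExampleGatedTest(unittest.TestCase):
    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "needs a torch build with the updated quantize ops")
    def test_only_on_new_torch(self):
        self.assertTrue(True)
```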
2 changes: 1 addition & 1 deletion test/quantization/test_quant_primitives.py
```diff
@@ -156,7 +156,7 @@ def test_quantize_activation_per_token_abs_max_zero_input(self):
         quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)
 
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
     def test_quantize_dequantize_group_sym(self):
         input = torch.randn(10, 10)
         mapping_type = MappingType.SYMMETRIC
```
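`test_quantize_dequantize_group_sym` exercises a symmetric, per-group quantize/dequantize round trip against the updated ops. A hypothetical, self-contained sketch of that kind of round trip (not the torchao implementation; the helper name, group size, and bit width here are made up), using the same reciprocal-multiply rounding path as this PR:

```python
import torch

def quant_dequant_group_sym(x: torch.Tensor, group_size: int = 2, bits: int = 8) -> torch.Tensor:
    # Symmetric per-group quantization: one scale per group, zero_point = 0.
    qmax = 2 ** (bits - 1) - 1
    grouped = x.reshape(-1, group_size)
    scale = grouped.abs().amax(dim=1, keepdim=True).clamp_min(1e-12) / qmax
    q = torch.clamp(torch.round(grouped * (1.0 / scale)), -qmax - 1, qmax)
    return (q * scale).reshape(x.shape)

x = torch.randn(10, 10)
print((x - quant_dequant_group_sym(x)).abs().max())  # small per-element reconstruction error
```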
2 changes: 1 addition & 1 deletion torchao/quantization/prototype/qat.py
```diff
@@ -209,7 +209,7 @@ def forward(ctx, input, scales, zero_points, quant_min, quant_max):
         # which rounds first before adding the zero points. However, this
         # is what `quantize_per_channel_group` and `quantize_per_token`
         # do and here we try to match that behavior as closely as possible.
-        q = input.div(scales).add(zero_points).round()
+        q = input.mul(1.0 / scales).add(zero_points).round()
         dq = q.clamp(quant_min, quant_max).sub(zero_points).mul(scales)
         # TODO: do we need this mask?
         mask = torch.logical_and((q >= quant_min), (q <= quant_max))
```
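This is the substantive change: dividing by `scales` becomes multiplying by `1.0 / scales`, so the QAT fake-quantize forward rounds the same way as the updated quantize ops in PyTorch core. Because `x / s` and `x * (1 / s)` are not always bit-identical in floating point, inputs that land near a rounding boundary can quantize to neighboring integers on the two paths, which is also why the SQNR threshold above was relaxed slightly. A small standalone check (not part of the PR) that counts such disagreements:

```python
import torch

torch.manual_seed(0)
x = torch.randn(1_000_000)
scales = torch.rand(1_000_000) + 0.01  # avoid degenerate tiny scales

q_div = torch.round(x.div(scales))        # old path: divide by the scale
q_mul = torch.round(x.mul(1.0 / scales))  # new path: multiply by the reciprocal

# The two intermediate results can differ in the last bit, so values near a
# .5 boundary may round to different integers on the two paths.
print("values that round differently:", (q_div != q_mul).sum().item())
```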
2 changes: 1 addition & 1 deletion torchao/quantization/quant_primitives.py
```diff
@@ -201,7 +201,7 @@ def quantize_affine(
 
     if zero_point_domain == ZeroPointDomain.INT:
         quant = torch.clamp(
-            torch.round(input / scale) + zero_point, quant_min, quant_max
+            torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
         ).to(output_dtype)
     else:
         assert zero_point_domain == ZeroPointDomain.FLOAT
```
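`quantize_affine` gets the same reciprocal-multiply treatment for the integer zero-point domain, i.e. q = clamp(round(x * (1 / s)) + z, qmin, qmax). A simplified, hypothetical version of just that branch (the real function also handles the float zero-point domain shown in the `else` branch and other arguments):

```python
import torch

def affine_quantize_int_zp(x, scale, zero_point, quant_min=-128, quant_max=127, dtype=torch.int8):
    # Integer zero-point domain: scale by the reciprocal, round, shift by the
    # zero point, then clamp into the target integer range.
    return torch.clamp(
        torch.round(x * (1.0 / scale)) + zero_point, quant_min, quant_max
    ).to(dtype)

x = torch.randn(4, 8)
scale, zero_point = torch.tensor(0.05), torch.tensor(3)
q = affine_quantize_int_zp(x, scale, zero_point)
x_hat = (q.to(torch.float32) - zero_point) * scale  # dequantize for comparison
```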