Refactor QAT to use tensor subclasses #585

Merged · 1 commit · Aug 20, 2024
147 changes: 108 additions & 39 deletions test/quantization/test_qat.py
@@ -12,11 +12,22 @@

import torch
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
from torchao.dtypes import (
TensorCoreTiledLayoutType,
)
from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
AffineFakeQuantizedTensor,
)
from torchao.quantization.prototype.qat.utils import (
_choose_qparams_per_token_asymmetric,
_fake_quantize_per_channel_group,
_fake_quantize_per_token,
_GenericFakeQuantize,
_QAT_LINEAR_SUBCLASS_INPUT_PREHOOK,
)
from torchao.quantization.quant_api import (
int4_weight_only,
quantize_,
)
from torchao.quantization.quant_primitives import (
fake_quantize_affine,
@@ -190,6 +201,7 @@ def test_qat_8da4w_linear(self):
ptq_out = ptq_linear(x2)
torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)

# TODO: compare against quantize_ API instead
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
Expand Down Expand Up @@ -217,13 +229,6 @@ def test_qat_8da4w_quantizer(self):
converted_out = converted_model(*x)
torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0)

# Compare converted state dict
ptq_state_dict = ptq_model.state_dict()
converted_state_dict = converted_model.state_dict()
self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
for k in ptq_state_dict.keys():
torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_meta_weights(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
@@ -236,6 +241,20 @@ def test_qat_8da4w_quantizer_meta_weights(self):
qat_model = qat_quantizer.prepare(m)
self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values()))

def _copy_subclass_weights(
self,
nn_linear: torch.nn.Linear,
subclass_linear: AffineFakeQuantizedTensor,
):
nn_linear.weight = torch.nn.Parameter(subclass_linear.weight.original_tensor)

def _assert_matches_subclass_weights(
self,
nn_linear: torch.nn.Linear,
subclass_linear: AffineFakeQuantizedTensor,
):
torch.testing.assert_close(nn_linear.weight, subclass_linear.weight.original_tensor, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_disable_fake_quant(self):
"""
@@ -247,6 +266,16 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
enable_8da4w_fake_quant,
)

def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool):
self.assertTrue(isinstance(m.weight, AffineFakeQuantizedTensor))
self.assertEqual(m.weight.fake_quant_enabled, enabled)
self.assertTrue(hasattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK))
(_, handle) = getattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK)
if enabled:
self.assertIsNotNone(handle)
else:
self.assertIsNone(handle)

group_size = 16
torch.manual_seed(self.SEED)
m = M()
@@ -255,14 +284,14 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
qat_model = quantizer.prepare(m)
qat_model.apply(disable_8da4w_fake_quant)
self.assertFalse(qat_model.linear1._fake_quant_enabled)
self.assertFalse(qat_model.linear2._fake_quant_enabled)
self.assertFalse(qat_model.sub.linear._fake_quant_enabled)
assert_fake_quant_enabled(qat_model.linear1, enabled=False)
assert_fake_quant_enabled(qat_model.linear2, enabled=False)
assert_fake_quant_enabled(qat_model.sub.linear, enabled=False)

# Disabled fake quant is just a normal linear
m2.linear1.weight = qat_model.linear1.weight
m2.linear2.weight = qat_model.linear2.weight
m2.sub.linear.weight = qat_model.sub.linear.weight
self._copy_subclass_weights(m2.linear1, qat_model.linear1)
self._copy_subclass_weights(m2.linear2, qat_model.linear2)
self._copy_subclass_weights(m2.sub.linear, qat_model.sub.linear)
torch.manual_seed(self.SEED)
x = m.example_inputs()
x2 = copy.deepcopy(x)
@@ -272,16 +301,16 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):

# Re-enable fake quant
qat_model.apply(enable_8da4w_fake_quant)
self.assertTrue(qat_model.linear1._fake_quant_enabled)
self.assertTrue(qat_model.linear2._fake_quant_enabled)
self.assertTrue(qat_model.sub.linear._fake_quant_enabled)
assert_fake_quant_enabled(qat_model.linear1, enabled=True)
assert_fake_quant_enabled(qat_model.linear2, enabled=True)
assert_fake_quant_enabled(qat_model.sub.linear, enabled=True)

# Fake quant should be applied as normal
quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
qat_model2 = quantizer2.prepare(m3)
qat_model2.linear1.weight = qat_model.linear1.weight
qat_model2.linear2.weight = qat_model.linear2.weight
qat_model2.sub.linear.weight = qat_model.sub.linear.weight
qat_model2.linear1.weight.original_tensor = qat_model.linear1.weight.original_tensor
qat_model2.linear2.weight.original_tensor = qat_model.linear2.weight.original_tensor
qat_model2.sub.linear.weight.original_tensor = qat_model.sub.linear.weight.original_tensor
torch.manual_seed(self.SEED)
x = m.example_inputs()
x2 = copy.deepcopy(x)
@@ -306,9 +335,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
qat_model = quantizer.prepare(m)
qat_model.apply(disable_8da4w_fake_quant)
nn_model.linear1.weight = qat_model.linear1.weight
nn_model.linear2.weight = qat_model.linear2.weight
nn_model.sub.linear.weight = qat_model.sub.linear.weight
self._copy_subclass_weights(nn_model.linear1, qat_model.linear1)
self._copy_subclass_weights(nn_model.linear2, qat_model.linear2)
self._copy_subclass_weights(nn_model.sub.linear, qat_model.sub.linear)

# Simulate training for both models
optimizer1 = torch.optim.SGD(nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
@@ -330,9 +359,55 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
optimizer2.step()

# After 1 training step, weights should match exactly
torch.testing.assert_close(nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0)
torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0)
torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0)
self._assert_matches_subclass_weights(nn_model.linear1, qat_model.linear1)
self._assert_matches_subclass_weights(nn_model.linear2, qat_model.linear2)
self._assert_matches_subclass_weights(nn_model.sub.linear, qat_model.sub.linear)

def _test_qat_quantized_gradients(self, quantizer):
"""
Test that QAT produces gradients in the backward pass.
"""
num_steps = 10
torch.manual_seed(self.SEED)
m = M()
model = quantizer.prepare(m)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
loss_fn = torch.nn.CrossEntropyLoss()

# Simulate training
current_step = 0
last_linear1_grad = None
last_linear2_grad = None
last_sub_linear_grad = None
while current_step < num_steps:
example_inputs = model.example_inputs()
target = torch.randn(1, 512).float()
output = model(*example_inputs)
loss = loss_fn(output, target)
loss.backward()
# assert each linear grad is updated
new_linear1_grad = model.linear1.weight.grad
new_linear2_grad = model.linear2.weight.grad
new_sub_linear_grad = model.sub.linear.weight.grad
self.assertIsNotNone(new_linear1_grad)
self.assertIsNotNone(new_linear2_grad)
self.assertIsNotNone(new_sub_linear_grad)
if current_step > 0:
self.assertFalse(torch.equal(last_linear1_grad, new_linear1_grad))
self.assertFalse(torch.equal(last_linear2_grad, new_linear2_grad))
self.assertFalse(torch.equal(last_sub_linear_grad, new_sub_linear_grad))
last_linear1_grad = new_linear1_grad
last_linear2_grad = new_linear2_grad
last_sub_linear_grad = new_sub_linear_grad
optimizer.zero_grad()
optimizer.step()
current_step += 1

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_gradients(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
self._test_qat_quantized_gradients(quantizer)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_generic_fake_quantize(self):
@@ -353,7 +428,7 @@ def test_qat_generic_fake_quantize(self):
block_size = (1, ao_input.shape[-1])
ao_s = copy.deepcopy(py_s)
ao_zp = copy.deepcopy(py_zp)
ao_out = _GenericFakeQuantize.apply(ao_input, ao_s, ao_zp, qmin, qmax, block_size)
ao_out = _GenericFakeQuantize.apply(ao_input, block_size, ao_s, ao_zp, qmin, qmax)
ao_out.sum().backward()

torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0)
@@ -373,10 +448,7 @@ def _assert_close_4w(self, val, ref):
print(mean_err)
self.assertTrue(mean_err < 0.05)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
# TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now")
def test_qat_4w_primitives(self):
n_bit = 4
group_size = 32
@@ -464,11 +536,9 @@ def test_qat_4w_quantizer(self):
qat_quantizer = Int4WeightOnlyQATQuantizer(
groupsize=group_size, inner_k_tiles=inner_k_tiles,
)
ptq_quantizer = Int4WeightOnlyQuantizer(
groupsize=group_size, inner_k_tiles=inner_k_tiles,
)
qat_model = qat_quantizer.prepare(m)
ptq_model = ptq_quantizer.quantize(m2)
ptq_model = m2
quantize_(ptq_model, int4_weight_only(group_size, TensorCoreTiledLayoutType(inner_k_tiles)))

# Compare model values
torch.manual_seed(self.SEED)
@@ -483,12 +553,11 @@ def test_qat_4w_quantizer(self):
converted_out = converted_model(*x)
torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0)

# Compare converted state dict
ptq_state_dict = ptq_model.state_dict()
converted_state_dict = converted_model.state_dict()
self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
for k in ptq_state_dict.keys():
torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_4w_quantizer_gradients(self):
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
self._test_qat_quantized_gradients(quantizer)


if __name__ == "__main__":
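For orientation, the flow these tests exercise is: prepare() swaps each linear weight for an AffineFakeQuantizedTensor subclass that fake-quantizes on the fly, training proceeds as usual, and convert() then produces an actually quantized model. Below is a minimal sketch of that flow under those assumptions, using a hypothetical two-layer module in place of the test suite's M; the layer sizes and hyperparameters are arbitrary.

```python
import torch
from torchao.quantization.prototype.qat import (
    Int8DynActInt4WeightQATQuantizer,
    disable_8da4w_fake_quant,
    enable_8da4w_fake_quant,
)
from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
    AffineFakeQuantizedTensor,
)


class TwoLinear(torch.nn.Module):
    """Hypothetical stand-in for the test suite's M module."""

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(512, 256, bias=False)
        self.linear2 = torch.nn.Linear(256, 512, bias=False)

    def forward(self, x):
        return self.linear2(self.linear1(x))


model = TwoLinear()
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)

# prepare() replaces each linear weight with an AffineFakeQuantizedTensor that
# fake-quantizes during the forward pass while keeping the float original around.
model = quantizer.prepare(model)
assert isinstance(model.linear1.weight, AffineFakeQuantizedTensor)

# Fake quantization can be toggled without swapping modules, as the tests above check.
model.apply(disable_8da4w_fake_quant)
model.apply(enable_8da4w_fake_quant)

# Train as usual; gradients flow through the fake-quantize ops to the float weights.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss = model(torch.randn(1, 512)).sum()
loss.backward()
optimizer.step()

# convert() lowers the fake-quantized weights to actually quantized ones for inference.
quantized_model = quantizer.convert(model)
```

The gradient test added above (_test_qat_quantized_gradients) relies on exactly this property: the subclassed weights still expose .grad after backward() and keep updating step to step.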
4 changes: 4 additions & 0 deletions torchao/quantization/prototype/qat/__init__.py
@@ -3,6 +3,8 @@
disable_8da4w_fake_quant,
enable_4w_fake_quant,
enable_8da4w_fake_quant,
int4_weight_only_fake_quantize,
int8_dynamic_activation_int4_weight_fake_quantize,
Int4WeightOnlyQATQuantizer,
Int8DynActInt4WeightQATQuantizer,
Int8DynActInt4WeightQATLinear,
@@ -13,6 +15,8 @@
"disable_8da4w_fake_quant",
"enable_4w_fake_quant",
"enable_8da4w_fake_quant",
"int4_weight_only_fake_quantize",
"int8_dynamic_activation_int4_weight_fake_quantize",
"Int4WeightOnlyQATQuantizer",
"Int8DynActInt4WeightQATQuantizer",
"Int8DynActInt4WeightQATLinear",
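The two new exports mirror the existing quantize_ configs (compare int4_weight_only in the updated test), suggesting a functional entry point for QAT alongside the quantizer classes. A hedged sketch of that usage follows; the group_size keyword is an assumption modeled on int4_weight_only(group_size, ...), since the diff only shows the export names.

```python
import torch
from torchao.quantization.quant_api import quantize_
from torchao.quantization.prototype.qat import (
    int8_dynamic_activation_int4_weight_fake_quantize,
)

model = torch.nn.Sequential(torch.nn.Linear(512, 512, bias=False))

# Assumed usage: insert fake quantization into every linear weight in place,
# analogous to quantize_(model, int4_weight_only(...)) for real quantization.
quantize_(model, int8_dynamic_activation_int4_weight_fake_quantize(group_size=16))
```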