Skip to content

Commit

Permalink
Refactor QAT to use tensor subclasses
Browse files Browse the repository at this point in the history
This commit refactors QAT to use tensor subclasses. This is motivated by the general move towards tensor subclasses in torchao for better composability with other subclasses like DTensors. To achieve this, we introduce `AffineFakeQuantizedTensor`, which is analogous to `AffineQuantizedTensor` but applies fake quantization instead and requires gradient updates.

`AffineFakeQuantizedTensor` wraps the original weight or input activation tensor and applies fake quantize dynamically only when the linear function is called. Gradients only flow to the outer tensor (`AffineFakeQuantizedTensor`) and never to the inner tensor. For weights, the outer tensor is also a `torch.nn.Parameter`, and gradient updates received by the outer tensor are then passed to the inner tensor through ops like `aten.add_` and `aten.mul_`.

An important difference between the PTQ and the QAT flows is how input activation subclasses are inserted. For QAT, we use the nn.module `forward_pre_hook` instead of relying on another subclass `LinearActivationQuantizedTensor` that wraps the weight subclass. The problem with the old PTQ approach is it can create subclasses under `__torch_dispatch__`, which runs below autograd and so the created subclasses cannot have gradients, so it was difficult to get the gradients to flow correctly in such cases. It's also not super intuitive because quantizing input activation needs to go through the weights. In the new approach used by QAT, we instead register a `forward_pre_hook` that wraps the input activations before each call to forward. This approach is also motivated by how [DTensor wraps their subclasses ](https://github.com/pytorch/pytorch/blob/844103197d3e8cf6b4b59176e473365113f4f962/torch/distributed/tensor/parallel/style.py#L521).

- [x] Add AffineFakeQuantizedTensor
- [x] Add support for int4 weight only fake quantize
- [x] Add support for int8 dynamic activations + int4 weight fake quantize (8da4w)
- [x] Add prepare and convert path to int4 QAT quantizer
- [x] Add prepare and convert path to 8da4w QAT quantizer
- [x] Support enabling and disabling fake quant dynamically
- [x] Support `__repr__` in AffineFakeQuantizedTensor
- [x] Fix backward pass for int4 weight only
- [x] Fix backward pass for int8 dynamic activations + int4 weight
  • Loading branch information
andrewor14 committed Aug 19, 2024
1 parent b523f9f commit 61bd8d3
Show file tree
Hide file tree
Showing 6 changed files with 670 additions and 142 deletions.
147 changes: 108 additions & 39 deletions test/quantization/test_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,22 @@

import torch
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
from torchao.dtypes import (
TensorCoreTiledLayoutType,
)
from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
AffineFakeQuantizedTensor,
)
from torchao.quantization.prototype.qat.utils import (
_choose_qparams_per_token_asymmetric,
_fake_quantize_per_channel_group,
_fake_quantize_per_token,
_GenericFakeQuantize,
_QAT_LINEAR_SUBCLASS_INPUT_PREHOOK,
)
from torchao.quantization.quant_api import (
int4_weight_only,
quantize_,
)
from torchao.quantization.quant_primitives import (
fake_quantize_affine,
Expand Down Expand Up @@ -190,6 +201,7 @@ def test_qat_8da4w_linear(self):
ptq_out = ptq_linear(x2)
torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)

# TODO: compare against quantize_ API instead
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
Expand Down Expand Up @@ -217,13 +229,6 @@ def test_qat_8da4w_quantizer(self):
converted_out = converted_model(*x)
torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0)

# Compare converted state dict
ptq_state_dict = ptq_model.state_dict()
converted_state_dict = converted_model.state_dict()
self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
for k in ptq_state_dict.keys():
torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_meta_weights(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
Expand All @@ -236,6 +241,20 @@ def test_qat_8da4w_quantizer_meta_weights(self):
qat_model = qat_quantizer.prepare(m)
self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values()))

def _copy_subclass_weights(
self,
nn_linear: torch.nn.Linear,
subclass_linear: AffineFakeQuantizedTensor,
):
nn_linear.weight = torch.nn.Parameter(subclass_linear.weight.original_tensor)

def _assert_matches_subclass_weights(
self,
nn_linear: torch.nn.Linear,
subclass_linear: AffineFakeQuantizedTensor,
):
torch.testing.assert_close(nn_linear.weight, subclass_linear.weight.original_tensor, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_disable_fake_quant(self):
"""
Expand All @@ -247,6 +266,16 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
enable_8da4w_fake_quant,
)

def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool):
self.assertTrue(isinstance(m.weight, AffineFakeQuantizedTensor))
self.assertEqual(m.weight.fake_quant_enabled, enabled)
self.assertTrue(hasattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK))
(_, handle) = getattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK)
if enabled:
self.assertIsNotNone(handle)
else:
self.assertIsNone(handle)

group_size = 16
torch.manual_seed(self.SEED)
m = M()
Expand All @@ -255,14 +284,14 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
qat_model = quantizer.prepare(m)
qat_model.apply(disable_8da4w_fake_quant)
self.assertFalse(qat_model.linear1._fake_quant_enabled)
self.assertFalse(qat_model.linear2._fake_quant_enabled)
self.assertFalse(qat_model.sub.linear._fake_quant_enabled)
assert_fake_quant_enabled(qat_model.linear1, enabled=False)
assert_fake_quant_enabled(qat_model.linear2, enabled=False)
assert_fake_quant_enabled(qat_model.sub.linear, enabled=False)

# Disabled fake quant is just a normal linear
m2.linear1.weight = qat_model.linear1.weight
m2.linear2.weight = qat_model.linear2.weight
m2.sub.linear.weight = qat_model.sub.linear.weight
self._copy_subclass_weights(m2.linear1, qat_model.linear1)
self._copy_subclass_weights(m2.linear2, qat_model.linear2)
self._copy_subclass_weights(m2.sub.linear, qat_model.sub.linear)
torch.manual_seed(self.SEED)
x = m.example_inputs()
x2 = copy.deepcopy(x)
Expand All @@ -272,16 +301,16 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):

# Renable fake quant
qat_model.apply(enable_8da4w_fake_quant)
self.assertTrue(qat_model.linear1._fake_quant_enabled)
self.assertTrue(qat_model.linear2._fake_quant_enabled)
self.assertTrue(qat_model.sub.linear._fake_quant_enabled)
assert_fake_quant_enabled(qat_model.linear1, enabled=True)
assert_fake_quant_enabled(qat_model.linear2, enabled=True)
assert_fake_quant_enabled(qat_model.sub.linear, enabled=True)

# Fake quant should be applied as normal
quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
qat_model2 = quantizer2.prepare(m3)
qat_model2.linear1.weight = qat_model.linear1.weight
qat_model2.linear2.weight = qat_model.linear2.weight
qat_model2.sub.linear.weight = qat_model.sub.linear.weight
qat_model2.linear1.weight.original_tensor = qat_model.linear1.weight.original_tensor
qat_model2.linear2.weight.original_tensor = qat_model.linear2.weight.original_tensor
qat_model2.sub.linear.weight.original_tensor = qat_model.sub.linear.weight.original_tensor
torch.manual_seed(self.SEED)
x = m.example_inputs()
x2 = copy.deepcopy(x)
Expand All @@ -306,9 +335,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
qat_model = quantizer.prepare(m)
qat_model.apply(disable_8da4w_fake_quant)
nn_model.linear1.weight = qat_model.linear1.weight
nn_model.linear2.weight = qat_model.linear2.weight
nn_model.sub.linear.weight = qat_model.sub.linear.weight
self._copy_subclass_weights(nn_model.linear1, qat_model.linear1)
self._copy_subclass_weights(nn_model.linear2, qat_model.linear2)
self._copy_subclass_weights(nn_model.sub.linear, qat_model.sub.linear)

# Simulate training for both models
optimizer1 = torch.optim.SGD(nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
Expand All @@ -330,9 +359,55 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
optimizer2.step()

# After 1 training step, weights should match exactly
torch.testing.assert_close(nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0)
torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0)
torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0)
self._assert_matches_subclass_weights(nn_model.linear1, qat_model.linear1)
self._assert_matches_subclass_weights(nn_model.linear2, qat_model.linear2)
self._assert_matches_subclass_weights(nn_model.sub.linear, qat_model.sub.linear)

def _test_qat_quantized_gradients(self, quantizer):
"""
Test that QAT produces gradients in the backward pass.
"""
num_steps = 10
torch.manual_seed(self.SEED)
m = M()
model = quantizer.prepare(m)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
loss_fn = torch.nn.CrossEntropyLoss()

# Simulate training
current_step = 0
last_linear1_grad = None
last_linear2_grad = None
last_sub_linear_grad = None
while current_step < num_steps:
example_inputs = model.example_inputs()
target = torch.randn(1, 512).float()
output = model(*example_inputs)
loss = loss_fn(output, target)
loss.backward()
# assert each linear grad is updated
new_linear1_grad = model.linear1.weight.grad
new_linear2_grad = model.linear2.weight.grad
new_sub_linear_grad = model.sub.linear.weight.grad
self.assertIsNotNone(new_linear1_grad)
self.assertIsNotNone(new_linear2_grad)
self.assertIsNotNone(new_sub_linear_grad)
if current_step > 0:
self.assertFalse(torch.equal(last_linear1_grad, new_linear1_grad))
self.assertFalse(torch.equal(last_linear2_grad, new_linear2_grad))
self.assertFalse(torch.equal(last_sub_linear_grad, new_sub_linear_grad))
last_linear1_grad = new_linear1_grad
last_linear2_grad = new_linear2_grad
last_sub_linear_grad = new_sub_linear_grad
optimizer.zero_grad()
optimizer.step()
current_step += 1

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_gradients(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
self._test_qat_quantized_gradients(quantizer)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_generic_fake_quantize(self):
Expand All @@ -353,7 +428,7 @@ def test_qat_generic_fake_quantize(self):
block_size = (1, ao_input.shape[-1])
ao_s = copy.deepcopy(py_s)
ao_zp = copy.deepcopy(py_zp)
ao_out = _GenericFakeQuantize.apply(ao_input, ao_s, ao_zp, qmin, qmax, block_size)
ao_out = _GenericFakeQuantize.apply(ao_input, block_size, ao_s, ao_zp, qmin, qmax)
ao_out.sum().backward()

torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0)
Expand All @@ -373,10 +448,7 @@ def _assert_close_4w(self, val, ref):
print(mean_err)
self.assertTrue(mean_err < 0.05)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
# TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now")
def test_qat_4w_primitives(self):
n_bit = 4
group_size = 32
Expand Down Expand Up @@ -464,11 +536,9 @@ def test_qat_4w_quantizer(self):
qat_quantizer = Int4WeightOnlyQATQuantizer(
groupsize=group_size, inner_k_tiles=inner_k_tiles,
)
ptq_quantizer = Int4WeightOnlyQuantizer(
groupsize=group_size, inner_k_tiles=inner_k_tiles,
)
qat_model = qat_quantizer.prepare(m)
ptq_model = ptq_quantizer.quantize(m2)
ptq_model = m2
quantize_(ptq_model, int4_weight_only(group_size, TensorCoreTiledLayoutType(inner_k_tiles)))

# Compare model values
torch.manual_seed(self.SEED)
Expand All @@ -483,12 +553,11 @@ def test_qat_4w_quantizer(self):
converted_out = converted_model(*x)
torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0)

# Compare converted state dict
ptq_state_dict = ptq_model.state_dict()
converted_state_dict = converted_model.state_dict()
self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
for k in ptq_state_dict.keys():
torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_4w_quantizer_gradients(self):
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
self._test_qat_quantized_gradients(quantizer)


if __name__ == "__main__":
Expand Down
4 changes: 4 additions & 0 deletions torchao/quantization/prototype/qat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
disable_8da4w_fake_quant,
enable_4w_fake_quant,
enable_8da4w_fake_quant,
int4_weight_only_fake_quantize,
int8_dynamic_activation_int4_weight_fake_quantize,
Int4WeightOnlyQATQuantizer,
Int8DynActInt4WeightQATQuantizer,
Int8DynActInt4WeightQATLinear,
Expand All @@ -13,6 +15,8 @@
"disable_8da4w_fake_quant",
"enable_4w_fake_quant",
"enable_8da4w_fake_quant",
"int4_weight_only_fake_quantize",
"int8_dynamic_activation_int4_weight_fake_quantize",
"Int4WeightOnlyQATQuantizer",
"Int8DynActInt4WeightQATQuantizer",
"Int8DynActInt4WeightQATLinear",
Expand Down
Loading

0 comments on commit 61bd8d3

Please sign in to comment.