diff --git a/torchao/experimental/ops/mps/test/test_quantizer.py b/torchao/experimental/ops/mps/test/test_quantizer.py
new file mode 100644
index 000000000..5640a5493
--- /dev/null
+++ b/torchao/experimental/ops/mps/test/test_quantizer.py
@@ -0,0 +1,146 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional
+import copy
+import os
+import sys
+
+import torch
+import torchao_mps_ops
+import unittest
+
+from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer
+from torchao.experimental.quant_api import _quantize
+
+
+def parameterized(test_cases):
+    def decorator(func):
+        def wrapper(self):
+            for case in test_cases:
+                with self.subTest(case=case):
+                    func(self, *case)
+
+        return wrapper
+
+    return decorator
+
+
+class TestUIntxWeightOnlyLinearQuantizer(unittest.TestCase):
+    cases = [(nbit,) for nbit in range(1, 8)]
+
+    # Currently, the quantization code in quant_api.py only supports K values
+    # multiple of group_size.
+    # TODO(mcandales): Generalize the code in quant_api.py and add tests to
+    # cover values of K not multiple of group_size.
+    def _model_setup(self):
+        group_size = 32
+        k0 = 96
+        k1 = 224
+        k2 = 160
+        n = 47
+        layers = [
+            torch.nn.Linear(k0, k1, bias=False),
+            torch.nn.Linear(k1, k2, bias=False),
+            torch.nn.Linear(k2, n, bias=False),
+        ]
+        model = torch.nn.Sequential(*layers)
+        return model, group_size, k0, n
+
+    def _quantize_model(self, model, precision, nbit, group_size):
+        quantizer = UIntxWeightOnlyLinearQuantizer(
+            device="mps",
+            precision=precision,
+            bitwidth=nbit,
+            groupsize=group_size,
+        )
+        quantized_model = copy.deepcopy(model)
+        quantized_model = quantizer.quantize(quantized_model)
+        return quantized_model
+
+    @parameterized(cases)
+    def test_export(self, nbit):
+        model, group_size, k0, n = self._model_setup()
+        m = 3
+        activations = torch.randn(m, k0, dtype=torch.float32, device="mps")
+
+        quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
+        exported = torch.export.export(quantized_model, (activations,))
+
+        for node in exported.graph.nodes:
+            if node.op == "call_function":
+                self.assertTrue(
+                    str(node.target)
+                    == f"torchao._linear_fp_act_{nbit}bit_weight.default"
+                )
+
+    @parameterized(cases)
+    def test_2d_output_device_and_shape(self, nbit):
+        model, group_size, k0, n = self._model_setup()
+        m = 3
+        activations = torch.randn(m, k0, dtype=torch.float32, device="mps")
+
+        quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
+        result = quantized_model(activations)
+        self.assertTrue(result.is_mps)
+        self.assertTrue(result.shape == (m, n))
+
+    @parameterized(cases)
+    def test_3d_output_device_and_shape(self, nbit):
+        model, group_size, k0, n = self._model_setup()
+        leading_shape = (3, 5)
+        activations = torch.randn(*leading_shape, k0, dtype=torch.float32, device="mps")
+
+        quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
+        result = quantized_model(activations)
+        self.assertTrue(result.is_mps)
+        self.assertTrue(result.shape == (*leading_shape, n))
+
+    # TODO(mcandales): Consolidate with the reference impl in test_lowbit.py
+    def _reference_linear_lowbit_quant_weights(self, A, W, group_size, S, Z):
+        N = W.shape[0]
+        K = W.shape[1]
+        W = W.to(torch.float32)
+        scales = S.t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K]
+        zeros = Z.t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K]
+        W = scales * W + zeros
+        return torch.mm(A, W.t())
+
+    @parameterized(cases)
+    def test_accuracy(self, nbit):
+        group_size = 32
+        m = 3
+        n = 7
+        k = 64
+        with torch.no_grad():
+            activations = torch.rand(m, k, dtype=torch.float32, device="mps")
+            model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
+            quantized_model = self._quantize_model(
+                model, torch.float32, nbit, group_size
+            )
+            result = quantized_model(activations)
+
+            # Compute expected result
+            weight_cpu = model[0].weight.data
+            weight_qvals_cpu, weight_scales_cpu, weight_zeros_cpu = _quantize(
+                weight_cpu, group_size, nbit, True, torch.uint8
+            )
+            weight_scales_cpu = weight_scales_cpu.t()
+            weight_zeros_cpu = -weight_zeros_cpu.t() * weight_scales_cpu
+            expected = self._reference_linear_lowbit_quant_weights(
+                activations.cpu(),
+                weight_qvals_cpu,
+                group_size,
+                weight_scales_cpu,
+                weight_zeros_cpu,
+            )
+
+            # Compare results
+            torch.testing.assert_close(result.cpu(), expected, rtol=0.001, atol=0.001)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py
index 1c04305d3..640acfa0b 100644
--- a/torchao/experimental/quant_api.py
+++ b/torchao/experimental/quant_api.py
@@ -25,10 +25,16 @@
 logger.addHandler(handler)
 
 
-def _quantize(vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros: bool):
+def _quantize(vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros: bool, dtype=torch.int8):
     assert nbit >= 1 and nbit <= 8
-    qmin = -(1 << (nbit - 1))
-    qmax = (1 << (nbit - 1)) - 1
+    if dtype == torch.int8:
+        qmin = -(1 << (nbit - 1))
+        qmax = (1 << (nbit - 1)) - 1
+    elif dtype == torch.uint8:
+        qmin = 0
+        qmax = (1 << nbit) - 1
+    else:
+        raise ValueError(f"Unsupported dtype {dtype}")
 
     n, k = vals.shape
     vals = vals.reshape(-1, group_size)
@@ -51,7 +57,7 @@ def _quantize(vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros:
         zero_points=group_zeros,
         quant_min=qmin,
         quant_max=qmax,
-        dtype=torch.int8,
+        dtype=dtype,
         group_size=group_size,
     )
 
@@ -516,3 +522,113 @@ def apply(weight):
         )
 
     return _get_linear_subclass_inserter(apply)
+
+
+class UIntxWeightOnlyQuantizedLinear(nn.Module):
+    def __init__(
+        self,
+        pack_weight_op,
+        linear_op,
+    ):
+        super().__init__()
+        self._pack_weights_op = pack_weight_op
+        self._linear_op = linear_op
+
+    def quantize_and_pack_weights(self, weights, nbit, group_size):
+        self.nbit = nbit
+        self.group_size = group_size
+
+        weight_qvals, weight_scales, weight_zeros = _quantize(
+            weights, self.group_size, self.nbit, True, torch.uint8
+        )
+        weight_scales = torch.transpose_copy(weight_scales, 1, 0)
+        weight_zeros = torch.transpose_copy(weight_zeros, 1, 0)
+        self.weight_scales = weight_scales
+        self.weight_zeros = -weight_zeros * weight_scales
+
+        self.packed_weights = self._pack_weights_op(weight_qvals.cpu()).to(device="mps")
+
+    def forward(self, x):
+        assert x.dim() >= 2
+        if x.dim() == 2:
+            return self._linear_op(
+                x, self.packed_weights, self.group_size, self.weight_scales, self.weight_zeros
+            )
+
+        lead_shape = x.shape[0:-1]
+        k = x.shape[-1]
+        n = self.weight_scales.shape[1]
+        return self._linear_op(
+            x.reshape(-1, k), self.packed_weights, self.group_size, self.weight_scales, self.weight_zeros
+        ).reshape(*lead_shape, n)
+
+
+def _replace_linear_with_quantized_linear_mps(module: nn.Module, kwargs={}):
+    group_size = kwargs["group_size"]
+    nbit = kwargs["nbit"]
+
+    assert not isinstance(module, nn.Linear)
+    assert nbit >= 1 and nbit <= 7
+
+    for name, child in module.named_children():
+        if not isinstance(child, nn.Linear):
+            _replace_linear_with_quantized_linear_mps(child, kwargs)
+        else:
+            assert child.bias is None
+            qlinear = UIntxWeightOnlyQuantizedLinear(
+                pack_weight_op=getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit"),
+                linear_op=getattr(
+                    torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight"
+                ),
+            )
+            setattr(module, name, qlinear)
+            getattr(module, name).quantize_and_pack_weights(
+                child.weight, nbit, group_size
+            )
+
+
+class UIntxWeightOnlyLinearQuantizer:
+    def __init__(
+        self,
+        device,
+        precision,
+        *,
+        bitwidth: Optional[int] = None,
+        groupsize: Optional[int] = None,
+    ):
+        if device != "mps":
+            raise NotImplementedError(
+                "Only device=mps is currently supported in UIntxWeightOnlyLinearQuantizer"
+            )
+        else:
+            self.device = device
+
+        if precision not in [torch.float32, torch.float16, torch.bfloat16]:
+            raise NotImplementedError(
+                "Only precisions float32, float16 & bfloat16 are currently supported in UIntxWeightOnlyLinearQuantizer"
+            )
+        else:
+            self.precision = precision
+
+        if bitwidth is None:
+            self.bitwidth = 4
+            logger.warning(f"bitwidth not specified, defaulting to {self.bitwidth}.")
+        else:
+            self.bitwidth = bitwidth
+
+        if groupsize is None:
+            self.groupsize = 128
+            logger.warning(f"groupsize not specified, defaulting to {self.groupsize}.")
+        else:
+            self.groupsize = groupsize
+
+    def quantize(self, model: nn.Module) -> nn.Module:
+        model = model.to(self.device).to(self.precision)
+        _replace_linear_with_quantized_linear_mps(
+            model,
+            kwargs={
+                "group_size": self.groupsize,
+                "nbit": self.bitwidth,
+            },
+        )
+        return model
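
For context, a minimal usage sketch of the quantizer API introduced by this patch (not part of the diff itself). It assumes the experimental torchao MPS lowbit kernels are built, so that importing torchao_mps_ops registers the torch.ops.torchao._pack_weight_{n}bit / _linear_fp_act_{n}bit_weight ops, and that an MPS device is available; the shapes below are illustrative.

# Usage sketch only; assumes the MPS lowbit custom ops are built and registered.
import torch
import torchao_mps_ops  # noqa: F401  (registers the torch.ops.torchao MPS lowbit ops)

from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer

# Bias-free Linear whose input dim (256) is a multiple of groupsize (32),
# matching the current restrictions checked in _replace_linear_with_quantized_linear_mps.
model = torch.nn.Sequential(torch.nn.Linear(256, 128, bias=False))

quantizer = UIntxWeightOnlyLinearQuantizer(
    device="mps",
    precision=torch.float32,
    bitwidth=4,    # 1..7 bits are supported
    groupsize=32,
)
quantized_model = quantizer.quantize(model)

activations = torch.randn(3, 256, dtype=torch.float32, device="mps")
out = quantized_model(activations)  # (3, 128) tensor on MPS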