intx weight only linear quantizer for mps #1192
base: main
@@ -0,0 +1,146 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional
import copy
import os
import sys

import torch
import torchao_mps_ops
import unittest

from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer
from torchao.experimental.quant_api import _quantize


def parameterized(test_cases):
    def decorator(func):
        def wrapper(self):
            for case in test_cases:
                with self.subTest(case=case):
                    func(self, *case)

        return wrapper

    return decorator


class TestUIntxWeightOnlyLinearQuantizer(unittest.TestCase):
    cases = [(nbit,) for nbit in range(1, 8)]

    # Currently, the quantization code in quant_api.py only supports K values
    # multiple of group_size.
    # TODO(mcandales): Generalize the code in quant_api.py and add tests to
    # cover values of K not multiple of group_size.
    def _model_setup(self):
        group_size = 32
        k0 = 96
        k1 = 224
        k2 = 160
        n = 47
        layers = [
            torch.nn.Linear(k0, k1, bias=False),
            torch.nn.Linear(k1, k2, bias=False),
            torch.nn.Linear(k2, n, bias=False),
        ]
        model = torch.nn.Sequential(*layers)
        return model, group_size, k0, n

    def _quantize_model(self, model, precision, nbit, group_size):
        quantizer = UIntxWeightOnlyLinearQuantizer(
            device="mps",
            precision=precision,
            bitwidth=nbit,
            groupsize=group_size,
        )
        quantized_model = copy.deepcopy(model)
        quantized_model = quantizer.quantize(quantized_model)
        return quantized_model

    @parameterized(cases)
    def test_export(self, nbit):
        model, group_size, k0, n = self._model_setup()
Review comment: Are all tests using only a group size of 32? If so, we should test other group sizes as well, including those that should result in an exception. (A hypothetical sketch of such coverage appears after this file's listing.)
        m = 3
        activations = torch.randn(m, k0, dtype=torch.float32, device="mps")

        quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
        exported = torch.export.export(quantized_model, (activations,))

        for node in exported.graph.nodes:
            if node.op == "call_function":
                self.assertTrue(
                    str(node.target)
                    == f"torchao._linear_fp_act_{nbit}bit_weight.default"
                )

    @parameterized(cases)
    def test_2d_output_device_and_shape(self, nbit):
        model, group_size, k0, n = self._model_setup()
        m = 3
        activations = torch.randn(m, k0, dtype=torch.float32, device="mps")

        quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
        result = quantized_model(activations)

        self.assertTrue(result.is_mps)
        self.assertTrue(result.shape == (m, n))

    @parameterized(cases)
    def test_3d_output_device_and_shape(self, nbit):
        model, group_size, k0, n = self._model_setup()
        leading_shape = (3, 5)
        activations = torch.randn(*leading_shape, k0, dtype=torch.float32, device="mps")

        quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
        result = quantized_model(activations)
        self.assertTrue(result.is_mps)
        self.assertTrue(result.shape == (*leading_shape, n))

    # TODO(mcandales): Consolidate with the reference impl in test_lowbit.py
    def _reference_linear_lowbit_quant_weights(self, A, W, group_size, S, Z):
        N = W.shape[0]
        K = W.shape[1]
        W = W.to(torch.float32)
        scales = S.t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K]
        zeros = Z.t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K]
        W = scales * W + zeros
        return torch.mm(A, W.t())

    @parameterized(cases)
    def test_accuracy(self, nbit):
        group_size = 32
        m = 3
        n = 7
        k = 64
        with torch.no_grad():
            activations = torch.rand(m, k, dtype=torch.float32, device="mps")
            model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
            quantized_model = self._quantize_model(
                model, torch.float32, nbit, group_size
            )
            result = quantized_model(activations)

            # Compute expected result
            weight_cpu = model[0].weight.data
            weight_qvals_cpu, weight_scales_cpu, weight_zeros_cpu = _quantize(
                weight_cpu, group_size, nbit, True, torch.uint8
            )
            weight_scales_cpu = weight_scales_cpu.t()
            weight_zeros_cpu = -weight_zeros_cpu.t() * weight_scales_cpu
            expected = self._reference_linear_lowbit_quant_weights(
                activations.cpu(),
                weight_qvals_cpu,
                group_size,
                weight_scales_cpu,
                weight_zeros_cpu,
            )

            # Compare results
            torch.testing.assert_close(result.cpu(), expected, rtol=0.001, atol=0.001)


if __name__ == "__main__":
    unittest.main()
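The review comment above asks for coverage of additional group sizes, including failure cases. A minimal, hypothetical sketch of what that could look like (not part of the PR; it assumes the MPS kernels accept group sizes 64 and 128, and that an unsupported K/group_size combination surfaces as an exception somewhere in the stack, which this PR does not yet assert explicitly):

class TestGroupSizes(unittest.TestCase):
    def test_other_group_sizes(self):
        # Assumes group sizes beyond 32 are supported by the packed MPS ops.
        for group_size in [32, 64, 128]:
            with self.subTest(group_size=group_size):
                model = torch.nn.Sequential(torch.nn.Linear(256, 16, bias=False))
                quantizer = UIntxWeightOnlyLinearQuantizer(
                    device="mps",
                    precision=torch.float32,
                    bitwidth=4,
                    groupsize=group_size,
                )
                quantized = quantizer.quantize(copy.deepcopy(model))
                out = quantized(torch.randn(3, 256, dtype=torch.float32, device="mps"))
                self.assertEqual(out.shape, (3, 16))

    def test_unsupported_k_raises(self):
        # 7 * 50 weight elements cannot be reshaped into groups of 32, so
        # quantization is assumed to fail with an exception.
        model = torch.nn.Sequential(torch.nn.Linear(50, 7, bias=False))
        quantizer = UIntxWeightOnlyLinearQuantizer(
            device="mps", precision=torch.float32, bitwidth=4, groupsize=32
        )
        with self.assertRaises(Exception):
            quantizer.quantize(copy.deepcopy(model))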
torchao/experimental/quant_api.py
@@ -25,10 +25,16 @@
 logger.addHandler(handler)


-def _quantize(vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros: bool):
+def _quantize(vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros: bool, dtype=torch.int8):
     assert nbit >= 1 and nbit <= 8
-    qmin = -(1 << (nbit - 1))
-    qmax = (1 << (nbit - 1)) - 1
+    if dtype == torch.int8:
+        qmin = -(1 << (nbit - 1))
+        qmax = (1 << (nbit - 1)) - 1
+    elif dtype == torch.uint8:
+        qmin = 0
+        qmax = (1 << nbit) - 1
+    else:
+        raise ValueError(f"Unsupported dtype {dtype}")

     n, k = vals.shape
     vals = vals.reshape(-1, group_size)

Review comment (on lines +30 to +35): Instead of overloading dtype=int8 to convey signed vs. unsigned, can you just do
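The reviewer's concrete alternative is not captured above. As a hypothetical illustration only (not from the PR), the signed/unsigned choice could be passed as an explicit flag instead of being inferred from the dtype argument; the ranges below are the same ones the diff computes:

def _qrange(nbit: int, signed: bool = True):
    # Signed n-bit range is [-2**(nbit - 1), 2**(nbit - 1) - 1];
    # unsigned n-bit range is [0, 2**nbit - 1].
    if signed:
        return -(1 << (nbit - 1)), (1 << (nbit - 1)) - 1
    return 0, (1 << nbit) - 1

# For example: _qrange(4, signed=True) == (-8, 7) and _qrange(4, signed=False) == (0, 15).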
@@ -51,7 +57,7 @@ def _quantize(vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros:
         zero_points=group_zeros,
         quant_min=qmin,
         quant_max=qmax,
-        dtype=torch.int8,
+        dtype=dtype,
         group_size=group_size,
     )
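For orientation on how these unsigned quantized weights are consumed: both the UIntxWeightOnlyQuantizedLinear module added below and the reference implementation in the test fold the zero point into an additive term, using the identity scale * (q - zero_point) == scale * q + z with z = -zero_point * scale. A small numeric check of that identity (illustrative only, values made up):

import torch

# Dequantization identity used by quantize_and_pack_weights and by the test's
# reference implementation: s * (q - zp) == s * q + (-zp * s).
s = torch.tensor([0.05])             # per-group scale
zp = torch.tensor([7.0])             # per-group zero point
q = torch.tensor([0.0, 7.0, 15.0])   # 4-bit unsigned quantized values

lhs = s * (q - zp)
rhs = s * q + (-zp * s)
assert torch.allclose(lhs, rhs)      # both give tensor([-0.3500, 0.0000, 0.4000])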
@@ -516,3 +522,113 @@ def apply(weight):
        )

    return _get_linear_subclass_inserter(apply)


class UIntxWeightOnlyQuantizedLinear(nn.Module):
    def __init__(
        self,
        pack_weight_op,
        linear_op,
    ):
        super().__init__()
        self._pack_weights_op = pack_weight_op
        self._linear_op = linear_op

    def quantize_and_pack_weights(self, weights, nbit, group_size):
        self.nbit = nbit
        self.group_size = group_size

        weight_qvals, weight_scales, weight_zeros = _quantize(
            weights, self.group_size, self.nbit, True, torch.uint8
        )
        weight_scales = torch.transpose_copy(weight_scales, 1, 0)
        weight_zeros = torch.transpose_copy(weight_zeros, 1, 0)
        self.weight_scales = weight_scales
        self.weight_zeros = -weight_zeros * weight_scales

        self.packed_weights = self._pack_weights_op(weight_qvals.cpu()).to(device="mps")

    def forward(self, x):
        assert x.dim() >= 2
        if x.dim() == 2:
            return self._linear_op(
                x, self.packed_weights, self.group_size, self.weight_scales, self.weight_zeros
            )

        lead_shape = x.shape[0:-1]
        k = x.shape[-1]
        n = self.weight_scales.shape[1]
        return self._linear_op(
            x.reshape(-1, k), self.packed_weights, self.group_size, self.weight_scales, self.weight_zeros
        ).reshape(*lead_shape, n)


def _replace_linear_with_quantized_linear_mps(module: nn.Module, kwargs={}):
Review comment: These functions can be factored out in a way that is agnostic of IntxWeightOnlyQuantizedLinear vs. the other one @metascroy added.
Author reply: I didn't want to touch code used by the lowbit CPU kernels in this PR.
Review comment: Can you at least add a TODO to that effect?
    group_size = kwargs["group_size"]
    nbit = kwargs["nbit"]

    assert not isinstance(module, nn.Linear)
    assert nbit >= 1 and nbit <= 7

    for name, child in module.named_children():
        if not isinstance(child, nn.Linear):
            _replace_linear_with_quantized_linear_mps(child, kwargs)
        else:
            assert child.bias is None
            qlinear = UIntxWeightOnlyQuantizedLinear(
                pack_weight_op=getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit"),
                linear_op=getattr(
                    torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight"
                ),
            )
            setattr(module, name, qlinear)
            getattr(module, name).quantize_and_pack_weights(
                child.weight, nbit, group_size
            )

Review comment (on the UIntxWeightOnlyQuantizedLinear construction above): Do you have a group size restriction? If so, where is that asserted? I would have expected group size to be a constructor argument so that the constructor can check and throw if the quantized linear does not support it. I don't exactly like exceptions in this scenario, but maybe that's the better choice, since then you cannot create an instance of the quantized linear for an invalid group size.
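As a hypothetical sketch of the factoring suggested in the review thread above (not part of the PR), the module traversal could be made agnostic of the concrete quantized-linear class by taking a factory callable; _replace_linear_with_quantized_linear_mps would then reduce to supplying a factory that builds and initializes UIntxWeightOnlyQuantizedLinear. The helper name and make_qlinear parameter are placeholders:

import torch.nn as nn

# Hypothetical helper, agnostic of the quantized-linear implementation: the
# caller supplies make_qlinear, which maps an nn.Linear child to its replacement.
def _replace_linear(module: nn.Module, make_qlinear):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, make_qlinear(child))
        else:
            _replace_linear(child, make_qlinear)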
class UIntxWeightOnlyLinearQuantizer:
    def __init__(
        self,
        device,
        precision,
        *,
        bitwidth: Optional[int] = None,
        groupsize: Optional[int] = None,
    ):
        if device != "mps":
            raise NotImplementedError(
                "Only device=mps is currently supported in UIntxWeightOnlyLinearQuantizer"
            )
        else:
            self.device = device

        if precision not in [torch.float32, torch.float16, torch.bfloat16]:
            raise NotImplementedError(
                "Only precisions float32, float16 & bfloat16 are currently supported in UIntxWeightOnlyLinearQuantizer"
            )
        else:
            self.precision = precision

        if bitwidth is None:
            self.bitwidth = 4
            logger.warning(f"bitwidth not specified, defaulting to {self.bitwidth}.")
        else:
            self.bitwidth = bitwidth

        if groupsize is None:
            self.groupsize = 128
            logger.warning(f"groupsize not specified, defaulting to {self.groupsize}.")
        else:
            self.groupsize = groupsize

    def quantize(self, model: nn.Module) -> nn.Module:
        model = model.to(self.device).to(self.precision)
        _replace_linear_with_quantized_linear_mps(
            model,
            kwargs={
                "group_size": self.groupsize,
                "nbit": self.bitwidth,
            },
        )
        return model
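For reference, a minimal usage sketch mirroring what the new tests exercise (assumes an MPS-capable build with the experimental torchao MPS ops available; the torchao_mps_ops import is taken from the test file above, where it is assumed to register the custom ops):

import copy

import torch
import torchao_mps_ops  # noqa: F401 -- assumed to register the MPS custom ops, as in the test file

from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer

model = torch.nn.Sequential(torch.nn.Linear(256, 64, bias=False))
quantizer = UIntxWeightOnlyLinearQuantizer(
    device="mps", precision=torch.float32, bitwidth=4, groupsize=32
)
quantized = quantizer.quantize(copy.deepcopy(model))
out = quantized(torch.randn(8, 256, dtype=torch.float32, device="mps"))
print(out.shape)  # torch.Size([8, 64])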
Review comment: Isn't there something equivalent within torchao?
Author reply: Well, there is something in torch.testing._internal, but then I ran into dependency issues with expecttest.
Review comment: We use this in other places: ao/test/integration/test_integration.py, line 647 in c546c5c.
Review comment: Can we please use what Jerry is pointing to? Less code is better. Unresolving the comment.
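If the thread above refers to the hand-rolled parameterized helper in the test file, the torch.testing._internal alternative mentioned looks roughly like the sketch below (illustrative only; the class and test names are placeholders, and, as the author notes, common_utils pulls in the expecttest dependency):

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)


class TestExample(TestCase):
    @parametrize("nbit", range(1, 8))
    def test_nbit_range(self, nbit):
        # Placeholder body; real tests would exercise the quantizer as above.
        self.assertTrue(1 <= nbit <= 7)


instantiate_parametrized_tests(TestExample)

if __name__ == "__main__":
    run_tests()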