intx weight only linear quantizer for mps (pytorch#1192)
Summary: Pull Request resolved: pytorch#1192

Differential Revision: D65079774
manuelcandales authored and facebook-github-bot committed Nov 13, 2024
1 parent c546c5c commit b1d27e1
Showing 2 changed files with 255 additions and 4 deletions.
135 changes: 135 additions & 0 deletions torchao/experimental/ops/mps/test/test_quantizer.py
@@ -0,0 +1,135 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional
import copy
import os
import sys

import torch
import torchao_mps_ops
import unittest

from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer
from torchao.experimental.quant_api import _quantize


def parameterized(test_cases):
def decorator(func):
def wrapper(self):
for case in test_cases:
with self.subTest(case=case):
func(self, *case)

return wrapper

return decorator


class TestUIntxWeightOnlyLinearQuantizer(unittest.TestCase):
cases = [(nbit,) for nbit in range(1, 8)]

def _model_setup(self):
group_size = 32
k0 = 96
k1 = 224
k2 = 160
n = 47
layers = [
torch.nn.Linear(k0, k1, bias=False),
torch.nn.Linear(k1, k2, bias=False),
torch.nn.Linear(k2, n, bias=False),
]
model = torch.nn.Sequential(*layers)
return model, group_size, k0, n

def _quantize_model(self, model, precision, nbit, group_size):
quantizer = UIntxWeightOnlyLinearQuantizer(
device="mps",
precision=precision,
bitwidth=nbit,
groupsize=group_size,
)
quantized_model = copy.deepcopy(model)
quantized_model = quantizer.quantize(quantized_model)
return quantized_model

@parameterized(cases)
def test_export(self, nbit):
model, group_size, k0, n = self._model_setup()
m = 3
activations = torch.randn(m, k0, dtype=torch.float32, device="mps")

quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
exported = torch.export.export(quantized_model, (activations,))

for node in exported.graph.nodes:
if node.op == "call_function":
self.assertTrue(
str(node.target)
== f"torchao._linear_fp_act_{nbit}bit_weight.default"
)

@parameterized(cases)
def test_2d_output_device_and_shape(self, nbit):
model, group_size, k0, n = self._model_setup()
m = 3
activations = torch.randn(m, k0, dtype=torch.float32, device="mps")

quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
result = quantized_model(activations)
self.assertTrue(result.is_mps)
self.assertTrue(result.shape == (m, n))

@parameterized(cases)
def test_3d_output_device_and_shape(self, nbit):
model, group_size, k0, n = self._model_setup()
leading_shape = (3, 5)
activations = torch.randn(
*leading_shape, k0, dtype=torch.float32, device="mps"
)

quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
result = quantized_model(activations)
self.assertTrue(result.is_mps)
self.assertTrue(result.shape == (*leading_shape, n))

def _reference_linear_lowbit_quant_weights(self, A, W, group_size, S, Z):
N = W.shape[0]
K = W.shape[1]
W = W.to(torch.float32)
scales = S.t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K]
zeros = Z.t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K]
W = scales * W + zeros
return torch.mm(A, W.t())

@parameterized(cases)
def test_accuracy(self, nbit):
group_size = 32
m = 3
n = 7
k = 64
with torch.no_grad():
activations = torch.rand(m, k, dtype=torch.float32, device="mps")
model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
result = quantized_model(activations)

# Compute expected result
weight_cpu = model[0].weight.data
weight_qvals_cpu, weight_scales_cpu, weight_zeros_cpu = _quantize(
weight_cpu, group_size, nbit, True, torch.uint8
)
weight_scales_cpu = weight_scales_cpu.t()
weight_zeros_cpu = -weight_zeros_cpu.t() * weight_scales_cpu
            expected = self._reference_linear_lowbit_quant_weights(
                activations.cpu(),
                weight_qvals_cpu,
                group_size,
                weight_scales_cpu,
                weight_zeros_cpu,
            )

# Compare results
torch.testing.assert_close(result.cpu(), expected, rtol=0.001, atol=0.001)


if __name__ == "__main__":
unittest.main()
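A note on the zero-point handling exercised by test_accuracy above (and mirrored in quantize_and_pack_weights below): _quantize produces an affine encoding w ≈ scale * (q - zero_point), and the MPS path precomputes zeros' = -zero_point * scale so dequantization becomes scale * q + zeros'. The standalone sketch below (not part of the commit; shapes inferred from the code above) checks that reading.

import torch

from torchao.experimental.quant_api import _quantize

group_size, nbit = 32, 4
w = torch.randn(8, 64)  # (n, k), with k divisible by group_size
qvals, scales, zeros = _quantize(w, group_size, nbit, True, torch.uint8)

# qvals is (n, k); scales and zeros are per-group, shape (n, k // group_size).
scales_full = scales.to(torch.float32).repeat_interleave(group_size, dim=1)
zeros_full = zeros.to(torch.float32).repeat_interleave(group_size, dim=1)

# w ≈ scale * (q - zero_point), equivalently scale * q + (-zero_point * scale)
w_hat = scales_full * (qvals.to(torch.float32) - zeros_full)
print((w - w_hat).abs().max())  # bounded by roughly half a quantization step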
124 changes: 120 additions & 4 deletions torchao/experimental/quant_api.py
@@ -25,10 +25,16 @@
logger.addHandler(handler)


-def _quantize(vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros: bool):
+def _quantize(vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros: bool, dtype=torch.int8):
    assert nbit >= 1 and nbit <= 8
-    qmin = -(1 << (nbit - 1))
-    qmax = (1 << (nbit - 1)) - 1
+    if dtype == torch.int8:
+        qmin = -(1 << (nbit - 1))
+        qmax = (1 << (nbit - 1)) - 1
+    elif dtype == torch.uint8:
+        qmin = 0
+        qmax = (1 << nbit) - 1
+    else:
+        raise ValueError(f"Unsupported dtype {dtype}")

    n, k = vals.shape
    vals = vals.reshape(-1, group_size)
@@ -51,7 +57,7 @@ def _quantize(vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros:
        zero_points=group_zeros,
        quant_min=qmin,
        quant_max=qmax,
-        dtype=torch.int8,
+        dtype=dtype,
        group_size=group_size,
    )
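As a quick aside (not part of the commit), the ranges produced by the new dtype branch in _quantize can be sanity-checked in isolation; the snippet below mirrors that branch and spells out what the signed and unsigned paths give for a few bit widths.

import torch

def qrange(nbit: int, dtype: torch.dtype):
    # Mirrors the dtype branch added to _quantize above.
    if dtype == torch.int8:
        return -(1 << (nbit - 1)), (1 << (nbit - 1)) - 1
    elif dtype == torch.uint8:
        return 0, (1 << nbit) - 1
    raise ValueError(f"Unsupported dtype {dtype}")

assert qrange(4, torch.int8) == (-8, 7)     # existing signed path
assert qrange(4, torch.uint8) == (0, 15)    # new unsigned path used by the MPS kernels
assert qrange(1, torch.uint8) == (0, 1)
assert qrange(7, torch.uint8) == (0, 127)   # largest bitwidth the MPS quantizer accepts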

@@ -516,3 +522,113 @@ def apply(weight):
)

return _get_linear_subclass_inserter(apply)


class UIntxWeightOnlyQuantizedLinear(nn.Module):
def __init__(
self,
pack_weight_op,
linear_op,
):
super().__init__()
self._pack_weights_op = pack_weight_op
self._linear_op = linear_op

def quantize_and_pack_weights(self, weights, nbit, group_size):
self.nbit = nbit
self.group_size = group_size

weight_qvals, weight_scales, weight_zeros = _quantize(
weights, self.group_size, self.nbit, True, torch.uint8
)
weight_scales = torch.transpose_copy(weight_scales, 1, 0)
weight_zeros = torch.transpose_copy(weight_zeros, 1, 0)
self.weight_scales = weight_scales
self.weight_zeros = -weight_zeros * weight_scales

self.packed_weights = self._pack_weights_op(weight_qvals.cpu()).to(device="mps")

def forward(self, x):
assert x.dim() >= 2
if x.dim() == 2:
return self._linear_op(
x, self.packed_weights, self.group_size, self.weight_scales, self.weight_zeros
)

lead_shape = x.shape[0:-1]
k = x.shape[-1]
n = self.weight_scales.shape[1]
return self._linear_op(
x.reshape(-1, k), self.packed_weights, self.group_size, self.weight_scales, self.weight_zeros
).reshape(*lead_shape, n)


def _replace_linear_with_quantized_linear_mps(module: nn.Module, kwargs={}):
group_size = kwargs["group_size"]
nbit = kwargs["nbit"]

assert not isinstance(module, nn.Linear)
assert nbit >= 1 and nbit <= 7

for name, child in module.named_children():
if not isinstance(child, nn.Linear):
_replace_linear_with_quantized_linear_mps(child, kwargs)
else:
assert child.bias is None
qlinear = UIntxWeightOnlyQuantizedLinear(
pack_weight_op=getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit"),
linear_op=getattr(
torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight"
),
)
setattr(module, name, qlinear)
getattr(module, name).quantize_and_pack_weights(
child.weight, nbit, group_size
)


class UIntxWeightOnlyLinearQuantizer:
def __init__(
self,
device,
precision,
*,
bitwidth: Optional[int] = None,
groupsize: Optional[int] = None,
):
if device != "mps":
raise NotImplementedError(
"Only device=mps is currently supported in UIntxWeightOnlyLinearQuantizer"
)
else:
self.device = device

if precision not in [torch.float32, torch.float16, torch.bfloat16]:
raise NotImplementedError(
"Only precisions float32, float16 & bfloat16 are currently supported in UIntxWeightOnlyLinearQuantizer"
)
else:
self.precision = precision

if bitwidth is None:
self.bitwidth = 4
logger.warning(f"bitwidth not specified, defaulting to {self.bitwidth}.")
else:
self.bitwidth = bitwidth

if groupsize is None:
self.groupsize = 128
logger.warning(f"groupsize not specified, defaulting to {self.groupsize}.")
else:
self.groupsize = groupsize

def quantize(self, model: nn.Module) -> nn.Module:
model = model.to(self.device).to(self.precision)
_replace_linear_with_quantized_linear_mps(
model,
kwargs={
"group_size": self.groupsize,
"nbit": self.bitwidth,
},
)
return model
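End-to-end, the new quantizer is used roughly as below: a sketch mirroring the test above, where the layer sizes and bitwidth are illustrative and the experimental torchao_mps_ops extension must be built and importable.

import torch
import torchao_mps_ops  # as in the test: makes the low-bit ops available under torch.ops.torchao

from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer

model = torch.nn.Sequential(torch.nn.Linear(256, 128, bias=False))  # bias-free Linear layers only

quantizer = UIntxWeightOnlyLinearQuantizer(
    device="mps",             # only mps is supported
    precision=torch.float32,  # float32, float16, or bfloat16
    bitwidth=4,               # 1 through 7 bits
    groupsize=32,             # must divide the Linear in_features
)
quantized = quantizer.quantize(model)

x = torch.randn(8, 256, dtype=torch.float32, device="mps")
y = quantized(x)
print(y.shape)  # torch.Size([8, 128]), on the mps device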
