intx weight only linear quantizer for mps (pytorch#1192)
Summary: Pull Request resolved: pytorch#1192

Differential Revision: D65079774
manuelcandales authored and facebook-github-bot committed Oct 30, 2024
1 parent 4f1fc4c commit ed83de7
Showing 3 changed files with 280 additions and 18 deletions.
83 changes: 65 additions & 18 deletions torchao/experimental/ops/mps/register.mm
@@ -17,54 +17,66 @@

// LowBit Quantized Linear on MPS Backend
template <int nbit>
- Tensor linear_mps_kernel(
+ void check_linear_mps_args(
const Tensor& A,
const Tensor& B,
int64_t group_size,
const Tensor& SZ) {
- auto M = A.size(0);
auto N = B.size(0);
auto K = A.size(1);

- TORCH_CHECK(
-     A.is_mps(), __func__, "A is on ", A.device(), " but expected on mps");
- TORCH_CHECK(
-     B.is_mps(), __func__, "B is on ", B.device(), " but expected on mps");
- TORCH_CHECK(
-     SZ.is_mps(), __func__, "SZ is on ", SZ.device(), " but expected on mps");

- TORCH_CHECK(
+ TORCHAO_CHECK(
A.dtype() == at::kBFloat16 || A.dtype() == at::kHalf ||
A.dtype() == at::kFloat,
__func__,
" : expect A to be either 32-bit or 16-bit float tensor.");
- TORCH_CHECK(A.is_contiguous(), __func__, " : expect A to be contiguous.");
- TORCH_CHECK(A.dim() == 2, __func__, " : expect A to be 2D tensor.");
+ TORCHAO_CHECK(A.is_contiguous(), __func__, " : expect A to be contiguous.");
+ TORCHAO_CHECK(A.dim() == 2, __func__, " : expect A to be 2D tensor.");

- TORCH_CHECK(
+ TORCHAO_CHECK(
B.dtype() == at::kByte, __func__, " : expect B to be uint8 tensor.");
- TORCH_CHECK(B.is_contiguous(), __func__, " : expect B to be contiguous.");
- TORCH_CHECK(
+ TORCHAO_CHECK(B.is_contiguous(), __func__, " : expect B to be contiguous.");
+ TORCHAO_CHECK(
B.size(1) == (K / 8) * nbit,
__func__,
" : expect B.size(1) == ",
(K / 8) * nbit);

- TORCH_CHECK(K % 8 == 0, __func__, ": expect K to be multiple of 8, got ", K);
+ TORCHAO_CHECK(K % 8 == 0, __func__, ": expect K to be multiple of 8, got ", K);

- TORCH_CHECK(
+ TORCHAO_CHECK(
group_size == 32 || group_size == 64 || group_size == 128 ||
group_size == 256,
__func__,
": expect group_size to be 32, 64, 128 or 256, got ",
group_size);

- TORCH_CHECK(
+ TORCHAO_CHECK(
SZ.dim() == 3 && SZ.size(1) == N && SZ.size(2) == 2,
__func__,
": expect SZ to be 3d tensor with sizes [:, ",
N,
", 2]");
}

template <int nbit>
Tensor linear_mps_kernel(
const Tensor& A,
const Tensor& B,
int64_t group_size,
const Tensor& SZ) {
TORCHAO_CHECK(
A.is_mps(), __func__, "A is on ", A.device(), " but expected on mps");
TORCHAO_CHECK(
B.is_mps(), __func__, "B is on ", B.device(), " but expected on mps");
TORCHAO_CHECK(
SZ.is_mps(), __func__, "SZ is on ", SZ.device(), " but expected on mps");

check_linear_mps_args<nbit>(A, B, group_size, SZ);

auto M = A.size(0);
auto N = B.size(0);
auto K = A.size(1);

auto C = at::empty({M, N}, A.options());

@@ -82,6 +94,31 @@ Tensor linear_mps_kernel(
return C;
}

template <int nbit>
Tensor linear_mps_kernel_meta(
const Tensor& A,
const Tensor& B,
int64_t group_size,
const Tensor& SZ) {
TORCHAO_CHECK(
A.is_meta(), __func__, "A is on ", A.device(), " but expected on meta");
TORCHAO_CHECK(
B.is_meta(), __func__, "B is on ", B.device(), " but expected on meta");
TORCHAO_CHECK(
SZ.is_meta(),
__func__,
"SZ is on ",
SZ.device(),
" but expected on meta");

check_linear_mps_args<nbit>(A, B, group_size, SZ);

auto M = A.size(0);
auto N = B.size(0);

return at::empty({M, N}, A.options()).to("meta");
}

// LowBit Packing on CPU Backend
template <int nbit>
Tensor pack_weights_cpu_kernel(const Tensor& W) {
@@ -144,4 +181,14 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) {
m.impl("_linear_fp_act_7bit_weight", &linear_mps_kernel<7>);
}

TORCH_LIBRARY_IMPL(torchao, Meta, m) {
m.impl("_linear_fp_act_1bit_weight", &linear_mps_kernel_meta<1>);
m.impl("_linear_fp_act_2bit_weight", &linear_mps_kernel_meta<2>);
m.impl("_linear_fp_act_3bit_weight", &linear_mps_kernel_meta<3>);
m.impl("_linear_fp_act_4bit_weight", &linear_mps_kernel_meta<4>);
m.impl("_linear_fp_act_5bit_weight", &linear_mps_kernel_meta<5>);
m.impl("_linear_fp_act_6bit_weight", &linear_mps_kernel_meta<6>);
m.impl("_linear_fp_act_7bit_weight", &linear_mps_kernel_meta<7>);
}

} // namespace torchao::kernels::mps::lowbit::aten
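
The Meta-backend registrations above let these ops participate in shape and dtype propagation (for example under FakeTensor tracing during torch.export) without allocating real MPS buffers. Below is a minimal sketch of exercising that path directly; it assumes the torchao_mps_ops extension from this repo is built and importable, and the shapes are chosen to satisfy check_linear_mps_args.

import torch
import torchao_mps_ops  # noqa: F401  # loads the extension that registers the torchao ops (as in the test below)

nbit = 4
m, k, n, group_size = 3, 512, 256, 32

# Meta tensors carry only shape/dtype, so dispatch goes to linear_mps_kernel_meta:
# no MPS kernel runs, but the (m, n) output shape is still derived.
a = torch.empty(m, k, dtype=torch.float32, device="meta")
b = torch.empty(n, (k // 8) * nbit, dtype=torch.uint8, device="meta")
sz = torch.empty(k // group_size, n, 2, dtype=torch.float32, device="meta")

out = torch.ops.torchao._linear_fp_act_4bit_weight(a, b, group_size, sz)
print(out.shape, out.device)  # torch.Size([3, 256]) meta
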
104 changes: 104 additions & 0 deletions torchao/experimental/ops/mps/test/test_quantizer.py
@@ -0,0 +1,104 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional
import copy
import os
import sys

import torch
import torchao_mps_ops
import unittest

torchao_root: Optional[str] = os.getenv("TORCHAO_ROOT")
assert torchao_root is not None, "TORCHAO_ROOT is not set"

sys.path.insert(0, torchao_root)
from torchao.experimental.quant_api import IntxWeightOnlyLinearQuantizer


def parameterized(test_cases):
def decorator(func):
def wrapper(self):
for case in test_cases:
with self.subTest(case=case):
func(self, *case)
return wrapper
return decorator


class TestIntxWeightOnlyLinearQuantizer(unittest.TestCase):
cases = [(nbit,) for nbit in range(1, 8)]

def _model_setup(self):
k0 = 512
k1 = 256
k2 = 128
k3 = 1024
layers = [
torch.nn.Linear(k0, k1, bias=False),
torch.nn.Linear(k1, k2, bias=False),
torch.nn.Linear(k2, k3, bias=False),
]
model = torch.nn.Sequential(*layers)
return model

def _quantize_model(self, model, precision, nbit, group_size):
quantizer = IntxWeightOnlyLinearQuantizer(
device="mps",
precision=precision,
bitwidth=nbit,
groupsize=group_size,
)
quantized_model = copy.deepcopy(model)
quantized_model = quantizer.quantize(quantized_model)
return quantized_model

@parameterized(cases)
def test_export(self, nbit):
model = self._model_setup()
group_size = 32
m = 3
k0 = 512
activations = torch.randn(m, k0, dtype=torch.float32, device="mps")

quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
exported = torch.export.export(quantized_model, (activations,))

for node in exported.graph.nodes:
if node.op == "call_function":
self.assertTrue(
str(node.target)
== f"torchao._linear_fp_act_{nbit}bit_weight.default"
)

@parameterized(cases)
def test_2d_output_device_and_shape(self, nbit):
model = self._model_setup()
group_size = 32
m = 3
activations = torch.randn(m, 512, dtype=torch.float32, device="mps")

quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
result = quantized_model(activations)
self.assertTrue(result.is_mps)
self.assertTrue(result.shape == (m, 1024))

@parameterized(cases)
def test_3d_output_device_and_shape(self, nbit):
model = self._model_setup()
group_size = 32
leading_shape = (3, 5)
activations = torch.randn(*leading_shape, 512, dtype=torch.float32, device="mps")

quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
result = quantized_model(activations)
self.assertTrue(result.is_mps)
self.assertTrue(result.shape == (*leading_shape, 1024))


if __name__ == "__main__":
unittest.main()
111 changes: 111 additions & 0 deletions torchao/experimental/quant_api.py
@@ -516,3 +516,114 @@ def apply(weight):
)

return _get_linear_subclass_inserter(apply)


class IntxWeightOnlyQuantizedLinear(nn.Module):
def __init__(
self,
pack_weight_op,
linear_op,
):
super().__init__()
self._pack_weights_op = pack_weight_op
self._linear_op = linear_op

def quantize_and_pack_weights(self, weights, nbit, group_size):
self.nbit = nbit
self.group_size = group_size

weight_qvals, weight_scales, weight_zeros = _quantize(
weights, self.group_size, self.nbit, True
)
weight_qvals = (weight_qvals + (1 << (nbit - 1))).to(torch.uint8)

self.weight_scales_and_zeros = torch.stack(
(weight_scales.t(), weight_zeros.t()), dim=2
)

self.packed_weights = self._pack_weights_op(weight_qvals.cpu()).to(device="mps")

def forward(self, x):
assert x.dim() >= 2
if x.dim() == 2:
return self._linear_op(
x, self.packed_weights, self.group_size, self.weight_scales_and_zeros
)

lead_shape = x.shape[0:-1]
k = x.shape[-1]
n = self.weight_scales_and_zeros.shape[1]
res = self._linear_op(x.reshape(-1, k), self.packed_weights, self.group_size, self.weight_scales_and_zeros)
res = res.reshape(*lead_shape, n)
return res


def _replace_linear_with_quantized_linear_mps(module: nn.Module, kwargs={}):
group_size = kwargs["group_size"]
nbit = kwargs["nbit"]

assert not isinstance(module, nn.Linear)
assert nbit >= 1 and nbit <= 7

for name, child in module.named_children():
if not isinstance(child, nn.Linear):
_replace_linear_with_quantized_linear_mps(child, kwargs)
else:
assert child.bias is None
qlinear = IntxWeightOnlyQuantizedLinear(
pack_weight_op=getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit"),
linear_op=getattr(
torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight"
),
)
setattr(module, name, qlinear)
getattr(module, name).quantize_and_pack_weights(
child.weight, nbit, group_size
)


class IntxWeightOnlyLinearQuantizer:
def __init__(
self,
device,
precision,
*,
bitwidth: Optional[int] = None,
groupsize: Optional[int] = None,
):
if device != "mps":
raise NotImplementedError(
"Only device=mps is currently supported in IntxWeightOnlyLinearQuantizer"
)
else:
self.device = device

if precision not in [torch.float32, torch.float16, torch.bfloat16]:
raise NotImplementedError(
"Only precisions float32, float16 & bfloat16 are currently supported in IntxWeightOnlyLinearQuantizer"
)
else:
self.precision = precision

if bitwidth is None:
self.bitwidth = 4
logger.warning(f"bitwidth not specified, defaulting to {self.bitwidth}.")
else:
self.bitwidth = bitwidth

if groupsize is None:
self.groupsize = 128
logger.warning(f"groupsize not specified, defaulting to {self.groupsize}.")
else:
self.groupsize = groupsize

def quantize(self, model: nn.Module) -> nn.Module:
model = model.to(self.device).to(self.precision)
_replace_linear_with_quantized_linear_mps(
model,
kwargs={
"group_size": self.groupsize,
"nbit": self.bitwidth,
},
)
return model
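
A short end-to-end usage sketch of the new quantizer, mirroring test_quantizer.py above; it assumes an MPS-capable machine and that the torchao_mps_ops extension is built and importable.

import torch
import torchao_mps_ops  # noqa: F401  # loads the custom MPS/Meta ops used by the quantized linears
from torchao.experimental.quant_api import IntxWeightOnlyLinearQuantizer

model = torch.nn.Sequential(
    torch.nn.Linear(512, 256, bias=False),   # bias must be None; see the assert above
    torch.nn.Linear(256, 1024, bias=False),
)

quantizer = IntxWeightOnlyLinearQuantizer(
    device="mps",             # only "mps" is supported
    precision=torch.float32,  # float32, float16 or bfloat16
    bitwidth=4,               # 1..7
    groupsize=32,             # 32, 64, 128 or 256
)
quantized = quantizer.quantize(model)  # moves the model to MPS and swaps in IntxWeightOnlyQuantizedLinear

x = torch.randn(3, 512, dtype=torch.float32, device="mps")
y = quantized(x)
print(y.shape, y.is_mps)  # torch.Size([3, 1024]) True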
