diff --git a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt index ecffb579c..a13737d87 100644 --- a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt @@ -6,8 +6,8 @@ add_library( kernel_aarch64 - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp ) diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt index 55bcdfbc2..10e44a79a 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt @@ -13,12 +13,12 @@ set(CMAKE_BUILD_TYPE Release) add_compile_options("-Wall" "-Werror") include(CMakePrintHelpers) -message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}") -include_directories(${TORCHAO_LIBRARIES}) +message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}") +include_directories(${TORCHAO_INCLUDE_DIRS}) -add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64) +add_subdirectory(${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64) -include(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/Utils.cmake) +include(${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/Utils.cmake) set(PLATFORM "ATEN" CACHE STRING "Choose platform surface: ATEN, EXECUTORCH") string(TOUPPER ${PLATFORM} PLATFORM_TO_UPPER) diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh index 94cb9587c..c657857fc 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh +++ b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh @@ -6,14 +6,14 @@ # LICENSE file in the root directory of this source tree. SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) -export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../../.. +export TORCHAO_INCLUDE_DIRS=${SCRIPT_DIR}/../../../../../../.. 
export CMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}" -export CMAKE_OUT=/tmp/cmake-out/torch_ao/examples/torch_custom_op -cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \ +export CMAKE_OUT=/tmp/cmake-out/torchao +cmake -DTORCHAO_INCLUDE_DIRS=${TORCHAO_INCLUDE_DIRS} \ -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ -DPLATFORM="ATEN" \ - -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op \ + -S ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op \ -B ${CMAKE_OUT} cmake --build ${CMAKE_OUT} diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py index 0b85583f7..e3d96df63 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py +++ b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py @@ -5,12 +5,21 @@ # LICENSE file in the root directory of this source tree. import copy +import glob +import os + +import sys import torch -from torch_custom_op import ( - linear_a8sz_w_lowbit_reference_impl, - replace_linear_with_quantized_linear, + +sys.path.insert( + 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../..")) ) +from quant_api import Int8DynActIntxWeightQuantizer + +libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*") +libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs)) +torch.ops.load_library(libs[0]) group_size = 256 m = 1 @@ -27,15 +36,15 @@ print("Quantizing random model") quantized_model = copy.deepcopy(model) -quantized_model = quantized_model.eval() -replace_linear_with_quantized_linear( - quantized_model, - kwargs={ - "group_size": group_size, - "nbit": nbit, - "has_weight_zeros": has_weight_zeros, - }, +quantizer = Int8DynActIntxWeightQuantizer( + device="cpu", + precision=torch.float32, + bitwidth=nbit, + groupsize=group_size, + has_weight_zeros=has_weight_zeros, ) +quantized_model = quantizer.quantize(quantized_model) +quantized_model = quantized_model.eval() print("Creating random activations") activations = torch.randn(m, k, dtype=torch.float32) @@ -58,44 +67,3 @@ print("Running AOTI") fn = torch._export.aot_load("/tmp/torch_custom_op_example_model.so", "cpu") fn(activations) - - -print("\nChecking correctness on layer 0") -linear = model[0] -quantized_linear = quantized_model[0] - -with torch.no_grad(): - result = quantized_linear(activations) - expected_result = linear_a8sz_w_lowbit_reference_impl( - linear.weight, activations, group_size, nbit, has_weight_zeros - ) - non_quantized_result = linear(activations) - - -# Check that entries in result match entries in expected_result -num_mismatch_at_low_tol = 0 -num_total = result.reshape(-1).shape[0] -for i in range(num_total): - actual_val = result.reshape(-1)[i] - expected_val = expected_result.reshape(-1)[i] - if not torch.allclose(actual_val, expected_val): - num_mismatch_at_low_tol += 1 - - # If results are not close at a relaxed tolerance, exit with failure - if not torch.allclose(actual_val, expected_val, atol=1e-6): - assert False, "Correctness check failed" - -# Assert at most 5% of entries are not close at a low tolerance -assert num_mismatch_at_low_tol / num_total <= 0.05, "Correctness check failed" -print( - "Correctness check passed. 
All results are close, and ", - (num_total - num_mismatch_at_low_tol), - "/", - num_total, - " entries are close at a low tolerance.", -) -print("Quantization errors:") -print("\tL1 error: ", torch.mean(torch.abs(result - non_quantized_result)).item()) -print("\tL2 error: ", torch.mean((result - non_quantized_result) ** 2).item()) -print("\tquantized_result[0:5]: ", result[0][0:5]) -print("\tnon_quantized_result[0:5]: ", non_quantized_result[0][0:5]) diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_custom_op.py b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_custom_op.py deleted file mode 100644 index e4e108b90..000000000 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_custom_op.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import unittest - -import torch -from torch_custom_op import ( - linear_a8sz_w_lowbit_reference_impl, - replace_linear_with_quantized_linear, -) -import copy - -class TestTorchCustomOp(unittest.TestCase): - def test_accuracy(self): - group_size = 128 - m = 1 - n = 1071 - k = 4096 - activations = torch.randn(m, k, dtype=torch.float32) - model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)]) - - for nbit in [2, 3, 4, 5]: - for has_weight_zeros in [False, True]: - quantized_model = copy.deepcopy(model) - replace_linear_with_quantized_linear( - quantized_model, - kwargs={ - "group_size": group_size, - "nbit": nbit, - "has_weight_zeros": has_weight_zeros, - }, - ) - - with torch.no_grad(): - result = quantized_model(activations) - expected_result = linear_a8sz_w_lowbit_reference_impl( - model[0].weight, activations, group_size, nbit, has_weight_zeros - ) - - num_mismatch_at_low_tol = 0 - num_total = result.reshape(-1).shape[0] - for i in range(num_total): - actual_val = result.reshape(-1)[i] - expected_val = expected_result.reshape(-1)[i] - self.assertTrue(torch.allclose(actual_val, expected_val, atol=1e-6)) - if not torch.allclose(actual_val, expected_val): - num_mismatch_at_low_tol += 1 - - # Assert at most 5% of entries are not close at a low tolerance - self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05) - -if __name__ == '__main__': - unittest.main() diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_int8_dyn_act_intx_weight_quantizer.py b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_int8_dyn_act_intx_weight_quantizer.py new file mode 100644 index 000000000..513088d2f --- /dev/null +++ b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_int8_dyn_act_intx_weight_quantizer.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+
+import copy
+
+import glob
+import os
+
+import sys
+import unittest
+
+import torch
+
+sys.path.insert(
+    0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../.."))
+)
+from quant_api import (
+    _Int8DynActIntxWeightQuantizedLinearFallback,
+    Int8DynActIntxWeightQuantizer,
+)
+
+libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*")
+libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
+if len(libs) == 0:
+    print(
+        "Could not find library lowbit_op_aten; please run `sh build_custom_op.sh` to build the library. A slow fallback kernel will be used instead."
+    )
+else:
+    torch.ops.load_library(libs[0])
+
+
+class TestInt8DynActIntxWeightQuantizer(unittest.TestCase):
+    def test_accuracy(self):
+        group_size = 128
+        m = 1
+        n = 1071
+        k = 4096
+        activations = torch.randn(m, k, dtype=torch.float32)
+        model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
+
+        for nbit in [1, 2, 3, 4, 5, 6, 7]:
+            for has_weight_zeros in [True, False]:
+                print(f"Testing nbit={nbit}, has_weight_zeros={has_weight_zeros}")
+                quantized_model = copy.deepcopy(model)
+                quantizer = Int8DynActIntxWeightQuantizer(
+                    device="cpu",
+                    precision=torch.float32,
+                    bitwidth=nbit,
+                    groupsize=group_size,
+                    has_weight_zeros=has_weight_zeros,
+                )
+                quantized_model = quantizer.quantize(quantized_model)
+
+                with torch.no_grad():
+                    result = quantized_model(activations)
+                    reference_impl = _Int8DynActIntxWeightQuantizedLinearFallback()
+                    reference_impl.quantize_and_pack_weights(
+                        model[0].weight, nbit, group_size, has_weight_zeros
+                    )
+                    expected_result = reference_impl(activations)
+
+                num_mismatch_at_low_tol = 0
+                num_total = result.reshape(-1).shape[0]
+                for i in range(num_total):
+                    actual_val = result.reshape(-1)[i]
+                    expected_val = expected_result.reshape(-1)[i]
+                    self.assertTrue(torch.allclose(actual_val, expected_val, atol=1e-6))
+                    if not torch.allclose(actual_val, expected_val):
+                        num_mismatch_at_low_tol += 1
+
+                # Assert at most 5% of entries are not close at a low tolerance
+                self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.py b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.py
deleted file mode 100644
index 46117db15..000000000
--- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.py
+++ /dev/null
@@ -1,231 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
- -from typing import Optional - -import torch -import torch.nn as nn - -import glob -libs = glob.glob("/tmp/cmake-out/torch_ao/examples/torch_custom_op/libtorch_custom_op.*") -libs = list(filter(lambda l:(l.endswith("so") or l.endswith("dylib")), libs)) -torch.ops.load_library(libs[0]) - -def quantize(vals: torch.Tensor, group_size: int, nbit: int, scale_only: bool): - assert nbit >= 2 and nbit <= 8 - qmin = -(1 << (nbit - 1)) - qmax = (1 << (nbit - 1)) - 1 - - n, k = vals.shape - vals = vals.reshape(-1, group_size) - vmins, _ = torch.min(vals, axis=1) - vmaxs, _ = torch.max(vals, axis=1) - group_scales = (vmaxs - vmins) / (qmax - qmin) - - if scale_only: - group_qvals = torch.round(vals / group_scales.reshape(-1, 1)) - else: - group_zeros = qmin - torch.round(vmins / group_scales) - group_qvals = torch.round( - group_zeros.reshape(-1, 1) + vals / group_scales.reshape(-1, 1) - ) - - group_qvals = torch.clip(group_qvals, qmin, qmax).reshape(n, k).to(torch.int8) - - if scale_only: - return group_qvals, group_scales - return group_qvals, group_scales, group_zeros - - -def linear_a8sz_w_lowbit_reference_impl( - weights, activations, group_size, nbit, has_weight_zeros -): - n, k = weights.shape - m, k = activations.shape - assert m == 1 - assert k % group_size == 0 - - if has_weight_zeros: - weight_qvals, weight_scales, weight_zeros = quantize( - weights, group_size, nbit, scale_only=False - ) - weights_dequantized = ( - weight_scales.reshape(-1, 1) - * (weight_qvals.reshape(-1, group_size) - weight_zeros.reshape(-1, 1)) - ).reshape(n, k) - else: - weight_qvals, weight_scales = quantize( - weights, group_size, nbit, scale_only=True - ) - weights_dequantized = ( - weight_scales.reshape(-1, 1) * (weight_qvals.reshape(-1, group_size)) - ).reshape(n, k) - - activation_qvals, activations_scales, activations_zeros = quantize( - activations, k, 8, False - ) - activations_dequantized = activations_scales * ( - activation_qvals - activations_zeros - ).reshape(m, k) - return torch.matmul(activations_dequantized, weights_dequantized.transpose(1, 0)) - - -class _quantized_linear(nn.Module): - def __init__( - self, - nbit, - has_weight_zeros, - pack_weight_op, - linear_op, - squeeze_unsqueeze_dim0=False, - ): - super().__init__() - self.squeeze_unsqueeze_dim0 = squeeze_unsqueeze_dim0 - self.nbit = nbit - - self._has_weight_zeros = has_weight_zeros - self._pack_weights_op = pack_weight_op - self._linear_op = linear_op - - def pack_weights(self, weight_qvals, weight_scales_and_zeros, group_size): - n, k = weight_qvals.shape - - # TODO(T200095131): convert self.n, self.k, self.group_size to - # int when supported by AOTI - self.n = torch.empty(n) - self.k = torch.empty(k) - self.group_size = torch.empty(group_size) - - if self._has_weight_zeros: - weight_scales, weight_zeros = weight_scales_and_zeros - self.packed_weights = self._pack_weights_op( - weight_qvals, weight_scales, weight_zeros, self.group_size - ) - else: - weight_scales = weight_scales_and_zeros - self.packed_weights = self._pack_weights_op( - weight_qvals, weight_scales, self.group_size - ) - - def forward(self, x): - if self.squeeze_unsqueeze_dim0: - x = x.squeeze(0) - - res = self._linear_op(self.packed_weights, self.n, self.k, self.group_size, x) - - if self.squeeze_unsqueeze_dim0: - res = res.unsqueeze(0) - return res - - -def replace_linear_with_quantized_linear(module: nn.Module, kwargs={}): - group_size = kwargs["group_size"] - nbit = kwargs["nbit"] - has_weight_zeros = kwargs["has_weight_zeros"] - squeeze_unsqueeze_dim0 = ( - 
kwargs["squeeze_unsqueeze_dim0"] - if "squeeze_unsqueeze_dim0" in kwargs - else False - ) - - for name, child in module.named_children(): - if isinstance(child, nn.Linear): - assert child.bias is None - - if not has_weight_zeros: - weight_qvals, weight_scales = quantize( - child.weight, group_size=group_size, nbit=nbit, scale_only=True - ) - weight_scales_and_zeros = weight_scales - else: - weight_qvals, weight_scales, weight_zeros = quantize( - child.weight, group_size=group_size, nbit=nbit, scale_only=False - ) - weight_scales_and_zeros = (weight_scales, weight_zeros.to(torch.int8)) - - qlinear = None - if nbit == 2: - if has_weight_zeros: - qlinear = _quantized_linear( - nbit=nbit, - has_weight_zeros=has_weight_zeros, - pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w2sz, - linear_op=torch.ops.torchao._linear_a8sz_w2sz, - squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, - ) - else: - qlinear = _quantized_linear( - nbit=nbit, - has_weight_zeros=has_weight_zeros, - pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w2s, - linear_op=torch.ops.torchao._linear_a8sz_w2s, - squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, - ) - elif nbit == 3: - if has_weight_zeros: - qlinear = _quantized_linear( - nbit=nbit, - has_weight_zeros=has_weight_zeros, - pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w3sz, - linear_op=torch.ops.torchao._linear_a8sz_w3sz, - squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, - ) - else: - qlinear = _quantized_linear( - nbit=nbit, - has_weight_zeros=has_weight_zeros, - pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w3s, - linear_op=torch.ops.torchao._linear_a8sz_w3s, - squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, - ) - elif nbit == 4: - if has_weight_zeros: - qlinear = _quantized_linear( - nbit=nbit, - has_weight_zeros=has_weight_zeros, - pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w4sz, - linear_op=torch.ops.torchao._linear_a8sz_w4sz, - squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, - ) - else: - qlinear = _quantized_linear( - nbit=nbit, - has_weight_zeros=has_weight_zeros, - pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w4s, - linear_op=torch.ops.torchao._linear_a8sz_w4s, - squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, - ) - elif nbit == 5: - if has_weight_zeros: - qlinear = _quantized_linear( - nbit=nbit, - has_weight_zeros=has_weight_zeros, - pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w5sz, - linear_op=torch.ops.torchao._linear_a8sz_w5sz, - squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, - ) - else: - qlinear = _quantized_linear( - nbit=nbit, - has_weight_zeros=has_weight_zeros, - pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w5s, - linear_op=torch.ops.torchao._linear_a8sz_w5s, - squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, - ) - else: - raise ValueError( - f"Unsupported nbit ({nbit}) and has_weight_zeros ({has_weight_zeros}) combination" - ) - - assert qlinear is not None - setattr(module, name, qlinear) - getattr(module, name).pack_weights( - weight_qvals, - weight_scales_and_zeros, - group_size, - ) - else: - replace_linear_with_quantized_linear(child, kwargs) diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py new file mode 100644 index 000000000..26797bdb1 --- /dev/null +++ b/torchao/experimental/quant_api.py @@ -0,0 +1,321 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +from typing import Optional + +import torch +import torch.nn as nn +from torch.ao.quantization.fx._decomposed import ( + dequantize_per_channel_group, + quantize_per_channel_group, +) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + +import sys + +handler = logging.StreamHandler(sys.stdout) +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) + + +def _quantize(vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros: bool): + assert nbit >= 1 and nbit <= 8 + qmin = -(1 << (nbit - 1)) + qmax = (1 << (nbit - 1)) - 1 + + n, k = vals.shape + vals = vals.reshape(-1, group_size) + vmins, _ = torch.min(vals, axis=1) + vmaxs, _ = torch.max(vals, axis=1) + group_scales = (vmaxs - vmins) / (qmax - qmin) + + if not has_weight_zeros: + group_zeros = torch.zeros_like(group_scales) + else: + group_zeros = qmin - torch.round(vmins / group_scales) + + vals = vals.reshape(n, k) + group_scales = group_scales.reshape(n, -1) + group_zeros = group_zeros.reshape(n, -1) + + group_qvals = quantize_per_channel_group( + input=vals, + scales=group_scales, + zero_points=group_zeros, + quant_min=qmin, + quant_max=qmax, + dtype=torch.int8, + group_size=group_size, + ) + + if not has_weight_zeros: + group_zeros = None + + return group_qvals, group_scales, group_zeros + + +class _Int8DynActIntxWeightQuantizedLinearNative(nn.Module): + def __init__( + self, + pack_weight_op, + linear_op, + ): + super().__init__() + self._pack_weights_op = pack_weight_op + self._linear_op = linear_op + + def quantize_and_pack_weights(self, weights, nbit, group_size, has_weight_zeros): + self.nbit = nbit + self.group_size = group_size + self.has_weight_zeros = has_weight_zeros + + n, k = weights.shape + + # TODO(T200095131): convert self.n, self.k, self.group_size to + # int when supported by AOTI + self._n = torch.empty(n, dtype=torch.int8) + self._k = torch.empty(k, dtype=torch.int8) + self._group_size = torch.empty(self.group_size, dtype=torch.int8) + + weight_qvals, weight_scales, weight_zeros = _quantize( + weights, self.group_size, self.nbit, self.has_weight_zeros + ) + if self.has_weight_zeros: + self.packed_weights = self._pack_weights_op( + weight_qvals, + weight_scales.reshape(-1), + weight_zeros.to(torch.int8).reshape(-1), + self._group_size, + ) + else: + self.packed_weights = self._pack_weights_op( + weight_qvals, weight_scales.reshape(-1), self._group_size + ) + + def forward(self, x): + assert x.dim() >= 2 + if x.dim() == 2: + return self._linear_op( + self.packed_weights, self._n, self._k, self._group_size, x + ) + + assert x.dim() >= 3 + lead_shape = x.shape[0:-2] + m, k = x.shape[-2], x.shape[-1] + n = self._n.shape[0] + x = x.reshape(-1, m, k) + + res = [ + self._linear_op( + self.packed_weights, self._n, self._k, self._group_size, x[i, :, :] + ) + for i in range(x.shape[0]) + ] + res = torch.stack(res) + res = res.reshape(*lead_shape, m, n) + return res + + +# Python-based reference implementation of Int8DynActLowbitWeightQuantizedLinear +# It is arithmetically equivalent to Int8DynActLowbitWeightQuantizedLinear +# This is used to test Int8DynActLowbitWeightQuantizedLinear, and as a fallback when +# Int8DynActLowbitWeightQuantizedLinear is not available +class _Int8DynActIntxWeightQuantizedLinearFallback(nn.Module): + def __init__(self): + super().__init__() + + def quantize_and_pack_weights(self, weights, nbit, group_size, has_weight_zeros): + self.nbit = nbit + 
self.group_size = group_size + self.has_weight_zeros = has_weight_zeros + + self._n, self._k = weights.shape + assert self._k % group_size == 0, "group_size must divide k" + + self.weight_qvals, self.weight_scales, self.weight_zeros = _quantize( + weights, self.group_size, self.nbit, self.has_weight_zeros + ) + + def _forward_2d(self, x): + assert x.dim() == 2 + + n, k = self._n, self._k + m, k_ = x.shape + assert k_ == k + + weights_dequantized = dequantize_per_channel_group( + w_int8=self.weight_qvals, + scales=self.weight_scales, + zero_points=( + self.weight_zeros + if self.has_weight_zeros + else torch.zeros_like(self.weight_scales) + ), + quant_min=None, # TODO: why is this an arg for this function + quant_max=None, # TODO: why is this an arg for this function + dtype=None, # TODO: why is this an arg for this function + group_size=self.group_size, + output_dtype=torch.float32, + ) + + activation_qvals, activation_scales, activation_zeros = _quantize( + x, group_size=k, nbit=8, has_weight_zeros=True + ) + activations_dequantized = dequantize_per_channel_group( + w_int8=activation_qvals, + scales=activation_scales, + zero_points=activation_zeros, + quant_min=None, # TODO: why is this an arg for this function + quant_max=None, # TODO: why is this an arg for this function + dtype=None, # TODO: why is this an arg for this function + group_size=k, + output_dtype=torch.float32, + ) + + res = torch.matmul(activations_dequantized, weights_dequantized.transpose(1, 0)) + return res + + def forward(self, x): + assert x.dim() >= 2 + if x.dim() == 2: + return self._forward_2d(x) + + assert x.dim() >= 3 + lead_shape = x.shape[0:-2] + m, k = x.shape[-2], x.shape[-1] + n = self._n + x = x.reshape(-1, m, k) + + res = [self._forward_2d(x[i, :, :]) for i in range(x.shape[0])] + res = torch.stack(res) + res = res.reshape(*lead_shape, m, n) + return res + + +def _maybe_get_quantized_linear_native(nbit, has_weight_zeros): + try: + if nbit in [2, 3, 4, 5]: + wzp_suffix = "z" if has_weight_zeros else "" + return _Int8DynActIntxWeightQuantizedLinearNative( + pack_weight_op=getattr( + torch.ops.torchao, f"_pack_weights_a8sz_w{nbit}s{wzp_suffix}" + ), + linear_op=getattr( + torch.ops.torchao, f"_linear_a8sz_w{nbit}s{wzp_suffix}" + ), + ) + else: + logger.warning( + f"_Int8DynActIntxWeightQuantizedLinearNative does not support: nbit={nbit}, has_weight_zeros={has_weight_zeros}." + ) + except Exception as e: + logger.warning( + f"_Int8DynActIntxWeightQuantizedLinearNative raised an exception during initialization: {e}" + ) + + logger.warning( + "Falling back to **slow** implementation _Int8DynActIntxWeightQuantizedLinearFallback." + ) + return _Int8DynActIntxWeightQuantizedLinearFallback() + + +def _replace_linear_with_quantized_linear(module: nn.Module, kwargs={}): + group_size = kwargs["group_size"] + nbit = kwargs["nbit"] + has_weight_zeros = kwargs["has_weight_zeros"] + + assert not isinstance(module, nn.Linear) + assert nbit >= 1 and nbit <= 7 + + for name, child in module.named_children(): + if not isinstance(child, nn.Linear): + _replace_linear_with_quantized_linear(child, kwargs) + else: + assert child.bias is None + qlinear = _maybe_get_quantized_linear_native( + nbit=nbit, has_weight_zeros=has_weight_zeros + ) + try: + # The packing function may raise some error from the C++ layer (e.g., if group_size is unsupported) + # so calling quantize_and_pack_weights can fail. 
In this case, we still switch to the fallback
+                # implementation
+                setattr(module, name, qlinear)
+                getattr(module, name).quantize_and_pack_weights(
+                    child.weight, nbit, group_size, has_weight_zeros
+                )
+            except Exception as e:
+                if not isinstance(qlinear, _Int8DynActIntxWeightQuantizedLinearNative):
+                    raise e
+                logger.warning(
+                    f"_Int8DynActIntxWeightQuantizedLinearNative raised an exception during quantize_and_pack_weights: {e}\n"
+                    + "Falling back to **slow** implementation _Int8DynActIntxWeightQuantizedLinearFallback."
+                )
+                qlinear = _Int8DynActIntxWeightQuantizedLinearFallback()
+                setattr(module, name, qlinear)
+                getattr(module, name).quantize_and_pack_weights(
+                    child.weight, nbit, group_size, has_weight_zeros
+                )
+
+
+class Int8DynActIntxWeightQuantizer:
+    def __init__(
+        self,
+        device,
+        precision,
+        *,
+        bitwidth: Optional[int] = None,
+        groupsize: Optional[int] = None,
+        has_weight_zeros: Optional[bool] = None,
+    ):
+        if device != "cpu":
+            raise NotImplementedError(
+                "Only device=cpu is currently supported in Int8DynActIntxWeightQuantizer"
+            )
+        else:
+            self.device = device
+
+        if precision != torch.float32:
+            raise NotImplementedError(
+                "Only precision=torch.float32 is currently supported in Int8DynActIntxWeightQuantizer"
+            )
+        else:
+            self.precision = precision
+
+        if bitwidth is None:
+            self.bitwidth = 4
+            logger.warning(f"bitwidth not specified, defaulting to {self.bitwidth}.")
+        else:
+            self.bitwidth = bitwidth
+
+        if groupsize is None:
+            self.groupsize = 128
+            logger.warning(f"groupsize not specified, defaulting to {self.groupsize}.")
+        else:
+            self.groupsize = groupsize
+
+        if has_weight_zeros is None:
+            self.has_weight_zeros = False
+            logger.warning(
+                f"has_weight_zeros not specified, defaulting to {self.has_weight_zeros}."
+            )
+        else:
+            self.has_weight_zeros = has_weight_zeros
+
+    def quantize(self, model: nn.Module) -> nn.Module:
+        model = model.to(self.device).to(self.precision)
+        _replace_linear_with_quantized_linear(
+            model,
+            kwargs={
+                "group_size": self.groupsize,
+                "nbit": self.bitwidth,
+                "has_weight_zeros": self.has_weight_zeros,
+            },
+        )
+        return model
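Usage sketch: the patch replaces the example's replace_linear_with_quantized_linear flow with the new Int8DynActIntxWeightQuantizer API; the snippet below mirrors run_custom_op.py to show the intended call sequence. It is a minimal sketch, assuming quant_api.py is importable (the example scripts add torchao/experimental to sys.path and import quant_api directly; the path used here is an assumption) and that the compiled lowbit_op_aten library may be absent, in which case the slow Python fallback kernel is used.

import copy
import sys

import torch

# Assumption: torchao/experimental is on sys.path, as run_custom_op.py and the new
# test do via sys.path.insert; adjust the path to your checkout.
sys.path.insert(0, "torchao/experimental")
from quant_api import Int8DynActIntxWeightQuantizer

# Toy model matching the example/test: a single bias-free linear layer
# (the quantizer asserts that replaced nn.Linear modules have no bias).
k, n = 4096, 1071
model = torch.nn.Sequential(torch.nn.Linear(k, n, bias=False))

quantizer = Int8DynActIntxWeightQuantizer(
    device="cpu",             # only "cpu" is supported
    precision=torch.float32,  # only float32 is supported
    bitwidth=4,               # 1-7 accepted; native ATen ops cover 2-5, else the fallback is used
    groupsize=256,
    has_weight_zeros=False,
)
quantized_model = quantizer.quantize(copy.deepcopy(model)).eval()

with torch.no_grad():
    activations = torch.randn(1, k, dtype=torch.float32)
    print(quantized_model(activations).shape)  # torch.Size([1, 1071])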