feat(cpp): add quantize_symmetric CPU kernel
For now, only per-tensor quantization is supported.
dacorvo committed Feb 9, 2024
1 parent 89e365a commit d612b6b
Showing 5 changed files with 103 additions and 1 deletion.
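
Once the extension is built, the new kernel is reachable through the quanto op namespace, exactly as the benchmark below invokes it. A minimal per-tensor usage sketch (illustrative, not part of the commit; it assumes the quanto_ext registration in this diff is dispatched via torch.ops.quanto, as in bench/library/benchmark.py):

import torch

# Per-tensor symmetric quantization: scalar float scale, int8 destination dtype.
a = torch.rand(16, 16, dtype=torch.float32)
scale = torch.tensor(0.5, dtype=torch.float32)
qa = torch.ops.quanto.quantize_symmetric(a, scale, torch.int8)
assert qa.dtype == torch.int8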
16 changes: 15 additions & 1 deletion bench/library/benchmark.py
@@ -9,6 +9,17 @@
from quanto.library import disable_extensions


def get_quantize_symmetric_bench(src_dtype, dst_dtype, per_axis, device):
    a = torch.rand([10240, 10240], dtype=src_dtype).to(device)
    scale = torch.full((10240,), 0.5) if per_axis else torch.tensor(0.5)
    scale = scale.to(src_dtype).to(device)

    def bench_fn():
        return torch.ops.quanto.quantize_symmetric(a, scale, dst_dtype)

    return bench_fn


def get_unpack_bench(bits, device):
    qmax = 2**bits
    a = torch.randint(0, qmax, [10240, 10240], dtype=torch.uint8).to(device)
@@ -69,6 +80,9 @@ def elapsed_time(self, other):


GET_BENCH_FUNCTIONS = {
    "quantize_symmetric_fp32_int8_per_tensor": lambda device: get_quantize_symmetric_bench(
        torch.float32, torch.int8, False, device
    ),
    "unpack_2bit": lambda device: get_unpack_bench(2, device),
    "unpack_4bit": lambda device: get_unpack_bench(4, device),
}
@@ -89,7 +103,7 @@ def main():
        device = torch.device("cpu")
    else:
        device = torch.device(args.device)
    all_kernels = ["unpack_2bit", "unpack_4bit"]
    all_kernels = GET_BENCH_FUNCTIONS.keys()
    kernels = all_kernels if args.kernel is None else [args.kernel]
    for kernel in kernels:
        get_bench_fn = GET_BENCH_FUNCTIONS[kernel]
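
The new kernel can then be benchmarked by name, e.g. (assuming --kernel and --device flags backing the args.kernel and args.device attributes used above):

python bench/library/benchmark.py --kernel quantize_symmetric_fp32_int8_per_tensor --device cpu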
6 changes: 6 additions & 0 deletions quanto/library/ext/cpp/__init__.py
@@ -19,6 +19,7 @@ def ext():
        _ext = load(
            name="quanto_cpp",
            sources=[
                f"{module_path}/quantize.cpp",
                f"{module_path}/unpack.cpp",
                f"{module_path}/pybind_module.cpp",
            ],
@@ -27,6 +28,11 @@
    return _ext


@impl("quanto_ext::quantize_symmetric", ["CPU"])
def quantize_symmetric_cpp(t: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype):
    return ext().quantize_symmetric(t, scale, dtype)


@impl("quanto_ext::unpack", ["CPU", "CUDA"])
def unpack_cpp(t: torch.Tensor, bits: int):
return ext().unpack(t, bits)
13 changes: 13 additions & 0 deletions quanto/library/ext/cpp/pybind_module.cpp
@@ -1,7 +1,20 @@
#include <torch/extension.h>
#include "quantize.h"
#include "unpack.h"

// !IMPORTANT! Some Python objects, such as dtype and device, are not mapped to
// C++ types and need to be explicitly converted using dedicated helpers before
// calling a C++ method. As a consequence, when an operation takes such an object
// as a parameter, instead of creating a binding directly to the C++ method, you
// must create a binding to a lambda that converts the unmapped types and calls
// the C++ method. See the binding of quantize_symmetric for an example.

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("quantize_symmetric",
          [](const torch::Tensor& t, const torch::Tensor& scale, py::object dtype) {
              return quantize_symmetric(t,
                                        scale,
                                        torch::python::detail::py_object_to_dtype(dtype));
          }, "quantize_symmetric");
    m.def("unpack", &unpack, "unpack");
}
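
On the Python side, a torch.dtype such as torch.int8 crosses the pybind11 boundary as a plain py::object, which is why the lambda above converts it with py_object_to_dtype before forwarding. A small sketch of calling the raw extension directly (assuming the module builds and loads):

import torch
from quanto.library.ext.cpp import ext

# torch.int8 is received by the lambda as a py::object and converted to at::ScalarType.
out = ext().quantize_symmetric(torch.rand(8, 8), torch.tensor(0.5), torch.int8)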
64 changes: 64 additions & 0 deletions quanto/library/ext/cpp/quantize.cpp
@@ -0,0 +1,64 @@
#include "quantize.h"
#include <torch/extension.h>


template <typename T>
torch::Tensor quantize_symmetric_per_tensor(const torch::Tensor& input, const torch::Tensor& scale) {
    torch::Tensor output = torch::empty_like(input, c10::TensorOptions(c10::kChar).dtype(), LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    auto qdata = reinterpret_cast<int8_t*>(output.data_ptr());
    auto numel = input.numel();
    const T* const data = input.data_ptr<T>();
    // The scale is assumed to have the same dtype as the input tensor.
    float float_scale = scale.data_ptr<T>()[0];
    float inv_scale = float_scale == 0 ? 1.0f : 1.0f / float_scale;
    for (const auto i : c10::irange(numel)) {
        // Round to nearest, then clamp to the symmetric int8 range [-127, 127].
        int64_t qvalue = lrintf(std::nearbyint(data[i] * inv_scale));
        qvalue = std::max<int64_t>(-127, std::min<int64_t>(qvalue, 127));
        qdata[i] = static_cast<int8_t>(qvalue);
    }
    return output;
}


// Return the index of the single non-unit dimension of the scale tensor,
// or -1 if all of its dimensions are 1 (i.e. the scale is per-tensor).
int get_scale_axis(const torch::Tensor& scale) {
    int axis = -1;
    auto scale_dims = scale.sizes();
    for (size_t i = 0; i < scale_dims.size(); ++i) {
        if (scale_dims[i] != 1) {
            axis = static_cast<int>(i);
        }
    }
    return axis;
}


torch::Tensor quantize_symmetric_char(const torch::Tensor& input,
                                      const torch::Tensor& scale) {
    int axis = get_scale_axis(scale);
    if (axis == -1) {
        auto scale_dtype = scale.dtype();
        if (scale_dtype == at::ScalarType::Float) {
            return quantize_symmetric_per_tensor<float>(input, scale);
        }
        if (scale_dtype == at::ScalarType::Half) {
            return quantize_symmetric_per_tensor<at::Half>(input, scale);
        }
        TORCH_CHECK(false, "Unsupported scale dtype:", scale_dtype);
    }
    TORCH_CHECK(false, "symmetric per-axis is not supported");
}


torch::Tensor quantize_symmetric(const torch::Tensor& input,
                                 const torch::Tensor& scale,
                                 at::ScalarType dtype) {
    bool scalar_scale = (scale.sizes().size() == 0);
    bool broadcastable_scale = (input.sizes().size() == scale.sizes().size());
    TORCH_CHECK(scalar_scale || broadcastable_scale,
                "Quantization scale must be scalar or broadcastable to the base tensor.");
    TORCH_CHECK((scale.dtype() == at::ScalarType::Float) || (scale.dtype() == at::ScalarType::Half),
                "Quantization scale must be float or float16.");
    if (dtype == at::ScalarType::Char) {
        return quantize_symmetric_char(input, scale);
    }
    TORCH_CHECK_NOT_IMPLEMENTED(false, "quantize_symmetric not supported for ", dtype);
}
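
For per-tensor scales, the kernel above is equivalent to rounding input / scale to the nearest integer and clamping to the symmetric int8 range [-127, 127]. A pure-PyTorch reference (illustrative only, not part of the commit, but handy for checking the extension):

import torch

def quantize_symmetric_ref(t: torch.Tensor, scale: float) -> torch.Tensor:
    # Mirror of the C++ kernel: round(t / scale), clamped to [-127, 127].
    inv_scale = 1.0 if scale == 0 else 1.0 / scale
    return torch.clamp(torch.round(t * inv_scale), -127, 127).to(torch.int8)

t = torch.randn(4, 4)
expected = quantize_symmetric_ref(t, 0.5)

Note that both torch.round and the C++ nearbyint round halfway cases to even under default settings, so the two should agree for float32 inputs.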
5 changes: 5 additions & 0 deletions quanto/library/ext/cpp/quantize.h
@@ -0,0 +1,5 @@
#pragma once

#include <torch/extension.h>

torch::Tensor quantize_symmetric(const torch::Tensor& input,
const torch::Tensor& scale,
at::ScalarType dtype);
