Skip to content

Commit

Permalink
feat(library): add CUDA unpack kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
dacorvo committed Mar 19, 2024
1 parent 1606e1b commit e5dc839
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 1 deletion.
3 changes: 3 additions & 0 deletions quanto/library/ext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,8 @@
from .cpp import *


if torch.cuda.is_available():
from .cuda import *

if torch.backends.mps.is_available():
from .mps import *
2 changes: 1 addition & 1 deletion quanto/library/ext/cpp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ def dqmm_cpp(input: torch.Tensor, other: torch.Tensor, other_scale: torch.Tensor
return ext().dqmm(input, other, other_scale)


@torch.library.impl("quanto_ext::unpack", ["CPU", "CUDA"])
@torch.library.impl("quanto_ext::unpack", ["CPU"])
def unpack_cpp(t: torch.Tensor, bits: int):
return ext().unpack(t, bits)
11 changes: 11 additions & 0 deletions quanto/library/ext/cuda/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Quanto generic CUDA extension

Kernels in this extension can use both the C++ and CUDA syntax.

They can use any pytorch operation defined under `aten::` or `c10::`.

To add a new implementation for an operation defined in `library./ops.py`:

- add the corresponding `.cpp` or `.cu` file to the list of sources in `__init__.py`,
- add a binding to `pybind_module.cpp`,
- provide an implementation calling the binding in `__init__.py`.
30 changes: 30 additions & 0 deletions quanto/library/ext/cuda/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import os

import torch
from torch.utils.cpp_extension import load


__all__ = []


_ext = None


def ext():
"""Helper to load the CUDA ext only when it is required"""
global _ext
if _ext is None:
module_path = os.path.dirname(__file__)
_ext = load(
name="quanto_cuda",
sources=[
f"{module_path}/unpack.cu",
f"{module_path}/pybind_module.cpp",
],
)
return _ext


@torch.library.impl("quanto_ext::unpack", ["CUDA"])
def unpack_cuda(t: torch.Tensor, bits: int):
return ext().unpack(t, bits)
13 changes: 13 additions & 0 deletions quanto/library/ext/cuda/pybind_module.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#include <torch/extension.h>
#include "unpack.h"

// !IMPORTANT! Some python objects such as dtype, device, are not mapped to C++ types,
// and need to be explicitly converted using dedicated helpers before calling a C++ method.
// As a consequence, when an operation takes such an object as parameter, instead
// of creating a binding directly to the C++ method, you must create a binding to a
// lambda method that converts the unmapped types and calls the C++ method.
// See the binding of quantize_symmetric for instance.

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("unpack", &unpack, "unpack");
}
83 changes: 83 additions & 0 deletions quanto/library/ext/cuda/unpack.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <c10/cuda/CUDAException.h>

inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b;}
#define BLOCK_SIZE 256

using namespace at;


static torch::Tensor allocate_output(const torch::Tensor& input, int bits) {
int n_packed = 8 / bits;
auto output_shape = input.sizes().vec();
output_shape[0] = output_shape[0] * n_packed;
return torch::empty(output_shape, input.options());
}

__global__ void unpack_4bit_kernel(unsigned char* input, unsigned char* output, int n) {
int i = blockIdx.x*blockDim.x + threadIdx.x;
if(i>=n) return;

output[i] = (input[i] & 0x0F);
output[i + n] = (input[i] & 0xF0) >> 4;
}

static torch::Tensor unpack_4bit(const torch::Tensor& input){

auto output = allocate_output(input, 4);

const auto numel = input.numel();
int blocks = cdiv(numel, BLOCK_SIZE);
unpack_4bit_kernel<<<blocks, BLOCK_SIZE>>>(
input.data_ptr<unsigned char>(),
output.data_ptr<unsigned char>(),
numel
);

C10_CUDA_KERNEL_LAUNCH_CHECK();

return output;
}

__global__ void unpack_2bit_kernel(unsigned char* input, unsigned char* output, int n) {
int i = blockIdx.x*blockDim.x + threadIdx.x;
if(i>=n) return;

output[i] = (input[i] & 0x03);
output[i + n] = (input[i] & 0x0C) >> 2;
output[i + n*2] = (input[i] & 0x30) >> 4;
output[i + n*3] = (input[i] & 0xC0) >> 6;
}

static torch::Tensor unpack_2bit(const torch::Tensor& input){

auto output = allocate_output(input, 2);

const auto numel = input.numel();
int blocks = cdiv(numel, BLOCK_SIZE);
unpack_2bit_kernel<<<blocks, BLOCK_SIZE>>>(
input.data_ptr<unsigned char>(),
output.data_ptr<unsigned char>(),
numel
);

C10_CUDA_KERNEL_LAUNCH_CHECK();

return output;
}

torch::Tensor unpack(torch::Tensor &t, int bits) {
TORCH_CHECK(t.scalar_type() == torch::kUInt8, "Unsupported data type: ", t.scalar_type());
TORCH_CHECK(t.device().is_cuda(), "t must be a CUDA tensor.");
TORCH_CHECK(t.is_contiguous(), "t must be contiguous.");
switch(bits) {
case 4:
return unpack_4bit(t);
case 2:
return unpack_2bit(t);
default:
throw std::invalid_argument("Can only unpack 2-bit or 4-bit tensors.");
}
}
3 changes: 3 additions & 0 deletions quanto/library/ext/cuda/unpack.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#include <torch/extension.h>

torch::Tensor unpack(torch::Tensor &t, int bits);

0 comments on commit e5dc839

Please sign in to comment.