Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CUDA unpack kernel #119

Merged
merged 1 commit into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions quanto/library/ext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,8 @@
from .cpp import *


if torch.cuda.is_available():
from .cuda import *

if torch.backends.mps.is_available():
from .mps import *
2 changes: 1 addition & 1 deletion quanto/library/ext/cpp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ def dqmm_cpp(input: torch.Tensor, other: torch.Tensor, other_scale: torch.Tensor
return ext().dqmm(input, other, other_scale)


@torch.library.impl("quanto_ext::unpack", ["CPU", "CUDA"])
@torch.library.impl("quanto_ext::unpack", ["CPU"])
def unpack_cpp(t: torch.Tensor, bits: int):
return ext().unpack(t, bits)
11 changes: 11 additions & 0 deletions quanto/library/ext/cuda/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Quanto generic CUDA extension

Kernels in this extension can use both the C++ and CUDA syntax.

They can use any pytorch operation defined under `aten::` or `c10::`.

To add a new implementation for an operation defined in `library./ops.py`:

- add the corresponding `.cpp` or `.cu` file to the list of sources in `__init__.py`,
- add a binding to `pybind_module.cpp`,
- provide an implementation calling the binding in `__init__.py`.
30 changes: 30 additions & 0 deletions quanto/library/ext/cuda/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import os

import torch
from torch.utils.cpp_extension import load


__all__ = []


_ext = None


def ext():
"""Helper to load the CUDA ext only when it is required"""
global _ext
if _ext is None:
module_path = os.path.dirname(__file__)
_ext = load(
name="quanto_cuda",
sources=[
f"{module_path}/unpack.cu",
f"{module_path}/pybind_module.cpp",
],
)
return _ext


@torch.library.impl("quanto_ext::unpack", ["CUDA"])
def unpack_cuda(t: torch.Tensor, bits: int):
return ext().unpack(t, bits)
13 changes: 13 additions & 0 deletions quanto/library/ext/cuda/pybind_module.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#include <torch/extension.h>
#include "unpack.h"

// !IMPORTANT! Some python objects such as dtype, device, are not mapped to C++ types,
// and need to be explicitly converted using dedicated helpers before calling a C++ method.
// As a consequence, when an operation takes such an object as parameter, instead
// of creating a binding directly to the C++ method, you must create a binding to a
// lambda method that converts the unmapped types and calls the C++ method.
// See the binding of quantize_symmetric for instance.

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("unpack", &unpack, "unpack");
}
83 changes: 83 additions & 0 deletions quanto/library/ext/cuda/unpack.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <c10/cuda/CUDAException.h>

inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b;}
#define BLOCK_SIZE 256

using namespace at;


static torch::Tensor allocate_output(const torch::Tensor& input, int bits) {
int n_packed = 8 / bits;
auto output_shape = input.sizes().vec();
output_shape[0] = output_shape[0] * n_packed;
return torch::empty(output_shape, input.options());
}

__global__ void unpack_4bit_kernel(unsigned char* input, unsigned char* output, int n) {
int i = blockIdx.x*blockDim.x + threadIdx.x;
if(i>=n) return;

output[i] = (input[i] & 0x0F);
output[i + n] = (input[i] & 0xF0) >> 4;
}

static torch::Tensor unpack_4bit(const torch::Tensor& input){

auto output = allocate_output(input, 4);

const auto numel = input.numel();
int blocks = cdiv(numel, BLOCK_SIZE);
unpack_4bit_kernel<<<blocks, BLOCK_SIZE>>>(
input.data_ptr<unsigned char>(),
output.data_ptr<unsigned char>(),
numel
);

C10_CUDA_KERNEL_LAUNCH_CHECK();

return output;
}

__global__ void unpack_2bit_kernel(unsigned char* input, unsigned char* output, int n) {
int i = blockIdx.x*blockDim.x + threadIdx.x;
if(i>=n) return;

output[i] = (input[i] & 0x03);
output[i + n] = (input[i] & 0x0C) >> 2;
output[i + n*2] = (input[i] & 0x30) >> 4;
output[i + n*3] = (input[i] & 0xC0) >> 6;
}

static torch::Tensor unpack_2bit(const torch::Tensor& input){

auto output = allocate_output(input, 2);

const auto numel = input.numel();
int blocks = cdiv(numel, BLOCK_SIZE);
unpack_2bit_kernel<<<blocks, BLOCK_SIZE>>>(
input.data_ptr<unsigned char>(),
output.data_ptr<unsigned char>(),
numel
);

C10_CUDA_KERNEL_LAUNCH_CHECK();

return output;
}

torch::Tensor unpack(torch::Tensor &t, int bits) {
TORCH_CHECK(t.scalar_type() == torch::kUInt8, "Unsupported data type: ", t.scalar_type());
TORCH_CHECK(t.device().is_cuda(), "t must be a CUDA tensor.");
TORCH_CHECK(t.is_contiguous(), "t must be contiguous.");
switch(bits) {
case 4:
return unpack_4bit(t);
case 2:
return unpack_2bit(t);
default:
throw std::invalid_argument("Can only unpack 2-bit or 4-bit tensors.");
}
}
3 changes: 3 additions & 0 deletions quanto/library/ext/cuda/unpack.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#include <torch/extension.h>

torch::Tensor unpack(torch::Tensor &t, int bits);
Loading