Integrating fast_hadamard_transform on C++ level #17

Merged · Dec 12, 2024 · 13 commits · Changes from 6 commits
3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "fast-hadamard-transform"]
path = fast-hadamard-transform
url = https://github.com/Dao-AILab/fast-hadamard-transform
4 changes: 2 additions & 2 deletions README.md
@@ -531,8 +531,8 @@ index 9fd91a2..80782ea 100644
--- a/flute/ops.py
+++ b/flute/ops.py
@@ -124 +124 @@ def _qgemm_simple_89_abstract(
-# @torch.library.impl_abstract("flute::qgemm_raw_simple_80")
+@torch.library.impl_abstract("flute::qgemm_raw_simple_80")
-# @torch.library.register_fake("flute::qgemm_raw_simple_80")
+@torch.library.register_fake("flute::qgemm_raw_simple_80")
```

</details>
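For context, `torch.library.register_fake` is the current name for what older PyTorch releases called `torch.library.impl_abstract`: it registers a shape-propagation ("fake"/meta) function for a custom op so `torch.compile` can trace it. A minimal sketch of what the re-enabled registration amounts to, assuming the argument list matches the (commented-out) `qgemm_raw_simple_80` schema in `qgemm.cpp`; since that op writes into `output` in place and its schema returns `()`, the fake simply returns `None`:

```python
import torch

# Hypothetical sketch; the real registration lives in flute/ops.py.
@torch.library.register_fake("flute::qgemm_raw_simple_80")
def _qgemm_raw_simple_80_fake(input, weight, output, scales, table, table2,
                              workspace, num_bits, group_size, template_id):
    # The op mutates `output` in place and returns nothing,
    # so no fake output tensor needs to be constructed here.
    return None
```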
1 change: 1 addition & 0 deletions fast-hadamard-transform
22 changes: 22 additions & 0 deletions flute/__init__.py
@@ -39,6 +39,21 @@
None,
]

QGEMM_HADAMARD_TYPE = Callable[
[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
int,
int,
int,
],
torch.Tensor,
]


# we use this instead of `torch.cuda.get_device_capability()` so that
# it works better with multiprocessing (which vLLM uses)
@@ -72,8 +87,15 @@
# 128: cast(QGEMM_RAW_SIMPLE_TYPE, torch.ops.flute.qgemm_raw_simple_89),
# }

QGEMM_HADAMARD_DICT = {
84 : cast(QGEMM_HADAMARD_TYPE, torch.ops.flute.qgemm_hadamard_86),
108: cast(QGEMM_HADAMARD_TYPE, torch.ops.flute.qgemm_hadamard_80),
128: cast(QGEMM_HADAMARD_TYPE, torch.ops.flute.qgemm_hadamard_89),
}

qgemm_simple = QGEMM_SIMPLE_DICT[NUM_SMS]
qgemm_raw_simple = None # QGEMM_RAW_SIMPLE_DICT[NUM_SMS]
qgemm_hadamard = QGEMM_HADAMARD_DICT[NUM_SMS]


# Load the template configs
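The dictionary keys are GPU streaming-multiprocessor counts, mirroring the `TORCH_LIBRARY_IMPL` registrations at the bottom of `qgemm.cpp`: 108 SMs resolves to the SM80 kernel, 84 to SM86, and 128 to SM89. A minimal sketch of the same resolution done by hand (the SM query is standard PyTorch; the tensor arguments to the resolved op are omitted here because they require FLUTE-packed weights):

```python
import torch

# Resolve the kernel the way the module-level dispatch does at import time.
num_sms = torch.cuda.get_device_properties(0).multi_processor_count
op_name = {84: "qgemm_hadamard_86",    # SM86-class GPUs (e.g. A40/A6000)
           108: "qgemm_hadamard_80",   # A100 (SM80)
           128: "qgemm_hadamard_89"}[num_sms]  # SM89-class (e.g. RTX 4090/L40)
qgemm_hadamard = getattr(torch.ops.flute, op_name)
```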
253 changes: 253 additions & 0 deletions flute/csrc/hadamard.cpp
Owner (review comment): Is this file a copy of the one in the original "fast-hadamard-transform"? (It looks slightly different to me; if it is not a copy, what's the difference?)

@@ -0,0 +1,253 @@
// Based on code by Tri Dao, 2023

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>
#include <vector>


struct HadamardParamsBase {
using index_t = int64_t;

int batch, dim, log_N;

index_t x_batch_stride;
index_t out_batch_stride;

float scale;

// Common data pointers.
void *__restrict__ x_ptr;
void *__restrict__ out_ptr;
};


#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")

#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
if (ITYPE == at::ScalarType::Half) { \
using input_t = at::Half; \
__VA_ARGS__(); \
} else if (ITYPE == at::ScalarType::BFloat16) { \
using input_t = at::BFloat16; \
__VA_ARGS__(); \
} else if (ITYPE == at::ScalarType::Float) { \
using input_t = float; \
__VA_ARGS__(); \
} else { \
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
}

template<typename input_t>
void fast_hadamard_transform_cuda(HadamardParamsBase &params, cudaStream_t stream);

template<typename input_t>
void fast_hadamard_transform_12N_cuda(HadamardParamsBase &params, cudaStream_t stream);

template<typename input_t>
void fast_hadamard_transform_20N_cuda(HadamardParamsBase &params, cudaStream_t stream);

template<typename input_t>
void fast_hadamard_transform_28N_cuda(HadamardParamsBase &params, cudaStream_t stream);

void set_hadamard_params(HadamardParamsBase &params,
// sizes
const size_t batch,
const size_t dim,
const size_t multiple,
// device pointers
const at::Tensor x,
const at::Tensor out,
float scale
) {

// Reset the parameters
memset(&params, 0, sizeof(params));

params.batch = batch;
params.dim = dim;
params.log_N = int(ceil(std::log2(dim / multiple)));

// Set the pointers and strides.
params.x_ptr = x.data_ptr();
params.out_ptr = out.data_ptr();
// All strides are in elements, not bytes.
params.x_batch_stride = x.stride(0);
params.out_batch_stride = out.stride(0);

params.scale = scale;
}
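The `log_N` field encodes the power-of-two part of the transform length: for the supported sizes `dim = multiple * 2^log_N` (with `multiple` being 1, 12, 20, or 28), the formula above recovers the exponent. A small sanity check of the formula, as a Python sketch:

```python
import math

def log_n(dim: int, multiple: int) -> int:
    # Mirrors set_hadamard_params: log_N = ceil(log2(dim / multiple)).
    return math.ceil(math.log2(dim / multiple))

assert log_n(4096, 1) == 12    # plain power-of-two transform, H_{2^12}
assert log_n(12288, 12) == 10  # 12N kernel: 12288 = 12 * 2^10
```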


at::Tensor
fast_hadamard_transform(at::Tensor &x, float scale) {
auto input_type = x.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);

TORCH_CHECK(x.is_cuda());

const auto shapes_og = x.sizes();
const int dim_og = x.size(-1);
x = x.reshape({-1, dim_og});
if (x.stride(-1) != 1) { x = x.contiguous(); }
const auto sizes = x.sizes();
const int batch_size = sizes[0];

CHECK_SHAPE(x, batch_size, dim_og);
TORCH_CHECK(x.stride(1) == 1);

if (dim_og % 8 != 0) {
x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, 8 - dim_og % 8}));
}
const int dim = x.size(1);

TORCH_CHECK(dim % 8 == 0, "fast_hadamard_transform only supports hidden dimension divisible by 8 for now");
TORCH_CHECK(dim <= 32768, "fast_hadamard_transform only supports hidden dimension at most 32768 for now");

at::Tensor out = torch::empty_like(x);

HadamardParamsBase params;
set_hadamard_params(params, batch_size, dim, 1, x, out, scale);

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
auto stream = at::cuda::getCurrentCUDAStream().stream();
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "fast_hadamard_transform", [&] {
fast_hadamard_transform_cuda<input_t>(params, stream);
});
if (dim_og % 8 != 0) {
out = out.index({torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, dim_og)});
}
return out.reshape(shapes_og);
}
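As a reference for what the power-of-two path computes, here is a minimal pure-PyTorch fast Walsh-Hadamard transform (the natural/Hadamard output ordering is an assumption about the kernel's convention; the 12N/20N/28N variants, which additionally mix in a Hadamard matrix of order 12, 20, or 28, are not reproduced):

```python
import torch

def fwht_reference(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    """O(dim log dim) Walsh-Hadamard transform over the last axis.

    Requires a power-of-two dim; natural (Hadamard) output ordering.
    """
    shape = x.shape
    dim = shape[-1]
    assert dim & (dim - 1) == 0, "power-of-two dims only"
    y = x.reshape(-1, dim)
    h = 1
    while h < dim:
        y = y.reshape(-1, dim // (2 * h), 2, h)
        a, b = y[:, :, 0, :], y[:, :, 1, :]
        # Butterfly step: combine each pair of length-h blocks.
        y = torch.stack((a + b, a - b), dim=2).reshape(-1, dim)
        h *= 2
    return (y * scale).reshape(shape)
```

Calling it with `scale = dim ** -0.5` gives the orthonormal rotation; the padding to a multiple of 8 and the slice back to `dim_og`, as done above, is how the kernel handles ragged trailing dimensions.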

at::Tensor
fast_hadamard_transform_12N(at::Tensor &x, float scale) {
auto input_type = x.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);

TORCH_CHECK(x.is_cuda());

const auto shapes_og = x.sizes();
const int dim_og = x.size(-1);
x = x.reshape({-1, dim_og});
if (x.stride(-1) != 1) { x = x.contiguous(); }
const auto sizes = x.sizes();
const int batch_size = sizes[0];

CHECK_SHAPE(x, batch_size, dim_og);
TORCH_CHECK(x.stride(1) == 1);

if (dim_og % (4 * 12) != 0) {
x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, (4 * 12) - dim_og % (4 * 12)}));
}
const int dim = x.size(1);

TORCH_CHECK(dim % (4 * 12) == 0, "fast_hadamard_transform_12N only supports hidden dimension divisible by 48 for now");
TORCH_CHECK(dim <= 12 * 1024, "fast_hadamard_transform_12N only supports hidden dimension at most 12288 for now");

at::Tensor out = torch::empty_like(x);

HadamardParamsBase params;
set_hadamard_params(params, batch_size, dim, 12, x, out, scale);

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
auto stream = at::cuda::getCurrentCUDAStream().stream();
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "fast_hadamard_transform", [&] {
fast_hadamard_transform_12N_cuda<input_t>(params, stream);
});
if (dim_og % (4 * 12) != 0) {
out = out.index({torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, dim_og)});
}
return out.reshape(shapes_og);
}

at::Tensor
fast_hadamard_transform_20N(at::Tensor &x, float scale) {
auto input_type = x.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);

TORCH_CHECK(x.is_cuda());

const auto shapes_og = x.sizes();
const int dim_og = x.size(-1);
x = x.reshape({-1, dim_og});
if (x.stride(-1) != 1) { x = x.contiguous(); }
const auto sizes = x.sizes();
const int batch_size = sizes[0];

CHECK_SHAPE(x, batch_size, dim_og);
TORCH_CHECK(x.stride(1) == 1);

if (dim_og % (4 * 20) != 0) {
x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, (4 * 20) - dim_og % (4 * 20)}));
}
const int dim = x.size(1);

TORCH_CHECK(dim % (4 * 20) == 0, "fast_hadamard_transform_20N only supports hidden dimension divisible by 80 for now");
TORCH_CHECK(dim <= 20 * 1024, "fast_hadamard_transform_20N only supports hidden dimension at most 20480 for now");

at::Tensor out = torch::empty_like(x);

HadamardParamsBase params;
set_hadamard_params(params, batch_size, dim, 20, x, out, scale);

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
auto stream = at::cuda::getCurrentCUDAStream().stream();
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "fast_hadamard_transform", [&] {
fast_hadamard_transform_20N_cuda<input_t>(params, stream);
});
if (dim_og % (4 * 20) != 0) {
out = out.index({torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, dim_og)});
}
return out.reshape(shapes_og);
}

at::Tensor
fast_hadamard_transform_28N(at::Tensor &x, float scale) {
auto input_type = x.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);

TORCH_CHECK(x.is_cuda());

const auto shapes_og = x.sizes();
const int dim_og = x.size(-1);
x = x.reshape({-1, dim_og});
if (x.stride(-1) != 1) { x = x.contiguous(); }
const auto sizes = x.sizes();
const int batch_size = sizes[0];

CHECK_SHAPE(x, batch_size, dim_og);
TORCH_CHECK(x.stride(1) == 1);

if (dim_og % (4 * 28) != 0) {
x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, (4 * 28) - dim_og % (4 * 28)}));
}
const int dim = x.size(1);

TORCH_CHECK(dim % (4 * 28) == 0, "fast_hadamard_transform_28N only supports hidden dimension divisible by 112 for now");
// TORCH_CHECK(dim <= 28 * 1024, "fast_hadamard_transform_28N only supports hidden dimension at most 28672 for now");
TORCH_CHECK(dim <= 28 * 2048, "fast_hadamard_transform_28N only supports hidden dimension at most 57344 for now");

at::Tensor out = torch::empty_like(x);

HadamardParamsBase params;
set_hadamard_params(params, batch_size, dim, 28, x, out, scale);

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
auto stream = at::cuda::getCurrentCUDAStream().stream();
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "fast_hadamard_transform", [&] {
fast_hadamard_transform_28N_cuda<input_t>(params, stream);
});
if (dim_og % (4 * 28) != 0) {
out = out.index({torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, dim_og)});
}
return out.reshape(shapes_og);
}
48 changes: 48 additions & 0 deletions flute/csrc/qgemm.cpp
@@ -7,6 +7,9 @@
#include "cute/numeric/integral_constant.hpp"


at::Tensor fast_hadamard_transform(at::Tensor &x, float scale);
Owner (review comment): Could we make the signature style similar to the others (one line for the output type, and one argument per line)? This is the style used in CUTLASS.



template <
typename SMs,
typename T,
@@ -369,6 +372,45 @@ qgemm_raw_simple(const at::Tensor& input,
}


at::Tensor apply_hadamard(const at::Tensor& input, const cute::int64_t hadamard_size) {
Owner (review comment): Could we make the signature style similar to the others (one line for the output type, and one argument per line)? This is the style used in CUTLASS.

auto input_sizes = input.sizes();
auto flat_input = input.reshape({-1, hadamard_size}).contiguous();
return fast_hadamard_transform(
flat_input,
1.0 / sqrt(static_cast<float>(hadamard_size))
).reshape(input_sizes).contiguous();
}
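A pure-PyTorch sketch of the same blockwise rotation, assuming `hadamard_size` is a power of two (the C++ path above also supports the 12N/20N/28N sizes): fold the trailing dimension into blocks of `hadamard_size`, apply the orthonormal transform with scale `1/sqrt(hadamard_size)`, and restore the original shape.

```python
import math
import torch

def apply_hadamard_reference(x: torch.Tensor, hadamard_size: int) -> torch.Tensor:
    # Build H_{hadamard_size} by Kronecker powers of H_2, normalized so the
    # rotation is orthonormal (matching the 1/sqrt scale passed above).
    h2 = torch.tensor([[1.0, 1.0], [1.0, -1.0]])
    H = torch.tensor([[1.0]])
    while H.size(0) < hadamard_size:
        H = torch.kron(H, h2)
    H = (H / math.sqrt(hadamard_size)).to(dtype=x.dtype, device=x.device)
    return (x.reshape(-1, hadamard_size) @ H).reshape(x.shape)
```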


template <
typename SMs
>
at::Tensor
qgemm_hadamard(const at::Tensor& input,
const at::Tensor& weight,
Owner (review comment): Minor indentation inconsistency :)

const at::Tensor& scales,
const at::Tensor& table,
const at::Tensor& table2,
at::Tensor& workspace,
const cute::int64_t num_bits,
const cute::int64_t group_size,
const cute::int64_t hadamard_size)
{
auto had_input = apply_hadamard(input, hadamard_size);

return qgemm_simple<SMs>(
had_input,
weight,
scales,
table,
table2,
workspace,
num_bits,
group_size
);
}


// Registers _C as an extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}

@@ -381,6 +423,9 @@ TORCH_LIBRARY(flute, m) {
// m.def("qgemm_raw_simple_80(Tensor input, Tensor weight, Tensor(a!) output, Tensor scales, Tensor table, Tensor table2, Tensor(b!) workspace, int num_bits, int group_size, int template_id) -> ()");
// m.def("qgemm_raw_simple_86(Tensor input, Tensor weight, Tensor(a!) output, Tensor scales, Tensor table, Tensor table2, Tensor(b!) workspace, int num_bits, int group_size, int template_id) -> ()");
// m.def("qgemm_raw_simple_89(Tensor input, Tensor weight, Tensor(a!) output, Tensor scales, Tensor table, Tensor table2, Tensor(b!) workspace, int num_bits, int group_size, int template_id) -> ()");
m.def("qgemm_hadamard_80(Tensor input, Tensor weight, Tensor scales, Tensor table, Tensor table2, Tensor(a!) workspace, int num_bits, int group_size, int hadamard_size) -> Tensor");
m.def("qgemm_hadamard_86(Tensor input, Tensor weight, Tensor scales, Tensor table, Tensor table2, Tensor(a!) workspace, int num_bits, int group_size, int hadamard_size) -> Tensor");
m.def("qgemm_hadamard_89(Tensor input, Tensor weight, Tensor scales, Tensor table, Tensor table2, Tensor(a!) workspace, int num_bits, int group_size, int hadamard_size) -> Tensor");
}


@@ -391,4 +436,7 @@ TORCH_LIBRARY_IMPL(flute, CUDA, m) {
// m.impl("qgemm_raw_simple_80", &qgemm_raw_simple<cute::Int<108>>);
// m.impl("qgemm_raw_simple_86", &qgemm_raw_simple<cute::Int<84>>);
// m.impl("qgemm_raw_simple_89", &qgemm_raw_simple<cute::Int<128>>);
m.impl("qgemm_hadamard_80", &qgemm_hadamard<cute::Int<108>>);
m.impl("qgemm_hadamard_86", &qgemm_hadamard<cute::Int<84>>);
m.impl("qgemm_hadamard_89", &qgemm_hadamard<cute::Int<128>>);
}