This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Compression config cutlass #205

Merged
3 changes: 0 additions & 3 deletions CMakeLists.txt
@@ -167,9 +167,6 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/smoothquant/int8gemm/cublasAlgoMap.cc"
"csrc/quantization/smoothquant/int8gemm/cublasINT8MMWrapper.cc"
"csrc/quantization/smoothquant/int8gemm/cuda_utils.cc"
"csrc/quantization/smoothquant/fused_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
22 changes: 0 additions & 22 deletions csrc/pybind.cpp
@@ -1,7 +1,6 @@
#include "cache.h"
#include "cuda_utils.h"
#include "ops.h"
#include "quantization/smoothquant/int8gemm/int8_gemm.h"
#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
@@ -50,21 +49,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"fused_add_rms_norm",
&fused_add_rms_norm,
"In-place fused Add and RMS Normalization");
ops.def(
"dequant",
py::overload_cast<
torch::Tensor&,
torch::Tensor&,
float>(&dequant),
"Dequant.");
ops.def(
"dequant",
py::overload_cast<
torch::Tensor&,
torch::Tensor&,
torch::Tensor&,
float>(&dequant),
"Per-token dequant.");
ops.def(
"quant",
py::overload_cast<
@@ -102,12 +86,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
pybind11::class_<I8CUGEMM>(ops, "I8CUGEMM")
.def(pybind11::init<>())
.def("linear_a8_w8_o32", &I8CUGEMM::linear_a8_w8_o32)
.def("linear_a8_w8_o8", &I8CUGEMM::linear_a8_w8_o8)
.def("linear_a8_w8_o8_", &I8CUGEMM::linear_a8_w8_o8_)
.def("linear_a8_w8_o32_", &I8CUGEMM::linear_a8_w8_o32_);
ops.def(
"moe_align_block_size",
&moe_align_block_size,
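The pybind changes above drop both `dequant` overloads and the cuBLAS-backed `I8CUGEMM` wrapper class; its `linear_a8_w8_o32` / `linear_a8_w8_o8` method names suggest int8 activations and weights with int32 or int8 outputs. For context only, here is a minimal CPU reference sketch of the underlying operation, under the usual assumption that an int8 × int8 GEMM accumulates into int32. This is not the removed wrapper, and `int8_gemm_ref` is a hypothetical name for illustration.

```cpp
// Illustrative CPU sketch (assumption: int8 x int8 GEMM with int32 accumulation).
// The removed I8CUGEMM class wrapped the GPU equivalent of this via cuBLAS.
#include <cstdint>
#include <vector>

// C[m][n] = sum_k A[m][k] * B[k][n]; int8 inputs, int32 accumulator.
void int8_gemm_ref(const std::vector<int8_t>& A, const std::vector<int8_t>& B,
                   std::vector<int32_t>& C, int M, int N, int K) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      int32_t acc = 0;
      for (int k = 0; k < K; ++k) {
        acc += static_cast<int32_t>(A[m * K + k]) *
               static_cast<int32_t>(B[k * N + n]);
      }
      C[m * N + n] = acc;
    }
  }
}
```

Because the accumulator is int32, a separate dequantization step is needed to return to floating point, which is what the kernels removed from fused_kernels.cu below provided.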
73 changes: 1 addition & 72 deletions csrc/quantization/smoothquant/fused_kernels.cu
@@ -7,27 +7,6 @@
#include "quant_utils.cuh"

namespace vllm {
template <typename scalar_t, bool use_per_token_dequant>
__global__ void dequant_kernel(
const int32_t* __restrict__ input,
scalar_t* __restrict__ out,
const float scale,
const int m,
const int hidden_size,
const int input_stride,
const int out_stride,
const float* __restrict__ act_scale = nullptr) {
const int tid = threadIdx.x;
const int token_idx = blockIdx.x;
float scale_ = scale;
if constexpr (use_per_token_dequant) {
scale_ = scale * act_scale[token_idx];
}
for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * out_stride + i] =
(scalar_t)(((float)input[token_idx * input_stride + i]) * scale_);
}
}

template <typename scalar_t, typename scale_type, bool use_per_token_quant>
__global__ void quant_kernel(
@@ -71,56 +50,6 @@ __global__ void quant_kernel(
}
} // namespace vllm

void dequant(
torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
float scale) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
int input_stride = input.stride(-2);
int out_stride = out.stride(-2);

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(out.scalar_type(), "dequant_kernel", [&] {
vllm::dequant_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
input.data_ptr<int32_t>(),
out.data_ptr<scalar_t>(),
scale,
num_tokens,
hidden_size,
input_stride,
out_stride);
});
}

void dequant(
torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& scale,
float weight_dequant_scale) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
int input_stride = input.stride(-2);
int out_stride = out.stride(-2);

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(out.scalar_type(), "dequant_kernel", [&] {
vllm::dequant_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
input.data_ptr<int32_t>(),
out.data_ptr<scalar_t>(),
weight_dequant_scale,
num_tokens,
hidden_size,
input_stride,
out_stride,
scale.data_ptr<float>());
});
}

void quant(
torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
@@ -159,4 +88,4 @@ void quant(
scale.data_ptr<float>(),
hidden_size);
});
}
}
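The deleted `dequant_kernel` and its two host launchers converted the int32 GEMM output back to floating point by multiplying each element with `scale` (per-tensor) or `scale * act_scale[token_idx]` (per-token). The following CPU sketch restates that math for reference, assuming contiguous rows (the deleted code also handled row strides) and a float output (the kernel dispatched over the extension's floating types); `dequant_ref` is a hypothetical name for illustration.

```cpp
// Reference sketch of the math in the removed dequant paths.
#include <cstdint>

void dequant_ref(const int32_t* input, float* out, int num_tokens,
                 int hidden_size, float weight_dequant_scale,
                 const float* act_scale /* per-token scales, or nullptr */) {
  for (int token = 0; token < num_tokens; ++token) {
    // Per-token variant folds the activation scale into the weight scale.
    const float scale = act_scale ? weight_dequant_scale * act_scale[token]
                                  : weight_dequant_scale;
    for (int i = 0; i < hidden_size; ++i) {
      out[token * hidden_size + i] =
          static_cast<float>(input[token * hidden_size + i]) * scale;
    }
  }
}
```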
232 changes: 0 additions & 232 deletions csrc/quantization/smoothquant/int8gemm/allocator.h

This file was deleted.
