From 155dbfdccf84c8a993bbd175371636cb6a2b1436 Mon Sep 17 00:00:00 2001
From: SzymonOzog
Date: Fri, 28 Feb 2025 10:58:46 +0000
Subject: [PATCH 1/7] Add more dtype support to matvec

Signed-off-by: SzymonOzog
---
 csrc/quantization/gguf/gguf_kernel.cu | 205 ++++++++++++++------------
 csrc/quantization/gguf/mmvq.cuh       | 101 +++++++------
 2 files changed, 170 insertions(+), 136 deletions(-)

diff --git a/csrc/quantization/gguf/gguf_kernel.cu b/csrc/quantization/gguf/gguf_kernel.cu
index 5f0eaf5a973f..99171dc07c9c 100644
--- a/csrc/quantization/gguf/gguf_kernel.cu
+++ b/csrc/quantization/gguf/gguf_kernel.cu
@@ -5,6 +5,7 @@
 #include 

 #include "cuda_compat.h"
+#include "dispatch_utils.h"
 #include "ggml-common.h"
 #include "vecdotq.cuh"
@@ -13,7 +14,8 @@
 #include "mmq.cuh"

 // Q8 gemv
-static __global__ void quantize_q8_1(const half* __restrict__ x,
+template <typename scalar_t>
+static __global__ void quantize_q8_1(const scalar_t* __restrict__ x,
                                      void* __restrict__ vy, const int kx,
                                      const int kx_padded) {
   const int ix = blockDim.x * blockIdx.x + threadIdx.x;
@@ -28,7 +30,7 @@ static __global__ void quantize_q8_1(const half* __restrict__ x,
   const int ib = i_padded / QK8_1;   // block index
   const int iqs = i_padded % QK8_1;  // quant index

-  const float xi = ix < kx ? __half2float(x[iy * kx + ix]) : 0.0f;
+  const float xi = ix < kx ? static_cast<float>(x[iy * kx + ix]) : 0.0f;
   float amax = fabsf(xi);
   float sum = xi;
@@ -51,14 +53,16 @@ static __global__ void quantize_q8_1(const half* __restrict__ x,
   y[ib].ds.y = __float2half(sum);
 }

-static void quantize_row_q8_1_cuda(const half* x, void* vy, const int kx,
+template <typename scalar_t>
+static void quantize_row_q8_1_cuda(const scalar_t* x, void* vy, const int kx,
                                    const int ky, cudaStream_t stream) {
   const int64_t kx_padded = (kx + 512 - 1) / 512 * 512;
   const int block_num_x =
       (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
   const dim3 num_blocks(block_num_x, ky, 1);
   const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1);
-  quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
+  quantize_q8_1<scalar_t>
+      <<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
 }

 torch::Tensor ggml_dequantize(torch::Tensor W,  // quant weight

 torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W,  // quant weight
   int col = X.sizes()[1];
   const int padded = (col + 512 - 1) / 512 * 512;
   const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
-  auto options =
-      torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
+  auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
   at::Tensor Y = torch::empty({1, row}, options);
   cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
   options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
   at::Tensor quant_X = torch::empty({1, padded / 32 * 9}, options);
-  quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col, 1,
-                         stream);
-  switch (type) {
-    case 2:
-      mul_mat_vec_q4_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 3:
-      mul_mat_vec_q4_1_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 6:
-      mul_mat_vec_q5_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 7:
-      mul_mat_vec_q5_1_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 8:
-      mul_mat_vec_q8_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 10:
-      mul_mat_vec_q2_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 11:
-      mul_mat_vec_q3_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 12:
-      mul_mat_vec_q4_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 13:
-      mul_mat_vec_q5_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 14:
-      mul_mat_vec_q6_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 16:
-      mul_mat_vec_iq2_xxs_q8_1_cuda((void*)W.data_ptr(),
-                                    (void*)quant_X.data_ptr(),
-                                    (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 17:
-      mul_mat_vec_iq2_xs_q8_1_cuda((void*)W.data_ptr(),
-                                   (void*)quant_X.data_ptr(),
-                                   (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 18:
-      mul_mat_vec_iq3_xxs_q8_1_cuda((void*)W.data_ptr(),
-                                    (void*)quant_X.data_ptr(),
-                                    (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 19:
-      mul_mat_vec_iq1_s_q8_1_cuda((void*)W.data_ptr(),
-                                  (void*)quant_X.data_ptr(),
-                                  (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 20:
-      mul_mat_vec_iq4_nl_q8_1_cuda((void*)W.data_ptr(),
-                                   (void*)quant_X.data_ptr(),
-                                   (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 21:
-      mul_mat_vec_iq3_s_q8_1_cuda((void*)W.data_ptr(),
-                                  (void*)quant_X.data_ptr(),
-                                  (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 22:
-      mul_mat_vec_iq2_s_q8_1_cuda((void*)W.data_ptr(),
-                                  (void*)quant_X.data_ptr(),
-                                  (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 23:
-      mul_mat_vec_iq4_xs_q8_1_cuda((void*)W.data_ptr(),
-                                   (void*)quant_X.data_ptr(),
-                                   (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 29:
-      mul_mat_vec_iq1_m_q8_1_cuda((void*)W.data_ptr(),
-                                  (void*)quant_X.data_ptr(),
-                                  (half*)Y.data_ptr(), col, row, stream);
-      break;
-  }
+  VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_vec_a8", [&] {
+    quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(),
+                           (void*)quant_X.data_ptr(), col, 1, stream);
+    switch (type) {
+      case 2:
+        mul_mat_vec_q4_0_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 3:
+        mul_mat_vec_q4_1_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 6:
+        mul_mat_vec_q5_0_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 7:
+        mul_mat_vec_q5_1_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 8:
+        mul_mat_vec_q8_0_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 10:
+        mul_mat_vec_q2_K_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 11:
+        mul_mat_vec_q3_K_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 12:
+        mul_mat_vec_q4_K_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 13:
+        mul_mat_vec_q5_K_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 14:
+        mul_mat_vec_q6_K_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 16:
+        mul_mat_vec_iq2_xxs_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 17:
+        mul_mat_vec_iq2_xs_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 18:
+        mul_mat_vec_iq3_xxs_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 19:
+        mul_mat_vec_iq1_s_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 20:
+        mul_mat_vec_iq4_nl_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 21:
+        mul_mat_vec_iq3_s_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 22:
+        mul_mat_vec_iq2_s_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 23:
+        mul_mat_vec_iq4_xs_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 29:
+        mul_mat_vec_iq1_m_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+    }
+  });
   return Y;
 }
diff --git a/csrc/quantization/gguf/mmvq.cuh b/csrc/quantization/gguf/mmvq.cuh
index b01e939808a3..d83f29745554 100644
--- a/csrc/quantization/gguf/mmvq.cuh
+++ b/csrc/quantization/gguf/mmvq.cuh
@@ -1,6 +1,6 @@
 // copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu
-template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
-static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, const int ncols, const int nrows) {
+template <typename scalar_t, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
+static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst, const int ncols, const int nrows) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;

     if (row >= nrows) {
@@ -33,158 +33,177 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
     }

     if (threadIdx.x == 0) {
-        dst[row] = __float2half(tmp);
+        dst[row] = tmp;
     }
 }

-static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+template <typename scalar_t>
+static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
+    mul_mat_vec_q<scalar_t, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }

-static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+template <typename scalar_t>
+static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
+    mul_mat_vec_q<scalar_t, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }

-static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+template <typename scalar_t>
+static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
+    mul_mat_vec_q<scalar_t, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }

-static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+template <typename scalar_t>
+static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
+    mul_mat_vec_q<scalar_t, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }

-static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+template <typename scalar_t>
+static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
+    mul_mat_vec_q<scalar_t, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }

-static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+template <typename scalar_t>
+static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
+    mul_mat_vec_q<scalar_t, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }

-static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+template <typename scalar_t>
+static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
+    mul_mat_vec_q<scalar_t, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }

-static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+template <typename scalar_t>
+static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
+    mul_mat_vec_q<scalar_t, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }

-static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+template <typename scalar_t>
+static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
+    mul_mat_vec_q<scalar_t, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }

-static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+template <typename scalar_t>
+static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
+    mul_mat_vec_q<scalar_t, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }

-static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+template <typename scalar_t>
+static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
+    mul_mat_vec_q<scalar_t, QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }

-static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+template <typename scalar_t>
+static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
+    mul_mat_vec_q<scalar_t, QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }

-static void mul_mat_vec_iq2_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+template <typename scalar_t>
+static void mul_mat_vec_iq2_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
+    mul_mat_vec_q<scalar_t, QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }

-static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+template <typename scalar_t>
+static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
+    mul_mat_vec_q<scalar_t, QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }

-static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+template <typename scalar_t>
+static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
+    mul_mat_vec_q<scalar_t, QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }

-static void mul_mat_vec_iq1_m_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+template <typename scalar_t>
+static void mul_mat_vec_iq1_m_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI1_M, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
+    mul_mat_vec_q<scalar_t, QK_K, QI1_M, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }

-static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+template <typename scalar_t>
+static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
+    mul_mat_vec_q<scalar_t, QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }

-static void mul_mat_vec_iq4_xs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+template <typename scalar_t>
+static void mul_mat_vec_iq4_xs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
+    mul_mat_vec_q<scalar_t, QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }

-static void mul_mat_vec_iq3_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+template <typename scalar_t>
+static void mul_mat_vec_iq3_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
+    mul_mat_vec_q<scalar_t, QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }

From f9bbd6f636f5e6d0b70253bdef94cecaec2e04a0 Mon Sep 17 00:00:00 2001
From: SzymonOzog
Date: Fri, 28 Feb 2025 10:59:45 +0000
Subject: [PATCH 2/7] Add more dtype support to matmul

Signed-off-by: SzymonOzog
---
 csrc/quantization/gguf/gguf_kernel.cu | 113 ++++++++++-----------
 csrc/quantization/gguf/mmq.cuh        | 136 ++++++++++++++------------
 2 files changed, 130 insertions(+), 119 deletions(-)

diff --git a/csrc/quantization/gguf/gguf_kernel.cu b/csrc/quantization/gguf/gguf_kernel.cu
index 99171dc07c9c..1150bd8f2258 100644
--- a/csrc/quantization/gguf/gguf_kernel.cu
+++ b/csrc/quantization/gguf/gguf_kernel.cu
@@ -199,66 +199,67 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W,  // quant weight
   int padded = (col + 512 - 1) / 512 * 512;
   int batch = X.sizes()[0];
   const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
-  auto options =
-      torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
+  auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
   at::Tensor Y = torch::empty({batch, row}, options);
   cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
   options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
   at::Tensor quant_X = torch::empty({batch, padded / 32 * 9}, options);
-  quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col,
-                         batch, stream);
+  VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_a8", [&] {
+    quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(),
+                           col, batch, stream);

-  switch (type) {
-    case 2:
-      ggml_mul_mat_q4_0_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 3:
-      ggml_mul_mat_q4_1_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 6:
-      ggml_mul_mat_q5_0_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 7:
-      ggml_mul_mat_q5_1_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 8:
-      ggml_mul_mat_q8_0_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 10:
-      ggml_mul_mat_q2_K_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 11:
-      ggml_mul_mat_q3_K_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 12:
-      ggml_mul_mat_q4_K_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 13:
-      ggml_mul_mat_q5_K_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 14:
-      ggml_mul_mat_q6_K_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-  }
+    switch (type) {
+      case 2:
+        ggml_mul_mat_q4_0_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 3:
+        ggml_mul_mat_q4_1_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 6:
+        ggml_mul_mat_q5_0_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 7:
+        ggml_mul_mat_q5_1_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 8:
+        ggml_mul_mat_q8_0_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 10:
+        ggml_mul_mat_q2_K_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 11:
+        ggml_mul_mat_q3_K_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 12:
+        ggml_mul_mat_q4_K_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 13:
+        ggml_mul_mat_q5_K_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 14:
+        ggml_mul_mat_q6_K_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+    }
+  });
   return Y;
 }
diff --git a/csrc/quantization/gguf/mmq.cuh b/csrc/quantization/gguf/mmq.cuh
index c935faa07df0..daa8b67dc60e 100644
--- a/csrc/quantization/gguf/mmq.cuh
+++ b/csrc/quantization/gguf/mmq.cuh
@@ -1,8 +1,8 @@
 // copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmq.cu
-template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
-          allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
+template <typename scalar_t, int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
+          allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
 static __device__ __forceinline__ void mul_mat_q(
-    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
+    const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

     const block_q_t * x = (const block_q_t *) vx;
@@ -98,7 +98,7 @@ static __device__ __forceinline__ void mul_mat_q(
             if (row_dst >= nrows_dst) {
                 continue;
             }
-            dst[col_dst*nrows_dst + row_dst] = __float2half(sum[i/WARP_SIZE_GGUF][j/nwarps]);
+            dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE_GGUF][j/nwarps];
         }
     }
 }
@@ -113,24 +113,25 @@
 #define NWARPS_Q4_0 4
 #endif

-template <bool need_check> static __global__ void
+template <typename scalar_t, bool need_check> static __global__ void
 #if defined(USE_ROCM)
 __launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_0, 2)
 #endif
 mul_mat_q4_0(
-    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
+    const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
     const int mmq_x = MMQ_X_Q4_0;
    const int mmq_y = MMQ_Y_Q4_0;
    const int nwarps = NWARPS_Q4_0;

-    mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
+    mul_mat_q<scalar_t, QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
        load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 }

+template <typename scalar_t>
 static void ggml_mul_mat_q4_0_q8_1_cuda(
-    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
+    const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

    int mmq_x = MMQ_X_Q4_0;
@@ -144,11 +145,11 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(

     if (nrows_x % mmq_y == 0) {
         const bool need_check = false;
-        mul_mat_q4_0<need_check><<<block_nums, block_dims, 0, stream>>>
+        mul_mat_q4_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
             (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
         const bool need_check = true;
-        mul_mat_q4_0<need_check><<<block_nums, block_dims, 0, stream>>>
+        mul_mat_q4_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
             (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -163,24 +164,25 @@
 #define NWARPS_Q4_1 4
 #endif

-template <bool need_check> static __global__ void
+template <typename scalar_t, bool need_check> static __global__ void
 #if defined(USE_ROCM)
 __launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_1, 2)
 #endif
 mul_mat_q4_1(
-    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
+    const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
     const int mmq_x = MMQ_X_Q4_1;
    const int mmq_y = MMQ_Y_Q4_1;
    const int nwarps = NWARPS_Q4_1;

-    mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
+    mul_mat_q<scalar_t, QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
        load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 }

+template <typename scalar_t>
 static void ggml_mul_mat_q4_1_q8_1_cuda(
-    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
+    const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

    int mmq_x = MMQ_X_Q4_1;
@@ -194,11 +196,11 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(

     if (nrows_x % mmq_y == 0) {
         const bool need_check = false;
-        mul_mat_q4_1<need_check><<<block_nums, block_dims, 0, stream>>>
+        mul_mat_q4_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
             (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
         const bool need_check = true;
-        mul_mat_q4_1<need_check><<<block_nums, block_dims, 0, stream>>>
+        mul_mat_q4_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
             (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -213,24 +215,25 @@
 #define NWARPS_Q5_0 4
 #endif

-template <bool need_check> static __global__ void
+template <typename scalar_t, bool need_check> static __global__ void
 #if defined(USE_ROCM)
 __launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_0, 2)
 #endif
 mul_mat_q5_0(
-    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
+    const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
     const int mmq_x = MMQ_X_Q5_0;
    const int mmq_y = MMQ_Y_Q5_0;
    const int nwarps = NWARPS_Q5_0;

-    mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
+    mul_mat_q<scalar_t, QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
        load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 }

+template <typename scalar_t>
 static void ggml_mul_mat_q5_0_q8_1_cuda(
-    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
+    const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

    const int mmq_x = MMQ_X_Q5_0;
@@ -244,11 +247,11 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(

     if (nrows_x % mmq_y == 0) {
         const bool need_check = false;
-        mul_mat_q5_0<need_check><<<block_nums, block_dims, 0, stream>>>
+        mul_mat_q5_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
             (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
         const bool need_check = true;
-        mul_mat_q5_0<need_check><<<block_nums, block_dims, 0, stream>>>
+        mul_mat_q5_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
             (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -263,24 +266,25 @@
 #define NWARPS_Q5_1 4
 #endif

-template <bool need_check> static __global__ void
+template <typename scalar_t, bool need_check> static __global__ void
 #if defined(USE_ROCM)
 __launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_1, 2)
 #endif
 mul_mat_q5_1(
-    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
+    const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
     const int mmq_x = MMQ_X_Q5_1;
    const int mmq_y = MMQ_Y_Q5_1;
    const int nwarps = NWARPS_Q5_1;

-    mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
+    mul_mat_q<scalar_t, QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
        load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 }

+template <typename scalar_t>
 static void ggml_mul_mat_q5_1_q8_1_cuda(
-    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
+    const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
     const int mmq_x = MMQ_X_Q5_1;
    const int mmq_y = MMQ_Y_Q5_1;
@@ -293,11 +297,11 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(

     if (nrows_x % mmq_y == 0) {
         const bool need_check = false;
-        mul_mat_q5_1<need_check><<<block_nums, block_dims, 0, stream>>>
+        mul_mat_q5_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
             (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
         const bool need_check = true;
-        mul_mat_q5_1<need_check><<<block_nums, block_dims, 0, stream>>>
+        mul_mat_q5_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
             (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -312,24 +316,25 @@
 #define NWARPS_Q8_0 4
 #endif

-template <bool need_check> static __global__ void
+template <typename scalar_t, bool need_check> static __global__ void
 #if defined(USE_ROCM)
 __launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q8_0, 2)
 #endif
 mul_mat_q8_0(
-    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
+    const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
     const int mmq_x = MMQ_X_Q8_0;
    const int mmq_y = MMQ_Y_Q8_0;
    const int nwarps = NWARPS_Q8_0;

-    mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
+    mul_mat_q<scalar_t, QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
        load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 }

+template <typename scalar_t>
 static void ggml_mul_mat_q8_0_q8_1_cuda(
-    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
+    const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
     const int mmq_x = MMQ_X_Q8_0;
    const int mmq_y = MMQ_Y_Q8_0;
@@ -342,11 +347,11 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(

     if (nrows_x % mmq_y == 0) {
         const bool need_check = false;
-        mul_mat_q8_0<need_check><<<block_nums, block_dims, 0, stream>>>
+        mul_mat_q8_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
             (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
         const bool need_check = true;
-        mul_mat_q8_0<need_check><<<block_nums, block_dims, 0, stream>>>
+        mul_mat_q8_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
             (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -361,24 +366,25 @@
 #define NWARPS_Q2_K 4
 #endif

-template <bool need_check> static __global__ void
+template <typename scalar_t, bool need_check> static __global__ void
 #if defined(USE_ROCM)
 __launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q2_K, 2)
 #endif
 mul_mat_q2_K(
-    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
+    const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
     const int mmq_x = MMQ_X_Q2_K;
    const int mmq_y = MMQ_Y_Q2_K;
    const int nwarps = NWARPS_Q2_K;

-    mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
+    mul_mat_q<scalar_t, QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
        load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 }

+template <typename scalar_t>
 static void ggml_mul_mat_q2_K_q8_1_cuda(
-    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
+    const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
     const int mmq_x = MMQ_X_Q2_K;
    const int mmq_y = MMQ_Y_Q2_K;
@@ -391,11 +397,11 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(

     if (nrows_x % mmq_y == 0) {
         const bool need_check = false;
-        mul_mat_q2_K<need_check><<<block_nums, block_dims, 0, stream>>>
+        mul_mat_q2_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
             (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
         const bool need_check = true;
-        mul_mat_q2_K<need_check><<<block_nums, block_dims, 0, stream>>>
+        mul_mat_q2_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
             (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -410,25 +416,26 @@
 #define NWARPS_Q3_K 4
 #endif

-template <bool need_check> static __global__ void
+template <typename scalar_t, bool need_check> static __global__ void
 #if defined(USE_ROCM)
 __launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q3_K, 2)
 #endif
 mul_mat_q3_K(
-    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
+    const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
     const int mmq_x = MMQ_X_Q3_K;
    const int mmq_y = MMQ_Y_Q3_K;
    const int nwarps = NWARPS_Q3_K;

-    mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
+    mul_mat_q<scalar_t, QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
        load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 }

+template <typename scalar_t>
 static void ggml_mul_mat_q3_K_q8_1_cuda(
-    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
+    const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

    const int mmq_x = MMQ_X_Q3_K;
@@ -442,11 +449,11 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(

     if (nrows_x % mmq_y == 0) {
         const bool need_check = false;
-        mul_mat_q3_K<need_check><<<block_nums, block_dims, 0, stream>>>
+        mul_mat_q3_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
             (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
         const bool need_check = true;
-        mul_mat_q3_K<need_check><<<block_nums, block_dims, 0, stream>>>
+        mul_mat_q3_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
             (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -461,24 +468,25 @@
 #define NWARPS_Q4_K 4
 #endif

-template <bool need_check> static __global__ void
+template <typename scalar_t, bool need_check> static __global__ void
 #if defined(USE_ROCM)
 __launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_K, 2)
 #endif
 mul_mat_q4_K(
-    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
+    const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
     const int mmq_x = MMQ_X_Q4_K;
    const int mmq_y = MMQ_Y_Q4_K;
    const int nwarps = NWARPS_Q4_K;

-    mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
+    mul_mat_q<scalar_t, QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
        load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 }

+template <typename scalar_t>
 static void ggml_mul_mat_q4_K_q8_1_cuda(
-    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
+    const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
     const int mmq_x = MMQ_X_Q4_K;
    const int mmq_y = MMQ_Y_Q4_K;
@@ -491,11 +499,11 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(

     if (nrows_x % mmq_y == 0) {
         const bool need_check = false;
-        mul_mat_q4_K<need_check><<<block_nums, block_dims, 0, stream>>>
+        mul_mat_q4_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
             (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
         const bool need_check = true;
-        mul_mat_q4_K<need_check><<<block_nums, block_dims, 0, stream>>>
+        mul_mat_q4_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
             (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -510,24 +518,25 @@
 #define NWARPS_Q5_K 4
 #endif

-template <bool need_check> static __global__ void
+template <typename scalar_t, bool need_check> static __global__ void
 #if defined(USE_ROCM)
 __launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_K, 2)
 #endif
 mul_mat_q5_K(
-    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
+    const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
     const int mmq_x = MMQ_X_Q5_K;
    const int mmq_y = MMQ_Y_Q5_K;
    const int nwarps = NWARPS_Q5_K;

-    mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
+    mul_mat_q<scalar_t, QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
        load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 }

+template <typename scalar_t>
 static void ggml_mul_mat_q5_K_q8_1_cuda(
-    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
+    const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

    const int mmq_x = MMQ_X_Q5_K;
@@ -541,11 +550,11 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(

     if (nrows_x % mmq_y == 0) {
         const bool need_check = false;
-        mul_mat_q5_K<need_check><<<block_nums, block_dims, 0, stream>>>
+        mul_mat_q5_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
             (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
         const bool need_check = true;
-        mul_mat_q5_K<need_check><<<block_nums, block_dims, 0, stream>>>
+        mul_mat_q5_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
             (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -560,24 +569,25 @@
 #define NWARPS_Q6_K 4
 #endif

-template <bool need_check> static __global__ void
+template <typename scalar_t, bool need_check> static __global__ void
 #if defined(USE_ROCM)
 __launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q6_K, 2)
 #endif
 mul_mat_q6_K(
-    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
+    const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
     const int mmq_x = MMQ_X_Q6_K;
    const int mmq_y = MMQ_Y_Q6_K;
    const int nwarps = NWARPS_Q6_K;

-    mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
+    mul_mat_q<scalar_t, QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
        load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 }

+template <typename scalar_t>
 static void ggml_mul_mat_q6_K_q8_1_cuda(
-    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
+    const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
     const int mmq_x = MMQ_X_Q6_K;
    const int mmq_y = MMQ_Y_Q6_K;
@@ -590,11 +600,11 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(

     if (nrows_x % mmq_y == 0) {
         const bool need_check = false;
-        mul_mat_q6_K<need_check><<<block_nums, block_dims, 0, stream>>>
+        mul_mat_q6_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
             (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
         const bool need_check = true;
-        mul_mat_q6_K<need_check><<<block_nums, block_dims, 0, stream>>>
+        mul_mat_q6_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
             (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }

From 8d57355dd55a57a24787cb38b5c918932cc4c468 Mon Sep 17 00:00:00 2001
From: SzymonOzog
Date: Fri, 28 Feb 2025 13:24:04 +0000
Subject: [PATCH 3/7] Add support in Python scripts

Signed-off-by: SzymonOzog
---
 vllm/model_executor/layers/quantization/gguf.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index ba176e4a567c..d4e971776934 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -32,7 +32,7 @@ def get_name(self) -> str:
         return "gguf"

     def get_supported_act_dtypes(self) -> List[torch.dtype]:
-        return [torch.half]
+        return [torch.half, torch.bfloat16, torch.float32]

     @classmethod
     def get_min_capability(cls) -> int:
@@ -134,6 +134,7 @@ def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
+        self.params_dtype = params_dtype
         output_size_per_partition = sum(output_partition_sizes)

         tensor_shape = (output_size_per_partition, input_size_per_partition)
@@ -326,7 +327,7 @@ def embedding(self, layer: torch.nn.Module,
         x_flat = x.flatten()
         quant = torch.index_select(qweight, dim=0, index=x_flat)
         dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
-                                      x_flat.shape[0])
+                                      x_flat.shape[0]).to(self.params_dtype)
         return dequant.view(*x.shape, hidden_size)

From 3a21d839d61e26c4e698627ad08f9a22d8b56c6f Mon Sep 17 00:00:00 2001
From: SzymonOzog
Date: Fri, 28 Feb 2025 14:19:31 +0000
Subject: [PATCH 4/7] add tests

Signed-off-by: SzymonOzog
---
 tests/kernels/test_gguf.py | 20 +++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/tests/kernels/test_gguf.py b/tests/kernels/test_gguf.py
index 847ca9f43105..e73c571ddc9b 100644
--- a/tests/kernels/test_gguf.py
+++ b/tests/kernels/test_gguf.py
@@ -23,8 +23,8 @@ def get_gguf_sample_tensors(
     return GGUFReader(sample_file).tensors


-DTYPES = [torch.half]
-# Hidden_size for testing, must match the sample file in HF repo,
+DTYPES = [torch.half, torch.bfloat16, torch.float32
+          ]  # Hidden_size for testing, must match the sample file in HF repo,
 # we have `hidden_size = 256, 1024` for test in HF repo currently.
 HIDDEN_SIZES = [256, 1024]
 NUM_TOKENS = [7, 83, 128, 2048]  # Arbitrary values for testing
@@ -53,7 +53,7 @@


 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
-@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("dtype", [torch.half])
 @pytest.mark.parametrize("quant_type", QUANT_TYPES)
 @torch.inference_mode()
 def test_dequantize(hidden_size: int, dtype: torch.dtype,
@@ -123,7 +123,13 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
     ref_output = x @ weight.T

     qweight = torch.tensor(tensor.data, device="cuda")
-    output = ops.ggml_mul_mat_a8(qweight, x, quant_type,
-                                 qweight.shape[0]).to(dtype)
-
-    torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)
+    output = ops.ggml_mul_mat_a8(qweight, x, quant_type, qweight.shape[0])
+    atols = {torch.half: 1, torch.bfloat16: 1.5, torch.float: 1.2}
+    # test matrix has inputs centered around 0 and lower precision from
+    # bfloat16 tends to accumulate and can greatly inflate rtol
+    # since outputs are also very close to 0
+    rtols = {torch.half: 1e-1, torch.bfloat16: 1e4, torch.float: 2e1}
+    torch.testing.assert_close(output,
+                               ref_output,
+                               atol=atols[dtype],
+                               rtol=rtols[dtype])

From 7e54164c38bee6f98a83854d0132f7fae69d1452 Mon Sep 17 00:00:00 2001
From: SzymonOzog
Date: Fri, 28 Feb 2025 14:27:37 +0000
Subject: [PATCH 5/7] correct dtype for fakes

Signed-off-by: SzymonOzog
---
 vllm/_custom_ops.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 0e83bcaead94..a547892cbb69 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -419,7 +419,7 @@ def _ggml_mul_mat_vec_a8_fake(
         quant_type: int,
         row: torch.SymInt,
     ) -> torch.Tensor:
-        return torch.empty((1, row), dtype=torch.float16, device=W.device)
+        return torch.empty((1, row), dtype=X.dtype, device=W.device)

     @register_fake("_C::ggml_mul_mat_a8")
     def _ggml_mul_mat_a8_fake(
@@ -429,7 +429,7 @@ def _ggml_mul_mat_a8_fake(
         row: torch.SymInt,
     ) -> torch.Tensor:
         batch = X.size(0)
-        return torch.empty((batch, row), dtype=torch.float16, device=W.device)
+        return torch.empty((batch, row), dtype=X.dtype, device=W.device)

 # cutlass

From fa79d19fb7b4fde1fe96dd87ea1141f02e6f30e3 Mon Sep 17 00:00:00 2001
From: SzymonOzog
Date: Mon, 3 Mar 2025 14:37:08 +0100
Subject: [PATCH 6/7] fix comment alignment

Signed-off-by: SzymonOzog
---
 tests/kernels/test_gguf.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/kernels/test_gguf.py b/tests/kernels/test_gguf.py
index e73c571ddc9b..d3e3d1e62c6d 100644
--- a/tests/kernels/test_gguf.py
+++ b/tests/kernels/test_gguf.py
@@ -23,8 +23,8 @@ def get_gguf_sample_tensors(
     return GGUFReader(sample_file).tensors


-DTYPES = [torch.half, torch.bfloat16, torch.float32
-          ]  # Hidden_size for testing, must match the sample file in HF repo,
+DTYPES = [torch.half, torch.bfloat16, torch.float32]
+# Hidden_size for testing, must match the sample file in HF repo,
 # we have `hidden_size = 256, 1024` for test in HF repo currently.
HIDDEN_SIZES = [256, 1024] NUM_TOKENS = [7, 83, 128, 2048] # Arbitrary values for testing From 126d4011593979df31cfda47222ad4f7091a388a Mon Sep 17 00:00:00 2001 From: SzymonOzog Date: Mon, 10 Mar 2025 09:10:34 +0000 Subject: [PATCH 7/7] another bounds check Signed-off-by: SzymonOzog --- csrc/quantization/gguf/mmq.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/quantization/gguf/mmq.cuh b/csrc/quantization/gguf/mmq.cuh index daa8b67dc60e..e2b93680ffb5 100644 --- a/csrc/quantization/gguf/mmq.cuh +++ b/csrc/quantization/gguf/mmq.cuh @@ -38,7 +38,7 @@ static __device__ __forceinline__ void mul_mat_q( threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, blocks_per_row_x); #pragma unroll - for (int ir = 0; ir < qr; ++ir) { + for (int ir = 0; ir < qr && ib0 + ir * blocks_per_warp/qr < blocks_per_row_x; ++ir) { const int kqs = ir*WARP_SIZE_GGUF + threadIdx.x; const int kbxd = kqs / QI8_1;
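
Taken together, the series makes the GGUF matvec/matmul path follow the activation dtype instead of hard-coding float16: the C++ entry points dispatch on X.scalar_type() through VLLM_DISPATCH_FLOATING_TYPES and allocate Y with X's dtype. Below is a minimal usage sketch from the Python side; it is illustrative only and not part of the patches. It assumes a CUDA build of vLLM with this series applied, and load_gguf_q4_0_weight is a hypothetical stand-in for however the quantized weight tensor is obtained (see tests/kernels/test_gguf.py for a real loader based on gguf.GGUFReader).

import torch
from vllm import _custom_ops as ops

GGML_TYPE_Q4_0 = 2  # GGUF quant id handled by "case 2:" in the kernels above

def gguf_matvec(qweight: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # After this series, x may be float16, bfloat16 or float32; the CUDA
    # side dispatches on x's scalar type and the output keeps x's dtype.
    assert x.dtype in (torch.half, torch.bfloat16, torch.float32)
    assert x.shape[0] == 1  # the matvec entry point computes a single row
    return ops.ggml_mul_mat_vec_a8(qweight, x, GGML_TYPE_Q4_0,
                                   qweight.shape[0])

# qweight = load_gguf_q4_0_weight(...)  # hypothetical loader, not in vLLM
# x = torch.rand((1, 256), device="cuda", dtype=torch.bfloat16)
# y = gguf_matvec(qweight, x)
# assert y.dtype == torch.bfloat16  # previously this was always float16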