From cae76f8f67f5fb2af86155e25eb8d6c93e741e39 Mon Sep 17 00:00:00 2001 From: SzymonOzog Date: Tue, 27 May 2025 07:49:53 +0000 Subject: [PATCH 1/4] mmvq multiple vectors Signed-off-by: SzymonOzog --- csrc/quantization/gguf/gguf_kernel.cu | 48 +++---- csrc/quantization/gguf/mmvq.cuh | 125 +++++++++--------- .../layers/quantization/gguf.py | 6 +- 3 files changed, 93 insertions(+), 86 deletions(-) diff --git a/csrc/quantization/gguf/gguf_kernel.cu b/csrc/quantization/gguf/gguf_kernel.cu index 6c146c3fb6fd..f3d4f0b0388d 100644 --- a/csrc/quantization/gguf/gguf_kernel.cu +++ b/csrc/quantization/gguf/gguf_kernel.cu @@ -92,111 +92,113 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight torch::Tensor X, // input int64_t type, int64_t row) { int col = X.sizes()[1]; + int vecs = X.sizes()[0]; const int padded = (col + 512 - 1) / 512 * 512; const at::cuda::OptionalCUDAGuard device_guard(device_of(X)); auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device()); - at::Tensor Y = torch::empty({1, row}, options); + printf("rinning vecs %d row %d \n", vecs, (int)row); + at::Tensor Y = torch::empty({vecs, row}, options); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); options = torch::TensorOptions().dtype(torch::kInt32).device(W.device()); - at::Tensor quant_X = torch::empty({1, padded / 32 * 9}, options); + at::Tensor quant_X = torch::empty({vecs, padded / 32 * 9}, options); VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_vec_a8", [&] { - quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), - (void*)quant_X.data_ptr(), col, 1, stream); + quantize_row_q8_1_cuda( + (scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(), col, vecs, stream); switch (type) { case 2: mul_mat_vec_q4_0_q8_1_cuda( (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, stream); + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); break; case 3: mul_mat_vec_q4_1_q8_1_cuda( (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, stream); + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); break; case 6: mul_mat_vec_q5_0_q8_1_cuda( (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, stream); + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); break; case 7: mul_mat_vec_q5_1_q8_1_cuda( (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, stream); + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); break; case 8: mul_mat_vec_q8_0_q8_1_cuda( (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, stream); + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); break; case 10: mul_mat_vec_q2_K_q8_1_cuda( (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, stream); + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); break; case 11: mul_mat_vec_q3_K_q8_1_cuda( (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, stream); + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); break; case 12: mul_mat_vec_q4_K_q8_1_cuda( (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, stream); + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); break; case 13: mul_mat_vec_q5_K_q8_1_cuda( (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, stream); + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); break; case 14: mul_mat_vec_q6_K_q8_1_cuda( (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, stream); 
+ (scalar_t*)Y.data_ptr(), col, row, vecs, stream); break; case 16: mul_mat_vec_iq2_xxs_q8_1_cuda( (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, stream); + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); break; case 17: mul_mat_vec_iq2_xs_q8_1_cuda( (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, stream); + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); break; case 18: mul_mat_vec_iq3_xxs_q8_1_cuda( (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, stream); + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); break; case 19: mul_mat_vec_iq1_s_q8_1_cuda( (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, stream); + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); break; case 20: mul_mat_vec_iq4_nl_q8_1_cuda( (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, stream); + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); break; case 21: mul_mat_vec_iq3_s_q8_1_cuda( (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, stream); + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); break; case 22: mul_mat_vec_iq2_s_q8_1_cuda( (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, stream); + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); break; case 23: mul_mat_vec_iq4_xs_q8_1_cuda( (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, stream); + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); break; case 29: mul_mat_vec_iq1_m_q8_1_cuda( (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, stream); + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); break; } }); diff --git a/csrc/quantization/gguf/mmvq.cuh b/csrc/quantization/gguf/mmvq.cuh index 687cb0a37410..87cc1c02af6c 100644 --- a/csrc/quantization/gguf/mmvq.cuh +++ b/csrc/quantization/gguf/mmvq.cuh @@ -1,16 +1,17 @@ // copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu template -static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst, const int ncols, const int nrows) { +static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst, const int ncols, const int nrows, const int nvecs) { const auto row = blockIdx.x*blockDim.y + threadIdx.y; + const auto vec = blockIdx.y; - if (row >= nrows) { + if (row >= nrows || vec >= nvecs) { return; } const int blocks_per_row = ncols / qk; const int blocks_per_warp = vdr * WARP_SIZE / qi; -// partial sum for each thread + // partial sum for each thread float tmp = 0.0f; const block_q_t * x = (const block_q_t *) vx; @@ -19,7 +20,7 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * for (auto i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) { const int ibx = row*blocks_per_row + i; // x block index - const int iby = i * (qk/QK8_1); // y block index that aligns with ibx + const int iby = vec*blocks_per_row + i * (qk/QK8_1); // y block index that aligns with ibx const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int @@ -33,177 +34,177 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * } if (threadIdx.x == 0) { - dst[row] = tmp; + dst[vec*nrows + row] = tmp; } } template -static void 
mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_nums(block_num_y, nvecs, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, nvecs); } template -static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_nums(block_num_y, nvecs, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, nvecs); } template -static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_nums(block_num_y, nvecs, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, nvecs); } template -static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_nums(block_num_y, nvecs, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, nvecs); } template -static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_nums(block_num_y, nvecs, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, nvecs); } template -static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 
block_nums(block_num_y, 1, 1); + const dim3 block_nums(block_num_y, nvecs, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, nvecs); } template -static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_nums(block_num_y, nvecs, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, nvecs); } template -static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_nums(block_num_y, nvecs, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, nvecs); } template -static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_nums(block_num_y, nvecs, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, nvecs); } template -static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_nums(block_num_y, nvecs, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, nvecs); } template -static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_nums(block_num_y, nvecs, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, nvecs); } template -static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { 
+static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_nums(block_num_y, nvecs, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, nvecs); } template -static void mul_mat_vec_iq2_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_iq2_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_nums(block_num_y, nvecs, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, nvecs); } template -static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_nums(block_num_y, nvecs, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, nvecs); } template -static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_nums(block_num_y, nvecs, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, nvecs); } template -static void mul_mat_vec_iq1_m_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_iq1_m_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_nums(block_num_y, nvecs, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, nvecs); } template -static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_nums(block_num_y, nvecs, 1); const dim3 block_dims(WARP_SIZE, 
GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, nvecs); } template -static void mul_mat_vec_iq4_xs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_iq4_xs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_nums(block_num_y, nvecs, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, nvecs); } template -static void mul_mat_vec_iq3_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_iq3_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_nums(block_num_y, nvecs, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, nvecs); } diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index d7d4a5d6acdb..49dcae4bafcb 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -98,6 +98,10 @@ def get_quant_method(self, layer: torch.nn.Module, def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor: + if qweight_type in MMQ_QUANT_TYPES: + mmvq_safe = 2 if qweight.shape[0] > 5120 else 6 + else: + mmvq_safe = 8 if qweight.shape[0] > 5120 else 16 # HACK: when doing chunked prefill we don't generate output tokens # so input to logits generator is empty which causes invalid parameter if x.shape[0] == 0: @@ -109,7 +113,7 @@ def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor, if qweight_type in UNQUANTIZED_TYPES: return x @ qweight.T # enable MMVQ in contiguous batching with batch_size=1 - if x.shape[0] == 1 and qweight_type in MMVQ_QUANT_TYPES: + if x.shape[0] <= mmvq_safe and qweight_type in MMVQ_QUANT_TYPES: y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0]) # Use MMQ Kernel if it's available (standard + k-quants) elif qweight_type in MMQ_QUANT_TYPES: From 4609aa06cf16907564385e195034c2175fd1ca97 Mon Sep 17 00:00:00 2001 From: SzymonOzog Date: Tue, 27 May 2025 08:05:20 +0000 Subject: [PATCH 2/4] cleanup Signed-off-by: SzymonOzog --- csrc/quantization/gguf/gguf_kernel.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/quantization/gguf/gguf_kernel.cu b/csrc/quantization/gguf/gguf_kernel.cu index f3d4f0b0388d..3b5180b51623 100644 --- a/csrc/quantization/gguf/gguf_kernel.cu +++ b/csrc/quantization/gguf/gguf_kernel.cu @@ -96,7 +96,6 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight const int padded = (col + 512 - 1) / 512 * 512; const at::cuda::OptionalCUDAGuard device_guard(device_of(X)); auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device()); - printf("rinning vecs %d row %d \n", vecs, (int)row); at::Tensor Y = torch::empty({vecs, row}, options); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); options = 
torch::TensorOptions().dtype(torch::kInt32).device(W.device()); From da62a6e78a0ca9218d6acab60a06f05e043e1b27 Mon Sep 17 00:00:00 2001 From: SzymonOzog Date: Tue, 27 May 2025 09:26:34 +0000 Subject: [PATCH 3/4] incorporate feedback Signed-off-by: SzymonOzog --- vllm/model_executor/layers/quantization/gguf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 15da5e161024..49d022369273 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -98,10 +98,10 @@ def get_quant_method(self, layer: torch.nn.Module, def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor: - if qweight_type in MMQ_QUANT_TYPES: - mmvq_safe = 2 if qweight.shape[0] > 5120 else 6 - else: + if qweight_type in IMATRIX_QUANT_TYPES: mmvq_safe = 8 if qweight.shape[0] > 5120 else 16 + else: + mmvq_safe = 2 if qweight.shape[0] > 5120 else 6 # HACK: when doing chunked prefill we don't generate output tokens # so input to logits generator is empty which causes invalid parameter if x.shape[0] == 0: From 251d0734ad0356338f54466b2146c723900684e4 Mon Sep 17 00:00:00 2001 From: SzymonOzog Date: Tue, 3 Jun 2025 02:51:51 -0700 Subject: [PATCH 4/4] fix fake shape and stride Signed-off-by: SzymonOzog --- csrc/quantization/gguf/mmvq.cuh | 4 +++- vllm/_custom_ops.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/csrc/quantization/gguf/mmvq.cuh b/csrc/quantization/gguf/mmvq.cuh index 87cc1c02af6c..e27bec7af5b7 100644 --- a/csrc/quantization/gguf/mmvq.cuh +++ b/csrc/quantization/gguf/mmvq.cuh @@ -10,6 +10,8 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * const int blocks_per_row = ncols / qk; const int blocks_per_warp = vdr * WARP_SIZE / qi; + const int nrows_y = (ncols + 512 - 1) / 512 * 512; + // partial sum for each thread float tmp = 0.0f; @@ -20,7 +22,7 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * for (auto i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) { const int ibx = row*blocks_per_row + i; // x block index - const int iby = vec*blocks_per_row + i * (qk/QK8_1); // y block index that aligns with ibx + const int iby = vec*(nrows_y/QK8_1) + i * (qk/QK8_1); // y block index that aligns with ibx const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 3c8e6b95ce76..aaba11947206 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -555,7 +555,7 @@ def _ggml_mul_mat_vec_a8_fake( quant_type: int, row: torch.SymInt, ) -> torch.Tensor: - return torch.empty((1, row), dtype=X.dtype, device=W.device) + return torch.empty((X.shape[0], row), dtype=X.dtype, device=W.device) @register_fake("_C::ggml_mul_mat_a8") def _ggml_mul_mat_a8_fake(
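
For reference, the batch-size heuristic introduced in PATCH 3/4 is easier to read outside of diff context. The sketch below is a minimal, self-contained illustration of the dispatch that _fused_mul_mat_gguf performs after this series: choose_gguf_kernel and the quant-type set arguments are hypothetical stand-ins for vLLM's MMVQ_QUANT_TYPES / IMATRIX_QUANT_TYPES / MMQ_QUANT_TYPES, and only the thresholds and ordering mirror the patched code.

# Minimal sketch of the MMVQ/MMQ dispatch heuristic from this series.
# The function and the set arguments are illustrative stand-ins, not vLLM API.

def choose_gguf_kernel(batch_size: int, out_features: int, qweight_type: int,
                       imatrix_quant_types: set, mmvq_quant_types: set,
                       mmq_quant_types: set) -> str:
    """Return which GGUF matmul path a layer would take for a given batch."""
    # i-matrix (IQ*) types tolerate larger batches on the MMVQ path, since the
    # MMQ kernel only covers standard + k-quants (per the comment in gguf.py).
    if qweight_type in imatrix_quant_types:
        mmvq_safe = 8 if out_features > 5120 else 16
    else:
        mmvq_safe = 2 if out_features > 5120 else 6

    if batch_size <= mmvq_safe and qweight_type in mmvq_quant_types:
        # Batched MMVQ: the grid's y-dimension indexes the input vectors and the
        # output is allocated as (batch_size, out_features) in gguf_kernel.cu.
        return "mmvq"
    if qweight_type in mmq_quant_types:
        return "mmq"
    return "dequantize_then_matmul"


if __name__ == "__main__":
    # Type ids follow the switch in gguf_kernel.cu (2 == Q4_0, 16 == IQ2_XXS);
    # the sets passed here are placeholders for illustration only.
    print(choose_gguf_kernel(4, 4096, 2, {16}, {2, 16}, {2}))    # -> mmvq
    print(choose_gguf_kernel(4, 8192, 2, {16}, {2, 16}, {2}))    # -> mmq
    print(choose_gguf_kernel(12, 4096, 16, {16}, {2, 16}, {2}))  # -> mmvq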