Skip to content

Commit 182827b

Browse files
committed
[Bugfix][Kernel]: Fix AllSpark kernel compilation errors and enable for CUDA version <12.0
Signed-off-by: wyj371990 <wyj371990@alibaba-inc.com>
1 parent e22ee1e commit 182827b

File tree

3 files changed

+15
-10
lines changed

3 files changed

+15
-10
lines changed

CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
319319

320320
# Only build AllSpark kernels if we are building for at least some compatible archs.
321321
cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}")
322-
if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND ALLSPARK_ARCHS)
322+
if (ALLSPARK_ARCHS)
323323
set(ALLSPARK_SRCS
324324
"csrc/quantization/gptq_allspark/allspark_repack.cu"
325325
"csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu")
@@ -330,7 +330,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
330330
message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}")
331331
else()
332332
message(STATUS "Not building AllSpark kernels as no compatible archs found"
333-
" in CUDA target architectures, or CUDA not >= 12.0")
333+
" in CUDA target architectures")
334334
endif()
335335

336336

csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -437,9 +437,10 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
437437
for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) {
438438
#pragma unroll
439439
for (int k_idx = 0; k_idx < 2; ++k_idx) {
440-
FType low16 = static_cast<FType>(C_frag[m_idx][n_idx][k_idx * 2]);
440+
FType low16 =
441+
ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2]);
441442
FType high16 =
442-
static_cast<FType>(C_frag[m_idx][n_idx][k_idx * 2 + 1]);
443+
ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2 + 1]);
443444
uint32_t tmp = (reinterpret_cast<uint32_t&>(low16) & 0xffff) |
444445
(reinterpret_cast<uint32_t&>(high16) << 16);
445446
int sts_offset =
@@ -793,7 +794,7 @@ __global__ void restore_N32_K16_dequantize_rhs_w8a16_perc_kernel(
793794
FT scale_reg[4];
794795
*(reinterpret_cast<uint2*>(scale_reg)) =
795796
*(reinterpret_cast<const uint2*>(scales + params_nidx));
796-
FT zero_reg[4] = {0};
797+
FT zero_reg[4];
797798
if (zeros != nullptr) {
798799
*(reinterpret_cast<uint2*>(zero_reg)) =
799800
*(reinterpret_cast<const uint2*>(zeros + params_nidx));
@@ -809,8 +810,10 @@ __global__ void restore_N32_K16_dequantize_rhs_w8a16_perc_kernel(
809810
reinterpret_cast<typename HalfType<FT>::T2*>(&(fval_reg[ni * 4])));
810811
#pragma unroll
811812
for (int ki = 0; ki < 4; ++ki) {
812-
fval_reg[ni * 4 + ki] =
813-
(fval_reg[ni * 4 + ki] - zero_reg[ni]) * scale_reg[ni];
813+
if (zeros != nullptr) {
814+
fval_reg[ni * 4 + ki] = __hsub(fval_reg[ni * 4 + ki], zero_reg[ni]);
815+
}
816+
fval_reg[ni * 4 + ki] = __hmul(fval_reg[ni * 4 + ki], scale_reg[ni]);
814817
int sts_offset = sts_base_offset + ((ki / 2) * 8 + (ki % 2)) * 32 +
815818
((ni + lane_id % 4) % 4) * 8;
816819
smem[sts_offset] = fval_reg[ni * 4 + ki];

csrc/quantization/gptq_allspark/allspark_utils.cuh

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#include <cuda_fp16.h>
88
#include <cuda_bf16.h>
99
#include <iostream>
10+
#include "../gptq_marlin/marlin_dtypes.cuh"
11+
using marlin::ScalarType;
1012

1113
namespace allspark {
1214

@@ -66,14 +68,14 @@ __global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C,
6668
return;
6769
}
6870

69-
FType sum(0);
71+
float sum = 0.f;
7072

7173
int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix;
7274
for (int i = 0; i < n_mat; ++i) {
73-
sum += C_split[idx + i * matrix_size];
75+
sum += ScalarType<FType>::num2float(C_split[idx + i * matrix_size]);
7476
}
7577

76-
C[idx] = sum;
78+
C[idx] = ScalarType<FType>::float2num(sum);
7779
}
7880

7981
template <typename FType>

0 commit comments

Comments
 (0)