From 62fa8a8e6af4d187ae89ef7f3cda4248dec9df85 Mon Sep 17 00:00:00 2001
From: arlo-phoenix
Date: Fri, 12 Jan 2024 17:55:05 +0100
Subject: [PATCH] Adjust kQuantizeBlockwise to work with WARP size 64

---
 csrc/kernels.cu | 37 ++++++++++++++++++++++---------------
 csrc/ops.cuh    |  6 +-----
 2 files changed, 23 insertions(+), 20 deletions(-)

diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 89af37378..5874d4979 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -740,21 +740,28 @@ template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE>
 //__launch_bounds__(TH, 4)
 __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * rand, const int rand_offset, const int n)
 {
+  const int DATA_NUM_PER_TH = (DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH;
+
   const int n_full = gridDim.x * BLOCK_SIZE;
   int valid_items = 0;
   const int base_idx = (blockIdx.x * BLOCK_SIZE);
 
-  T vals[NUM_PER_TH];
-  float rand_vals[NUM_PER_TH];
-  unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH];
+  T vals[CUB_NUM_PER_TH];
+  float rand_vals[CUB_NUM_PER_TH];
+  unsigned char qvals[DATA_NUM_PER_TH];
   //float local_abs_max = -FLT_MAX;
   float local_abs_max = 0.0f;
   int local_rand_idx = 0;
 
-  typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
-  typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, (DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
-  typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;
-  typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
+  typedef cub::BlockLoad<T, BLOCK_SIZE/CUB_NUM_PER_TH, CUB_NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
+  typedef cub::BlockStore<unsigned char, BLOCK_SIZE/CUB_NUM_PER_TH, DATA_NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
+  typedef cub::BlockReduce<float, BLOCK_SIZE/CUB_NUM_PER_TH> BlockReduce;
+  typedef cub::BlockLoad<float, BLOCK_SIZE/CUB_NUM_PER_TH, CUB_NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
 
   __shared__ typename LoadT::TempStorage loadt;
   __shared__ typename LoadFloat::TempStorage loadf;
@@ -779,8 +786,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
       // 2. broadcast local max
       // 3. normalize inputs and quantize
 
-      #pragma unroll NUM_PER_TH
-      for(int j = 0; j < NUM_PER_TH; j++)
+      #pragma unroll CUB_NUM_PER_TH
+      for(int j = 0; j < CUB_NUM_PER_TH; j++)
         local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));
 
       local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items);
@@ -809,8 +816,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
       switch(DATA_TYPE)
       {
           case General8bit:
-              #pragma unroll NUM_PER_TH
-              for(int j = 0; j < NUM_PER_TH; j++)
+              #pragma unroll CUB_NUM_PER_TH
+              for(int j = 0; j < CUB_NUM_PER_TH; j++)
              {
                if(!STOCHASTIC)
                  qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max);
@@ -819,8 +826,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
              }
              break;
           case FP4:
-              #pragma unroll NUM_PER_TH
-              for(int j = 0; j < NUM_PER_TH/2; j++)
+              #pragma unroll CUB_NUM_PER_TH
+              for(int j = 0; j < DATA_NUM_PER_TH; j++)
              {
                packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4;
                packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max);
@@ -828,8 +835,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
              }
              break;
           case NF4:
-              #pragma unroll NUM_PER_TH
-              for(int j = 0; j < NUM_PER_TH/2; j++)
+              #pragma unroll CUB_NUM_PER_TH
+              for(int j = 0; j < DATA_NUM_PER_TH; j++)
              {
                packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4;
                packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max);
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index 3584e5982..87203edae 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -14,10 +14,6 @@
 #ifdef BNB_USE_HIP
-// check rocminfo | grep "Wavefront Size". Should be supported on all new GPU's
-// dirty hack to force wavefront_size 32 so this compiles
-// RDNA 2 defaults to 64 which conflicts with kQuantizeBlockwise
-#define __AMDGCN_WAVEFRONT_SIZE 32
 #include <hip/hip_runtime_api.h>
 #include <hip/hip_fp16.h>
@@ -58,7 +54,7 @@
 #define cublasLtHandle_t hipblasLtHandle_t
 #define cublasLtCreate hipblasLtCreate
 #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
-#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT //TODO: HIP didn't have the right one, might cause issues
+#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
 
 #else
 #include <cuda_runtime_api.h>
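
A note on the constraint behind this change, with a standalone sketch (not code from the patch): cub's BLOCK_LOAD_WARP_TRANSPOSE and BLOCK_STORE_WARP_TRANSPOSE require the block's thread count to be a multiple of the warp size. The smallest blockwise-quantization configurations run on the order of 32 threads per block, which satisfies that requirement at warp size 32 but not on GPUs whose wavefront size is 64 (the removed comment names RDNA 2). Forcing __AMDGCN_WAVEFRONT_SIZE to 32 only hid the mismatch, which is why this patch instead adjusts the per-thread element counts (CUB_NUM_PER_TH, DATA_NUM_PER_TH) used by the cub collectives. The sketch below reproduces the constraint in isolation; DEMO_WARP_SIZE, demo_blockwise_load, and the launch numbers are illustrative assumptions, not identifiers or values from bitsandbytes.

// Minimal sketch: a cub BlockLoad in the WARP_TRANSPOSE configuration is only
// valid when the block's thread count is a multiple of the warp size.
#include <cub/cub.cuh>
#include <cstdio>

#ifndef DEMO_WARP_SIZE
#define DEMO_WARP_SIZE 32   // assumption: set to 64 when targeting a wave64 AMD GPU via hipCUB
#endif

template <int BLOCK_SIZE, int NUM_PER_TH>
__global__ void demo_blockwise_load(const float* in, float* out, int n)
{
    constexpr int THREADS = BLOCK_SIZE / NUM_PER_TH;
    // This is the requirement the removed "#define __AMDGCN_WAVEFRONT_SIZE 32"
    // hack papered over: 32 threads per block is not a multiple of a 64-wide wavefront.
    static_assert(THREADS % DEMO_WARP_SIZE == 0,
                  "BLOCK_LOAD_WARP_TRANSPOSE needs a thread count that is a multiple of the warp size");

    typedef cub::BlockLoad<float, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
    __shared__ typename LoadT::TempStorage loadt;

    const int base_idx = blockIdx.x * BLOCK_SIZE;
    const int valid_items = min(BLOCK_SIZE, n - base_idx);

    float vals[NUM_PER_TH];
    LoadT(loadt).Load(&in[base_idx], vals, valid_items, 0.0f);

    // After the load, each thread owns a blocked range of NUM_PER_TH elements.
    for (int j = 0; j < NUM_PER_TH; j++)
    {
        const int idx = base_idx + threadIdx.x * NUM_PER_TH + j;
        if (idx < n)
            out[idx] = vals[j];
    }
}

int main()
{
    const int n = 256;
    float *in = nullptr, *out = nullptr;
    cudaMallocManaged(&in, n * sizeof(float));
    cudaMallocManaged(&out, n * sizeof(float));
    for (int i = 0; i < n; i++) in[i] = (float)i;

    // 64-element blocks with 2 items per thread -> 32 threads per block:
    // fine at warp size 32, rejected at compile time when DEMO_WARP_SIZE is 64.
    demo_blockwise_load<64, 2><<<n / 64, 32>>>(in, out, n);
    cudaDeviceSynchronize();
    printf("out[65] = %f\n", out[65]);

    cudaFree(in);
    cudaFree(out);
    return 0;
}

On a wave64 build, a configuration like the one above only compiles if the thread count or the per-thread element count is chosen so the multiple-of-warp-size requirement still holds; that is the role the CUB_NUM_PER_TH and DATA_NUM_PER_TH parameters play inside kQuantizeBlockwise, which in turn lets the __AMDGCN_WAVEFRONT_SIZE override in ops.cuh be dropped.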