diff --git a/csrc/dequant_impl_packed.cu b/csrc/dequant_impl_packed.cu
index f379064..8e07847 100644
--- a/csrc/dequant_impl_packed.cu
+++ b/csrc/dequant_impl_packed.cu
@@ -18,6 +18,11 @@ struct C10ToNvType {
   typedef __half type;
 };
 
+template <>
+struct C10ToNvType<float> {
+  typedef float type;
+};
+
 template <typename scalar_t, int IDXBITS, int ResidualBits, int GROUPSIZE, int OL_GroupSize, int Do_Reduce>
 __global__ void WqA16WithOutliers_PackIndice(
     scalar_t* out, const scalar_t* input_data, const int32_t* q_indice, const uint16_t* q_indice_outliers,
@@ -25,6 +30,7 @@ __global__ void WqA16WithOutliers_PackIndice(
     const scalar_t* outliers_centroids, const uint16_t* invert_perm, const scalar_t* weight_scale,
     const scalar_t* weight_bias, const scalar_t* bias, int out_features, int in_features, int outliers_infeatures,
     const int index_stride_0, const int index_stride_1, const int centroids_stride_0, const int group_nums) {
+  static_assert((GROUPSIZE & 1) == 0, "GROUPSIZE must be even ");
   int bidx = blockIdx.x;  // out_features//base_groupsize
   int bidy = blockIdx.y;  // batch
   int bidz = blockIdx.z;  // segment in_features
@@ -34,10 +40,6 @@ __global__ void WqA16WithOutliers_PackIndice(
     tidx += bidz * cuda::kBlockSize * Do_Reduce;
   }
   int in_y = bidx;
-  __shared__ scalar_t shared_memory[1];  // 3xin_features, dynamic
-  scalar_t* shared_input = shared_memory;  // in_features, dynamic
-  // scalar_t* shared_w_scales = shared_memory+in_features;// in_features, dynamic
-  scalar_t* shared_w_bias = shared_memory + in_features;  // in_features, dynamic
   __shared__ float shared_output[GROUPSIZE][cuda::kBlockSize / 32 + 1];
   scalar_t tmp_output[GROUPSIZE];
 #pragma unroll
@@ -46,14 +48,6 @@ __global__ void WqA16WithOutliers_PackIndice(
   }
   input_data = input_data + in_features * bidy;
   out = out + out_features * bidy * gridDim.z;
-  if constexpr (Do_Reduce == 0) {
-    for (int i = tidx; i < in_features; i += cuda::kBlockSize) {
-      int w_col = invert_perm ? invert_perm[i] : i;
-      shared_input[i] = input_data[w_col] * weight_scale[w_col];
-      shared_w_bias[i] = input_data[w_col] * weight_bias[w_col];
-    }
-    __syncthreads();
-  }
   if (tidx >= in_features) {
     return;
   }
@@ -64,10 +58,10 @@ __global__ void WqA16WithOutliers_PackIndice(
     // const scalar_t scale = shared_w_scales[col];
     const int w_col = Do_Reduce ? (invert_perm ? invert_perm[col] : col) : 0;
     const scalar_t input_col_v = input_data[w_col];
-    const scalar_t bias = Do_Reduce ? input_col_v * weight_bias[w_col] : shared_w_bias[col];
-    scalar_t input_v = Do_Reduce ? input_col_v * weight_scale[w_col] : shared_input[col];
-    VecType input_v2 = VecType(input_v, input_v);
-    VecType bias2 = VecType(bias, bias);
+    const scalar_t bias = input_col_v * weight_bias[w_col];
+    scalar_t input_v = input_col_v * weight_scale[w_col];
+    VecType input_v2 = VecType{input_v, input_v};
+    VecType bias2 = VecType{bias, bias};
     int32_t mapped_index_x = col;
 
     if (mapped_index_x < outliers_infeatures) {
@@ -84,14 +78,13 @@ __global__ void WqA16WithOutliers_PackIndice(
         scalar_t* tmp_output_off_p = tmp_output + gi;
         scalar_t scalar_weight[OL_GroupSize];
         if (out_y < out_features) {
-          cuda::ldg_vec_x(reinterpret_cast<uint32_t*>(scalar_weight),
-                          (const uint32_t*)outliers_centroids_start);
+          cuda::ldg_vec_x((scalar_weight), (const uint32_t*)outliers_centroids_start);
           VecType* weight_h2 = (VecType*)scalar_weight;
           VecType* tmp_output_off_h2 = (VecType*)tmp_output_off_p;
-          tmp_output_off_h2[0] = __hfma2(weight_h2[0], input_v2, tmp_output_off_h2[0]);
-          tmp_output_off_h2[1] = __hfma2(weight_h2[1], input_v2, tmp_output_off_h2[1]);
-          tmp_output_off_h2[0] = __hadd2(tmp_output_off_h2[0], bias2);
-          tmp_output_off_h2[1] = __hadd2(tmp_output_off_h2[1], bias2);
+          tmp_output_off_h2[0] = FMA2(weight_h2[0], input_v2, tmp_output_off_h2[0]);
+          tmp_output_off_h2[1] = FMA2(weight_h2[1], input_v2, tmp_output_off_h2[1]);
+          tmp_output_off_h2[0] = ADD2(tmp_output_off_h2[0], bias2);
+          tmp_output_off_h2[1] = ADD2(tmp_output_off_h2[1], bias2);
         }
       }
     } else {
@@ -113,21 +106,21 @@ __global__ void WqA16WithOutliers_PackIndice(
 
       const uint32_t base_ind = merged_ind & ((1 << IDXBITS) - 1);
       const scalar_t* centroids_start = (centroids_cb) + base_ind * GROUPSIZE;
-      cuda::ldg_vec_x(reinterpret_cast<uint32_t*>(base), (const uint32_t*)(centroids_start));
+      cuda::ldg_vec_x((base), (const uint32_t*)(centroids_start));
 
       VecType* hres_ptr = nullptr;
       if constexpr (ResidualBits > 0) {
        scalar_t residual[GROUPSIZE];
        const uint32_t res_ind = (merged_ind >> IDXBITS) & ((1 << ResidualBits) - 1);
        const scalar_t* residual_centroids_start = (residual_centroids_cb) + res_ind * GROUPSIZE;
-        cuda::ldg_vec_x(reinterpret_cast<uint32_t*>(residual), (const uint32_t*)(residual_centroids_start));
+        cuda::ldg_vec_x((residual), (const uint32_t*)(residual_centroids_start));
 
        VecType hres[GROUPSIZE / 2];
        hres_ptr = hres;
 #pragma unroll
        for (int i = 0; i < GROUPSIZE / 2; i++) {
-          hres[i] = __hadd2(*(((VecType*)base) + i), *(((VecType*)residual) + i));
-          // hres[i] = __hfma2(hres[i], scale2, bias2);
+          hres[i] = ADD2(*(((VecType*)base) + i), *(((VecType*)residual) + i));
+          // hres[i] = FMA2(hres[i], scale2, bias2);
        }
      } else {
        hres_ptr = (VecType*)base;
@@ -141,8 +134,8 @@ __global__ void WqA16WithOutliers_PackIndice(
       VecType* h2_tmp_output = (VecType*)tmp_output;
 #pragma unroll
       for (int gi = 0; gi < GROUPSIZE / 2; gi++) {
-        h2_tmp_output[gi] = __hfma2(hres_ptr[gi], input_v2, h2_tmp_output[gi]);
-        h2_tmp_output[gi] = __hadd2(h2_tmp_output[gi], bias2);
+        h2_tmp_output[gi] = FMA2(hres_ptr[gi], input_v2, h2_tmp_output[gi]);
+        h2_tmp_output[gi] = ADD2(h2_tmp_output[gi], bias2);
       }
     }
   }
@@ -246,26 +239,26 @@ __global__ void DequantizeWithOutliers_PackIndice(scalar_t* out, const int32_t*
   const uint16_t base_ind = merged_ind & ((1 << IDXBITS) - 1);
 
   VecType base[GROUPSIZE / 2];
   const scalar_t* centroids_start = centroids + base_ind * GROUPSIZE;
-  cuda::ldg_vec_x((uint32_t*)(base), (const uint32_t*)(centroids_start));
+  cuda::ldg_vec_x((base), (const uint32_t*)(centroids_start));
 
   if constexpr (ResidualBits > 0) {
     VecType residual[GROUPSIZE / 2];
     merged_ind >>= IDXBITS;
     const uint16_t res_ind = merged_ind & ((1 << ResidualBits) - 1);
     const scalar_t* residual_centroids_start = residual_centroids + res_ind * GROUPSIZE;
-    cuda::ldg_vec_x((uint32_t*)(residual), (const uint32_t*)(residual_centroids_start));
+    cuda::ldg_vec_x((residual), (const uint32_t*)(residual_centroids_start));
 #pragma unroll
     for (int i = 0; i < GROUPSIZE / 2; i++) {
-      base[i] = __hadd2(*(((VecType*)base) + i), *(((VecType*)residual) + i));
+      base[i] = ADD2(*(((VecType*)base) + i), *(((VecType*)residual) + i));
     }
   }
   VecType hres[GROUPSIZE / 2];
-  VecType scale2 = VecType(scale, scale);
-  VecType bias2 = VecType(bias, bias);
+  VecType scale2 = VecType{scale, scale};
+  VecType bias2 = VecType{bias, bias};
 #pragma unroll
   for (int i = 0; i < GROUPSIZE / 2; i++) {
-    hres[i] = __hfma2(base[i], scale2, bias2);
+    hres[i] = FMA2(base[i], scale2, bias2);
   }
   scalar_t* res = (scalar_t*)hres;
   const int group_step = in_y * GROUPSIZE;
@@ -298,8 +291,8 @@ torch::Tensor lauch_deqantize_outliers_cuda_packkernel(
   OptionalCUDAGuard cudaguard(q_indice.device().index());
   int base_groupsize = centroids.size(-1);  // how many elements in a vector
   int res_groupsize = residual_centroids.has_value() ? residual_centroids.value().size(-1) : 0;
-  // TORCH_CHECK((res_groupsize===base_groupsize||res_groupsize==0), "res_groupsize===base_groupsize is false, must be
-  // true");
+  TORCH_CHECK(((res_groupsize == base_groupsize) || (res_groupsize == 0)),
+              "res_groupsize==base_groupsize is false, must be true");
   int index_bits = log2(centroids.size(1));  // how many bits to index quantization vector
   int res_index_bits = residual_centroids.has_value() ? log2(residual_centroids.value().size(1)) : 0;
   auto out_size = outf_x_inf;
@@ -337,26 +330,18 @@ torch::Tensor lauch_deqantize_outliers_cuda_packkernel(
       outliers_indices_size_n1, outliers_centroids_size_n1, q_indice.stride(0), q_indice.stride(1), \
       centroids.stride(0), q_indice.size(0)); \
   }
 
-#if __CUDA_ARCH__ < 800
-  #define callDequantWithOutliers_dtype(IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits) \
-    if (centroids.dtype() == at::ScalarType::Half) { \
-      using scalar_t = c10::Half; \
-      callDequantWithOutliers(scalar_t, IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits); \
-    } else { \
-      TORCH_CHECK(false, "un-supported dtype: bfloat16"); \
-    }
-#else
-  #define callDequantWithOutliers_dtype(IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits) \
-    if (centroids.dtype() == at::ScalarType::Half) { \
-      using scalar_t = c10::Half; \
-      callDequantWithOutliers(scalar_t, IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits); \
-    } else { \
-      using scalar_t = c10::BFloat16; \
-      callDequantWithOutliers(scalar_t, IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits); \
-    }
-
-#endif
+#define callDequantWithOutliers_dtype(IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits) \
+  if (centroids.dtype() == at::ScalarType::Half) { \
+    using scalar_t = c10::Half; \
+    callDequantWithOutliers(scalar_t, IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+  } else if (centroids.dtype() == at::ScalarType::Float) { \
+    using scalar_t = float; \
+    callDequantWithOutliers(scalar_t, IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+  } else { \
+    using scalar_t = c10::BFloat16; \
+    callDequantWithOutliers(scalar_t, IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+  }
 
 #define callDequantWithOutliers_bits(BASEGROUP, OUT_OUF_INF, ResidualBits) \
   switch (index_bits) { \
@@ -516,24 +501,17 @@ torch::Tensor lauch_gemv_outliers_cuda_packkernel(
       centroids.stride(0), q_indice.size(0)); \
   }
 
-#if __CUDA_ARCH__ < 800
-  #define CallWqA16kernel_dtype(out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits) \
-    if (input.dtype() == at::ScalarType::Half) { \
-      using scalar_t = c10::Half; \
-      CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits); \
-    } else { \
-      TORCH_CHECK(false, "un-supported dtype: bfloat16"); \
-    }
-#else
-  #define CallWqA16kernel_dtype(out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits) \
-    if (input.dtype() == at::ScalarType::Half) { \
-      using scalar_t = c10::Half; \
-      CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits); \
-    } else { \
-      using scalar_t = c10::BFloat16; \
-      CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits); \
-    }
-#endif
+#define CallWqA16kernel_dtype(out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits) \
+  if (input.dtype() == at::ScalarType::Half) { \
+    using scalar_t = c10::Half; \
+    CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits); \
+  } else if (input.dtype() == at::ScalarType::Float) { \
+    using scalar_t = float; \
+    CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits); \
+  } else { \
+    using scalar_t = c10::BFloat16; \
+    CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits); \
+  }
 
 #define CallWqA16kernel_bits(out_buf, BASEGROUP, Do_Reduce, ResidualBits) \
   switch (index_bits) { \
diff --git a/csrc/utils.cuh b/csrc/utils.cuh
index 3155810..8b94493 100644
--- a/csrc/utils.cuh
+++ b/csrc/utils.cuh
@@ -20,10 +20,18 @@ struct TypeVec2<__nv_bfloat16> {
   typedef __nv_bfloat162 type;
 };
 
+template <>
+struct TypeVec2<float> {
+  typedef float2 type;
+};
+
 template <typename T>
 T __device__ __forceinline__ ConvertFromFloat(float v, T vv) {
+  (void)(vv);
   if constexpr (std::is_same<T, __nv_bfloat16>::value) {
     return vv = __float2bfloat16(v);
+  } else if constexpr (std::is_same<T, float>::value) {
+    return vv = v;
   } else {
     static_assert(std::is_same<T, __half>::value);
     return vv = __float2half(v);
@@ -34,6 +42,8 @@ template <typename T>
 float __device__ __forceinline__ ConvertToFloat(T v) {
   if constexpr (std::is_same<T, __nv_bfloat16>::value) {
     return __bfloat162float(v);
+  } else if constexpr (std::is_same<T, float>::value) {
+    return v;
   } else {
     static_assert(std::is_same<T, __half>::value);
     return __half2float(v);
@@ -50,8 +60,12 @@ __device__ __forceinline__ float warpReduceSum(float sum) {
   return sum;
 }
 
-template <int GROUPSIZE>
-__device__ __forceinline__ void ldg_vec_x(uint32_t* __restrict__ dst_u32, const uint32_t* __restrict__ src_u32) {
+template <int GROUPSIZE, typename T>
+__device__ __forceinline__ void ldg_vec_x(T* __restrict__ dst_t32, const uint32_t* __restrict__ src_u32) {
+  uint32_t* dst_u32 = (uint32_t*)dst_t32;
+  if constexpr (std::is_same<T, float>::value || std::is_same<T, float2>::value) {
+    return ldg_vec_x(dst_u32, src_u32);
+  }
   int2* dst = (int2*)dst_u32;
   const int2* src = (const int2*)src_u32;
   if constexpr (GROUPSIZE == 2) {
@@ -144,3 +158,39 @@ __forceinline__ T ceil_div(T a, T b) {
 }
 
 }  // namespace cuda
+
+template <typename T>
+T __device__ __forceinline__ FMA2(T a, T b, T c) {
+  if constexpr (std::is_same<T, __nv_bfloat162>::value) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    float x = __bfloat162float(a.x) * __bfloat162float(b.x) + __bfloat162float(c.x);
+    float y = __bfloat162float(a.y) * __bfloat162float(b.y) + __bfloat162float(c.y);
+    return __nv_bfloat162{__float2bfloat16(x), __float2bfloat16(y)};
+#else
+    return __hfma2(a, b, c);
+#endif
+  } else if constexpr (std::is_same<T, float2>::value) {
+    return float2{a.x * b.x + c.x, a.y * b.y + c.y};
+  } else {
+    return __hfma2(a, b, c);
+  }
+  __builtin_unreachable();  // Suppress missing return statement warning
+}
+
+template <typename T>
+T __device__ __forceinline__ ADD2(T a, T b) {
+  if constexpr (std::is_same<T, __nv_bfloat162>::value) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    float x = __bfloat162float(a.x) + __bfloat162float(b.x);
+    float y = __bfloat162float(a.y) + __bfloat162float(b.y);
+    return __nv_bfloat162{__float2bfloat16(x), __float2bfloat16(y)};
+#else
+    return __hadd2(a, b);
+#endif
+  } else if constexpr (std::is_same<T, float2>::value) {
+    return float2{a.x + b.x, a.y + b.y};
+  } else {
+    return __hadd2(a, b);
+  }
+  __builtin_unreachable();  // Suppress missing return statement warning
+}
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index a475439..b04a5ef 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 datasets
 torch
-transformers>=4.44
+transformers>=4.45
 safetensors
 psutil
 accelerate
\ No newline at end of file
diff --git a/setup.py b/setup.py
index ec84302..eeeed0e 100644
--- a/setup.py
+++ b/setup.py
@@ -30,6 +30,8 @@ def build_cuda_extensions():
     delimiter = ' ' if ';' not in TORCH_CUDA_ARCH_LIST else ' '
     TORCH_CUDA_ARCH_LIST = TORCH_CUDA_ARCH_LIST.split(delimiter)
     compute_capabilities = [int(10 * float(arch)) for arch in TORCH_CUDA_ARCH_LIST if '+' not in arch]
+
+    print(" build for compute capabilities: ==============", compute_capabilities)
     for cap in compute_capabilities:
         arch_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"]
     extra_compile_args = {
diff --git a/vptq/__init__.py b/vptq/__init__.py
index e08d8c0..3cda50f 100644
--- a/vptq/__init__.py
+++ b/vptq/__init__.py
@@ -3,5 +3,5 @@
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
 
-__version__ = "0.0.2"
+__version__ = "0.0.2.post1"
 from .layers import AutoModelForCausalLM as AutoModelForCausalLM