support rocm (#63)
wejoncy authored Oct 12, 2024
1 parent abdc473 commit f169d8d
Showing 4 changed files with 155 additions and 82 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -160,3 +160,6 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
+ csrc/*_hip*
+ csrc/*.hip

37 changes: 15 additions & 22 deletions csrc/dequant_impl_packed.cu
@@ -2,15 +2,14 @@
// Licensed under the MIT License.

#include <cmath>
- #include <math_constants.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "common.h"
#include "utils.cuh"

template <typename T>
struct C10ToNvType {
- typedef __nv_bfloat16 type;
+ typedef __bfloat16 type;
};

template <>
@@ -40,11 +39,12 @@ __global__ void WqA16WithOutliers_PackIndice(
tidx += bidz * cuda::kBlockSize * Do_Reduce;
}
int in_y = bidx;
- __shared__ float shared_output[GROUPSIZE][cuda::kBlockSize / 32 + 1];
+ __shared__ float shared_output[GROUPSIZE][cuda::kBlockSize / WARP_SIZE + 1];
scalar_t tmp_output[GROUPSIZE];
+ const scalar_t zero_value = ZERO_VALUE(scalar_t());
#pragma unroll
for (int i = 0; i < GROUPSIZE; i++) {
- tmp_output[i] = scalar_t(0.0f);
+ tmp_output[i] = zero_value;
}
input_data = input_data + in_features * bidy;
out = out + out_features * bidy * gridDim.z;
@@ -140,14 +140,14 @@ __global__ void WqA16WithOutliers_PackIndice(
}
}

- // warp_size = 32
- int warpid = threadIdx.x / 32; // at most 8 warp= 256/32
- int landid = threadIdx.x % 32;
+ // warp_size = WARP_SIZE
+ int warpid = threadIdx.x / WARP_SIZE; // at most 8 warp= 256/WARP_SIZE
+ int landid = threadIdx.x % WARP_SIZE;
#pragma unroll
for (int gi = 0; gi < GROUPSIZE; gi++) {
float reduce_out = 0.f;
reduce_out = cuda::ConvertToFloat(tmp_output[gi]);
- reduce_out = cuda::warpReduceSum<32>(reduce_out);
+ reduce_out = cuda::warpReduceSum<WARP_SIZE>(reduce_out);
if (landid == 0) {
shared_output[gi][warpid] = reduce_out;
}
@@ -162,18 +162,17 @@ __global__ void WqA16WithOutliers_PackIndice(
}

__syncthreads();
- if (landid < cuda::kBlockSize / 32) {
+ if (landid < cuda::kBlockSize / WARP_SIZE) {
#pragma unroll
- for (int wid = warpid; wid < GROUPSIZE; wid += cuda::kBlockSize / 32) {
+ for (int wid = warpid; wid < GROUPSIZE; wid += cuda::kBlockSize / WARP_SIZE) {
float reduce_out = shared_output[wid][landid];
- reduce_out = cuda::warpReduceSum<cuda::kBlockSize / 32>(reduce_out);
+ reduce_out = cuda::warpReduceSum<cuda::kBlockSize / WARP_SIZE>(reduce_out);
if (landid == 0 && (in_y * GROUPSIZE + wid) < out_features) {
if constexpr (Do_Reduce) {
- out[(wid)*gridDim.z] = cuda::ConvertFromFloat<scalar_t>(reduce_out, scalar_t(0.0f)) +
- ((bidz == 0 && bias != 0) ? bias[wid] : scalar_t(0.0f));
+ out[(wid)*gridDim.z] = cuda::ConvertFromFloat<scalar_t>(reduce_out, zero_value) +
+ ((bidz == 0 && bias != 0) ? bias[wid] : zero_value);
} else {
- out[wid] =
- cuda::ConvertFromFloat<scalar_t>(reduce_out, scalar_t(0.0f)) + ((bias != 0) ? bias[wid] : scalar_t(0.0f));
+ out[wid] = cuda::ConvertFromFloat<scalar_t>(reduce_out, zero_value) + ((bias != 0) ? bias[wid] : zero_value);
}
}
}
@@ -215,7 +214,7 @@ __global__ void DequantizeWithOutliers_PackIndice(scalar_t* out, const int32_t*
if ((gi + j) >= out_features) {
return;
}
- out[(gi + j) * in_features + in_x] = outliers_centroids_start[j] * scale + bias;
+ out[(gi + j) * in_features + in_x] = FMA(outliers_centroids_start[j], scale, bias);
}
}
return;
@@ -335,9 +334,6 @@ torch::Tensor lauch_deqantize_outliers_cuda_packkernel(
if (centroids.dtype() == at::ScalarType::Half) { \
using scalar_t = c10::Half; \
callDequantWithOutliers(scalar_t, IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits); \
- } else if (centroids.dtype() == at::ScalarType::Float) { \
- using scalar_t = float; \
- callDequantWithOutliers(scalar_t, IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits); \
} else { \
using scalar_t = c10::BFloat16; \
callDequantWithOutliers(scalar_t, IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits); \
@@ -505,9 +501,6 @@ torch::Tensor lauch_gemv_outliers_cuda_packkernel(
if (input.dtype() == at::ScalarType::Half) { \
using scalar_t = c10::Half; \
CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits); \
- } else if (input.dtype() == at::ScalarType::Float) { \
- using scalar_t = float; \
- CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits); \
} else { \
using scalar_t = c10::BFloat16; \
CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits); \
170 changes: 116 additions & 54 deletions csrc/utils.cuh
@@ -3,8 +3,25 @@

#pragma once

- #include <cuda_fp16.h>
- #include <cuda_bf16.h>
+ #if defined(USE_ROCM)
+ #include <hip/hip_fp16.h>
+ #include <hip/hip_bf16.h>
+
+ #define VPTQ_LDG(arg) __ldg(arg)
+ #define SHFL_DOWN(val, offset) __shfl_down(val, offset)
+ #define WARP_SIZE warpSize
+ typedef __hip_bfloat162 __bfloat162;
+ typedef __hip_bfloat16 __bfloat16;
+ #else
+ #include <cuda_fp16.h>
+ #include <cuda_bf16.h>
+
+ #define WARP_SIZE 32
+ #define VPTQ_LDG(arg) *(arg)
+ #define SHFL_DOWN(val, offset) __shfl_down_sync(0xffffffff, val, offset)
+ typedef __nv_bfloat162 __bfloat162;
+ typedef __nv_bfloat16 __bfloat16;
+ #endif

namespace cuda {

@@ -16,8 +33,8 @@ struct TypeVec2 {
};

template <>
- struct TypeVec2<__nv_bfloat16> {
- typedef __nv_bfloat162 type;
+ struct TypeVec2<__bfloat16> {
+ typedef __bfloat162 type;
};

template <>
@@ -28,7 +45,7 @@ struct TypeVec2<float> {
template <typename T>
T __device__ __forceinline__ ConvertFromFloat(float v, T vv) {
(void)(vv);
- if constexpr (std::is_same<T, __nv_bfloat16>::value) {
+ if constexpr (std::is_same<T, __bfloat16>::value) {
return vv = __float2bfloat16(v);
} else if constexpr (std::is_same<T, float>::value) {
return vv = v;
@@ -40,7 +57,7 @@ T __device__ __forceinline__ ConvertFromFloat(float v, T vv) {

template <typename T>
float __device__ __forceinline__ ConvertToFloat(T v) {
- if constexpr (std::is_same<T, __nv_bfloat16>::value) {
+ if constexpr (std::is_same<T, __bfloat16>::value) {
return __bfloat162float(v);
} else if constexpr (std::is_same<T, float>::value) {
return v;
@@ -52,11 +69,12 @@ float __device__ __forceinline__ ConvertToFloat(T v) {

template <unsigned int WarpSize>
__device__ __forceinline__ float warpReduceSum(float sum) {
- if constexpr (WarpSize >= 32) sum += __shfl_down_sync(0xffffffff, sum, 16); // 0-16, 1-17, 2-18, etc.
- if constexpr (WarpSize >= 16) sum += __shfl_down_sync(0xffffffff, sum, 8); // 0-8, 1-9, 2-10, etc.
- if constexpr (WarpSize >= 8) sum += __shfl_down_sync(0xffffffff, sum, 4); // 0-4, 1-5, 2-6, etc.
- if constexpr (WarpSize >= 4) sum += __shfl_down_sync(0xffffffff, sum, 2); // 0-2, 1-3, 4-6, 5-7, etc.
- if constexpr (WarpSize >= 2) sum += __shfl_down_sync(0xffffffff, sum, 1); // 0-1, 2-3, 4-5, etc.
+ if constexpr (WarpSize >= 64) sum += SHFL_DOWN(sum, 32); // 0-16, 1-17, 2-18, etc.
+ if constexpr (WarpSize >= 32) sum += SHFL_DOWN(sum, 16); // 0-16, 1-17, 2-18, etc.
+ if constexpr (WarpSize >= 16) sum += SHFL_DOWN(sum, 8); // 0-8, 1-9, 2-10, etc.
+ if constexpr (WarpSize >= 8) sum += SHFL_DOWN(sum, 4); // 0-4, 1-5, 2-6, etc.
+ if constexpr (WarpSize >= 4) sum += SHFL_DOWN(sum, 2); // 0-2, 1-3, 4-6, 5-7, etc.
+ if constexpr (WarpSize >= 2) sum += SHFL_DOWN(sum, 1); // 0-1, 2-3, 4-5, etc.
return sum;
}

@@ -69,49 +87,49 @@ __device__ __forceinline__ void ldg_vec_x(T* __restrict__ dst_t32, const uint32_
int2* dst = (int2*)dst_u32;
const int2* src = (const int2*)src_u32;
if constexpr (GROUPSIZE == 2) {
- *dst_u32 = __ldg(src_u32);
+ *dst_u32 = VPTQ_LDG(src_u32);
// uint32_t* dec = (uint32_t*)dst;
// asm volatile (
// "ld.cg.global.v2.u32 {%0, %1}, [%2];"
// : "=r"(dec[0]), "=r"(dec[1])
// : "l"((const void*)src)
// );
} else if constexpr (GROUPSIZE == 4) {
- *dst = __ldg(src);
+ *dst = VPTQ_LDG(src);
// uint32_t* dec = (uint32_t*)dst;
// asm volatile (
// "ld.cg.global.v2.u32 {%0, %1}, [%2];"
// : "=r"(dec[0]), "=r"(dec[1])
// : "l"((const void*)src)
// );
} else if constexpr (GROUPSIZE == 6) {
- dst_u32[0] = __ldg(src_u32);
- dst_u32[1] = __ldg(src_u32 + 1);
- dst_u32[2] = __ldg(src_u32 + 2);
+ dst_u32[0] = VPTQ_LDG(src_u32);
+ dst_u32[1] = VPTQ_LDG(src_u32 + 1);
+ dst_u32[2] = VPTQ_LDG(src_u32 + 2);
} else if constexpr (GROUPSIZE == 8) {
- *(int4*)dst = __ldg((const int4*)src);
+ *(int4*)dst = VPTQ_LDG((const int4*)src);
} else if constexpr (GROUPSIZE == 16) {
- // *(int4*)dst = __ldg((const int4*)src);
- // *(int4*)(dst+2) = __ldg((const int4*)(src+2));
- asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
- : "=r"(dst_u32[0]), "=r"(dst_u32[1]), "=r"(dst_u32[2]), "=r"(dst_u32[3])
- : "l"((const void*)src_u32));
- asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
- : "=r"(dst_u32[4]), "=r"(dst_u32[5]), "=r"(dst_u32[6]), "=r"(dst_u32[7])
- : "l"((const void*)(src_u32 + 4)));
+ *(int4*)dst = VPTQ_LDG((const int4*)src);
+ *(int4*)(dst + 2) = VPTQ_LDG((const int4*)(src + 2));
+ // asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
+ //              : "=r"(dst_u32[0]), "=r"(dst_u32[1]), "=r"(dst_u32[2]), "=r"(dst_u32[3])
+ //              : "l"((const void*)src_u32));
+ // asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
+ //              : "=r"(dst_u32[4]), "=r"(dst_u32[5]), "=r"(dst_u32[6]), "=r"(dst_u32[7])
+ //              : "l"((const void*)(src_u32 + 4)));
} else if constexpr (GROUPSIZE == 12) {
if (uint64_t(src) % 16) {
- dst[0] = __ldg(src);
- int4 b = __ldg((const int4*)(src + 1));
+ dst[0] = VPTQ_LDG(src);
+ int4 b = VPTQ_LDG((const int4*)(src + 1));
dst[1] = *((const int2*)&b);
dst[2] = *((const int2*)&b + 1);
} else {
- *(int4*)dst = __ldg((int4*)(src));
- dst[2] = __ldg((src + 2));
+ *(int4*)dst = VPTQ_LDG((int4*)(src));
+ dst[2] = VPTQ_LDG((src + 2));
}
- // dst[0] = __ldg(src);
- // dst[1] = __ldg((src+1));
- // dst[2] = __ldg((src+2));
+ // dst[0] = VPTQ_LDG(src);
+ // dst[1] = VPTQ_LDG((src+1));
+ // dst[2] = VPTQ_LDG((src+2));

// uint32_t* dec = (uint32_t*)dst;
// asm volatile (
@@ -125,22 +143,26 @@ __device__ __forceinline__ void ldg_vec_x(T* __restrict__ dst_t32, const uint32_
// : "l"((const void*)src)
// );
} else if constexpr (GROUPSIZE == 24) {
- *((int4*)(dst)) = __ldg((const int4*)(src));
- *(((int4*)(dst)) + 1) = __ldg(((const int4*)(src)) + 1);
- *(((int4*)(dst)) + 2) = __ldg(((const int4*)(src)) + 2);
+ *((int4*)(dst)) = VPTQ_LDG((const int4*)(src));
+ *(((int4*)(dst)) + 1) = VPTQ_LDG(((const int4*)(src)) + 1);
+ *(((int4*)(dst)) + 2) = VPTQ_LDG(((const int4*)(src)) + 2);
} else if constexpr (GROUPSIZE == 32) {
- asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
- : "=r"(dst_u32[0]), "=r"(dst_u32[1]), "=r"(dst_u32[2]), "=r"(dst_u32[3])
- : "l"((const void*)src_u32));
- asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
- : "=r"(dst_u32[4]), "=r"(dst_u32[5]), "=r"(dst_u32[6]), "=r"(dst_u32[7])
- : "l"((const void*)(src_u32 + 4)));
- asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
- : "=r"(dst_u32[8]), "=r"(dst_u32[9]), "=r"(dst_u32[10]), "=r"(dst_u32[11])
- : "l"((const void*)(src_u32 + 8)));
- asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
- : "=r"(dst_u32[12]), "=r"(dst_u32[13]), "=r"(dst_u32[14]), "=r"(dst_u32[15])
- : "l"((const void*)(src_u32 + 12)));
+ // asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
+ //              : "=r"(dst_u32[0]), "=r"(dst_u32[1]), "=r"(dst_u32[2]), "=r"(dst_u32[3])
+ //              : "l"((const void*)src_u32));
+ // asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
+ //              : "=r"(dst_u32[4]), "=r"(dst_u32[5]), "=r"(dst_u32[6]), "=r"(dst_u32[7])
+ //              : "l"((const void*)(src_u32 + 4)));
+ // asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
+ //              : "=r"(dst_u32[8]), "=r"(dst_u32[9]), "=r"(dst_u32[10]), "=r"(dst_u32[11])
+ //              : "l"((const void*)(src_u32 + 8)));
+ // asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
+ //              : "=r"(dst_u32[12]), "=r"(dst_u32[13]), "=r"(dst_u32[14]), "=r"(dst_u32[15])
+ //              : "l"((const void*)(src_u32 + 12)));
+ *((int4*)(dst)) = VPTQ_LDG((const int4*)(src));
+ *(((int4*)(dst)) + 1) = VPTQ_LDG(((const int4*)(src)) + 1);
+ *(((int4*)(dst)) + 2) = VPTQ_LDG(((const int4*)(src)) + 2);
+ *(((int4*)(dst)) + 3) = VPTQ_LDG(((const int4*)(src)) + 3);
} else {
assert(false);
}
@@ -180,11 +202,11 @@ __forceinline__ T ceil_div(T a, T b) {

template <typename T>
T __device__ __forceinline__ FMA2(T a, T b, T c) {
- if constexpr (std::is_same<T, __nv_bfloat162>::value) {
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+ if constexpr (std::is_same<T, __bfloat162>::value) {
+ #if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
float x = __bfloat162float(a.x) * __bfloat162float(b.x) + __bfloat162float(c.x);
float y = __bfloat162float(a.y) * __bfloat162float(b.y) + __bfloat162float(c.y);
- return __nv_bfloat162{__float2bfloat16(x), __float2bfloat16(y)};
+ return __bfloat162{__float2bfloat16(x), __float2bfloat16(y)};
#else
return __hfma2(a, b, c);
#endif
@@ -196,13 +218,30 @@ T __device__ __forceinline__ FMA2(T a, T b, T c) {
__builtin_unreachable(); // Suppress missing return statement warning
}

+ template <typename T>
+ T __device__ __forceinline__ FMA(T a, T b, T c) {
+ if constexpr (std::is_same<T, __bfloat16>::value) {
+ #if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
+ float x = __bfloat16float(a) * __bfloat16float(b) + __bfloat16float(c);
+ return __bfloat16{__float2bfloat16(x)};
+ #else
+ return __hfma(a, b, c);
+ #endif
+ } else if constexpr (std::is_same<T, float>::value) {
+ return float{a.x * b.x + c.x};
+ } else {
+ return __hfma(a, b, c);
+ }
+ __builtin_unreachable(); // Suppress missing return statement warning
+ }
+
template <typename T>
T __device__ __forceinline__ ADD2(T a, T b) {
- if constexpr (std::is_same<T, __nv_bfloat162>::value) {
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+ if constexpr (std::is_same<T, __bfloat162>::value) {
+ #if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined(USE_ROCM)
float x = __bfloat162float(a.x) + __bfloat162float(b.x);
float y = __bfloat162float(a.y) + __bfloat162float(b.y);
- return __nv_bfloat162{__float2bfloat16(x), __float2bfloat16(y)};
+ return __bfloat162{__float2bfloat16(x), __float2bfloat16(y)};
#else
return __hadd2(a, b);
#endif
@@ -212,4 +251,27 @@ T __device__ __forceinline__ ADD2(T a, T b) {
return __hadd2(a, b);
}
__builtin_unreachable(); // Suppress missing return statement warning
- }
+ }
+
+ template <typename T>
+ T __device__ __forceinline__ ZERO_VALUE(T a) {
+ if constexpr (std::is_same<T, __bfloat16>::value) {
+ return __ushort_as_bfloat16((unsigned short)0x0000U);
+ } else if constexpr (std::is_same<T, float>::value) {
+ return 0.0f;
+ } else {
+ return __float2half(0.0f);
+ }
+ }
+
+ #if defined(USE_ROCM)
+ __device__ __half operator+(const __half& a, const __half& b) {
+ return __hadd(a, b); // Use HIP's intrinsic __hadd for half-precision addition
+ }
+
+ // Overload the * operator for __half
+ __device__ __half operator*(const __half& a, const __half& b) {
+ return __hmul(a, b); // Use HIP's intrinsic __hmul for half-precision multiplication
+ }
+
+ #endif
