Layernorm optimizations #8

Merged 1 commit on Mar 23, 2024
csrc/layernorm_kernels.cu (163 additions, 15 deletions)
@@ -4,6 +4,7 @@

#include "dispatch_utils.h"
#include "reduction_utils.cuh"
#include "attention/dtype_float16.cuh"

namespace vllm {

@@ -35,8 +36,119 @@ __global__ void rms_norm_kernel(
}
}

// TODO: Further optimize this kernel.
template<typename scalar_t>
/* Helper struct to generate vectorized and packed FP16 ops
for appropriate overloads of fused_add_rms_norm_kernel.
Only the special member functions and the operations needed by that
kernel are implemented.
*/
template<int width>
struct _half2Vec {
/* It is not strictly necessary for width to be a power of 2, but it
should almost always be the case for optimization purposes */
static_assert(width > 0 && (width & (width - 1)) == 0,
"Width is not a positive power of 2!");
__half2 data[width];

__device__ _half2Vec() = default;
__device__ ~_half2Vec() = default;
__device__ _half2Vec(const _half2Vec<width>&) = default;
__device__ _half2Vec& operator=(const _half2Vec<width>&) = default;
__device__ _half2Vec(_half2Vec<width>&&) = default;
__device__ _half2Vec& operator=(_half2Vec<width>&&) = default;

__device__ inline _half2Vec& operator+=(const _half2Vec<width>& other) {
#pragma unroll
for (int i = 0; i < width; ++i)
data[i] += other.data[i];
return *this;
}

__device__ inline _half2Vec& operator*=(const _half2Vec<width>& other) {
#pragma unroll
for (int i = 0; i < width; ++i)
data[i] *= other.data[i];
return *this;
}

__device__ inline _half2Vec& operator*=(const float scale) {
#pragma unroll
for (int i = 0; i < width; ++i)
data[i] = __float22half2_rn(__half22float2(data[i]) * scale);
return *this;
}

__device__ inline float sum_squares() const {
float result = 0.0f;
#pragma unroll
for (int i = 0; i < width; ++i) {
float2 z = __half22float2(data[i]);
result += z.x * z.x + z.y * z.y;
}
return result;
}
};
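
// Illustrative only, not part of the diff: the layout guarantees the
// vectorized kernel below relies on. With width = 4, one _half2Vec packs
// 8 FP16 values into 16 bytes, so a single load or store of the struct
// maps onto one 128-bit global memory transaction.
static_assert(sizeof(_half2Vec<1>) == sizeof(__half2), "");
static_assert(sizeof(_half2Vec<4>) == 4 * sizeof(__half2), "");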

/* Max blockSize to use for fused_add_rms_norm_kernel
This kernel is memory-latency bound in many scenarios, so a smaller
block size allows for increased block occupancy on CUs and better
latency hiding on global mem ops. */
#define _FUSED_RMS_MAX_BLOCKSIZE 256

/* Function overload in the case of FP16 tensors.
Additional optimizations we can make in this case are packed and
vectorized operations, which help with the aforementioned memory
latency bottleneck. */
template<typename scalar_t, int width>
__global__ typename std::enable_if<
(width > 0) && std::is_same<scalar_t, c10::Half>::value,
void>::type
fused_add_rms_norm_kernel(
c10::Half* __restrict__ input, // [..., hidden_size]
c10::Half* __restrict__ residual, // [..., hidden_size]
const c10::Half* __restrict__ weight, // [hidden_size]
const float epsilon,
const int num_tokens,
const int hidden_size)
{
static_assert(sizeof(_half2Vec<width>) == sizeof(c10::Half) * width * 2);
const int vec_hidden_size = hidden_size / (width * 2);
__shared__ float s_variance;
float variance = 0.0f;
/* These and the argument pointers are all declared `restrict` as they are
not aliased in practice */
auto* __restrict__ input_v = reinterpret_cast<_half2Vec<width>*>(input);
auto* __restrict__ residual_v = reinterpret_cast<_half2Vec<width>*>(residual);
auto* __restrict__ weight_v = reinterpret_cast<const _half2Vec<width>*>(weight);

for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_half2Vec<width> temp = residual_v[id];
temp += input_v[id];
residual_v[id] = temp;
variance += temp.sum_squares();
}
variance = blockReduceSum<float, _FUSED_RMS_MAX_BLOCKSIZE>(variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();

for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_half2Vec<width> temp = residual_v[id];
temp *= s_variance;
temp *= weight_v[idx];
input_v[id] = temp;
}
}
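
// Worked example (illustrative only, not part of the diff): with
// hidden_size = 4096 on the width-4 path, each row is viewed as
// 4096 / (4 * 2) = 512 packed vectors, so a block of
// _FUSED_RMS_MAX_BLOCKSIZE = 256 threads covers a row in two strided
// iterations, each moving 16-byte (128-bit) vectors.
static_assert(4096 / (4 * 2) == 512, "");
static_assert((4096 / (4 * 2)) / _FUSED_RMS_MAX_BLOCKSIZE == 2, "");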


/* Generic fused_add_rms_norm_kernel
No optimizations are applied in this case; the width parameter is
unused, but it is required for the FP16 overload above to be
selected correctly.
*/
template<typename scalar_t, int width> // width is not used in this overload
__global__ void fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
@@ -93,6 +205,21 @@ void rms_norm(
});
}

#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \
"fused_add_rms_norm_kernel", \
[&] { \
vllm::fused_add_rms_norm_kernel \
<scalar_t, width><<<grid, block, 0, stream>>>( \
input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), \
epsilon, \
num_tokens, \
hidden_size); \
});

void fused_add_rms_norm(
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
@@ -102,19 +229,40 @@ void fused_add_rms_norm(
int num_tokens = input.numel() / hidden_size;

dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
dim3 block(std::min(hidden_size, _FUSED_RMS_MAX_BLOCKSIZE));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
"fused_add_rms_norm_kernel",
[&] {
vllm::fused_add_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(),
residual.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);
});
/* If the tensor types are FP16, try to use the optimized kernel
with packed vectors. Max optimization is achieved with a width-4
vector of 2-packed-FP16s (equivalent to a vector of 8 FP16s)
since we can load at most 128 bits at once in a global memory op.
However, we have to narrow the vectors if hidden_size is not
divisible by 8.

Specifically, assuming hidden_size is not divisible by 8:
If hidden_size is divisible by 4, we can use a width-2 packed
vector (equivalent to a vector of 4 FP16s).
If hidden_size is divisible by 2 but not by 4 (i.e. hidden_size % 8
is 2 or 6), we can use a width-1 packed vector (equivalent to a
vector of 2 FP16s).
If hidden_size is odd, we cannot use packed vectors at all, so we
fall back to the generic kernel, which is signified by setting
(packed vector) width = 0.
*/
switch (hidden_size % 8) {
case 0:
LAUNCH_FUSED_ADD_RMS_NORM(4);
break;
case 2:
[[fallthrough]];
case 6:
LAUNCH_FUSED_ADD_RMS_NORM(1);
break;
case 4:
LAUNCH_FUSED_ADD_RMS_NORM(2);
break;
default:
LAUNCH_FUSED_ADD_RMS_NORM(0);
break;
}
}
#undef _FUSED_RMS_MAX_BLOCKSIZE
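
For context, a minimal host-side usage sketch (not part of this PR; it assumes the extension builds against libtorch and exposes fused_add_rms_norm with the (input, residual, weight, epsilon) signature used above, and the example function name is hypothetical):

#include <torch/torch.h>

// With hidden_size divisible by 8, FP16 inputs take the width-4 vectorized kernel.
void fused_add_rms_norm_example() {
  const int64_t num_tokens = 16, hidden_size = 4096;
  auto opts = torch::dtype(torch::kHalf).device(torch::kCUDA);
  auto input    = torch::randn({num_tokens, hidden_size}, opts);
  auto residual = torch::randn({num_tokens, hidden_size}, opts);
  auto weight   = torch::ones({hidden_size}, opts);
  // In-place semantics: afterwards, residual holds input + residual and
  // input holds RMSNorm(input + residual) * weight.
  fused_add_rms_norm(input, residual, weight, /*epsilon=*/1e-5f);
}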
csrc/reduction_utils.cuh (33 additions, 27 deletions)
@@ -19,45 +19,51 @@

#include "cuda_compat.h"

namespace vllm {
/* On ROCm, warpSize is a reserved keyword implemented as a macro.
On CUDA, warpSize is also a reserved keyword, but its value is read
from a special memory region at run time, so it cannot be used in
constant expressions. Thus, we have to define warpSize at compile
time for CUDA.
*/
#ifndef USE_ROCM
// On CUDA, limit our macro's scope as much as possible
#pragma push_macro("warpSize")
#undef warpSize
#define warpSize 32
#endif

template<typename T>
namespace vllm {
template<typename T, int numLanes = warpSize>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1)
static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
"numLanes is not a positive power of 2!");
#pragma unroll
for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
val += VLLM_SHFL_XOR_SYNC(val, mask);
return val;
}

__inline__ __device__ constexpr int _calculateLaneMask(int warp_size) {
return warp_size - 1;
}

__inline__ __device__ constexpr int _calculateWidShift(int warp_size) {
return 5 + (warp_size >> 6);
}

/* Calculate the sum of all elements in a block */
template<typename T>
template<typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceSum(T val) {
static __shared__ T shared[WARP_SIZE];
constexpr auto LANE_MASK = _calculateLaneMask(WARP_SIZE);
constexpr auto WID_SHIFT = _calculateWidShift(WARP_SIZE);
int lane = threadIdx.x & LANE_MASK;
int wid = threadIdx.x >> WID_SHIFT;

val = warpReduceSum<T>(val);
// If the block fits into a single warp, we are already done
if constexpr (maxBlockSize > warpSize) {
constexpr int maxActiveLanes = (maxBlockSize + warpSize - 1) / warpSize;
static __shared__ T shared[maxActiveLanes];
int lane = threadIdx.x % warpSize;
int wid = threadIdx.x / warpSize;
if (lane == 0)
shared[wid] = val;

if (lane == 0)
shared[wid] = val;
__syncthreads();

__syncthreads();

// Changed from blockDim.x >> 5 to blockDim.x / 32.0 so that blocks
// whose size is not a multiple of 32 are handled correctly
val = (threadIdx.x < (blockDim.x / (WARP_SIZE * 1.0f))) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T>(val);
val = (threadIdx.x < (blockDim.x / (float) warpSize)) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T, maxActiveLanes>(val);
}
return val;
}

} // namespace vllm
#ifndef USE_ROCM
#pragma pop_macro("warpSize")
#endif
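
Finally, a hedged sketch of how the templated reduction would typically be consumed (row_sum_kernel is a hypothetical example kernel, not part of this PR):

#include "reduction_utils.cuh"

// One block per row; BLOCK is a compile-time cap on blockDim.x so
// blockReduceSum can size its shared-memory scratch at compile time.
template<int BLOCK = 256>
__global__ void row_sum_kernel(const float* __restrict__ in,
                               float* __restrict__ out,
                               int row_len) {
  float acc = 0.0f;
  for (int i = threadIdx.x; i < row_len; i += blockDim.x)
    acc += in[blockIdx.x * row_len + i];
  acc = vllm::blockReduceSum<float, BLOCK>(acc);   // block-wide sum
  if (threadIdx.x == 0)
    out[blockIdx.x] = acc;
}
// Launch example: row_sum_kernel<256><<<num_rows, 256, 0, stream>>>(in, out, row_len);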