diff --git a/vllm/awq_quantization/kernels/csrc/layernorm/layernorm.cu b/vllm/awq_quantization/kernels/csrc/layernorm/layernorm.cu
deleted file mode 100644
index 0a5d2d251c21c..0000000000000
--- a/vllm/awq_quantization/kernels/csrc/layernorm/layernorm.cu
+++ /dev/null
@@ -1,113 +0,0 @@
-/*
-
-Adapted from NVIDIA FasterTransformer:
-https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/layernorm_kernels.cu
-
-*/
-
-#include <torch/extension.h>
-#include <cuda_fp16.h>
-#include "reduction.cuh"
-#include "layernorm.h"
-#include <cuda_runtime.h>
-#include <c10/cuda/CUDAGuard.h>
-
-static inline __device__ float to_float(half src)
-{
-    return __half2float(src);
-}
-
-static inline __device__ float to_float(float src)
-{
-    return src;
-}
-
-template<typename T>
-__global__ void generalT5LayerNorm(
-    const T* __restrict input, const T* __restrict gamma, T* output, const float layernorm_eps, int m, int n)
-{
-    // layernorm module in the T5 style No bias and no subtraction of mean.
-    const int tid = threadIdx.x;
-
-    __shared__ float s_variance;
-    float variance = 0.0f;
-
-    float local_var_sum = 0.0f;
-    for (int i = tid; i < n; i += blockDim.x) {
-        float diff = to_float(__ldg(&input[blockIdx.x * n + i]));
-        local_var_sum += diff * diff;
-    }
-    variance = blockReduceSum(local_var_sum);
-
-    if (threadIdx.x == 0) {
-        s_variance = rsqrtf(variance / (float)n + layernorm_eps);
-    }
-    __syncthreads();
-
-    for (int i = tid; i < n; i += blockDim.x) {
-        output[blockIdx.x * n + i] =
-            clamp_inf_for_half<T>((to_float(input[blockIdx.x * n + i]) * s_variance) * to_float(__ldg(&gamma[i])));
-    }
-}
-
-
-template<typename T>
-void invokeGeneralT5LayerNorm(T* out,
-                              const T* input,
-                              const T* gamma,
-                              // const T* beta,
-                              const float layernorm_eps,
-                              const int m,
-                              const int n)
-{
-    dim3 grid(m);
-    dim3 block(min(n, 1024));
-
-    /* For general cases, n is equal to hidden_units, e.g., 512/1024.
-       Since we have warp shuffle inside the code, block.x % 32 should be 0.
-    */
-    if (n % 32 != 0) {
-        block.x = 1024;
-    }
-
-    block.x = block.x / (4 / sizeof(T));  // if using half, only need half of block.x
-
-    /* should pay attention to the rsqrt precision*/
-    generalT5LayerNorm<T><<<grid, block>>>(input, gamma, out, layernorm_eps, m, n);  // For gpt-3
-}
-
-template void invokeGeneralT5LayerNorm(half* out,
-                                       const half* input,
-                                       const half* gamma,
-                                       // const half* beta,
-                                       const float layernorm_eps,
-                                       const int m,
-                                       const int n);
-
-template void invokeGeneralT5LayerNorm(float* out,
-                                       const float* input,
-                                       const float* gamma,
-                                       // const half* beta,
-                                       const float layernorm_eps,
-                                       const int m,
-                                       const int n);
-
-
-
-// input b, n, c
-void layernorm_forward_cuda(
-    torch::Tensor _input,
-    torch::Tensor _gamma,
-    torch::Tensor _out,
-    float eps)
-{
-    int m = _input.size(0) * _input.size(1);
-    int n = _input.size(2);
-    const at::cuda::OptionalCUDAGuard device_guard(device_of(_input));
-
-    auto input = reinterpret_cast<half*>(_input.data_ptr<at::Half>());
-    auto gamma = reinterpret_cast<half*>(_gamma.data_ptr<at::Half>());
-    auto out = reinterpret_cast<half*>(_out.data_ptr<at::Half>());
-
-    invokeGeneralT5LayerNorm(out, input, gamma, eps, m, n);
-}
diff --git a/vllm/awq_quantization/kernels/csrc/layernorm/layernorm.h b/vllm/awq_quantization/kernels/csrc/layernorm/layernorm.h
deleted file mode 100644
index de43ccac688d6..0000000000000
--- a/vllm/awq_quantization/kernels/csrc/layernorm/layernorm.h
+++ /dev/null
@@ -1,3 +0,0 @@
-#include <torch/extension.h>
-
-void layernorm_forward_cuda(torch::Tensor _input, torch::Tensor _gamma, torch::Tensor _out, float eps);
diff --git a/vllm/awq_quantization/kernels/csrc/layernorm/reduction.cuh b/vllm/awq_quantization/kernels/csrc/layernorm/reduction.cuh
deleted file mode 100644
index 678160e8fdf57..0000000000000
--- a/vllm/awq_quantization/kernels/csrc/layernorm/reduction.cuh
+++ /dev/null
@@ -1,82 +0,0 @@
-/*
-
-Adapted from NVIDIA FasterTransformer:
-https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/reduce_kernel_utils.cuh
-*/
-
-#pragma once
-#include <assert.h>
-#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
-#include <cooperative_groups/reduce.h>
-#else
-#include <cooperative_groups.h>
-#endif
-#include <cuda.h>
-#include <cuda_fp16.h>
-#include <cuda_runtime.h>
-#include <float.h>
-
-static const float HALF_FLT_MAX = 65504.F;
-#define FINAL_MASK 0xffffffff
-
-
-template<typename T>
-inline __device__ T add(T a, T b) {
-    return a + b;
-}
-
-template<>
-inline __device__ half2 add(half2 a, half2 b) {
-    return __hadd2(a, b);
-}
-
-template<>
-inline __device__ half add(half a, half b) {
-    return __hadd(a, b);
-}
-
-template<typename T>
-__inline__ __device__ T warpReduceSum(T val)
-{
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1)
-        val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));  //__shfl_sync bf16 return float when sm < 80
-    return val;
-}
-
-/* Calculate the sum of all elements in a block */
-template<typename T>
-__inline__ __device__ T blockReduceSum(T val)
-{
-    static __shared__ T shared[32];
-    int lane = threadIdx.x & 0x1f;
-    int wid = threadIdx.x >> 5;
-
-    val = warpReduceSum(val);
-
-    if (lane == 0)
-        shared[wid] = val;
-
-    __syncthreads();
-
-    // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
-    // blockDim.x is not divided by 32
-    val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
-    val = warpReduceSum(val);
-
-    return val;
-}
-
-
-template<typename T>
-__device__ __forceinline__ T clamp_inf_for_half(const float input)
-{
-    return input;
-}
-
-template<>
-__device__ __forceinline__ half clamp_inf_for_half(const float input)
-{
-    // clamp inf values to enable fp16 training
-    return input > 0.0f ? __float2half(min(input, HALF_FLT_MAX - 1000)) : __float2half(max(input, -HALF_FLT_MAX + 1000));
-}
diff --git a/vllm/awq_quantization/kernels/csrc/position_embedding/pos_encoding.h b/vllm/awq_quantization/kernels/csrc/position_embedding/pos_encoding.h
deleted file mode 100644
index fe2de2f92908b..0000000000000
--- a/vllm/awq_quantization/kernels/csrc/position_embedding/pos_encoding.h
+++ /dev/null
@@ -1,9 +0,0 @@
-#pragma once
-#include <torch/extension.h>
-
-void rotary_embedding_neox(
-    torch::Tensor& positions,
-    torch::Tensor& query,
-    torch::Tensor& key,
-    int head_size,
-    torch::Tensor& cos_sin_cache);
diff --git a/vllm/awq_quantization/kernels/csrc/position_embedding/pos_encoding_kernels.cu b/vllm/awq_quantization/kernels/csrc/position_embedding/pos_encoding_kernels.cu
deleted file mode 100644
index 883b59c41b74c..0000000000000
--- a/vllm/awq_quantization/kernels/csrc/position_embedding/pos_encoding_kernels.cu
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
-
-Adapted from the VLLM project:
-https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu
-
-*/
-
-#include <torch/extension.h>
-#include <ATen/cuda/CUDAContext.h>
-#include "pos_encoding.h"
-
-template<typename scalar_t>
-__global__ void rotary_embedding_neox_kernel(
-  const int64_t* __restrict__ positions,        // [num_tokens]
-  scalar_t* __restrict__ query,                 // [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key,                   // [num_tokens, num_heads, head_size]
-  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
-  const int rot_dim,
-  const int stride,
-  const int num_heads,
-  const int head_size) {
-  // Each thread block is responsible for one token.
-  const int token_idx = blockIdx.x;
-  int64_t pos = positions[token_idx];
-  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
-
-  const int embed_dim = rot_dim / 2;
-  const int n = num_heads * embed_dim;
-  for (int i = threadIdx.x; i < n; i += blockDim.x) {
-    const int head_idx = i / embed_dim;
-    const int token_head = token_idx * stride + head_idx * head_size;
-
-    const int rot_offset = i % embed_dim;
-    const int x_index = rot_offset;
-    const int y_index = embed_dim + rot_offset;
-
-    const int out_x = token_idx * stride + head_idx * head_size + x_index;
-    const int out_y = token_idx * stride + head_idx * head_size + y_index;
-
-    const scalar_t cos = __ldg(cache_ptr + x_index);
-    const scalar_t sin = __ldg(cache_ptr + y_index);
-
-    const scalar_t q_x = query[token_head + x_index];
-    const scalar_t q_y = query[token_head + y_index];
-    query[out_x] = q_x * cos - q_y * sin;
-    query[out_y] = q_y * cos + q_x * sin;
-
-    const scalar_t k_x = key[token_head + x_index];
-    const scalar_t k_y = key[token_head + y_index];
-    key[out_x] = k_x * cos - k_y * sin;
-    key[out_y] = k_y * cos + k_x * sin;
-  }
-}
-
-void rotary_embedding_neox(
-  torch::Tensor& positions,         // [b, num_tokens]
-  torch::Tensor& query,             // [b, num_tokens, 1, num_heads, head_size]
-  torch::Tensor& key,               // [b, num_tokens, 1, num_heads, head_size]
-  int head_size,
-  torch::Tensor& cos_sin_cache)     // [max_position, rot_dim]
-{
-  int num_tokens = query.size(0) * query.size(1);
-  int rot_dim = cos_sin_cache.size(1);
-  int num_heads = query.size(-2);
-  int stride = num_heads * head_size;
-  // TORCH_CHECK(stride == key.stride(0));
-
-  dim3 grid(num_tokens);
-  dim3 block(std::min(num_heads * rot_dim / 2, 512));
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
-    query.scalar_type(),
-    "rotary_embedding_neox",
-    [&] {
-      rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        positions.data_ptr<int64_t>(),
-        query.data_ptr<scalar_t>(),
-        key.data_ptr<scalar_t>(),
-        cos_sin_cache.data_ptr<scalar_t>(),
-        rot_dim,
-        stride,
-        num_heads,
-        head_size);
-    });
-}
-
diff --git a/vllm/awq_quantization/kernels/csrc/pybind.cpp b/vllm/awq_quantization/kernels/csrc/pybind.cpp
index 8f67b8e3cccf4..ec5d051606350 100644
--- a/vllm/awq_quantization/kernels/csrc/pybind.cpp
+++ b/vllm/awq_quantization/kernels/csrc/pybind.cpp
@@ -2,13 +2,9 @@
 #include <pybind11/pybind11.h>
 #include <torch/extension.h>
 
-#include "layernorm/layernorm.h"
 #include "quantization/gemm_cuda.h"
-#include "position_embedding/pos_encoding.h"
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
-    m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel");
     m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel.");
-    m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
 }
diff --git a/vllm/awq_quantization/kernels/setup.py b/vllm/awq_quantization/kernels/setup.py
index b7f406e5d3cb0..e69a7c6acb95f 100644
--- a/vllm/awq_quantization/kernels/setup.py
+++ b/vllm/awq_quantization/kernels/setup.py
@@ -17,8 +17,6 @@
         sources=[
             "csrc/pybind.cpp",
             "csrc/quantization/gemm_cuda_gen.cu",
-            "csrc/layernorm/layernorm.cu",
-            "csrc/position_embedding/pos_encoding_kernels.cu"
        ],
         extra_compile_args=extra_compile_args,
     ),
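
For reference, the math implemented by the two deleted kernels is small enough to state directly. The sketch below is a hypothetical PyTorch equivalent, not part of the vLLM or AWQ extension API: the function names are illustrative, and it assumes the same layouts the deleted kernels used (gamma of shape [n]; query/key of shape [num_tokens, num_heads, head_size]; cos_sin_cache of shape [max_position, rot_dim] with cosines in the first half and sines in the second).

import torch

def t5_layernorm_ref(x: torch.Tensor, gamma: torch.Tensor, eps: float) -> torch.Tensor:
    # T5-style layernorm as in generalT5LayerNorm: no bias, no mean subtraction.
    # (The deleted CUDA kernel additionally clamps fp16 outputs away from +/-inf.)
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * gamma

def rotary_embedding_neox_ref(positions, query, key, cos_sin_cache):
    # GPT-NeoX-style rotary embedding as in rotary_embedding_neox_kernel:
    # the first rot_dim/2 elements of each head are rotated against the next rot_dim/2.
    rot_dim = cos_sin_cache.size(-1)
    cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)  # each [num_tokens, rot_dim // 2]
    cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)        # broadcast over the head dimension

    def rotate(x):
        x1 = x[..., : rot_dim // 2]
        x2 = x[..., rot_dim // 2 : rot_dim]
        rotated = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
        return torch.cat((rotated, x[..., rot_dim:]), dim=-1)  # pass any non-rotary tail through

    return rotate(query), rotate(key)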