From ac55621687b39b385d740575f7f673fae9c20a56 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Mon, 15 Apr 2024 13:57:57 -0700 Subject: [PATCH] [feat] added rms norm residual kernel --- src/kernels/layernorm_kernels.cu | 78 +++++++++++++++++++++++++++---- src/kernels/layernorm_kernels.h | 8 +++- src/layers/normalization.h | 35 ++++++++++---- src/layers/normalization_test.cpp | 38 +++++++++++++-- src/models/huggingface/baichuan.h | 10 +--- src/models/huggingface/gemma.h | 19 ++++---- 6 files changed, 150 insertions(+), 38 deletions(-) diff --git a/src/kernels/layernorm_kernels.cu b/src/kernels/layernorm_kernels.cu index 4a41fa81..eefd7b9e 100644 --- a/src/kernels/layernorm_kernels.cu +++ b/src/kernels/layernorm_kernels.cu @@ -1,5 +1,6 @@ #include #include + #include "dispatch.h" #include "reduce_kernel_utils.cuh" @@ -23,7 +24,7 @@ __global__ void rms_norm_kernel(T* __restrict__ out, float variance = 0.0f; for (int i = tidx; i < n; i += blockDim.x) { - const float x = __ldg(&input[bidx * n + i]); + const float x = input[bidx * n + i]; variance += x * x; } variance = block_reduce_sum(variance); @@ -34,8 +35,8 @@ __global__ void rms_norm_kernel(T* __restrict__ out, for (int i = tidx; i < n; i += blockDim.x) { const int idx = bidx * n + i; - const float x = __ldg(&input[idx]); - out[idx] = (T)(x * s_variance * weight[i]); + const float x = input[idx]; + out[idx] = (T)(x * s_variance) * weight[i]; } } @@ -61,6 +62,69 @@ void rms_norm(torch::Tensor& out, }); } +// calculate the root mean square norm. +// equation: x -> w * x / sqrt(E[x^2] + eps) +// The mean is calculated over the last dimension +// equilvalent to layernorm module in the T5 style No bias and no subtraction of +// mean. +template +__global__ void rms_norm_residual_kernel(T* __restrict__ out, + T* __restrict__ residual, + const T* __restrict__ input, + const T* __restrict__ weight, + const float epsilon, + int n) { + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + + __shared__ float s_variance; + float variance = 0.0f; + + for (int i = tidx; i < n; i += blockDim.x) { + const int idx = bidx * n + i; + const float r = residual[idx]; + const float x = r + input[idx]; + residual[idx] = x; + variance += x * x; + } + variance = block_reduce_sum(variance); + if (tidx == 0) { + s_variance = rsqrtf(variance / n + epsilon); + } + __syncthreads(); + + for (int i = tidx; i < n; i += blockDim.x) { + const int idx = bidx * n + i; + const float x = residual[idx]; + out[idx] = (T)(x * s_variance) * weight[i]; + } +} + +void rms_norm_residual(torch::Tensor& out, + torch::Tensor& residual, + torch::Tensor input, + torch::Tensor weight, + float epsilon) { + DCHECK(input.is_contiguous()) << "input tensor must be contiguous"; + DCHECK(out.is_contiguous()) << "output tensor must be contiguous"; + DCHECK(residual.is_contiguous()) << "residual tensor must be contiguous"; + + const int n = input.size(1); + + dim3 grid(input.size(0)); + dim3 block(std::min(n, 1024)); + DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_residual_kernel", [&] { + rms_norm_residual_kernel + <<>>( + out.data_ptr(), + residual.data_ptr(), + input.data_ptr(), + weight.data_ptr(), + epsilon, + n); + }); +} + // equation: x -> (x - E[x]) / sqrt(Var[x] + eps) * w + b // The mean and standard-deviation are calculated over the last dimension template @@ -80,8 +144,7 @@ __global__ void layer_norm_kernel(T* __restrict__ out, // calculate mean of the input. for (int i = tidx; i < n; i += blockDim.x) { - const int idx = bidx * n + i; - mean += __ldg(&input[idx]); + mean += input[bidx * n + i]; } mean = block_reduce_sum(mean); if (tidx == 0) { @@ -102,10 +165,9 @@ __global__ void layer_norm_kernel(T* __restrict__ out, for (int i = tidx; i < n; i += blockDim.x) { const int idx = bidx * n + i; - float local_out = - (__ldg(&input[idx]) - s_mean) * s_variance * __ldg(&weight[i]); + float local_out = (input[idx] - s_mean) * s_variance * weight[i]; if (bias != nullptr) { - local_out += __ldg(&bias[i]); + local_out += bias[i]; } out[idx] = (T)(local_out); } diff --git a/src/kernels/layernorm_kernels.h b/src/kernels/layernorm_kernels.h index 96df3852..496622bb 100644 --- a/src/kernels/layernorm_kernels.h +++ b/src/kernels/layernorm_kernels.h @@ -7,7 +7,13 @@ void rms_norm(torch::Tensor& out, torch::Tensor input, torch::Tensor weight, float epsilon); - + +void rms_norm_residual(torch::Tensor& out, + torch::Tensor& residual, + torch::Tensor input, + torch::Tensor weight, + float epsilon); + void layer_norm(torch::Tensor& out, torch::Tensor input, torch::Tensor weight, diff --git a/src/layers/normalization.h b/src/layers/normalization.h index b15b756e..84211f34 100644 --- a/src/layers/normalization.h +++ b/src/layers/normalization.h @@ -10,15 +10,30 @@ DECLARE_bool(disable_custom_kernels); namespace llm { namespace detail { -inline torch::Tensor rms_norm(torch::Tensor input, +inline torch::Tensor rms_norm(const torch::Tensor& input, const torch::Tensor& weight, float eps) { // it is important to use float to calculate the mean and std const auto x = input.to(torch::kFloat); const auto mean = x.pow(/*exponent=*/2).mean(/*dim=*/-1, /*keepdim=*/true); - const auto output = x * torch::rsqrt(mean + eps) * weight; + const auto output = x * torch::rsqrt(mean + eps); // convert back to the original dtype - return output.to(input); + return output.to(input) * weight; +} + +inline torch::Tensor rms_norm_residual(const torch::Tensor& input, + torch::Tensor& residual, + const torch::Tensor& weight, + float eps) { + // it is important to use float for the residual + auto x = input.to(torch::kFloat) + residual.to(torch::kFloat); + residual = x.to(input); + + // it is important to use float to calculate the mean and std + const auto mean = x.pow(/*exponent=*/2).mean(/*dim=*/-1, /*keepdim=*/true); + const auto output = x * torch::rsqrt(mean + eps); + // convert back to the original dtype + return output.to(input) * weight; } inline torch::Tensor layer_norm(torch::Tensor input, @@ -124,7 +139,7 @@ class RMSNormImpl : public torch::nn::Module { /*requires_grad=*/false); } - torch::Tensor forward(torch::Tensor input) { + torch::Tensor forward(const torch::Tensor& input) { if (input.is_cuda() && !FLAGS_disable_custom_kernels) { auto output = torch::empty_like(input); kernel::rms_norm(output, input, weight_, eps_); @@ -177,20 +192,22 @@ class RMSNormResidualImpl : public torch::nn::Module { /*requires_grad=*/false); } - torch::Tensor forward(torch::Tensor input, torch::Tensor& residual) { + torch::Tensor forward(const torch::Tensor& input, torch::Tensor& residual) { if (input.is_cuda() && !FLAGS_disable_custom_kernels) { auto output = torch::empty_like(input); if (residual.defined()) { - input = input + residual; + kernel::rms_norm_residual(output, residual, input, weight_, eps_); + } else { residual = input; + kernel::rms_norm(output, input, weight_, eps_); } - kernel::rms_norm(output, input, weight_, eps_); return output; } + if (residual.defined()) { - input = input + residual; - residual = input; + return detail::rms_norm_residual(input, residual, weight_, eps_); } + residual = input; return detail::rms_norm(input, weight_, eps_); } diff --git a/src/layers/normalization_test.cpp b/src/layers/normalization_test.cpp index fa6c5687..1b6bebd4 100644 --- a/src/layers/normalization_test.cpp +++ b/src/layers/normalization_test.cpp @@ -116,11 +116,43 @@ TEST(NormalizationTest, RMSNormKernel) { kernel::rms_norm(output, input, weight, eps); // use float result as baseline - auto desired_output = - detail::rms_norm(input.to(torch::kFloat32), weight, eps).to(dtype); + auto output_ref = detail::rms_norm(input, weight, eps); EXPECT_TRUE(torch::allclose(output, - desired_output, + output_ref, + /*rtol=*/1e-03, + /*atol=*/1e-05)); +} + +TEST(NormalizationTest, RMSNormResidualKernel) { + const auto dtype = torch::kHalf; + const auto device = torch::kCUDA; + const auto options = torch::dtype(dtype).device(device); + + const int64_t dim = 1024; + const float eps = 1e-5; + + // generate weight + const auto weight = torch::rand({dim}, options); + + // verify output + const auto input = torch::randn({100, dim}, options); + auto residual = torch::randn({100, dim}, options); + auto residual_ref = residual.clone(); + + auto output = torch::empty_like(input); + kernel::rms_norm_residual(output, residual, input, weight, eps); + + // use float result as baseline + auto output_ref = detail::rms_norm_residual(input, residual_ref, weight, eps); + + EXPECT_TRUE(torch::allclose(output, + output_ref, + /*rtol=*/1e-02, + /*atol=*/1e-03)); + + EXPECT_TRUE(torch::allclose(residual, + residual_ref, /*rtol=*/1e-03, /*atol=*/1e-05)); } diff --git a/src/models/huggingface/baichuan.h b/src/models/huggingface/baichuan.h index 6bb9ee9f..3e205de5 100644 --- a/src/models/huggingface/baichuan.h +++ b/src/models/huggingface/baichuan.h @@ -4,6 +4,7 @@ #include #include +#include "chat_template/coded_chat_template.h" #include "layers/activation.h" #include "layers/attention/attention.h" #include "layers/attention/handler.h" @@ -204,14 +205,7 @@ class BaichuanDecoderLayerImpl : public torch::nn::Module { torch::Tensor& residual, KVCache& kv_cache, const InputParameters& input_params) { - torch::Tensor hidden_states; - if (!residual.defined()) { - residual = x; - torch::Tensor placeholder; - hidden_states = input_layernorm_(x, placeholder); - } else { - hidden_states = input_layernorm_(x, residual); - } + auto hidden_states = input_layernorm_(x, residual); hidden_states = self_attn_(hidden_states, positions, kv_cache, input_params); diff --git a/src/models/huggingface/gemma.h b/src/models/huggingface/gemma.h index 6226feb6..30382b35 100644 --- a/src/models/huggingface/gemma.h +++ b/src/models/huggingface/gemma.h @@ -2,9 +2,18 @@ #include +#include "chat_template/coded_chat_template.h" +#include "layers/activation.h" +#include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" +#include "layers/linear.h" #include "layers/normalization.h" +#include "memory/kv_cache.h" +#include "models/model_args.h" +#include "models/model_registry.h" +#include "models/parameters.h" + // gemma model compatible with huggingface weight namespace llm::hf { @@ -191,15 +200,7 @@ class GemmaDecoderLayerImpl : public torch::nn::Module { KVCache& kv_cache, const InputParameters& input_params, torch::Tensor& residual) { - torch::Tensor hidden_states; - - if (!residual.defined()) { - residual = x; - torch::Tensor placeholder; - hidden_states = input_layernorm_(x, placeholder); - } else { - hidden_states = input_layernorm_(x, residual); - } + auto hidden_states = input_layernorm_(x, residual); hidden_states = self_attn_(hidden_states, positions, kv_cache, input_params);