[feat] added rms norm residual kernel #125

Merged: 1 commit, Apr 15, 2024
78 changes: 70 additions & 8 deletions src/kernels/layernorm_kernels.cu
@@ -1,5 +1,6 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>

#include "dispatch.h"
#include "reduce_kernel_utils.cuh"

@@ -23,7 +24,7 @@ __global__ void rms_norm_kernel(T* __restrict__ out,
float variance = 0.0f;

for (int i = tidx; i < n; i += blockDim.x) {
const float x = __ldg(&input[bidx * n + i]);
const float x = input[bidx * n + i];
variance += x * x;
}
variance = block_reduce_sum<float>(variance);
@@ -34,8 +35,8 @@ __global__ void rms_norm_kernel(T* __restrict__ out,

for (int i = tidx; i < n; i += blockDim.x) {
const int idx = bidx * n + i;
const float x = __ldg(&input[idx]);
out[idx] = (T)(x * s_variance * weight[i]);
const float x = input[idx];
out[idx] = (T)(x * s_variance) * weight[i];
}
}

@@ -61,6 +62,69 @@ void rms_norm(torch::Tensor& out,
});
}

// calculate the root mean square norm with a fused residual add.
// equation: x -> w * x / sqrt(E[x^2] + eps), where x = input + residual
// (the sum is also written back to residual).
// The mean is calculated over the last dimension.
// Equivalent to the T5-style LayerNorm module: no bias and no subtraction of
// the mean.
template <typename T>
__global__ void rms_norm_residual_kernel(T* __restrict__ out,
T* __restrict__ residual,
const T* __restrict__ input,
const T* __restrict__ weight,
const float epsilon,
int n) {
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;

__shared__ float s_variance;
float variance = 0.0f;

for (int i = tidx; i < n; i += blockDim.x) {
const int idx = bidx * n + i;
const float r = residual[idx];
const float x = r + input[idx];
residual[idx] = x;
variance += x * x;
}
variance = block_reduce_sum<float>(variance);
if (tidx == 0) {
s_variance = rsqrtf(variance / n + epsilon);
}
__syncthreads();

for (int i = tidx; i < n; i += blockDim.x) {
const int idx = bidx * n + i;
const float x = residual[idx];
out[idx] = (T)(x * s_variance) * weight[i];
}
}

void rms_norm_residual(torch::Tensor& out,
torch::Tensor& residual,
torch::Tensor input,
torch::Tensor weight,
float epsilon) {
DCHECK(input.is_contiguous()) << "input tensor must be contiguous";
DCHECK(out.is_contiguous()) << "output tensor must be contiguous";
DCHECK(residual.is_contiguous()) << "residual tensor must be contiguous";

const int n = input.size(1);

dim3 grid(input.size(0));
dim3 block(std::min(n, 1024));
DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_residual_kernel", [&] {
rms_norm_residual_kernel<scalar_t>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
out.data_ptr<scalar_t>(),
residual.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
n);
});
}

// equation: x -> (x - E[x]) / sqrt(Var[x] + eps) * w + b
// The mean and standard-deviation are calculated over the last dimension
template <typename T>
@@ -80,8 +144,7 @@ __global__ void layer_norm_kernel(T* __restrict__ out,

// calculate mean of the input.
for (int i = tidx; i < n; i += blockDim.x) {
const int idx = bidx * n + i;
mean += __ldg(&input[idx]);
mean += input[bidx * n + i];
}
mean = block_reduce_sum<float>(mean);
if (tidx == 0) {
@@ -102,10 +165,9 @@ __global__ void layer_norm_kernel(T* __restrict__ out,

for (int i = tidx; i < n; i += blockDim.x) {
const int idx = bidx * n + i;
float local_out =
(__ldg(&input[idx]) - s_mean) * s_variance * __ldg(&weight[i]);
float local_out = (input[idx] - s_mean) * s_variance * weight[i];
if (bias != nullptr) {
local_out += __ldg(&bias[i]);
local_out += bias[i];
}
out[idx] = (T)(local_out);
}
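A subtle change in this file is the order of the final cast: the kernels now compute `(T)(x * s_variance) * weight[i]` instead of `(T)(x * s_variance * weight[i])`, so the weight multiply happens after casting back to the storage dtype, matching the updated libtorch reference (`output.to(input) * weight`) in normalization.h later in this PR. The snippet below is illustrative only and not part of the PR; the tensor shapes and values are arbitrary assumptions. It just shows that the two orderings can round differently in half precision:

```cpp
#include <iostream>
#include <torch/torch.h>

int main() {
  // Weight stored in half precision, intermediate value in float,
  // mirroring the situation inside the kernels.
  auto x = torch::randn({8});               // float32 intermediate (x * s_variance)
  auto w = torch::rand({8}, torch::kHalf);  // norm weight stored in half

  auto cast_then_scale = x.to(torch::kHalf) * w;                      // new ordering
  auto scale_then_cast = (x * w.to(torch::kFloat)).to(torch::kHalf);  // old ordering

  // The difference is at most a rounding error, but it is why the test
  // baseline and the kernel need to agree on the ordering.
  std::cout << (cast_then_scale - scale_then_cast).abs().max() << std::endl;
  return 0;
}
```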
8 changes: 7 additions & 1 deletion src/kernels/layernorm_kernels.h
@@ -7,7 +7,13 @@ void rms_norm(torch::Tensor& out,
torch::Tensor input,
torch::Tensor weight,
float epsilon);


void rms_norm_residual(torch::Tensor& out,
torch::Tensor& residual,
torch::Tensor input,
torch::Tensor weight,
float epsilon);

void layer_norm(torch::Tensor& out,
torch::Tensor input,
torch::Tensor weight,
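For completeness, here is a minimal usage sketch of the new declaration (not part of the PR). It assumes the functions live in the `kernel` namespace, as the `kernel::rms_norm_residual` call in the unit test below suggests, and that the header is reachable as `layernorm_kernels.h`:

```cpp
#include <torch/torch.h>

#include "layernorm_kernels.h"  // assumed include path

void example() {
  // All tensors must be contiguous CUDA tensors of the same floating dtype.
  const auto options = torch::dtype(torch::kHalf).device(torch::kCUDA);
  auto input        = torch::randn({16, 1024}, options);  // [tokens, hidden]
  auto residual     = torch::randn({16, 1024}, options);
  const auto weight = torch::rand({1024}, options);
  auto out          = torch::empty_like(input);

  // residual is updated in place to input + residual;
  // out receives rms_norm(input + residual) * weight.
  kernel::rms_norm_residual(out, residual, input, weight, /*epsilon=*/1e-5f);
}
```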
35 changes: 26 additions & 9 deletions src/layers/normalization.h
@@ -10,15 +10,30 @@
DECLARE_bool(disable_custom_kernels);
namespace llm {
namespace detail {
inline torch::Tensor rms_norm(torch::Tensor input,
inline torch::Tensor rms_norm(const torch::Tensor& input,
const torch::Tensor& weight,
float eps) {
// it is important to use float to calculate the mean and std
const auto x = input.to(torch::kFloat);
const auto mean = x.pow(/*exponent=*/2).mean(/*dim=*/-1, /*keepdim=*/true);
const auto output = x * torch::rsqrt(mean + eps) * weight;
const auto output = x * torch::rsqrt(mean + eps);
// convert back to the original dtype
return output.to(input);
return output.to(input) * weight;
}

inline torch::Tensor rms_norm_residual(const torch::Tensor& input,
torch::Tensor& residual,
const torch::Tensor& weight,
float eps) {
// it is important to use float for the residual
auto x = input.to(torch::kFloat) + residual.to(torch::kFloat);
residual = x.to(input);

// it is important to use float to calculate the mean and std
const auto mean = x.pow(/*exponent=*/2).mean(/*dim=*/-1, /*keepdim=*/true);
const auto output = x * torch::rsqrt(mean + eps);
// convert back to the original dtype
return output.to(input) * weight;
}

inline torch::Tensor layer_norm(torch::Tensor input,
@@ -124,7 +139,7 @@ class RMSNormImpl : public torch::nn::Module {
/*requires_grad=*/false);
}

torch::Tensor forward(torch::Tensor input) {
torch::Tensor forward(const torch::Tensor& input) {
if (input.is_cuda() && !FLAGS_disable_custom_kernels) {
auto output = torch::empty_like(input);
kernel::rms_norm(output, input, weight_, eps_);
@@ -177,20 +192,22 @@ class RMSNormResidualImpl : public torch::nn::Module {
/*requires_grad=*/false);
}

torch::Tensor forward(torch::Tensor input, torch::Tensor& residual) {
torch::Tensor forward(const torch::Tensor& input, torch::Tensor& residual) {
if (input.is_cuda() && !FLAGS_disable_custom_kernels) {
auto output = torch::empty_like(input);
if (residual.defined()) {
input = input + residual;
kernel::rms_norm_residual(output, residual, input, weight_, eps_);
} else {
residual = input;
kernel::rms_norm(output, input, weight_, eps_);
}
kernel::rms_norm(output, input, weight_, eps_);
return output;
}

if (residual.defined()) {
input = input + residual;
residual = input;
return detail::rms_norm_residual(input, residual, weight_, eps_);
}
residual = input;
return detail::rms_norm(input, weight_, eps_);
}

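The decoder-layer changes below (baichuan.h, gemma.h) rely on the forward overload above to thread a single residual tensor through the layer, starting out undefined on the first call. A condensed sketch of that call pattern follows; the function and parameter names are hypothetical, and it assumes the usual `TORCH_MODULE(RMSNormResidual)` holder exists for `RMSNormResidualImpl`:

```cpp
#include <torch/torch.h>

#include "layers/normalization.h"  // assumed include path for RMSNormResidual

// Hypothetical sketch, not code from the PR.
torch::Tensor decoder_layer_sketch(const torch::Tensor& x,
                                   torch::Tensor& residual,  // undefined on the first layer
                                   RMSNormResidual& input_norm,
                                   RMSNormResidual& post_attn_norm) {
  // First call: residual is undefined, so the norm stores x into residual and
  // applies a plain RMS norm. Subsequent calls: fused add + norm, with
  // residual updated in place to x + residual.
  auto h = input_norm(x, residual);
  // ... attention would run on h here ...
  h = post_attn_norm(h, residual);
  // ... the MLP would run on h here ...
  return h;
}
```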
38 changes: 35 additions & 3 deletions src/layers/normalization_test.cpp
@@ -116,11 +116,43 @@ TEST(NormalizationTest, RMSNormKernel) {
kernel::rms_norm(output, input, weight, eps);

// use float result as baseline
auto desired_output =
detail::rms_norm(input.to(torch::kFloat32), weight, eps).to(dtype);
auto output_ref = detail::rms_norm(input, weight, eps);

EXPECT_TRUE(torch::allclose(output,
desired_output,
output_ref,
/*rtol=*/1e-03,
/*atol=*/1e-05));
}

TEST(NormalizationTest, RMSNormResidualKernel) {
const auto dtype = torch::kHalf;
const auto device = torch::kCUDA;
const auto options = torch::dtype(dtype).device(device);

const int64_t dim = 1024;
const float eps = 1e-5;

// generate weight
const auto weight = torch::rand({dim}, options);

// verify output
const auto input = torch::randn({100, dim}, options);
auto residual = torch::randn({100, dim}, options);
auto residual_ref = residual.clone();

auto output = torch::empty_like(input);
kernel::rms_norm_residual(output, residual, input, weight, eps);

// use float result as baseline
auto output_ref = detail::rms_norm_residual(input, residual_ref, weight, eps);

EXPECT_TRUE(torch::allclose(output,
output_ref,
/*rtol=*/1e-02,
/*atol=*/1e-03));

EXPECT_TRUE(torch::allclose(residual,
residual_ref,
/*rtol=*/1e-03,
/*atol=*/1e-05));
}
10 changes: 2 additions & 8 deletions src/models/huggingface/baichuan.h
@@ -4,6 +4,7 @@
#include <torch/torch.h>
#include <torch/types.h>

#include "chat_template/coded_chat_template.h"
#include "layers/activation.h"
#include "layers/attention/attention.h"
#include "layers/attention/handler.h"
@@ -204,14 +205,7 @@ class BaichuanDecoderLayerImpl : public torch::nn::Module {
torch::Tensor& residual,
KVCache& kv_cache,
const InputParameters& input_params) {
torch::Tensor hidden_states;
if (!residual.defined()) {
residual = x;
torch::Tensor placeholder;
hidden_states = input_layernorm_(x, placeholder);
} else {
hidden_states = input_layernorm_(x, residual);
}
auto hidden_states = input_layernorm_(x, residual);

hidden_states =
self_attn_(hidden_states, positions, kv_cache, input_params);
19 changes: 10 additions & 9 deletions src/models/huggingface/gemma.h
@@ -2,9 +2,18 @@

#include <torch/torch.h>

#include "chat_template/coded_chat_template.h"
#include "layers/activation.h"
#include "layers/attention/attention.h"
#include "layers/attention/handler.h"
#include "layers/embedding.h"
#include "layers/linear.h"
#include "layers/normalization.h"
#include "memory/kv_cache.h"
#include "models/model_args.h"
#include "models/model_registry.h"
#include "models/parameters.h"

// gemma model compatible with huggingface weight
namespace llm::hf {

@@ -191,15 +200,7 @@ class GemmaDecoderLayerImpl : public torch::nn::Module {
KVCache& kv_cache,
const InputParameters& input_params,
torch::Tensor& residual) {
torch::Tensor hidden_states;

if (!residual.defined()) {
residual = x;
torch::Tensor placeholder;
hidden_states = input_layernorm_(x, placeholder);
} else {
hidden_states = input_layernorm_(x, residual);
}
auto hidden_states = input_layernorm_(x, residual);

hidden_states =
self_attn_(hidden_states, positions, kv_cache, input_params);