Skip to content

Commit

Permalink
Add RMS Normalization Layer (davisking#2999)
Browse files Browse the repository at this point in the history
* Add RMS Normalization Layer

* Update dnn.cpp

* Missing entry in visitors.h to take into account the new rms_norm_ layer

* Fix test function name

* Fix dangling pointer issue in CUDA implementation of rms_normalize_gradient

* Fixing the dnn.cpp test program for the new rms_norm_ layer

* General update of the rms_norm_ class
  • Loading branch information
Cydral authored Sep 7, 2024
1 parent 253098e commit fafdac3
Show file tree
Hide file tree
Showing 11 changed files with 863 additions and 4 deletions.
138 changes: 138 additions & 0 deletions dlib/cuda/cpu_dlib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1447,6 +1447,144 @@ namespace dlib
}
}

// -----------------------------------------------------------------------------------

void rms_normalize(
const double eps,
resizable_tensor& dest,
resizable_tensor& scale,
const tensor& src,
const tensor& gamma
)
{
DLIB_CASSERT(
gamma.k() == src.k() &&
gamma.nr() == 1 &&
gamma.nc() == 1 &&
eps > 0,
"\nsrc.k(): " << src.k() <<
"\ngamma.k(): " << gamma.k() <<
"\ngamma.nr(): " << gamma.nr() <<
"\ngamma.nc(): " << gamma.nc() <<
"\neps: " << eps
);

const long ns = src.num_samples();
const long ks = src.k();
const long num = src.nr() * src.nc();

dest.copy_size(src);
scale.set_size(ns);

// Compute RMS values
scale = 0;
const float* p_src = src.host();
float* p_scale = scale.host();
for (long n = 0; n < ns; ++n)
{
for (long k = 0; k < ks; ++k)
{
for (long i = 0; i < num; ++i)
{
p_scale[n] += (*p_src) * (*p_src);
++p_src;
}
}
p_scale[n] = 1.0f / std::sqrt(p_scale[n] / (ks * num) + static_cast<float>(eps));
}
scale.host();

// Apply RMS normalization
p_src = src.host();
float* p_dest = dest.host();
const float* p_gamma = gamma.host();
for (long n = 0; n < ns; ++n)
{
for (long k = 0; k < ks; ++k)
{
for (long i = 0; i < num; ++i)
{
*p_dest = (*p_src) * p_scale[n] * p_gamma[k];
++p_src;
++p_dest;
}
}
}
}

void rms_normalize_gradient(
const tensor& gradient_input,
const tensor& scale,
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
resizable_tensor& dscale
)
{
DLIB_CASSERT(src.num_samples() == scale.size());
DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad));
DLIB_CASSERT(gamma.k() == src.k());
DLIB_CASSERT(gamma.nr() == 1);
DLIB_CASSERT(gamma.nc() == 1);
DLIB_CASSERT(have_same_dimensions(gradient_input, src));
DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));

const long ns = src.num_samples();
const long ks = src.k();
const long num = src.nr() * src.nc();

gamma_grad = 0;
dscale.copy_size(scale);
dscale = 0;

auto p_grad = gradient_input.host();
auto p_src = src.host();
const auto p_gamma = gamma.host();
const auto p_gamma_grad = gamma_grad.host();
const auto p_scale = scale.host();
auto p_dscale = dscale.host();

for (long n = 0; n < ns; ++n)
{
const float scale_pow = -0.5f * std::pow(p_scale[n], 3.0f);
for (long k = 0; k < ks; ++k)
{
for (long i = 0; i < num; ++i)
{
const float x_hat = *p_src * p_scale[n];
p_gamma_grad[k] += (*p_grad) * x_hat;

const float dx = *p_grad * p_gamma[k];
p_dscale[n] += dx * *p_src * scale_pow;

++p_grad;
++p_src;
}
}
}

p_grad = gradient_input.host();
p_src = src.host();
auto p_src_grad = src_grad.host();
const float invnum = 1.0f / (ks * num);
for (long n = 0; n < ns; ++n)
{
for (long k = 0; k < ks; ++k)
{
for (long i = 0; i < num; ++i)
{
const float dx = *p_grad * p_gamma[k];
*p_src_grad += dx * p_scale[n] + p_dscale[n] * 2 * *p_src * invnum;

++p_grad;
++p_src;
++p_src_grad;
}
}
}
}

// -----------------------------------------------------------------------------------

void threshold (
Expand Down
20 changes: 20 additions & 0 deletions dlib/cuda/cpu_dlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,26 @@ namespace dlib
resizable_tensor& dvars
);

// -----------------------------------------------------------------------------------

void rms_normalize(
const double eps,
resizable_tensor& dest,
resizable_tensor& scale,
const tensor& src,
const tensor& gamma
);

void rms_normalize_gradient(
const tensor& gradient_input,
const tensor& scale,
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
resizable_tensor& dscale
);

// -----------------------------------------------------------------------------------

void threshold (
Expand Down
160 changes: 160 additions & 0 deletions dlib/cuda/cuda_dlib.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2280,6 +2280,166 @@ namespace dlib
dmeans.device(), dvars.device(), eps, src.num_samples(), src.k(), num);
}

// ----------------------------------------------------------------------------------------

__global__ void _cuda_rms_normalize(
float* dest,
float* scale,
const float* src,
const float* gamma,
float eps,
size_t ns,
size_t ks,
size_t num
)
{
for (auto n : grid_stride_range_y(0, ns))
{
const auto ps = src + n * ks * num;
float sum_squares = 0.0f;
for (auto i : grid_stride_range(0, ks * num))
{
sum_squares += ps[i] * ps[i];
}
warp_reduce_atomic_add(scale[n], sum_squares / (ks * num));
}
__syncthreads();

for (auto n : grid_stride_range_y(0, ns))
{
for (auto i : grid_stride_range(0, 1))
{
scale[n] = 1.0f / std::sqrt(scale[n] + eps);
}
}
__syncthreads();

for (auto n : grid_stride_range_y(0, ns))
{
const auto ps = src + n * ks * num;
const auto pd = dest + n * ks * num;
for (auto i : grid_stride_range(0, ks * num))
{
pd[i] = ps[i] * scale[n] * gamma[i / num];
}
}
}

void rms_normalize(
const double eps,
resizable_tensor& dest,
resizable_tensor& scale,
const tensor& src,
const tensor& gamma
)
{
DLIB_CASSERT(
gamma.k() == src.k() &&
gamma.nr() == 1 &&
gamma.nc() == 1 &&
eps > 0,
"\nsrc.k(): " << src.k() <<
"\ngamma.k(): " << gamma.k() <<
"\ngamma.nr(): " << gamma.nr() <<
"\ngamma.nc(): " << gamma.nc() <<
"\neps: " << eps
);

const long ns = src.num_samples();
const long ks = src.k();
const long num = src.nr() * src.nc();

dest.copy_size(src);
scale.set_size(ns);
scale = 0;

launch_kernel(_cuda_rms_normalize, max_jobs(ks * num, ns),
dest.device(), scale.device(), src.device(), gamma.device(), eps, ns, ks, num);
}

// ----------------------------------------------------------------------------------------

__global__ void _cuda_rms_normalize_gradient(
float* src_grad,
float* gamma_grad,
float* dscale,
const float* src,
const float* gradient_input,
const float* scale,
const float* gamma,
size_t ns,
size_t ks,
size_t num
)
{
for (auto nk : grid_stride_range_y(0, ns * ks))
{
const auto n = nk / ks;
const auto k = nk % ks;
const auto ps = src + (n * ks + k) * num;
const auto pgi = gradient_input + (n * ks + k) * num;
const float scale_pow = -0.5f * std::pow(scale[n], 3.0f);
float temp_gg = 0.0f;
float temp_ds = 0.0f;
for (auto i : grid_stride_range(0, num))
{
const float x_hat = ps[i] * scale[n];
const float dx = pgi[i] * gamma[i / num];
temp_gg += pgi[i] * x_hat;
temp_ds += dx * ps[i] * scale_pow;
}
warp_reduce_atomic_add(gamma_grad[k], temp_gg);
warp_reduce_atomic_add(dscale[n], temp_ds);
}
__syncthreads();

const float invnum = 1.0f / (ks * num);
for (auto n : grid_stride_range_y(0, ns))
{
const auto ps = src + n * ks * num;
const auto pgi = gradient_input + n * ks * num;
const auto psg = src_grad + n * ks * num;
for (auto i : grid_stride_range(0, ks * num))
{
const float dx = pgi[i] * gamma[i / num];
psg[i] += dx * scale[n] + dscale[n] * 2 * ps[i] * invnum;
}
}
}

void rms_normalize_gradient(
const tensor& gradient_input,
const tensor& scale,
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
resizable_tensor& dscale
)
{
DLIB_CASSERT(src.num_samples() == scale.size());
DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad));
DLIB_CASSERT(gamma.k() == src.k());
DLIB_CASSERT(gamma.nr() == 1);
DLIB_CASSERT(gamma.nc() == 1);
DLIB_CASSERT(have_same_dimensions(gradient_input, src));
DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));

const long ns = src.num_samples();
const long ks = src.k();
const long num = src.nr() * src.nc();

gamma_grad = 0;
dscale.copy_size(scale);
dscale = 0;

// Lancement du kernel CUDA
launch_kernel(_cuda_rms_normalize_gradient, max_jobs(ks * num, ns),
src_grad.device(), gamma_grad.device(), dscale.device(),
src.device(), gradient_input.device(), scale.device(), gamma.device(),
ns, ks, num);
}

// ----------------------------------------------------------------------------------------

__global__ void _cuda_copy_tensor_add_to (float* dest, size_t size, const float* src, size_t dest_stride, size_t src_stride, size_t block_size)
Expand Down
20 changes: 20 additions & 0 deletions dlib/cuda/cuda_dlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,26 @@ namespace dlib
resizable_tensor& dvars
);

// -----------------------------------------------------------------------------------

void rms_normalize(
const double eps,
resizable_tensor& dest,
resizable_tensor& scale,
const tensor& src,
const tensor& gamma
);

void rms_normalize_gradient(
const tensor& gradient_input,
const tensor& scale,
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
resizable_tensor& dscale
);

// -----------------------------------------------------------------------------------

void threshold (
Expand Down
Loading

0 comments on commit fafdac3

Please sign in to comment.