
General update of the rms_norm_ class
Cydral committed Sep 7, 2024
1 parent 03cdf93 commit 1083aab
Showing 11 changed files with 494 additions and 316 deletions.
125 changes: 69 additions & 56 deletions dlib/cuda/cpu_dlib.cpp
@@ -1457,117 +1457,130 @@ namespace dlib
             const tensor& gamma
         )
         {
-            const long num = src.k() * src.nr() * src.nc();
             DLIB_CASSERT(
-                src.k() == gamma.k() &&
-                src.nr() == gamma.nr() &&
-                src.nc() == gamma.nc() &&
+                gamma.k() == src.k() &&
+                gamma.nr() == 1 &&
+                gamma.nc() == 1 &&
                 eps > 0,
-                "\nsrc.k(): " << src.k() <<
                 "\ngamma.k(): " << gamma.k() <<
                 "\ngamma.nr(): " << gamma.nr() <<
                 "\ngamma.nc(): " << gamma.nc() <<
+                "\nsrc.k(): " << src.k() <<
                 "\nsrc.nr(): " << src.nr() <<
                 "\nsrc.nc(): " << src.nc() <<
                 "\neps: " << eps
             );
 
+            const long ns = src.num_samples();
+            const long ks = src.k();
+            const long num = src.nr() * src.nc();
+
             dest.copy_size(src);
-            scale.set_size(src.num_samples());
+            scale.set_size(ns);
 
-            // Compute RMS
-            const auto p_scale = scale.host();
-            auto p_src = src.host();
-            for (long n = 0; n < src.num_samples(); ++n)
+            // Compute RMS values
+            scale = 0;
+            const float* p_src = src.host();
+            float* p_scale = scale.host();
+            for (long n = 0; n < ns; ++n)
             {
-                float sum_squares = 0;
-                for (long i = 0; i < num; ++i)
+                for (long k = 0; k < ks; ++k)
                 {
-                    float val = p_src[n * num + i];
-                    sum_squares += val * val;
+                    for (long i = 0; i < num; ++i)
+                    {
+                        p_scale[n] += (*p_src) * (*p_src);
+                        ++p_src;
+                    }
                 }
-                p_scale[n] = sum_squares / num;
-            }
-            // Compute RMS inverse
-            for (long n = 0; n < src.num_samples(); ++n)
-            {
-                p_scale[n] = 1.0f / std::sqrt(p_scale[n] + eps);
+                p_scale[n] = 1.0f / std::sqrt(p_scale[n] / (ks * num) + static_cast<float>(eps));
             }
+            scale.host();
 
             // Apply RMS normalization
             p_src = src.host();
-            auto p_dest = dest.host();
-            auto p_gamma = gamma.host();
-            for (long n = 0; n < src.num_samples(); ++n)
+            float* p_dest = dest.host();
+            const float* p_gamma = gamma.host();
+            for (long n = 0; n < ns; ++n)
             {
-                for (long i = 0; i < num; ++i)
+                for (long k = 0; k < ks; ++k)
                 {
-                    *p_dest = (*p_src) * p_scale[n] * p_gamma[i];
-                    ++p_src;
-                    ++p_dest;
+                    for (long i = 0; i < num; ++i)
+                    {
+                        *p_dest = (*p_src) * p_scale[n] * p_gamma[k];
+                        ++p_src;
+                        ++p_dest;
+                    }
                 }
             }
         }
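
Note on the forward pass above: the gamma contract changes, so gamma now carries one learnable weight per channel (k x 1 x 1) instead of one weight per element, and the statistic is pooled over all of a sample's elements with a single pointer walk instead of indexed access. Writing K = src.k(), R = src.nr(), C = src.nc(), the new loops compute, per sample n,

    s_n = \frac{1}{\sqrt{\frac{1}{KRC}\sum_{k,r,c} x_{n,k,r,c}^2 + \varepsilon}},
    \qquad
    y_{n,k,r,c} = \gamma_k \, s_n \, x_{n,k,r,c}

with s_n stored in scale[n] for reuse by the gradient pass.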

         void rms_normalize_gradient(
-            const double eps,
             const tensor& gradient_input,
             const tensor& scale,
             const tensor& src,
             const tensor& gamma,
             tensor& src_grad,
             tensor& gamma_grad,
-            tensor& dscale
+            resizable_tensor& dscale
         )
         {
-            const long num = src.k() * src.nr() * src.nc();
             DLIB_CASSERT(src.num_samples() == scale.size());
-            DLIB_CASSERT(src.k() == gamma.k());
-            DLIB_CASSERT(src.nr() == gamma.nr());
-            DLIB_CASSERT(src.nc() == gamma.nc());
+            DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad));
+            DLIB_CASSERT(gamma.k() == src.k());
+            DLIB_CASSERT(gamma.nr() == 1);
+            DLIB_CASSERT(gamma.nc() == 1);
             DLIB_CASSERT(have_same_dimensions(gradient_input, src));
             DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
-            DLIB_CASSERT(have_same_dimensions(gamma_grad, gamma));
-            DLIB_CASSERT(eps > 0);
+
+            const long ns = src.num_samples();
+            const long ks = src.k();
+            const long num = src.nr() * src.nc();
 
             gamma_grad = 0;
             dscale.copy_size(scale);
+            dscale = 0;
 
             auto p_grad = gradient_input.host();
             auto p_src = src.host();
             const auto p_gamma = gamma.host();
             const auto p_gamma_grad = gamma_grad.host();
             const auto p_scale = scale.host();
-            auto p_dscale = dscale.host();
-
-            dscale = 0;
+            const auto p_dscale = dscale.host();
 
-            for (long n = 0; n < src.num_samples(); ++n)
+            for (long n = 0; n < ns; ++n)
             {
-                for (long i = 0; i < num; ++i)
+                const float scale_pow = -0.5f * std::pow(p_scale[n], 3.0f);
+                for (long k = 0; k < ks; ++k)
                 {
-                    const float x_hat = (*p_src) * p_scale[n];
-                    p_gamma_grad[i] += (*p_grad) * x_hat;
-
-                    const float dx = *p_grad * p_gamma[i];
-                    p_dscale[n] += dx * (*p_src) * (-0.5) * p_scale[n] * p_scale[n] * p_scale[n];
-
-                    ++p_grad;
-                    ++p_src;
+                    for (long i = 0; i < num; ++i)
+                    {
+                        const float x_hat = *p_src * p_scale[n];
+                        p_gamma_grad[k] += (*p_grad) * x_hat;
+
+                        const float dx = *p_grad * p_gamma[k];
+                        p_dscale[n] += dx * *p_src * scale_pow;
+
+                        ++p_grad;
+                        ++p_src;
+                    }
                 }
             }
 
             p_grad = gradient_input.host();
             p_src = src.host();
             auto p_src_grad = src_grad.host();
-            for (long n = 0; n < src.num_samples(); ++n)
+            const float invnum = 1.0f / (ks * num);
+            for (long n = 0; n < ns; ++n)
             {
-                for (long i = 0; i < num; ++i)
+                for (long k = 0; k < ks; ++k)
                 {
-                    const float dx = *p_grad * p_gamma[i];
-
-                    *p_src_grad += dx * p_scale[n] + p_dscale[n] * 2 * (*p_src) / num;
-
-                    ++p_grad;
-                    ++p_src;
-                    ++p_src_grad;
+                    for (long i = 0; i < num; ++i)
+                    {
+                        const float dx = *p_grad * p_gamma[k];
+                        *p_src_grad += dx * p_scale[n] + p_dscale[n] * 2 * *p_src * invnum;
+
+                        ++p_grad;
+                        ++p_src;
+                        ++p_src_grad;
+                    }
                 }
             }
         }
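
The backward pass above is the chain rule applied to the forward formula. With m_n = \frac{1}{KRC}\sum_{k,r,c} x^2 (so s_n = (m_n + \varepsilon)^{-1/2}), \hat{x} = s_n x, and g the incoming gradient, the loops accumulate

    \frac{\partial L}{\partial \gamma_k} = \sum_{n,r,c} g \, \hat{x},
    \qquad
    \mathrm{dscale}_n = \frac{\partial L}{\partial m_n} = -\tfrac{1}{2} s_n^3 \sum_{k,r,c} g \, \gamma_k \, x,
    \qquad
    \frac{\partial L}{\partial x} = g \, \gamma_k \, s_n + \frac{2x}{KRC} \, \mathrm{dscale}_n.

Hoisting the per-sample factor -\tfrac{1}{2} s_n^3 into scale_pow and 1/(KRC) into invnum moves two repeated computations out of the inner loop; together with the per-channel p_gamma[k] / p_gamma_grad[k] indexing, these are the substantive changes to the gradient.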
3 changes: 1 addition & 2 deletions dlib/cuda/cpu_dlib.h
@@ -266,14 +266,13 @@ namespace dlib
         );
 
         void rms_normalize_gradient(
-            const double eps,
             const tensor& gradient_input,
             const tensor& scale,
             const tensor& src,
             const tensor& gamma,
             tensor& src_grad,
             tensor& gamma_grad,
-            tensor& dscale
+            resizable_tensor& dscale
         );
 
         // -----------------------------------------------------------------------------------
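To make the new calling contract concrete, here is a minimal sketch of driving the updated CPU pair directly. It is illustrative, not part of the commit, and it assumes the leading parameters of rms_normalize are (eps, dest, scale, src, gamma) as declared earlier in cpu_dlib.h (that part of the declaration is outside these hunks):

#include <dlib/dnn.h>

int main()
{
    using namespace dlib;

    resizable_tensor src(2, 3, 4, 4);    // {num_samples, k, nr, nc}
    resizable_tensor gamma(1, 3, 1, 1);  // new contract: one weight per channel
    tt::tensor_rand rnd;
    rnd.fill_gaussian(src);
    gamma = 1;

    // Forward: dest = gamma[k] * src * s_n, with s_n the per-sample inverse RMS
    // stored in scale.
    resizable_tensor dest, scale;
    cpu::rms_normalize(1e-5, dest, scale, src, gamma);

    // Backward: note the updated signature -- no eps parameter, and dscale is
    // now a resizable_tensor that the function sizes itself.
    resizable_tensor gradient_input;
    gradient_input.copy_size(dest);
    rnd.fill_gaussian(gradient_input);

    resizable_tensor src_grad, gamma_grad, dscale;
    src_grad.copy_size(src);
    src_grad = 0;   // src_grad is accumulated into (+=), so clear it first
    gamma_grad.copy_size(gamma);
    cpu::rms_normalize_gradient(gradient_input, scale, src, gamma,
                                src_grad, gamma_grad, dscale);
}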
[Diff not shown for the remaining 9 changed files.]
