Add embeddings_ layer and supporting utility functions (#3021)
* Fix Stride Indexing Bugs in `reorg` and `reorg_gradient` Functions (CPU & CUDA) and Add `add_to` Parameter

* Add the missing 'add_to' parameter to the CUDA reorg_gradient launch_kernel() call

* Cleanup: remove using namespace std; (#3016)

* remove using namespace std from headers

* more std::

* more std::

* more std:: on windows stuff

* remove uses of using namespace std::chrono

* do not use C++17 features

* Add Davis suggestion

* revert some more stuff

* revert removing include

* more std::chrono stuff

* fix build error

* Adjust comment formatting to be like other dlib comments

* Add positional encodings layer to Dlib

* Implement embeddings_ layer and add supporting utility functions to tensor_tools.h

* Updates

* Updates

* Updates

* Updates

* Update

* Update dlib/cuda/tensor_tools.h

---------

Co-authored-by: Adrià <1671644+arrufat@users.noreply.github.com>
Co-authored-by: Davis King <davis@dlib.net>
Co-authored-by: Davis E. King <davis685@gmail.com>
4 people authored Oct 17, 2024
1 parent 488ee5c commit 8567075
Showing 11 changed files with 982 additions and 10 deletions.
5 changes: 5 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,5 @@
{
"githubPullRequests.ignoredPullRequestBranches": [
"master"
]
}
115 changes: 115 additions & 0 deletions dlib/cuda/cpu_dlib.cpp
@@ -2421,6 +2421,121 @@ namespace dlib
}

// ------------------------------------------------------------------------------------

void embeddings(
resizable_tensor& dest,
const tensor& src,
const tensor& embs
)
{
DLIB_CASSERT(
src.nr() > 0 &&
embs.num_samples() > 0 &&
embs.k() > 0 &&
embs.nr() == 1 &&
embs.nc() == 1,
"\nsrc.num_samples(): " << src.num_samples() <<
"\nsrc.k(): " << src.k() <<
"\nsrc.nr(): " << src.nr() <<
"\nsrc.nc(): " << src.nc() <<
"\nembs.num_samples(): " << embs.num_samples() <<
"\nembs.k(): " << embs.k() <<
"\nembs.nr(): " << embs.nr() <<
"\nembs.nc(): " << embs.nc()
);

long ns = dest.num_samples(), nk = dest.k(), nr = dest.nr(), nc = dest.nc();
const float* src_data = src.host();
float* dest_data = dest.host();
const float* embs_data = embs.host();
for (long s = 0; s < ns; ++s)
{
for (long k = 0; k < nk; ++k)
{
for (long r = 0; r < nr; ++r)
{
// Each (sample, k, row) cell of src holds one token id.
const unsigned long token_idx = static_cast<unsigned long>(src_data[tensor_index(src, s, k, r, 0)]);
if (token_idx < embs.num_samples())
{
for (long c = 0; c < nc; ++c)
dest_data[tensor_index(dest, s, k, r, c)] = embs_data[tensor_index(embs, token_idx, c, 0, 0)];
}
else
{
// Out-of-range token ids produce a zero embedding.
for (long c = 0; c < nc; ++c)
dest_data[tensor_index(dest, s, k, r, c)] = 0;
}
}
}
}
}
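
In effect, embeddings is a batched table lookup: each (sample, k, row) cell of src holds a token id, and the matching row of embs is copied into dest, with out-of-range ids mapped to zeros. A minimal standalone sketch of that lookup rule, using illustrative shapes and values rather than code from this commit:

#include <cstdio>

int main()
{
    const float table[3][2] = {{0, 1}, {2, 3}, {4, 5}}; // vocab of 3, embedding dim 2
    const float tokens[2] = {2, 7};                     // ids stored as floats; 7 is out of range
    float out[2][2];
    for (int r = 0; r < 2; ++r)
    {
        const unsigned long idx = static_cast<unsigned long>(tokens[r]);
        for (int c = 0; c < 2; ++c)
            out[r][c] = idx < 3 ? table[idx][c] : 0.0f; // out-of-range ids yield zeros
    }
    std::printf("%g %g | %g %g\n", out[0][0], out[0][1], out[1][0], out[1][1]); // prints: 4 5 | 0 0
}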

void embeddings_gradient(
const tensor& prev,
const tensor& gradient_input,
tensor& grads,
const tensor& freqs,
float learning_rate,
bool scale
)
{
DLIB_CASSERT(
prev.nr() > 0 &&
gradient_input.num_samples() == prev.num_samples() &&
gradient_input.k() == prev.k() &&
gradient_input.nr() == prev.nr() &&
gradient_input.nc() == grads.k() &&
grads.num_samples() > 0 &&
grads.k() > 0 &&
grads.nr() == 1 &&
grads.nc() == 1,
"\ngradient_input.num_samples(): " << gradient_input.num_samples() <<
"\ngradient_input.k(): " << gradient_input.k() <<
"\ngradient_input.nr(): " << gradient_input.nr() <<
"\ngradient_input.nc(): " << gradient_input.nc() <<
"\nprev.num_samples(): " << prev.num_samples() <<
"\nprev.k(): " << prev.k() <<
"\nprev.nr(): " << prev.nr() <<
"\nprev.nc(): " << prev.nc() <<
"\ngrads.num_samples(): " << grads.num_samples() <<
"\ngrads.k(): " << grads.k() <<
"\ngrads.nr(): " << grads.nr() <<
"\ngrads.nc(): " << grads.nc()
);

const float* prev_data = prev.host();
const float* gradient_input_data = gradient_input.host();
const float* freqs_data = freqs.host();
float* grads_data = grads.host();
long ns = gradient_input.num_samples(), nk = gradient_input.k();
long nr = gradient_input.nr(), nc = gradient_input.nc();

// One mutex per embedding table row, so distinct tokens can be updated in parallel.
std::vector<dlib::mutex> embedding_mutexes(grads.num_samples());
parallel_for(0, ns * nk, [&](long i)
{
long s = i / nk;
long k = i % nk;

for (long r = 0; r < nr; ++r)
{
const unsigned long token_idx = static_cast<unsigned long>(prev_data[tensor_index(prev, s, k, r, 0)]);
if (token_idx < grads.num_samples())
{
const float freq_token = freqs_data[token_idx];
float freq_scale = 1.0f;

// Optionally scale the update by the token's frequency.
if (scale && freq_token != 0.0f) freq_scale = std::min(0.15f, std::max(1.0f / freq_token, 1.0f));
auto_mutex locker(embedding_mutexes[token_idx]);
for (long c = 0; c < nc; ++c)
{
const float gradient = gradient_input_data[tensor_index(gradient_input, s, k, r, c)];
grads_data[tensor_index(grads, token_idx, c, 0, 0)] -= (gradient * learning_rate * freq_scale);
}
}
}
});
}
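
The update writes straight into the embedding table: each token's row is decremented by gradient * learning_rate * freq_scale, and the per-row mutexes serialize iterations that touch the same row. A toy single-threaded illustration of how repeated tokens accumulate (shapes and values are assumptions, not from the diff):

#include <cstdio>

int main()
{
    float grads[3][2] = {};                 // gradient table: vocab of 3, embedding dim 2
    const unsigned long tokens[2] = {1, 1}; // token 1 occurs twice
    const float gi[2][2] = {{0.5f, 0.5f}, {0.25f, 0.25f}};
    const float lr = 0.1f;
    for (int r = 0; r < 2; ++r)
        for (int c = 0; c < 2; ++c)
            grads[tokens[r]][c] -= gi[r][c] * lr; // same update rule, without frequency scaling
    std::printf("%g %g\n", grads[1][0], grads[1][1]); // prints: -0.075 -0.075
}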

// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------

17 changes: 17 additions & 0 deletions dlib/cuda/cpu_dlib.h
@@ -517,6 +517,23 @@ namespace dlib
const tensor& gradient_input
);

// -----------------------------------------------------------------------------------

void embeddings(
resizable_tensor& dest,
const tensor& src,
const tensor& embs
);

void embeddings_gradient(
const tensor& prev,
const tensor& gradient_input,
tensor& grads,
const tensor& freqs,
float learning_rate,
bool scale
);

// -----------------------------------------------------------------------------------

class pooling
120 changes: 120 additions & 0 deletions dlib/cuda/cuda_dlib.cu
@@ -2088,6 +2088,126 @@ namespace dlib
row_stride, col_stride, add_to);
}

// ----------------------------------------------------------------------------------------

__global__ void _cuda_embeddings(size_t dsize, size_t dk, size_t dr, size_t dc,
float* d, const float* s, const float* e, size_t es
)
{
// d: dest, s: src token ids, e: embedding table, es: vocabulary size.
for (auto i : grid_stride_range(0, dsize))
{
// Decompose the flat destination index into (sample, k, row, column).
const auto n = i / (dk * dr * dc);
const auto s_idx = i % (dk * dr * dc);
const auto k = (s_idx / (dr * dc)) % dk;
const auto r = (s_idx / dc) % dr;
const auto c = s_idx % dc;

// Each (sample, k, row) cell of src holds one token id.
const unsigned long t_idx = static_cast<unsigned long>(s[(n * dk + k) * dr + r]);

if (t_idx < es)
d[i] = e[t_idx * dc + c];
else
d[i] = 0.0f;
}
}

void embeddings(
resizable_tensor& dest,
const tensor& src,
const tensor& embs
)
{
DLIB_CASSERT(
src.nr() > 0 &&
embs.num_samples() > 0 &&
embs.k() > 0 &&
embs.nr() == 1 &&
embs.nc() == 1,
"\nsrc.num_samples(): " << src.num_samples() <<
"\nsrc.k(): " << src.k() <<
"\nsrc.nr(): " << src.nr() <<
"\nsrc.nc(): " << src.nc() <<
"\nembs.num_samples(): " << embs.num_samples() <<
"\nembs.k(): " << embs.k() <<
"\nembs.nr(): " << embs.nr() <<
"\nembs.nc(): " << embs.nc()
);

const long dk = dest.k();
const long dr = dest.nr();
const long dc = dest.nc();

launch_kernel(_cuda_embeddings, dest.size(), dk, dr, dc,
dest.device(), src.device(), embs.device(), embs.num_samples());
}
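
The kernel walks dest as one flat index range and recovers (sample, k, row, column) with divisions and modulos. A host-side check of that decomposition round trip, with assumed sizes:

#include <cassert>
#include <cstddef>

int main()
{
    const std::size_t dn = 2, dk = 2, dr = 3, dc = 4; // illustrative tensor dims
    for (std::size_t i = 0; i < dn * dk * dr * dc; ++i)
    {
        const auto n = i / (dk * dr * dc);
        const auto s_idx = i % (dk * dr * dc);
        const auto k = (s_idx / (dr * dc)) % dk;
        const auto r = (s_idx / dc) % dr;
        const auto c = s_idx % dc;
        // Recomposing the coordinates must give back the flat index.
        assert(((n * dk + k) * dr + r) * dc + c == i);
    }
}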

__global__ void _cuda_embeddings_gradient(size_t ssize, size_t sk, size_t sr, size_t sc,
const float* o, const float* gi, float* g, const float* f, float lr, bool sl, size_t es
)
{
// o: prev token ids, gi: gradient input, g: gradient table, f: token frequencies.
for (auto i : grid_stride_range(0, ssize))
{
const auto n = i / (sk * sr * sc);
const auto s_idx = i % (sk * sr * sc);
const auto k = (s_idx / (sr * sc)) % sk;
const auto r = (s_idx / sc) % sr;
const auto c = s_idx % sc;

const unsigned long t_idx = static_cast<unsigned long>(o[(n * sk + k) * sr + r]);
if (t_idx < es)
{
const float f_t = f[t_idx];
float f_s = 1.0f;

// Optionally scale the update by the token's frequency.
if (sl && f_t != 0.0f) f_s = fminf(0.15f, fmaxf(1.0f / f_t, 1.0f));
// Rows for tokens seen more than once can be written by several threads
// at the same time, so update those rows atomically.
if (f_t > 1) atomicAdd(&g[t_idx * sc + c], -gi[i] * lr * f_s);
else g[t_idx * sc + c] -= gi[i] * lr * f_s;
}
}
}

void embeddings_gradient(
const tensor& prev,
const tensor& gradient_input,
tensor& grads,
const tensor& freqs,
float learning_rate,
bool scale
)
{
DLIB_CASSERT(
prev.nr() > 0 &&
gradient_input.num_samples() == prev.num_samples() &&
gradient_input.k() == prev.k() &&
gradient_input.nr() == prev.nr() &&
gradient_input.nc() == grads.k() &&
grads.num_samples() > 0 &&
grads.k() > 0 &&
grads.nr() == 1 &&
grads.nc() == 1,
"\ngradient_input.num_samples(): " << gradient_input.num_samples() <<
"\ngradient_input.k(): " << gradient_input.k() <<
"\ngradient_input.nr(): " << gradient_input.nr() <<
"\ngradient_input.nc(): " << gradient_input.nc() <<
"\nprev.num_samples(): " << prev.num_samples() <<
"\nprev.k(): " << prev.k() <<
"\nprev.nr(): " << prev.nr() <<
"\nprev.nc(): " << prev.nc() <<
"\ngrads.num_samples(): " << grads.num_samples() <<
"\ngrads.k(): " << grads.k() <<
"\ngrads.nr(): " << grads.nr() <<
"\ngrads.nc(): " << grads.nc()
);

const long sk = gradient_input.k();
const long sr = gradient_input.nr();
const long sc = gradient_input.nc();

launch_kernel(_cuda_embeddings_gradient, gradient_input.size(), sk, sr, sc,
prev.device(), gradient_input.device(), grads.device(), freqs.device(),
learning_rate, scale, grads.num_samples());
}

// ----------------------------------------------------------------------------------------

__global__ void _cuda_layer_normalize(
17 changes: 17 additions & 0 deletions dlib/cuda/cuda_dlib.h
@@ -561,6 +561,23 @@ namespace dlib
const tensor& gradient_input
);

// -----------------------------------------------------------------------------------

void embeddings(
resizable_tensor& dest,
const tensor& src,
const tensor& embs
);

void embeddings_gradient(
const tensor& prev,
const tensor& gradient_input,
tensor& grads,
const tensor& freqs,
float learning_rate,
bool scale
);

// ----------------------------------------------------------------------------------------

void copy_tensor(
31 changes: 31 additions & 0 deletions dlib/cuda/tensor_tools.cpp
@@ -1296,6 +1296,37 @@ namespace dlib { namespace tt
#endif
}

// ----------------------------------------------------------------------------------------

void embeddings(
resizable_tensor& dest,
const tensor& src,
const tensor& embs
)
{
#ifdef DLIB_USE_CUDA
cuda::embeddings(dest, src, embs);
#else
cpu::embeddings(dest, src, embs);
#endif
}

void embeddings_gradient(
const tensor& prev,
const tensor& gradient_input,
tensor& grads,
const tensor& freqs,
float learning_rate,
bool scale
)
{
#ifdef DLIB_USE_CUDA
cuda::embeddings_gradient(prev, gradient_input, grads, freqs, learning_rate, scale);
#else
cpu::embeddings_gradient(prev, gradient_input, grads, freqs, learning_rate, scale);
#endif
}

// ----------------------------------------------------------------------------------------

}}
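
A hedged usage sketch of the new tt:: entry points; the shapes below are assumptions chosen for illustration. tt::embeddings routes to the CUDA implementation when dlib is built with DLIB_USE_CUDA and to the CPU implementation otherwise, and tt::embeddings_gradient takes the same (prev, gradient_input, grads, freqs, learning_rate, scale) arguments as the backends above:

#include <dlib/dnn.h>

int main()
{
    dlib::resizable_tensor src(1, 1, 3, 1);   // 3 token ids per sample, stored as floats
    dlib::resizable_tensor embs(16, 4, 1, 1); // vocabulary of 16, embedding dim 4
    dlib::resizable_tensor dest(1, 1, 3, 4);  // one embedding row per token
    src = 2.0f;                               // every token id is 2
    embs = 0.5f;
    dlib::tt::embeddings(dest, src, embs);    // each dest row now equals row 2 of embs
}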
