Add multm_prev_ layer and enhance gemm() function for PLANE_WISE operations #3020

Merged

merged 36 commits on Dec 20, 2024
Changes from 18 commits

Commits (36)
4d698fc
Fix Stride Indexing Bugs in `reorg` and `reorg_gradient` Functions (C…
Cydral Sep 16, 2024
1d73b6c
'add_to' parameter missing in cuda call reorg_gradient.launch_kernel()
Cydral Sep 20, 2024
c343779
Cleanup: remove using namespace std; (#3016)
arrufat Sep 23, 2024
724ec09
Merge branch 'refs/heads/master' into Cydral-master
davisking Sep 23, 2024
4dca9b2
fix build error
davisking Sep 23, 2024
2f68a11
Adjust comment formatting to be like other dlib comments
davisking Sep 23, 2024
64e3471
Merge branch 'davisking:master' into master
Cydral Sep 23, 2024
640c02f
Add positional encodings layer to Dlib
Cydral Sep 24, 2024
0f1e250
Add multm_prev layer and enhance gemm() function for PLANE_WISE opera…
Cydral Sep 26, 2024
e8e10ce
Updates
Cydral Sep 26, 2024
06a7f6a
Updates
Cydral Sep 26, 2024
d40171d
Merge branch 'master' into multm-prev-layer
Cydral Sep 30, 2024
0d60627
Resynchronization with tril_ class
Cydral Sep 30, 2024
ed39b2c
Delete .vscode/settings.json
Cydral Oct 6, 2024
8e2a48c
Merge branch 'master' into multm-prev-layer
Cydral Nov 4, 2024
300a8c6
Remove duplicates
Cydral Nov 4, 2024
d173fbd
Small improvements to PLANE_WISE in gemm() function
Cydral Nov 8, 2024
c81efb7
Same improvements for the CPU version
Cydral Nov 11, 2024
89746e2
Merge branch 'davisking:master' into multm-prev-layer
Cydral Nov 18, 2024
3d60227
Introducing a new enum for operation modes in tensor computations
Cydral Nov 18, 2024
a257f02
Remove a test duplicated call in dnn tests
Cydral Nov 18, 2024
21dc524
Remove duplicated declaration
Cydral Nov 18, 2024
439bb87
Comment fixed
Cydral Nov 18, 2024
ca01599
Fixing the Cuda compilation
Cydral Dec 7, 2024
2772dca
Merging with updated softmax_ layer
Cydral Dec 9, 2024
1ff436e
Fixing header for CPU compilation
Cydral Dec 9, 2024
274f32f
Adding a missing cast
Cydral Dec 9, 2024
8685ed8
Test fixed to use the new operation_mode enum
Cydral Dec 10, 2024
275bafc
softmaxm test fixed
Cydral Dec 10, 2024
6beab3b
Enum test removed
Cydral Dec 16, 2024
39b09d9
Enum test removed
Cydral Dec 16, 2024
caed8ff
Fixing indentation
Cydral Dec 16, 2024
fbaa299
Fixing indentation
Cydral Dec 16, 2024
f2dea1e
Test removed
Cydral Dec 16, 2024
c9cc82f
Move the operation_mode enumeration to its own header
Cydral Dec 17, 2024
efda8e7
Use operation_mode instead of unsigned long
davisking Dec 20, 2024
149 changes: 106 additions & 43 deletions dlib/cuda/cublas_dlibapi.cpp
@@ -101,55 +101,118 @@ namespace dlib
const tensor& lhs,
bool trans_lhs,
const tensor& rhs,
bool trans_rhs
bool trans_rhs,
size_t g_mode
)
{
// Recall that BLAS uses column major order so to deal with that we flip the
// order of the lhs and rhs arguments.
const auto transa = trans_lhs ? CUBLAS_OP_T : CUBLAS_OP_N;
const auto transb = trans_rhs ? CUBLAS_OP_T : CUBLAS_OP_N;

const int dest_nr = dest.num_samples();
const int dest_nc = dest.size()/dest_nr;
const int lhs_nr = lhs.num_samples();
const int lhs_nc = lhs.size()/lhs_nr;
const int rhs_nr = rhs.num_samples();
const int rhs_nc = rhs.size()/rhs_nr;
if (trans_lhs && trans_rhs)
if (g_mode == 0) // gemm_mode::CHANNEL_WISE
{
DLIB_ASSERT( dest_nr == lhs_nc &&
dest_nc == rhs_nr &&
lhs_nr == rhs_nc)
}
else if (!trans_lhs && trans_rhs)
{
DLIB_ASSERT( dest_nr == lhs_nr &&
dest_nc == rhs_nr &&
lhs_nc == rhs_nc)
}
else if (trans_lhs && !trans_rhs)
{
DLIB_ASSERT( dest_nr == lhs_nc &&
dest_nc == rhs_nc &&
lhs_nr == rhs_nr)
// Recall that BLAS uses column major order so to deal with that we flip the
// order of the lhs and rhs arguments.
const auto transa = trans_lhs ? CUBLAS_OP_T : CUBLAS_OP_N;
const auto transb = trans_rhs ? CUBLAS_OP_T : CUBLAS_OP_N;

const int dest_nr = dest.num_samples();
const int dest_nc = dest.size() / dest_nr;
const int lhs_nr = lhs.num_samples();
const int lhs_nc = lhs.size() / lhs_nr;
const int rhs_nr = rhs.num_samples();
const int rhs_nc = rhs.size() / rhs_nr;
if (trans_lhs && trans_rhs)
{
DLIB_ASSERT(dest_nr == lhs_nc &&
dest_nc == rhs_nr &&
lhs_nr == rhs_nc)
}
else if (!trans_lhs && trans_rhs)
{
DLIB_ASSERT(dest_nr == lhs_nr &&
dest_nc == rhs_nr &&
lhs_nc == rhs_nc)
}
else if (trans_lhs && !trans_rhs)
{
DLIB_ASSERT(dest_nr == lhs_nc &&
dest_nc == rhs_nc &&
lhs_nr == rhs_nr)
}
else
{
DLIB_ASSERT(dest_nr == lhs_nr &&
dest_nc == rhs_nc &&
lhs_nc == rhs_nr)
}

const int k = trans_rhs ? rhs_nc : rhs_nr;
CHECK_CUBLAS(cublasSgemm(context(),
transb,
transa,
dest_nc, dest_nr, k,
&alpha,
rhs.device(), rhs_nc,
lhs.device(), lhs_nc,
&beta,
dest.device(), dest_nc));
}
else
else if (g_mode == 1) // gemm_mode::PLANE_WISE
{
DLIB_ASSERT( dest_nr == lhs_nr &&
dest_nc == rhs_nc &&
lhs_nc == rhs_nr)
}
const auto transa = trans_lhs ? CUBLAS_OP_T : CUBLAS_OP_N;
const auto transb = trans_rhs ? CUBLAS_OP_T : CUBLAS_OP_N;

long num_samples = std::min({ lhs.num_samples(), rhs.num_samples(), dest.num_samples() });
long num_channels = std::min({ lhs.k(), rhs.k(), dest.k() });

auto is_matrix = [](const auto& tensor) {
return ((tensor.num_samples() * tensor.k() == 1 && tensor.nr() * tensor.nc() > 1) ||
(tensor.num_samples() * tensor.k() > 1 && tensor.nr() * tensor.nc() == 1));
};
const bool lhs_is_matrix = is_matrix(lhs), rhs_is_matrix = is_matrix(rhs), dest_is_matrix = is_matrix(dest);

if (lhs_is_matrix && rhs_is_matrix && dest_is_matrix) num_samples = num_channels = 1;

size_t lhs_rows = lhs.nr();
size_t lhs_cols = lhs.nc();
if (lhs_is_matrix && (lhs.num_samples() > 1 || lhs.k() > 1)) {
lhs_rows = lhs.num_samples();
lhs_cols = lhs.k();
}
size_t rhs_rows = rhs.nr();
size_t rhs_cols = rhs.nc();
if (rhs_is_matrix && (rhs.num_samples() > 1 || rhs.k() > 1)) {
rhs_rows = rhs.num_samples();
rhs_cols = rhs.k();
}
size_t dest_rows = dest.nr();
size_t dest_cols = dest.nc();
if (dest_is_matrix && (dest.num_samples() > 1 || dest.k() > 1)) {
dest_rows = dest.num_samples();
dest_cols = dest.k();
}

const size_t lhs_plane_size = lhs_rows * lhs_cols;
const size_t rhs_plane_size = rhs_rows * rhs_cols;
const size_t dest_plane_size = dest_rows * dest_cols;

const int k = trans_rhs ? rhs_nc : rhs_nr;
CHECK_CUBLAS(cublasSgemm(context(),
transb,
transa,
dest_nc, dest_nr, k,
&alpha,
rhs.device(), rhs_nc,
lhs.device(), lhs_nc,
&beta,
dest.device(),dest_nc));
for (long b = 0; b < num_samples; ++b)
{
for (long c = 0; c < num_channels; ++c)
{
auto lhs_slice = lhs_is_matrix ? lhs.device() :
lhs.device() + (b * num_channels + c) * lhs_plane_size;
auto rhs_slice = rhs_is_matrix ? rhs.device() :
rhs.device() + (b * num_channels + c) * rhs_plane_size;
auto dest_slice = dest_is_matrix ? dest.device() :
dest.device() + (b * num_channels + c) * dest_plane_size;
const int k = trans_rhs ? rhs_cols : rhs_rows;

CHECK_CUBLAS(cublasSgemm(
context(), transb, transa, dest_cols, dest_rows, k,
&alpha, rhs_slice, rhs_cols, lhs_slice, lhs_cols,
&beta, dest_slice, dest_cols
));
}
}
}
}

// ------------------------------------------------------------------------------------
Expand Down
57 changes: 44 additions & 13 deletions dlib/cuda/cublas_dlibapi.h
@@ -22,21 +22,52 @@ namespace dlib
const tensor& lhs,
bool trans_lhs,
const tensor& rhs,
bool trans_rhs
bool trans_rhs,
size_t g_mode = 0
);
/*!
requires
- The dimensions of lhs and rhs must be compatible for matrix
multiplication. In particular:
- Let L == trans_lhs ? trans(mat(lhs)) : mat(lhs)
- Let R == trans_rhs ? trans(mat(rhs)) : mat(rhs)
- Let D == mat(dest)
- D.nr() == L.nr() && D.nc() == R.nc()
(i.e. dest must be preallocated and have the correct output dimensions)
- L.nc() == R.nr()
ensures
/*!
requires
- The dimensions of lhs and rhs must be compatible for matrix multiplication.
The specific requirements depend on the g_mode:

For g_mode == 0 (CHANNEL_WISE, default):
- Let L == trans_lhs ? trans(mat(lhs)) : mat(lhs)
- Let R == trans_rhs ? trans(mat(rhs)) : mat(rhs)
- Let D == mat(dest)
- D.nr() == L.nr() && D.nc() == R.nc()
(i.e. dest must be preallocated and have the correct output dimensions)
- L.nc() == R.nr()

For g_mode == 1 (PLANE_WISE):
- lhs.num_samples() == rhs.num_samples() && lhs.k() == rhs.k()
- If !trans_lhs && !trans_rhs:
lhs.nc() == rhs.nr()
dest.nr() == lhs.nr() && dest.nc() == rhs.nc()
- If trans_lhs && !trans_rhs:
lhs.nr() == rhs.nr()
dest.nr() == lhs.nc() && dest.nc() == rhs.nc()
- If !trans_lhs && trans_rhs:
lhs.nc() == rhs.nc()
dest.nr() == lhs.nr() && dest.nc() == rhs.nr()
- If trans_lhs && trans_rhs:
lhs.nr() == rhs.nc()
dest.nr() == lhs.nc() && dest.nc() == rhs.nr()

ensures
- Performs matrix multiplication based on the specified g_mode:

For g_mode == 0 (CHANNEL_WISE):
- performs: dest = alpha*L*R + beta*mat(dest)
!*/
Where L, R, and D are as defined above.

For g_mode == 1 (PLANE_WISE):
- Performs matrix multiplication for each corresponding 2D plane (nr x nc)
in lhs and rhs across all samples and channels.
- The operation is equivalent to performing the following for each sample
and channel:
dest[s][k] = alpha * (lhs[s][k] * rhs[s][k]) + beta * dest[s][k]
Where [s][k] represents the 2D plane for sample s and channel k.
!*/

// ------------------------------------------------------------------------------------

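
The contract documented above can be exercised through the tensor_tools wrapper shown in the next file. The sketch below is a usage illustration, not code from the PR: it assumes the tt::gemm overload and the gemm_mode values of this revision (later commits in the PR rename the enum to operation_mode), so the exact spelling of tt::PLANE_WISE is an assumption.

#include <dlib/dnn.h>

int main()
{
    using namespace dlib;

    // Two samples, three channels; every plane of lhs is 4x6 and every plane
    // of rhs is 6x5, so each per-plane product is a 4x5 matrix.
    const long n = 2, k = 3;
    resizable_tensor lhs(n, k, 4, 6);
    resizable_tensor rhs(n, k, 6, 5);
    resizable_tensor dest(n, k, 4, 5);   // preallocated with the output shape

    tt::tensor_rand rnd;
    rnd.fill_uniform(lhs);
    rnd.fill_uniform(rhs);
    dest = 0;

    // dest[s][c] = 1*(lhs[s][c]*rhs[s][c]) + 0*dest[s][c] for every sample s
    // and channel c (assumed enum spelling, see note above).
    tt::gemm(0, dest, 1, lhs, false, rhs, false, tt::PLANE_WISE);
}

Passing the channel-wise mode instead (or omitting the last argument, if the tt overload keeps the same default as the cublas declaration above) reproduces the original single num_samples x (size/num_samples) product.
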
104 changes: 85 additions & 19 deletions dlib/cuda/tensor_tools.cpp
@@ -208,33 +208,99 @@ namespace dlib { namespace tt
const tensor& lhs,
bool trans_lhs,
const tensor& rhs,
bool trans_rhs
bool trans_rhs,
gemm_mode g_mode
)
{
#ifdef DLIB_USE_CUDA
cuda::gemm(beta, dest, alpha, lhs, trans_lhs, rhs, trans_rhs);
cuda::gemm(beta, dest, alpha, lhs, trans_lhs, rhs, trans_rhs, g_mode);
#else
if (beta != 0)
if (g_mode == CHANNEL_WISE)
{
if (trans_lhs && trans_rhs)
dest = alpha*trans(mat(lhs))*trans(mat(rhs)) + beta*mat(dest);
else if (!trans_lhs && trans_rhs)
dest = alpha*mat(lhs)*trans(mat(rhs)) + beta*mat(dest);
else if (trans_lhs && !trans_rhs)
dest = alpha*trans(mat(lhs))*mat(rhs) + beta*mat(dest);
if (beta != 0)
{
if (trans_lhs && trans_rhs)
dest = alpha * trans(mat(lhs)) * trans(mat(rhs)) + beta * mat(dest);
else if (!trans_lhs && trans_rhs)
dest = alpha * mat(lhs) * trans(mat(rhs)) + beta * mat(dest);
else if (trans_lhs && !trans_rhs)
dest = alpha * trans(mat(lhs)) * mat(rhs) + beta * mat(dest);
else
dest = alpha * mat(lhs) * mat(rhs) + beta * mat(dest);
}
else
dest = alpha*mat(lhs)*mat(rhs) + beta*mat(dest);
{
if (trans_lhs && trans_rhs)
dest = alpha * trans(mat(lhs)) * trans(mat(rhs));
else if (!trans_lhs && trans_rhs)
dest = alpha * mat(lhs) * trans(mat(rhs));
else if (trans_lhs && !trans_rhs)
dest = alpha * trans(mat(lhs)) * mat(rhs);
else
dest = alpha * mat(lhs) * mat(rhs);
}
}
else
else if (g_mode == PLANE_WISE)
{
if (trans_lhs && trans_rhs)
dest = alpha*trans(mat(lhs))*trans(mat(rhs));
else if (!trans_lhs && trans_rhs)
dest = alpha*mat(lhs)*trans(mat(rhs));
else if (trans_lhs && !trans_rhs)
dest = alpha*trans(mat(lhs))*mat(rhs);
else
dest = alpha*mat(lhs)*mat(rhs);
auto is_matrix = [](const auto& tensor) {
return ((tensor.num_samples() * tensor.k() == 1 && tensor.nr() * tensor.nc() > 1) ||
(tensor.num_samples() * tensor.k() > 1 && tensor.nr() * tensor.nc() == 1));
};

long num_samples = std::min({ lhs.num_samples(), rhs.num_samples(), dest.num_samples() });
long num_channels = std::min({ lhs.k(), rhs.k(), dest.k() });
const bool lhs_is_matrix = is_matrix(lhs), rhs_is_matrix = is_matrix(rhs), dest_is_matrix = is_matrix(dest);

if (lhs_is_matrix && rhs_is_matrix && dest_is_matrix) {
num_samples = num_channels = 1;
}

long lhs_rows = (lhs_is_matrix && lhs.num_samples() > 1) ? lhs.num_samples() : lhs.nr();
long lhs_cols = (lhs_is_matrix && lhs.k() > 1) ? lhs.k() : lhs.nc();
long rhs_rows = (rhs_is_matrix && rhs.num_samples() > 1) ? rhs.num_samples() : rhs.nr();
long rhs_cols = (rhs_is_matrix && rhs.k() > 1) ? rhs.k() : rhs.nc();
long dest_rows = (dest_is_matrix && dest.num_samples() > 1) ? dest.num_samples() : dest.nr();
long dest_cols = (dest_is_matrix && dest.k() > 1) ? dest.k() : dest.nc();

const size_t lhs_plane_size = lhs_rows * lhs_cols;
const size_t rhs_plane_size = rhs_rows * rhs_cols;
const size_t dest_plane_size = dest_rows * dest_cols;

for (long b = 0; b < num_samples; ++b)
{
for (long c = 0; c < num_channels; ++c)
{
auto lhs_slice = lhs_is_matrix ? alias_tensor(lhs_rows, lhs_cols)(lhs, 0) :
alias_tensor(lhs_rows, lhs_cols)(lhs, (b * num_channels + c) * lhs_plane_size);
auto rhs_slice = rhs_is_matrix ? alias_tensor(rhs_rows, rhs_cols)(rhs, 0) :
alias_tensor(rhs_rows, rhs_cols)(rhs, (b * num_channels + c) * rhs_plane_size);
auto dest_slice = dest_is_matrix ? alias_tensor(dest_rows, dest_cols)(dest, 0) :
alias_tensor(dest_rows, dest_cols)(dest, (b * num_channels + c) * dest_plane_size);

if (beta != 0)
{
if (trans_lhs && trans_rhs)
dest_slice = alpha * trans(mat(lhs_slice)) * trans(mat(rhs_slice)) + beta * mat(dest_slice);
else if (!trans_lhs && trans_rhs)
dest_slice = alpha * mat(lhs_slice) * trans(mat(rhs_slice)) + beta * mat(dest_slice);
else if (trans_lhs && !trans_rhs)
dest_slice = alpha * trans(mat(lhs_slice)) * mat(rhs_slice) + beta * mat(dest_slice);
else
dest_slice = alpha * mat(lhs_slice) * mat(rhs_slice) + beta * mat(dest_slice);
}
else
{
if (trans_lhs && trans_rhs)
dest_slice = alpha * trans(mat(lhs_slice)) * trans(mat(rhs_slice));
else if (!trans_lhs && trans_rhs)
dest_slice = alpha * mat(lhs_slice) * trans(mat(rhs_slice));
else if (trans_lhs && !trans_rhs)
dest_slice = alpha * trans(mat(lhs_slice)) * mat(rhs_slice);
else
dest_slice = alpha * mat(lhs_slice) * mat(rhs_slice);
}
}
}
}
#endif
}
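
One way to sanity-check the CPU path above is to recompute each plane with ordinary dlib matrix products and compare. The snippet below is a hypothetical check in the spirit of the dnn tests touched by this PR, not code from it, and it relies on the same assumed tt::gemm overload and enum spelling as the example above.

#include <dlib/dnn.h>
#include <dlib/matrix.h>
#include <algorithm>
#include <iostream>

int main()
{
    using namespace dlib;

    const long n = 2, k = 3, nr = 3, nc = 4, inner = 5;
    resizable_tensor lhs(n, k, nr, inner), rhs(n, k, inner, nc), dest(n, k, nr, nc);

    tt::tensor_rand rnd;
    rnd.fill_uniform(lhs);
    rnd.fill_uniform(rhs);
    dest = 0;

    tt::gemm(0, dest, 1, lhs, false, rhs, false, tt::PLANE_WISE);  // assumed spelling

    // Recompute every (sample, channel) plane with plain matrices and compare.
    float max_err = 0;
    for (long b = 0; b < n; ++b)
    {
        for (long c = 0; c < k; ++c)
        {
            const float* L = lhs.host()  + (b * k + c) * nr * inner;
            const float* R = rhs.host()  + (b * k + c) * inner * nc;
            const float* D = dest.host() + (b * k + c) * nr * nc;

            matrix<float> ref = mat(L, nr, inner) * mat(R, inner, nc);
            max_err = std::max(max_err, max(abs(mat(D, nr, nc) - ref)));
        }
    }
    std::cout << "max per-plane deviation: " << max_err << "\n";
}
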