Skip to content

Commit

Permalink
Reduce code duplications in matrixMulWithMdSpan (#2326)
Browse files Browse the repository at this point in the history
  • Loading branch information
psychocoderHPC authored Jul 31, 2024
1 parent d1cc2e0 commit f612f97
Showing 1 changed file with 20 additions and 19 deletions.
39 changes: 20 additions & 19 deletions example/matrixMulWithMdspan/src/matrixMulMdSpan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,21 @@ struct MatrixMulKernel
}
};

// initialize the matrix
template<typename TMdSpan>
inline void initializeMatrx(TMdSpan& span)
{
auto const numColumns = span.extent(1);
for(Idx i = 0; i < span.extent(0); ++i)
{
for(Idx j = 0; j < numColumns; ++j)
{
// fill with some data
span(i, j) = static_cast<DataType>(i * numColumns + j);
}
}
}

// In standard projects, you typically do not execute the code with any available accelerator.
// Instead, a single accelerator is selected once from the active accelerators and the kernels are executed with the
// selected accelerator only. If you use the example as the starting point for your project, you can rename the
Expand All @@ -82,9 +97,9 @@ auto example(TAccTag const&) -> int
using Dim = alpaka::DimInt<2>;

// Define matrix dimensions, A is MxK and B is KxN
const Idx M = 1024;
const Idx N = 512;
const Idx K = 1024;
Idx const M = 1024;
Idx const N = 512;
Idx const K = 1024;

// Define device and queue
using Acc = alpaka::AccCpuSerial<Dim, Idx>;
Expand Down Expand Up @@ -113,22 +128,8 @@ auto example(TAccTag const&) -> int
auto mdHostB = alpaka::experimental::getMdSpan(bufHostB);

// Initialize host matrices
for(Idx i = 0; i < M; ++i)
{
for(Idx j = 0; j < K; ++j)
{
// fill with some data
mdHostA(i, j) = static_cast<DataType>(i * K + j);
}
}
for(Idx i = 0; i < K; ++i)
{
for(Idx j = 0; j < N; ++j)
{
// fill with some data
mdHostB(i, j) = static_cast<DataType>(i * N + j);
}
}
initializeMatrx(mdHostA);
initializeMatrx(mdHostB);

// Allocate device memory
auto bufDevA = alpaka::allocBuf<DataType, Idx>(devAcc, extentA);
Expand Down

0 comments on commit f612f97

Please sign in to comment.