Merge pull request #1299 from e10harvey/issue1274
src/batched/dense: Add Gemm_DblBuf LayoutLeft operator
e10harvey authored Feb 9, 2022
2 parents 43a0398 + 57b2cf3 commit 8f79037
Showing 1 changed file with 219 additions and 13 deletions.
232 changes: 219 additions & 13 deletions src/batched/dense/impl/KokkosBatched_Gemm_DblBuf_Impl.hpp
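For readers unfamiliar with the dispatch mechanism this commit relies on: the functor's layout type is passed to Kokkos::TeamPolicy as a work tag (see the policy_type change in the first hunk), and Kokkos routes each team to the operator() overload whose first parameter matches that tag. Below is a minimal, hedged sketch of that mechanism only; LayoutTaggedFunctor, league_size, and team_size are illustrative names, not part of this file.

#include <Kokkos_Core.hpp>

// Illustrative sketch of Kokkos work-tag dispatch (assumes an initialized
// Kokkos runtime); not the KokkosKernels implementation itself.
struct LayoutTaggedFunctor {
  using policy_type = Kokkos::TeamPolicy<Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace>;
  using member_type = typename policy_type::member_type;

  KOKKOS_INLINE_FUNCTION
  void operator()(const Kokkos::LayoutRight &, const member_type &) const {
    // A row-major tile mapping would live here (mirrors the LayoutRight operator below).
  }

  KOKKOS_INLINE_FUNCTION
  void operator()(const Kokkos::LayoutLeft &, const member_type &) const {
    // A column-major tile mapping would live here (mirrors the new LayoutLeft operator).
  }
};

// Because the policy carries the Kokkos::LayoutLeft tag, only the LayoutLeft
// overload is invoked:
//   Kokkos::parallel_for(LayoutTaggedFunctor::policy_type(league_size, team_size),
//                        LayoutTaggedFunctor());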
@@ -117,7 +117,7 @@ class BatchedDblBufGemm {

private:
void __run() {
using policy_type = Kokkos::TeamPolicy<execution_space_type>;
using policy_type = Kokkos::TeamPolicy<layout_type, execution_space_type>;
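// NOTE: layout_type is used here as a Kokkos work tag: each team is dispatched
// to the operator() overload below whose first argument matches the layout
// (Kokkos::LayoutRight or the new Kokkos::LayoutLeft).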
using member_type = typename policy_type::member_type;

// Compile-time expressions required for functor-level register allocations:
@@ -299,7 +299,44 @@ class BatchedDblBufGemm {
}

KOKKOS_INLINE_FUNCTION
void operator()(const MemberType &member) const {
void __rshmem_and_mul_ll(const int &thread_id, const int &vlane_id,
const unsigned &nk, view_value_type reg_a[REG_M],
view_value_type reg_b[REG_N],
view_value_type reg_c[REG_M][REG_N],
view_type_2d_scratch &svA_scr,
view_type_2d_scratch &svB_scr) const {
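// Reads the current nk-wide (at most TILE_K) slice of A and B out of team
// scratch memory into registers and accumulates this thread's REG_M x REG_N
// partial products of C; the last call may pass nk < TILE_K for a partial tile.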
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (unsigned k = 0; k < nk; ++k) {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (int m = 0; m < REG_M; ++m)
reg_a[m] = svA_scr(k, vlane_id + m * STRIDE_M);

#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (int n = 0; n < REG_N; ++n)
reg_b[n] = svB_scr(thread_id + n * STRIDE_N, k);

#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (int m = 0; m < REG_M; ++m) {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (int n = 0; n < REG_N; ++n)
__mul(reg_a[m], reg_b[n], reg_c[m][n], __ei.__alpha_mul_tag);
}
}
}

KOKKOS_INLINE_FUNCTION
void operator()(const Kokkos::LayoutRight &,
const MemberType &member) const {
// TODO: use a Kokkos view with compile-time size to allocate registers?
// Then we can use local deep copy for prefetch_reg population.
// Allocate registers used for prefetching
@@ -336,12 +373,12 @@ class BatchedDblBufGemm {
Kokkos::parallel_for(
Kokkos::TeamThreadRange(member, 0, STRIDE_M),
[&](const int &thread_id) {
int thread_offset = thread_id + start_m;
int m_offset = thread_id + start_m;

Kokkos::parallel_for(
Kokkos::ThreadVectorRange(member, 0, STRIDE_N),
[&](const int &vlane_id) {
int vlane_offset = vlane_id + start_n;
int n_offset = vlane_id + start_n;

// Here we populate scratch memory with one or more "k" tiles for
// every thread of the team!
@@ -351,7 +388,7 @@ class BatchedDblBufGemm {
for (int i = 0; i < REG_N * STRIDE_N; i += STRIDE_N)
svB_scr(thread_id, vlane_id + i) =
access_view_bounds_check<view_value_type>(
svB, thread_id, vlane_offset + i,
svB, thread_id, n_offset + i,
__ei.__bounds_check_tag);

#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
@@ -360,7 +397,7 @@ class BatchedDblBufGemm {
for (int i = 0; i < REG_M * STRIDE_M; i += STRIDE_M)
svA_scr(thread_id + i, vlane_id) =
access_view_bounds_check<view_value_type>(
svA, thread_offset + i, vlane_id,
svA, m_offset + i, vlane_id,
__ei.__bounds_check_tag);

// Wait for A, B to reside in scratch memory
@@ -380,16 +417,15 @@ class BatchedDblBufGemm {
prefetch_reg_b[i] =
access_view_bounds_check<view_value_type>(
svB, thread_id + k_tile_offset,
vlane_offset + i * STRIDE_N,
__ei.__bounds_check_tag);
n_offset + i * STRIDE_N, __ei.__bounds_check_tag);

#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (int i = 0; i < REG_M; ++i)
prefetch_reg_a[i] =
access_view_bounds_check<view_value_type>(
svA, thread_offset + i * STRIDE_M,
svA, m_offset + i * STRIDE_M,
vlane_id + k_tile_offset,
__ei.__bounds_check_tag);

@@ -434,12 +470,12 @@ class BatchedDblBufGemm {
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (int m = 0; m < REG_M; ++m) {
int cm = thread_offset + m * STRIDE_M;
int cm = m_offset + m * STRIDE_M;
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (int n = 0; n < REG_N; ++n) {
int cn = vlane_offset + n * STRIDE_N;
int cn = n_offset + n * STRIDE_N;
fma_bounds_check(svC, cm, cn, reg_c[m][n], __alpha,
__ei.__alpha_fma_tag,
__ei.__bounds_check_tag);
@@ -450,12 +486,182 @@ class BatchedDblBufGemm {
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (int m = 0; m < REG_M; ++m) {
int cm = thread_offset + m * STRIDE_M;
int cm = m_offset + m * STRIDE_M;
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (int n = 0; n < REG_N; ++n) {
int cn = vlane_offset + n * STRIDE_N;
int cn = n_offset + n * STRIDE_N;
fma_bounds_check(svC, cm, cn, reg_c[m][n], __alpha,
__beta, __ei.__alpha_fma_tag,
__ei.__bounds_check_tag);
}
}
}
});
});
}

KOKKOS_INLINE_FUNCTION
void operator()(const Kokkos::LayoutLeft &,
const MemberType &member) const {
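// LayoutLeft path added by this commit: each team computes one TILE_M x TILE_N
// block of C for one batch entry, streaming TILE_K-wide slices of A and B
// through team scratch memory with double buffering.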
// TODO: use a Kokkos view with compile-time size to allocate registers?
// Then we can use local deep copy for prefetch_reg population.
// Allocate registers used for prefetching
view_value_type prefetch_reg_a[REG_M] = {0}, prefetch_reg_b[REG_N] = {0};

// Allocate registers used for FMAs
view_value_type reg_a[REG_M] = {0}, reg_b[REG_N] = {0},
reg_c[REG_M][REG_N] = {{0}};
// TODO: look at local loads and stores via nvprof
// TODO: look at GPU trace in nvprof to find out how many registers are
// used.

unsigned batch_idx = member.league_rank() / __n_sub_tiles;

// Compute starting tile offsets for each team into svA, svB, svC
unsigned local_team_idx = member.league_rank() % __n_sub_tiles;
unsigned start_m = (local_team_idx % __tiles_per_row) * TILE_M;
unsigned start_n = (local_team_idx / __tiles_per_row) * TILE_N;

int kk;

// Fetch entire 2-rank sub-matrix
auto svA = subview_wrapper(__A, batch_idx, Kokkos::ALL(), Kokkos::ALL(),
__ei.__batch_layout_tag, __ei.__transA_tag);
auto svB = subview_wrapper(__B, batch_idx, Kokkos::ALL(), Kokkos::ALL(),
__ei.__batch_layout_tag, __ei.__transB_tag);
auto svC = subview_wrapper(__C, batch_idx, Kokkos::ALL(), Kokkos::ALL(),
__ei.__batch_layout_tag);

// Allocate scratch memory buffers used for prefetching
view_type_2d_scratch svA_scr(member.team_scratch(0), TILE_K, TILE_M);
view_type_2d_scratch svB_scr(member.team_scratch(0), TILE_N, TILE_K);
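// Thread mapping is transposed relative to the LayoutRight operator above:
// team threads stride over n (STRIDE_N) and vector lanes stride over m
// (STRIDE_M), so consecutive vector lanes touch the contiguous first index of
// the LayoutLeft views.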

Kokkos::parallel_for(
Kokkos::TeamThreadRange(member, 0, STRIDE_N),
[&](const int &thread_id) {
int n_offset = thread_id + start_n;

Kokkos::parallel_for(
Kokkos::ThreadVectorRange(member, 0, STRIDE_M),
[&](const int &vlane_id) {
int m_offset = vlane_id + start_m;

// Here we populate scratch memory with one or more "k" tiles for
// every thread of the team!
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (int i = 0; i < REG_N * STRIDE_N; i += STRIDE_N)
svB_scr(thread_id + i, vlane_id) =
access_view_bounds_check<view_value_type>(
svB, vlane_id, n_offset + i,
__ei.__bounds_check_tag);

#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (int i = 0; i < REG_M * STRIDE_M; i += STRIDE_M)
svA_scr(thread_id, vlane_id + i) =
access_view_bounds_check<view_value_type>(
svA, m_offset + i, thread_id,
__ei.__bounds_check_tag);

// Wait for A, B to reside in scratch memory
member.team_barrier();

// Each thread accumulates its REG_M x REG_N dot products in chunks of
// size TILE_K
for (kk = 0; kk < __k - TILE_K; kk += TILE_K) {
int k_tile_offset = kk + TILE_K;
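// Double buffering: the next TILE_K slice is loaded from global memory into
// registers before the slice currently in scratch is consumed, letting the
// global loads overlap the multiply; the two barriers below separate reading
// scratch from overwriting it with the prefetched data.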

// Get this thread's next TILE_K entries from global memory
// Each thread has its own copy of prefetch_reg_b.
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (int i = 0; i < REG_N; ++i)
prefetch_reg_b[i] =
access_view_bounds_check<view_value_type>(
svB, vlane_id + k_tile_offset,
n_offset + i * STRIDE_N, __ei.__bounds_check_tag);

#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (int i = 0; i < REG_M; ++i)
prefetch_reg_a[i] =
access_view_bounds_check<view_value_type>(
svA, m_offset + i * STRIDE_M,
thread_id + k_tile_offset,
__ei.__bounds_check_tag);

__rshmem_and_mul_ll(thread_id, vlane_id, TILE_K, reg_a,
reg_b, reg_c, svA_scr, svB_scr);

// Wait for:
// 1. prefetch_regs to be populated
// 2. for shmem to no longer be read from
member.team_barrier();

// populate shmem from prefetch registers. Each thread has its own
// copy of prefetch_reg_b.
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (int i = 0; i < REG_N; ++i)
svB_scr(thread_id + i * STRIDE_N, vlane_id) =
prefetch_reg_b[i];

// populate shmem from prefetch registers. Each thread has its own
// copy of prefetch_reg_a.
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (int i = 0; i < REG_M; ++i)
svA_scr(thread_id, vlane_id + i * STRIDE_M) =
prefetch_reg_a[i];

// Wait for shmem stores to land before performing next
// TILE_K multiply
member.team_barrier();
} // end n_tile_k_tiles loop

// Multiply last tile, may be a partial tile
__rshmem_and_mul_ll(thread_id, vlane_id, __k - kk, reg_a,
reg_b, reg_c, svA_scr, svB_scr);

// store results back to global memory
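// The beta == 0 branch uses the fma_bounds_check overload without a beta
// argument, so C's prior contents are not scaled; the general branch also
// applies beta to the existing C entries.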
if (__beta == 0.0F) {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (int n = 0; n < REG_N; ++n) {
int cn = n_offset + n * STRIDE_N;

#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (int m = 0; m < REG_M; ++m) {
int cm = m_offset + m * STRIDE_M;
fma_bounds_check(svC, cm, cn, reg_c[m][n], __alpha,
__ei.__alpha_fma_tag,
__ei.__bounds_check_tag);
}
}
} else {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (int n = 0; n < REG_N; ++n) {
int cn = n_offset + n * STRIDE_N;

#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (int m = 0; m < REG_M; ++m) {
int cm = m_offset + m * STRIDE_M;
fma_bounds_check(svC, cm, cn, reg_c[m][n], __alpha,
__beta, __ei.__alpha_fma_tag,
__ei.__bounds_check_tag);
