Skip to content

Commit

Permalink
Move casting scratch_ptr
Browse files Browse the repository at this point in the history
  • Loading branch information
masterleinad committed Dec 9, 2021
1 parent daf8ff9 commit 0d64ed3
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions core/src/SYCL/Kokkos_SYCL_Parallel_Team.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ class ParallelFor<FunctorType, Kokkos::TeamPolicy<Properties...>,
size_type const m_vector_size;
int m_shmem_begin;
int m_shmem_size;
void* m_scratch_ptr[2];
char* m_scratch_ptr[2];
int m_scratch_size[2];
// Only let one ParallelFor/Reduce modify the team scratch memory. The
// constructor acquires the mutex which is released in the destructor.
Expand All @@ -413,7 +413,7 @@ class ParallelFor<FunctorType, Kokkos::TeamPolicy<Properties...>,
// Avoid capturing *this since it might not be trivially copyable
const auto shmem_begin = m_shmem_begin;
const int scratch_size[2] = {m_scratch_size[0], m_scratch_size[1]};
void* const scratch_ptr[2] = {m_scratch_ptr[0], m_scratch_ptr[1]};
char* const scratch_ptr[2] = {m_scratch_ptr[0], m_scratch_ptr[1]};

// FIXME_SYCL accessors seem to need a size greater than zero at least for
// host queues
Expand All @@ -424,11 +424,10 @@ class ParallelFor<FunctorType, Kokkos::TeamPolicy<Properties...>,
cgh);

auto lambda = [=](sycl::nd_item<2> item) {
const member_type team_member(team_scratch_memory_L0.get_pointer(),
shmem_begin, scratch_size[0],
static_cast<char*>(scratch_ptr[1]) +
item.get_group(1) * scratch_size[1],
scratch_size[1], item);
const member_type team_member(
team_scratch_memory_L0.get_pointer(), shmem_begin, scratch_size[0],
scratch_ptr[1] + item.get_group(1) * scratch_size[1],
scratch_size[1], item);
if constexpr (std::is_same<work_tag, void>::value)
functor(team_member);
else
Expand Down Expand Up @@ -512,8 +511,8 @@ class ParallelFor<FunctorType, Kokkos::TeamPolicy<Properties...>,
// upon team size.
auto& space = *m_policy.space().impl_internal_space_instance();
m_scratch_ptr[0] = nullptr;
m_scratch_ptr[1] = space.resize_team_scratch_space(
static_cast<ptrdiff_t>(m_scratch_size[1]) * m_league_size);
m_scratch_ptr[1] = static_cast<char*>(space.resize_team_scratch_space(
static_cast<ptrdiff_t>(m_scratch_size[1]) * m_league_size));

if (static_cast<int>(space.m_maxShmemPerBlock) <
m_shmem_size - m_shmem_begin) {
Expand Down

0 comments on commit 0d64ed3

Please sign in to comment.