Functors for masked extract/place changed to store typed pointers
Also implement get_lws to choose the local work-group size from
given choices I0 > I1 > I2 > ...: if n >= I0, use I0; else if
n >= I1, use I1; and so on, falling back to the smallest choice.
oleksandr-pavlyk committed Dec 7, 2024
1 parent ff93cfc commit f9abe3e
Showing 1 changed file with 83 additions and 48 deletions.
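
For reference, a minimal standalone sketch of that selection rule (illustrative only; pick_lws is a hypothetical stand-in that mirrors the _get_lws_impl/get_lws helpers added in the diff below, with this commit's choices 256 > 128 > 64):

#include <cstddef>
#include <iostream>

// Walk the descending choices I0 > I1 > ... and return the first one that
// n reaches; if n is smaller than all of them, return the last (smallest).
template <std::size_t I, std::size_t... IR>
std::size_t pick_lws(std::size_t n) {
    if constexpr (sizeof...(IR) == 0) {
        return I;
    } else {
        return (n < I) ? pick_lws<IR...>(n) : I;
    }
}

int main() {
    std::cout << pick_lws<256u, 128u, 64u>(1000) << "\n"; // 256 (1000 >= 256)
    std::cout << pick_lws<256u, 128u, 64u>(200) << "\n";  // 128 (200 >= 128)
    std::cout << pick_lws<256u, 128u, 64u>(10) << "\n";   // 64  (fallback)
}

The chosen lws may exceed n for small extents; the kernels below round the global range up to a multiple of lws (n_groups * lws), so out-of-range work-items are handled by the usual bounds logic.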
131 changes: 83 additions & 48 deletions dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp
@@ -52,15 +52,15 @@ template <typename OrthogIndexerT,
           typename LocalAccessorT>
 struct MaskedExtractStridedFunctor
 {
-    MaskedExtractStridedFunctor(const char *src_data_p,
-                                const char *cumsum_data_p,
-                                char *dst_data_p,
+    MaskedExtractStridedFunctor(const dataT *src_data_p,
+                                const indT *cumsum_data_p,
+                                dataT *dst_data_p,
                                 size_t masked_iter_size,
                                 const OrthogIndexerT &orthog_src_dst_indexer_,
                                 const MaskedSrcIndexerT &masked_src_indexer_,
                                 const MaskedDstIndexerT &masked_dst_indexer_,
                                 const LocalAccessorT &lacc_)
-        : src_cp(src_data_p), cumsum_cp(cumsum_data_p), dst_cp(dst_data_p),
+        : src(src_data_p), cumsum(cumsum_data_p), dst(dst_data_p),
          masked_nelems(masked_iter_size),
          orthog_src_dst_indexer(orthog_src_dst_indexer_),
          masked_src_indexer(masked_src_indexer_),
@@ -72,24 +72,20 @@ struct MaskedExtractStridedFunctor
 
     void operator()(sycl::nd_item<2> ndit) const
     {
-        const dataT *src_data = reinterpret_cast<const dataT *>(src_cp);
-        dataT *dst_data = reinterpret_cast<dataT *>(dst_cp);
-        const indT *cumsum_data = reinterpret_cast<const indT *>(cumsum_cp);
-
-        const size_t orthog_i = ndit.get_global_id(0);
-        const size_t group_i = ndit.get_group(1);
+        const std::size_t orthog_i = ndit.get_global_id(0);
         const std::uint32_t l_i = ndit.get_local_id(1);
         const std::uint32_t lws = ndit.get_local_range(1);
 
-        const size_t masked_block_start = group_i * lws;
-        const size_t masked_i = masked_block_start + l_i;
+        const std::size_t masked_i = ndit.get_global_id(1);
+        const std::size_t masked_block_start = masked_i - l_i;
 
+        const std::size_t max_offset = masked_nelems + 1;
         for (std::uint32_t i = l_i; i < lacc.size(); i += lws) {
             const size_t offset = masked_block_start + i;
             lacc[i] = (offset == 0) ? indT(0)
-                      : (offset - 1 < masked_nelems)
-                          ? cumsum_data[offset - 1]
-                          : cumsum_data[masked_nelems - 1] + 1;
+                      : (offset < max_offset)
+                          ? cumsum[offset - 1]
+                          : cumsum[masked_nelems - 1] + 1;
         }
 
         sycl::group_barrier(ndit.get_group());
@@ -110,14 +106,14 @@ struct MaskedExtractStridedFunctor
                 masked_dst_indexer(current_running_count - 1) +
                 orthog_offsets.get_second_offset();
 
-            dst_data[total_dst_offset] = src_data[total_src_offset];
+            dst[total_dst_offset] = src[total_src_offset];
         }
     }
 
 private:
-    const char *src_cp = nullptr;
-    const char *cumsum_cp = nullptr;
-    char *dst_cp = nullptr;
+    const dataT *src = nullptr;
+    const indT *cumsum = nullptr;
+    dataT *dst = nullptr;
     const size_t masked_nelems = 0;
     // has nd, shape, src_strides, dst_strides for
     // dimensions that ARE NOT masked
@@ -138,15 +134,15 @@ template <typename OrthogIndexerT,
           typename LocalAccessorT>
 struct MaskedPlaceStridedFunctor
 {
-    MaskedPlaceStridedFunctor(char *dst_data_p,
-                              const char *cumsum_data_p,
-                              const char *rhs_data_p,
+    MaskedPlaceStridedFunctor(dataT *dst_data_p,
+                              const indT *cumsum_data_p,
+                              const dataT *rhs_data_p,
                               size_t masked_iter_size,
                               const OrthogIndexerT &orthog_dst_rhs_indexer_,
                               const MaskedDstIndexerT &masked_dst_indexer_,
                               const MaskedRhsIndexerT &masked_rhs_indexer_,
                               const LocalAccessorT &lacc_)
-        : dst_cp(dst_data_p), cumsum_cp(cumsum_data_p), rhs_cp(rhs_data_p),
+        : dst(dst_data_p), cumsum(cumsum_data_p), rhs(rhs_data_p),
          masked_nelems(masked_iter_size),
          orthog_dst_rhs_indexer(orthog_dst_rhs_indexer_),
          masked_dst_indexer(masked_dst_indexer_),
@@ -158,24 +154,20 @@ struct MaskedPlaceStridedFunctor
 
     void operator()(sycl::nd_item<2> ndit) const
     {
-        dataT *dst_data = reinterpret_cast<dataT *>(dst_cp);
-        const indT *cumsum_data = reinterpret_cast<const indT *>(cumsum_cp);
-        const dataT *rhs_data = reinterpret_cast<const dataT *>(rhs_cp);
-
         const std::size_t orthog_i = ndit.get_global_id(0);
-        const std::size_t group_i = ndit.get_group(1);
         const std::uint32_t l_i = ndit.get_local_id(1);
         const std::uint32_t lws = ndit.get_local_range(1);
 
-        const size_t masked_block_start = group_i * lws;
-        const size_t masked_i = masked_block_start + l_i;
+        const size_t masked_i = ndit.get_global_id(1);
+        const size_t masked_block_start = masked_i - l_i;
 
+        const std::size_t max_offset = masked_nelems + 1;
         for (std::uint32_t i = l_i; i < lacc.size(); i += lws) {
             const size_t offset = masked_block_start + i;
             lacc[i] = (offset == 0) ? indT(0)
-                      : (offset - 1 < masked_nelems)
-                          ? cumsum_data[offset - 1]
-                          : cumsum_data[masked_nelems - 1] + 1;
+                      : (offset < max_offset)
+                          ? cumsum[offset - 1]
+                          : cumsum[masked_nelems - 1] + 1;
         }
 
         sycl::group_barrier(ndit.get_group());
@@ -196,14 +188,14 @@ struct MaskedPlaceStridedFunctor
                 masked_rhs_indexer(current_running_count - 1) +
                 orthog_offsets.get_second_offset();
 
-            dst_data[total_dst_offset] = rhs_data[total_rhs_offset];
+            dst[total_dst_offset] = rhs[total_rhs_offset];
         }
     }
 
 private:
-    char *dst_cp = nullptr;
-    const char *cumsum_cp = nullptr;
-    const char *rhs_cp = nullptr;
+    dataT *dst = nullptr;
+    const indT *cumsum = nullptr;
+    const dataT *rhs = nullptr;
     const size_t masked_nelems = 0;
     // has nd, shape, dst_strides, rhs_strides for
    // dimensions that ARE NOT masked
@@ -218,6 +210,26 @@ struct MaskedPlaceStridedFunctor
 
 // ======= Masked extraction ================================
 
+namespace {
+
+template <std::size_t I, std::size_t... IR>
+std::size_t _get_lws_impl(std::size_t n) {
+    if constexpr (sizeof...(IR) == 0) {
+        return I;
+    } else {
+        return (n < I) ? _get_lws_impl<IR...>(n) : I;
+    }
+}
+
+std::size_t get_lws(std::size_t n) {
+    constexpr std::size_t lws0 = 256u;
+    constexpr std::size_t lws1 = 128u;
+    constexpr std::size_t lws2 = 64u;
+    return _get_lws_impl<lws0, lws1, lws2>(n);
+}
+
+} // end of anonymous namespace
+
 template <typename MaskedDstIndexerT, typename dataT, typename indT>
 class masked_extract_all_slices_contig_impl_krn;
 
@@ -258,16 +270,21 @@ sycl::event masked_extract_all_slices_contig_impl(
                                       Strided1DIndexer, dataT, indT,
                                       LocalAccessorT>;
 
-    constexpr std::size_t nominal_lws = 256;
     const std::size_t masked_extent = iteration_size;
-    const std::size_t lws = std::min(masked_extent, nominal_lws);
+
+    const std::size_t lws = get_lws(masked_extent);
+
     const std::size_t n_groups = (iteration_size + lws - 1) / lws;
 
     sycl::range<2> gRange{1, n_groups * lws};
     sycl::range<2> lRange{1, lws};
 
     sycl::nd_range<2> ndRange(gRange, lRange);
 
+    const dataT *src_tp = reinterpret_cast<const dataT *>(src_p);
+    const indT *cumsum_tp = reinterpret_cast<const indT *>(cumsum_p);
+    dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
+
     sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);
 
@@ -276,7 +293,7 @@ sycl::event masked_extract_all_slices_contig_impl(
 
         cgh.parallel_for<KernelName>(
             ndRange,
-            Impl(src_p, cumsum_p, dst_p, masked_extent, orthog_src_dst_indexer,
+            Impl(src_tp, cumsum_tp, dst_tp, masked_extent, orthog_src_dst_indexer,
                  masked_src_indexer, masked_dst_indexer, lacc));
     });
 
@@ -332,16 +349,21 @@ sycl::event masked_extract_all_slices_strided_impl(
                                       StridedIndexer, Strided1DIndexer,
                                       dataT, indT, LocalAccessorT>;
 
-    constexpr std::size_t nominal_lws = 256;
     const std::size_t masked_nelems = iteration_size;
-    const std::size_t lws = std::min(masked_nelems, nominal_lws);
+
+    const std::size_t lws = get_lws(masked_nelems);
+
     const std::size_t n_groups = (masked_nelems + lws - 1) / lws;
 
     sycl::range<2> gRange{1, n_groups * lws};
     sycl::range<2> lRange{1, lws};
 
     sycl::nd_range<2> ndRange(gRange, lRange);
 
+    const dataT *src_tp = reinterpret_cast<const dataT *>(src_p);
+    const indT *cumsum_tp = reinterpret_cast<const indT *>(cumsum_p);
+    dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
+
     sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);
 
@@ -350,7 +372,7 @@ sycl::event masked_extract_all_slices_strided_impl(
 
         cgh.parallel_for<KernelName>(
             ndRange,
-            Impl(src_p, cumsum_p, dst_p, iteration_size, orthog_src_dst_indexer,
+            Impl(src_tp, cumsum_tp, dst_tp, iteration_size, orthog_src_dst_indexer,
                  masked_src_indexer, masked_dst_indexer, lacc));
     });
 
@@ -422,9 +444,10 @@ sycl::event masked_extract_some_slices_strided_impl(
                                       StridedIndexer, Strided1DIndexer,
                                       dataT, indT, LocalAccessorT>;
 
-    const size_t nominal_lws = 256;
     const std::size_t masked_extent = masked_nelems;
-    const size_t lws = std::min(masked_extent, nominal_lws);
+
+    const std::size_t lws = get_lws(masked_extent);
+
     const size_t n_groups = ((masked_extent + lws - 1) / lws);
     const size_t orthog_extent = static_cast<size_t>(orthog_nelems);
 
@@ -433,6 +456,10 @@ sycl::event masked_extract_some_slices_strided_impl(
 
     sycl::nd_range<2> ndRange(gRange, lRange);
 
+    const dataT *src_tp = reinterpret_cast<const dataT *>(src_p);
+    const indT *cumsum_tp = reinterpret_cast<const indT *>(cumsum_p);
+    dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
+
     sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);
 
@@ -442,7 +469,7 @@ sycl::event masked_extract_some_slices_strided_impl(
 
         cgh.parallel_for<KernelName>(
             ndRange,
-            Impl(src_p, cumsum_p, dst_p, masked_nelems, orthog_src_dst_indexer,
+            Impl(src_tp, cumsum_tp, dst_tp, masked_nelems, orthog_src_dst_indexer,
                  masked_src_indexer, masked_dst_indexer, lacc));
     });
 
@@ -567,6 +594,10 @@ sycl::event masked_place_all_slices_strided_impl(
 
     using LocalAccessorT = sycl::local_accessor<indT, 1>;
 
+    dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
+    const dataT *rhs_tp = reinterpret_cast<const dataT *>(rhs_p);
+    const indT *cumsum_tp = reinterpret_cast<const indT *>(cumsum_p);
+
     sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);
 
@@ -578,7 +609,7 @@ sycl::event masked_place_all_slices_strided_impl(
             MaskedPlaceStridedFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
                                       Strided1DCyclicIndexer, dataT, indT,
                                       LocalAccessorT>(
-                dst_p, cumsum_p, rhs_p, iteration_size, orthog_dst_rhs_indexer,
+                dst_tp, cumsum_tp, rhs_tp, iteration_size, orthog_dst_rhs_indexer,
                 masked_dst_indexer, masked_rhs_indexer, lacc));
     });
 
@@ -659,6 +690,10 @@ sycl::event masked_place_some_slices_strided_impl(
 
     using LocalAccessorT = sycl::local_accessor<indT, 1>;
 
+    dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
+    const dataT *rhs_tp = reinterpret_cast<const dataT *>(rhs_p);
+    const indT *cumsum_tp = reinterpret_cast<const indT *>(cumsum_p);
+
     sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);
 
@@ -670,7 +705,7 @@ sycl::event masked_place_some_slices_strided_impl(
             MaskedPlaceStridedFunctor<TwoOffsets_StridedIndexer, StridedIndexer,
                                       Strided1DCyclicIndexer, dataT, indT,
                                       LocalAccessorT>(
-                dst_p, cumsum_p, rhs_p, masked_nelems, orthog_dst_rhs_indexer,
+                dst_tp, cumsum_tp, rhs_tp, masked_nelems, orthog_dst_rhs_indexer,
                 masked_dst_indexer, masked_rhs_indexer, lacc));
     });
 
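For context, a minimal sketch of the pattern this commit adopts (hypothetical names, not the dpctl API): the host casts the type-erased byte pointers exactly once, and the functor stores typed pointers, so operator() needs no reinterpret_cast in device code:

#include <cstddef>

// Sketch: the functor stores typed pointers, so device code in operator()
// uses them directly instead of reinterpret_cast-ing char* per work-item.
template <typename dataT> struct TypedCopyFunctor
{
    TypedCopyFunctor(const dataT *src, dataT *dst, std::size_t n)
        : src_(src), dst_(dst), n_(n)
    {
    }

    void operator()(std::size_t i) const
    {
        if (i < n_) {
            dst_[i] = src_[i]; // typed access, no cast in device code
        }
    }

private:
    const dataT *src_ = nullptr;
    dataT *dst_ = nullptr;
    std::size_t n_ = 0;
};

// Host side: cast the untyped byte pointers once, then pass typed pointers
// to the functor (mirrors the *_tp variables introduced in the diff above).
template <typename dataT>
TypedCopyFunctor<dataT> make_copy_functor(const char *src_p, char *dst_p,
                                          std::size_t n)
{
    const dataT *src_tp = reinterpret_cast<const dataT *>(src_p);
    dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
    return TypedCopyFunctor<dataT>(src_tp, dst_tp, n);
}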

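The local-accessor fill loop is the subtle part of both functors. A host-side model of its indexing rule (hypothetical helper, for illustration only): entry i of the local accessor holds the running count just before global offset block_start + i, i.e. 0 at offset 0, cumsum[offset - 1] within range, and the sentinel cumsum[n - 1] + 1 past the end:

#include <cstddef>
#include <iostream>
#include <vector>

// Models one work-group's local-accessor contents for a given block start.
std::vector<int> fill_lacc(const std::vector<int> &cumsum,
                           std::size_t block_start,
                           std::size_t lacc_size)
{
    const std::size_t n = cumsum.size();
    const std::size_t max_offset = n + 1;
    std::vector<int> lacc(lacc_size);
    for (std::size_t i = 0; i < lacc_size; ++i) {
        const std::size_t offset = block_start + i;
        lacc[i] = (offset == 0) ? 0
                  : (offset < max_offset) ? cumsum[offset - 1]
                                          : cumsum[n - 1] + 1;
    }
    return lacc;
}

int main()
{
    // mask = {0, 1, 1, 0, 1} -> inclusive cumsum = {0, 1, 2, 2, 3}
    const std::vector<int> cumsum{0, 1, 2, 2, 3};
    for (int v : fill_lacc(cumsum, 0, 8)) {
        std::cout << v << ' '; // prints: 0 0 1 2 2 3 4 4
    }
    std::cout << "\n";
}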