Skip to content

Commit

Permalink
Simplify write-out kernels in topk implementation (avoid recomputing gid)
Browse files Browse the repository at this point in the history
  • Loading branch information
oleksandr-pavlyk committed Dec 22, 2024
1 parent d4f5aa4 commit ce5914f
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,17 +417,17 @@ sycl::event topk_merge_impl(
topk_partial_merge_map_back_krn<argTy, IndexTy, ValueComp>;

cgh.parallel_for<KernelName>(iter_nelems * k, [=](sycl::id<1> id) {
std::size_t gid = id[0];
const std::size_t gid = id[0];

std::size_t iter_gid = gid / k;
std::size_t axis_gid = gid - (iter_gid * k);
const std::size_t iter_gid = gid / k;
const std::size_t axis_gid = gid - (iter_gid * k);

std::size_t src_idx = iter_gid * alloc_len + axis_gid;
std::size_t dst_idx = iter_gid * k + axis_gid;
const std::size_t src_idx = iter_gid * alloc_len + axis_gid;
const std::size_t dst_idx = gid;

auto res_ind = index_data[src_idx];
const auto res_ind = index_data[src_idx];
vals_tp[dst_idx] = arg_tp[res_ind];
inds_tp[dst_idx] = res_ind % axis_nelems;
inds_tp[dst_idx] = (res_ind % axis_nelems);
});
});

Expand Down Expand Up @@ -529,17 +529,17 @@ sycl::event topk_radix_impl(sycl::queue &exec_q,
using KernelName = topk_radix_map_back_krn<argTy, IndexTy>;

cgh.parallel_for<KernelName>(iter_nelems * k, [=](sycl::id<1> id) {
std::size_t gid = id[0];
const std::size_t gid = id[0];

std::size_t iter_gid = gid / k;
std::size_t axis_gid = gid - (iter_gid * k);
const std::size_t iter_gid = gid / k;
const std::size_t axis_gid = gid - (iter_gid * k);

std::size_t src_idx = iter_gid * axis_nelems + axis_gid;
std::size_t dst_idx = iter_gid * k + axis_gid;
const std::size_t src_idx = iter_gid * axis_nelems + axis_gid;
const std::size_t dst_idx = gid;

IndexTy res_ind = tmp_tp[src_idx];
const IndexTy res_ind = tmp_tp[src_idx];
vals_tp[dst_idx] = arg_tp[res_ind];
inds_tp[dst_idx] = res_ind % axis_nelems;
inds_tp[dst_idx] = (res_ind % axis_nelems);
});
});

Expand Down

0 comments on commit ce5914f

Please sign in to comment.