Skip to content

Commit

Permalink
Merge pull request #1895 from IntelPython/contribution-to-1894
Browse files Browse the repository at this point in the history
Contribution to 1894
  • Loading branch information
oleksandr-pavlyk authored Nov 15, 2024
2 parents 9e40df8 + 0493fdd commit 6dc298c
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class TakeFunctor
ssize_t src_offset = orthog_offsets.get_first_offset();
ssize_t dst_offset = orthog_offsets.get_second_offset();

const ProjectorT proj{};
constexpr ProjectorT proj{};
for (int axis_idx = 0; axis_idx < k_; ++axis_idx) {
indT *ind_data = reinterpret_cast<indT *>(ind_[axis_idx]);

Expand Down Expand Up @@ -239,7 +239,7 @@ class PutFunctor
ssize_t dst_offset = orthog_offsets.get_first_offset();
ssize_t val_offset = orthog_offsets.get_second_offset();

const ProjectorT proj{};
constexpr ProjectorT proj{};
for (int axis_idx = 0; axis_idx < k_; ++axis_idx) {
indT *ind_data = reinterpret_cast<indT *>(ind_[axis_idx]);

Expand Down
81 changes: 39 additions & 42 deletions dpctl/tensor/libtensor/include/utils/indexing_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,41 +49,40 @@ template <typename IndT> struct WrapIndex
ssize_t operator()(ssize_t max_item, IndT ind) const
{
ssize_t projected;
max_item = sycl::max<ssize_t>(max_item, 1);
constexpr ssize_t unit(1);
max_item = sycl::max(max_item, unit);

constexpr std::uintmax_t ind_max = std::numeric_limits<IndT>::max();
constexpr std::uintmax_t ssize_max =
std::numeric_limits<ssize_t>::max();

if constexpr (std::is_signed_v<IndT>) {
static constexpr std::uintmax_t ind_max =
std::numeric_limits<IndT>::max();
static constexpr std::uintmax_t ssize_max =
std::numeric_limits<ssize_t>::max();
static constexpr std::intmax_t ind_min =
std::numeric_limits<IndT>::min();
static constexpr std::intmax_t ssize_min =
constexpr std::intmax_t ind_min = std::numeric_limits<IndT>::min();
constexpr std::intmax_t ssize_min =
std::numeric_limits<ssize_t>::min();

if constexpr (ind_max <= ssize_max && ind_min >= ssize_min) {
projected = sycl::clamp<ssize_t>(static_cast<ssize_t>(ind),
-max_item, max_item - 1);
const ssize_t ind_ = static_cast<ssize_t>(ind);
const ssize_t lb = -max_item;
const ssize_t ub = max_item - 1;
projected = sycl::clamp(ind_, lb, ub);
}
else {
projected = sycl::clamp<IndT>(ind, static_cast<IndT>(-max_item),
static_cast<IndT>(max_item - 1));
const IndT lb = static_cast<IndT>(-max_item);
const IndT ub = static_cast<IndT>(max_item - 1);
projected = static_cast<ssize_t>(sycl::clamp(ind, lb, ub));
}
return (projected < 0) ? projected + max_item : projected;
}
else {
static constexpr std::uintmax_t ind_max =
std::numeric_limits<IndT>::max();
static constexpr std::uintmax_t ssize_max =
std::numeric_limits<ssize_t>::max();

if constexpr (ind_max <= ssize_max) {
projected =
sycl::min<ssize_t>(static_cast<ssize_t>(ind), max_item - 1);
const ssize_t ind_ = static_cast<ssize_t>(ind);
const ssize_t ub = max_item - 1;
projected = sycl::min(ind_, ub);
}
else {
projected =
sycl::min<IndT>(ind, static_cast<IndT>(max_item - 1));
const IndT ub = static_cast<IndT>(max_item - 1);
projected = static_cast<ssize_t>(sycl::min(ind, ub));
}
return projected;
}
Expand All @@ -95,40 +94,38 @@ template <typename IndT> struct ClipIndex
ssize_t operator()(ssize_t max_item, IndT ind) const
{
ssize_t projected;
max_item = sycl::max<ssize_t>(max_item, 1);
constexpr ssize_t unit(1);
max_item = sycl::max<ssize_t>(max_item, unit);

constexpr std::uintmax_t ind_max = std::numeric_limits<IndT>::max();
constexpr std::uintmax_t ssize_max =
std::numeric_limits<ssize_t>::max();
if constexpr (std::is_signed_v<IndT>) {
static constexpr std::uintmax_t ind_max =
std::numeric_limits<IndT>::max();
static constexpr std::uintmax_t ssize_max =
std::numeric_limits<ssize_t>::max();
static constexpr std::intmax_t ind_min =
std::numeric_limits<IndT>::min();
static constexpr std::intmax_t ssize_min =
constexpr std::intmax_t ind_min = std::numeric_limits<IndT>::min();
constexpr std::intmax_t ssize_min =
std::numeric_limits<ssize_t>::min();

if constexpr (ind_max <= ssize_max && ind_min >= ssize_min) {
projected = sycl::clamp<ssize_t>(static_cast<ssize_t>(ind),
ssize_t(0), max_item - 1);
const ssize_t ind_ = static_cast<ssize_t>(ind);
constexpr ssize_t lb(0);
const ssize_t ub = max_item - 1;
projected = sycl::clamp(ind_, lb, ub);
}
else {
projected = sycl::clamp<IndT>(ind, IndT(0),
static_cast<IndT>(max_item - 1));
constexpr IndT lb(0);
const IndT ub = static_cast<IndT>(max_item - 1);
projected = static_cast<size_t>(sycl::clamp(ind, lb, ub));
}
}
else {
static constexpr std::uintmax_t ind_max =
std::numeric_limits<IndT>::max();
static constexpr std::uintmax_t ssize_max =
std::numeric_limits<ssize_t>::max();

if constexpr (ind_max <= ssize_max) {
projected =
sycl::min<ssize_t>(static_cast<ssize_t>(ind), max_item - 1);
const ssize_t ind_ = static_cast<ssize_t>(ind);
const ssize_t ub = max_item - 1;
projected = sycl::min(ind_, ub);
}
else {
projected =
sycl::min<IndT>(ind, static_cast<IndT>(max_item - 1));
const IndT ub = static_cast<IndT>(max_item - 1);
projected = static_cast<ssize_t>(sycl::min(ind, ub));
}
}
return projected;
Expand Down

0 comments on commit 6dc298c

Please sign in to comment.