Skip to content

Commit

Permalink
Use extended comparison operators to define weak order on real/comple…
Browse files Browse the repository at this point in the history
…x FP types

We use extended comparison operators compatible with NumPy's behavior:

https://numpy.org/devdocs/reference/generated/numpy.sort.html

Specifically, we use [R, nan] block ordering for reals, and
[(R, R), (R, nan), (nan, R), (nan, nan)] for complexes.
  • Loading branch information
oleksandr-pavlyk committed Jan 9, 2024
1 parent 08e5dac commit a3d0d08
Showing 1 changed file with 68 additions and 12 deletions.
80 changes: 68 additions & 12 deletions dpctl/tensor/libtensor/source/sorting/sorting_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@

#pragma once

#include "utils/math_utils.hpp"
#include "sycl/sycl.hpp"
#include <type_traits>

namespace dpctl
{
Expand All @@ -33,44 +34,99 @@ namespace tensor
namespace py_internal
{

template <typename cT> struct ComplexLess
namespace
{
bool operator()(const cT &v1, const cT &v2) const
template <typename fpT> struct ExtendedRealFPLess
{
/* [R, nan] */
bool operator()(const fpT v1, const fpT v2) const
{
using dpctl::tensor::math_utils::less_complex;
return (!sycl::isnan(v1) && (sycl::isnan(v2) || (v1 < v2)));
}
};

return less_complex(v1, v2);
template <typename fpT> struct ExtendedRealFPGreater
{
bool operator()(const fpT v1, const fpT v2) const
{
return (!sycl::isnan(v2) && (sycl::isnan(v1) || (v2 < v1)));
}
};

template <typename cT> struct ComplexGreater
template <typename cT> struct ExtendedComplexFPLess
{
/* [(R, R), (R, nan), (nan, R), (nan, nan)] */

bool operator()(const cT &v1, const cT &v2) const
{
using dpctl::tensor::math_utils::greater_complex;
using realT = typename cT::value_type;

const realT real1 = std::real(v1);
const realT real2 = std::real(v2);

const bool r1_nan = sycl::isnan(real1);
const bool r2_nan = sycl::isnan(real2);

const realT imag1 = std::imag(v1);
const realT imag2 = std::imag(v2);

const bool i1_nan = sycl::isnan(imag1);
const bool i2_nan = sycl::isnan(imag2);

return greater_complex(v1, v2);
const int idx1 = ((r1_nan) ? 2 : 0) + ((i1_nan) ? 1 : 0);
const int idx2 = ((r2_nan) ? 2 : 0) + ((i2_nan) ? 1 : 0);

const bool res =
!(r1_nan && i1_nan) &&
((idx1 < idx2) ||
((idx1 == idx2) &&
((r1_nan && !i1_nan && (imag1 < imag2)) ||
(!r1_nan && i1_nan && (real1 < real2)) ||
(!r1_nan && !i1_nan &&
((real1 < real2) || (!(real2 < real1) && (imag1 < imag2)))))));

return res;
}
};

template <typename cT> struct ExtendedComplexFPGreater
{
bool operator()(const cT &v1, const cT &v2) const
{
auto less_ = ExtendedComplexFPLess<cT>{};
return less_(v2, v1);
}
};

template <typename T>
inline constexpr bool is_fp_v = (std::is_same_v<T, sycl::half> ||
std::is_same_v<T, float> ||
std::is_same_v<T, double>);

} // end of anonymous namespace

template <typename argTy> struct AscendingSorter
{
using type = std::less<argTy>;
using type = std::conditional_t<is_fp_v<argTy>,
ExtendedRealFPLess<argTy>,
std::less<argTy>>;
};

template <typename T> struct AscendingSorter<std::complex<T>>
{
using type = ComplexLess<std::complex<T>>;
using type = ExtendedComplexFPLess<std::complex<T>>;
};

template <typename argTy> struct DescendingSorter
{
using type = std::greater<argTy>;
using type = std::conditional_t<is_fp_v<argTy>,
ExtendedRealFPGreater<argTy>,
std::greater<argTy>>;
};

template <typename T> struct DescendingSorter<std::complex<T>>
{
using type = ComplexGreater<std::complex<T>>;
using type = ExtendedComplexFPGreater<std::complex<T>>;
};

} // end of namespace py_internal
Expand Down

0 comments on commit a3d0d08

Please sign in to comment.