Skip to content

Commit

Permalink
logaddexp implementation moved to math_utils
Browse files Browse the repository at this point in the history
Reduces code repetition between logsumexp and logaddexp
  • Loading branch information
ndgrigorian committed Oct 25, 2023
1 parent 448a7f1 commit d88e78f
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <limits>
#include <type_traits>

#include "utils/math_utils.hpp"
#include "utils/offset_utils.hpp"
#include "utils/type_dispatch.hpp"
#include "utils/type_utils.hpp"
Expand Down Expand Up @@ -61,7 +62,8 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor

resT operator()(const argT1 &in1, const argT2 &in2) const
{
return impl<resT>(in1, in2);
using dpctl::tensor::math_utils::logaddexp;
return logaddexp<resT>(in1, in2);
}

template <int vec_sz>
Expand All @@ -79,34 +81,15 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
impl_finite<resT>(-std::abs(diff[i]));
}
else {
res[i] = impl<resT>(in1[i], in2[i]);
using dpctl::tensor::math_utils::logaddexp;
res[i] = logaddexp<resT>(in1[i], in2[i]);
}
}

return res;
}

private:
template <typename T> T impl(T const &in1, T const &in2) const
{
if (in1 == in2) { // handle signed infinities
const T log2 = std::log(T(2));
return in1 + log2;
}
else {
const T tmp = in1 - in2;
if (tmp > 0) {
return in1 + std::log1p(std::exp(-tmp));
}
else if (tmp <= 0) {
return in2 + std::log1p(std::exp(tmp));
}
else {
return std::numeric_limits<T>::quiet_NaN();
}
}
}

template <typename T> T impl_finite(T const &in) const
{
return (in > 0) ? (in + std::log1p(std::exp(-in)))
Expand Down
20 changes: 20 additions & 0 deletions dpctl/tensor/libtensor/include/utils/math_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,26 @@ template <typename T> T min_complex(const T &x1, const T &x2)
return (std::isnan(real1) || isnan_imag1 || lt) ? x1 : x2;
}

template <typename T> T logaddexp(T x, T y)
{
if (x == y) { // handle signed infinities
const T log2 = std::log(T(2));
return x + log2;
}
else {
const T tmp = x - y;
if (tmp > 0) {
return x + std::log1p(std::exp(-tmp));
}
else if (tmp <= 0) {
return y + std::log1p(std::exp(tmp));
}
else {
return std::numeric_limits<T>::quiet_NaN();
}
}
}

} // namespace math_utils
} // namespace tensor
} // namespace dpctl
18 changes: 2 additions & 16 deletions dpctl/tensor/libtensor/include/utils/sycl_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,22 +292,8 @@ template <typename T> struct LogSumExp
{
T operator()(const T &x, const T &y) const
{
if (x == y) {
const T log2 = std::log(T(2));
return x + log2;
}
else {
const T tmp = x - y;
if (tmp > 0) {
return x + std::log1p(std::exp(-tmp));
}
else if (tmp <= 0) {
return y + std::log1p(std::exp(tmp));
}
else {
return std::numeric_limits<T>::quiet_NaN();
}
}
using dpctl::tensor::math_utils::logaddexp;
return logaddexp<T>(x, y);
}
};

Expand Down

0 comments on commit d88e78f

Please sign in to comment.