diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp index 90b7997a37..6a187da6f4 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp @@ -31,6 +31,7 @@ #include #include +#include "utils/math_utils.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" #include "utils/type_utils.hpp" @@ -61,7 +62,8 @@ template struct LogAddExpFunctor resT operator()(const argT1 &in1, const argT2 &in2) const { - return impl(in1, in2); + using dpctl::tensor::math_utils::logaddexp; + return logaddexp(in1, in2); } template @@ -79,7 +81,8 @@ template struct LogAddExpFunctor impl_finite(-std::abs(diff[i])); } else { - res[i] = impl(in1[i], in2[i]); + using dpctl::tensor::math_utils::logaddexp; + res[i] = logaddexp(in1[i], in2[i]); } } @@ -87,26 +90,6 @@ template struct LogAddExpFunctor } private: - template T impl(T const &in1, T const &in2) const - { - if (in1 == in2) { // handle signed infinities - const T log2 = std::log(T(2)); - return in1 + log2; - } - else { - const T tmp = in1 - in2; - if (tmp > 0) { - return in1 + std::log1p(std::exp(-tmp)); - } - else if (tmp <= 0) { - return in2 + std::log1p(std::exp(tmp)); - } - else { - return std::numeric_limits::quiet_NaN(); - } - } - } - template T impl_finite(T const &in) const { return (in > 0) ? (in + std::log1p(std::exp(-in))) diff --git a/dpctl/tensor/libtensor/include/utils/math_utils.hpp b/dpctl/tensor/libtensor/include/utils/math_utils.hpp index d724e03e35..120a14d536 100644 --- a/dpctl/tensor/libtensor/include/utils/math_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/math_utils.hpp @@ -115,6 +115,26 @@ template T min_complex(const T &x1, const T &x2) return (std::isnan(real1) || isnan_imag1 || lt) ? x1 : x2; } +template T logaddexp(T x, T y) +{ + if (x == y) { // handle signed infinities + const T log2 = std::log(T(2)); + return x + log2; + } + else { + const T tmp = x - y; + if (tmp > 0) { + return x + std::log1p(std::exp(-tmp)); + } + else if (tmp <= 0) { + return y + std::log1p(std::exp(tmp)); + } + else { + return std::numeric_limits::quiet_NaN(); + } + } +} + } // namespace math_utils } // namespace tensor } // namespace dpctl diff --git a/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp b/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp index 6e8a68a8b5..c0165b0ecc 100644 --- a/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp @@ -292,22 +292,8 @@ template struct LogSumExp { T operator()(const T &x, const T &y) const { - if (x == y) { - const T log2 = std::log(T(2)); - return x + log2; - } - else { - const T tmp = x - y; - if (tmp > 0) { - return x + std::log1p(std::exp(-tmp)); - } - else if (tmp <= 0) { - return y + std::log1p(std::exp(tmp)); - } - else { - return std::numeric_limits::quiet_NaN(); - } - } + using dpctl::tensor::math_utils::logaddexp; + return logaddexp(x, y); } };