Skip to content

Commit

Permalink
Add threaded implementation
Browse files Browse the repository at this point in the history
[ci skip]
  • Loading branch information
mborland committed Apr 19, 2022
1 parent 3b99162 commit 66217f2
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 21 deletions.
17 changes: 0 additions & 17 deletions include/boost/math/statistics/bivariate_statistics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,23 +323,6 @@ ReturnType correlation_coefficient_parallel_impl(ForwardIterator u_begin, Forwar

#endif // BOOST_MATH_EXEC_COMPATIBLE

// Computes Chatterjee's rank correlation coefficient
//
//     xi = 1 - 3 * sum_{i=1}^{n-1} |r_i - r_{i-1}| / (n^2 - 1)
//
// where r is the rank vector of the v range and n is its length.
//
// Parameters:
//   u_begin, u_end - the x values; asserted to be sorted.
//   v_begin, v_end - the y values, passed to rank().
//
// Returns the coefficient converted to ReturnType.
//
// NOTE(review): no guard is made for ranges with fewer than two elements; for
// n == 1 the quotient is 0/0 and for n == 0 the unsigned denominator wraps.
template<typename ReturnType, typename ForwardIterator>
ReturnType chatterjee_correlation(ForwardIterator u_begin, ForwardIterator u_end, ForwardIterator v_begin, ForwardIterator v_end)
{
    BOOST_MATH_ASSERT_MSG(std::is_sorted(u_begin, u_end), "Data set must be sorted in order to calculate the chatterjee correlation.");

    const auto rank_vector = rank(v_begin, v_end);
    const std::size_t n = rank_vector.size();

    std::size_t sum = 0;
    for (std::size_t i = 1; i < n; ++i)
    {
        // Ranks are unsigned, so subtract the smaller from the larger to get
        // |r_i - r_{i-1}| without wrapping around.
        sum += rank_vector[i] > rank_vector[i-1] ? rank_vector[i] - rank_vector[i-1] : rank_vector[i-1] - rank_vector[i];
    }

    // The normalising denominator is n^2 - 1 (not n - 1), matching the
    // definition of the coefficient and the seq/par implementations.
    return static_cast<ReturnType>(1) - static_cast<ReturnType>(3 * sum) / static_cast<ReturnType>(n * n - 1);
}

} // namespace detail

#ifdef BOOST_MATH_EXEC_COMPATIBLE
Expand Down
77 changes: 75 additions & 2 deletions include/boost/math/statistics/chatterjee_correlation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,18 @@
#include <iterator>
#include <vector>
#include <limits>
#include <utility>
#include <type_traits>
#include <boost/math/tools/assert.hpp>
#include <boost/math/tools/config.hpp>
#include <boost/math/statistics/detail/rank.hpp>

#ifdef BOOST_MATH_EXEC_COMPATIBLE
#include <execution>
#include <future>
#include <numeric>
#include <thread>
#endif

namespace boost { namespace math { namespace statistics {

namespace detail {
Expand All @@ -26,7 +33,7 @@ ReturnType chatterjee_correlation_seq_impl(ForwardIterator u_begin, ForwardItera
{
using std::abs;

BOOST_MATH_ASSERT_MSG(std::is_sorted(u_begin, u_end), "The x values must be sorted in order to use this funtionality");
BOOST_MATH_ASSERT_MSG(std::is_sorted(u_begin, u_end), "The x values must be sorted in order to use this functionality");

const std::vector<std::size_t> rank_vector = rank(v_begin, v_end);

Expand All @@ -46,7 +53,7 @@ ReturnType chatterjee_correlation_seq_impl(ForwardIterator u_begin, ForwardItera

ReturnType result = static_cast<ReturnType>(1) - (static_cast<ReturnType>(3 * sum) / static_cast<ReturnType>(rank_vector.size() * rank_vector.size() - 1));

// If the result is 1 then Y is constant and all of the elements must be ties
// If the result is 1 then Y is constant and all the elements must be ties
if (abs(result - static_cast<ReturnType>(1)) < std::numeric_limits<ReturnType>::epsilon())
{
return std::numeric_limits<ReturnType>::quiet_NaN();
Expand All @@ -66,4 +73,70 @@ inline ReturnType chatterjee_correlation(const Container& u, const Container& v)

}}} // Namespace boost::math::statistics

#ifdef BOOST_MATH_EXEC_COMPATIBLE

namespace boost::math::statistics {

namespace detail {

// Function object returning the absolute difference |val_i - val_im1| of two
// values without ever forming a negative intermediate, so it is safe for
// unsigned types such as rank indices.
//
// The call operator is const-qualified: the functor carries no state, so it
// must be usable through const objects and const references.
template <typename Real>
struct rank_compare_value
{
    // val_i   - the current value
    // val_im1 - the preceding value
    // Returns the larger of the two minus the smaller.
    Real operator()(Real val_i, Real val_im1) const
    {
        if (val_i > val_im1)
        {
            return val_i - val_im1;
        }
        else
        {
            return val_im1 - val_i;
        }
    }
};

// Execution-policy implementation of Chatterjee's rank correlation coefficient
//
//     xi = 1 - 3 * sum_{i=1}^{n-1} |r_i - r_{i-1}| / (n^2 - 1)
//
// where r is the rank vector of the v range and n is its length.
//
// Parameters:
//   exec           - execution policy handed to std::is_sorted, rank() and
//                    std::transform_reduce.
//   u_begin, u_end - the x values; asserted to be sorted.
//   v_begin, v_end - the y values, passed to rank().
//
// Returns the coefficient, or quiet_NaN when there are fewer than two
// elements or when the result is 1 (Y constant, every element tied).
template <typename ReturnType, typename ExecutionPolicy, typename ForwardIterator>
ReturnType chatterjee_correlation_par_impl(ExecutionPolicy&& exec, ForwardIterator u_begin, ForwardIterator u_end,
                                           ForwardIterator v_begin, ForwardIterator v_end)
{
    using std::abs;

    // The policy is used several times, so it is passed as an lvalue rather
    // than being std::forward-ed (and potentially moved from) at each call.
    BOOST_MATH_ASSERT_MSG(std::is_sorted(exec, u_begin, u_end), "The x values must be sorted in order to use this functionality");

    const auto rank_vector = rank(exec, v_begin, v_end);
    const std::size_t n = rank_vector.size();

    // With fewer than two ranks there are no adjacent pairs, the iterator
    // arithmetic below would be invalid for an empty vector, and n^2 - 1
    // would be zero or wrap around.
    if (n < 2)
    {
        return std::numeric_limits<ReturnType>::quiet_NaN();
    }

    // Sum of |r_i - r_{i-1}| over adjacent pairs. The transform step pairs
    // each element of [begin + 1, end) with its predecessor; the reduction
    // step is plain addition, which is associative and commutative as the
    // parallel algorithm requires.
    const std::size_t sum = std::transform_reduce(exec,
                                                  rank_vector.cbegin() + 1, rank_vector.cend(),
                                                  rank_vector.cbegin(),
                                                  std::size_t{0},
                                                  [](std::size_t lhs, std::size_t rhs) { return lhs + rhs; },
                                                  [](std::size_t current, std::size_t previous)
                                                  {
                                                      // Unsigned-safe absolute difference
                                                      return current > previous ? current - previous : previous - current;
                                                  });

    ReturnType result = static_cast<ReturnType>(1) - (static_cast<ReturnType>(3 * sum) / static_cast<ReturnType>(n * n - 1));

    // If the result is 1 then Y is constant and all the elements must be ties
    if (abs(result - static_cast<ReturnType>(1)) < std::numeric_limits<ReturnType>::epsilon())
    {
        return std::numeric_limits<ReturnType>::quiet_NaN();
    }

    return result;
}

} // Namespace detail

// Chatterjee's rank correlation of the paired samples (u, v), computed under
// the supplied execution policy.
//
// Parameters:
//   exec - execution policy. Any sequenced_policy object (const or not,
//          lvalue or rvalue) dispatches to the sequential implementation;
//          every other policy is forwarded to the parallel implementation.
//   u    - the x values.
//   v    - the y values.
//
// ReturnType defaults to double for integral value types and to the
// container's value type otherwise.
template <typename ExecutionPolicy, typename Container, typename Real = typename Container::value_type,
          typename ReturnType = std::conditional_t<std::is_integral_v<Real>, double, Real>>
inline ReturnType chatterjee_correlation(ExecutionPolicy&& exec, const Container& u, const Container& v)
{
    // Compare the decayed policy type: decltype(std::execution::seq) is
    // const-qualified, so a comparison against it would miss a non-const
    // sequenced_policy argument and skip the sequential fast path.
    if constexpr (std::is_same_v<std::decay_t<ExecutionPolicy>, std::execution::sequenced_policy>)
    {
        return detail::chatterjee_correlation_seq_impl<ReturnType>(std::cbegin(u), std::cend(u),
                                                                   std::cbegin(v), std::cend(v));
    }
    else
    {
        return detail::chatterjee_correlation_par_impl<ReturnType>(std::forward<ExecutionPolicy>(exec),
                                                                   std::cbegin(u), std::cend(u),
                                                                   std::cbegin(v), std::cend(v));
    }
}

} // Namespace boost::math::statistics

#endif

#endif // BOOST_MATH_STATISTICS_CHATTERJEE_CORRELATION_HPP
6 changes: 4 additions & 2 deletions include/boost/math/statistics/detail/rank.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
#include <algorithm>
#include <boost/math/tools/config.hpp>

#ifdef BOOST_MATH_EXEC_COMPATIBLE
#include <execution>
#endif

namespace boost { namespace math { namespace statistics { namespace detail {

struct pair_equal
Expand Down Expand Up @@ -75,8 +79,6 @@ inline auto rank(const Container& c) -> std::vector<std::size_t>

#else

#include <execution>

namespace boost::math::statistics::detail {

template <typename ExecutionPolicy, typename ForwardIterator, typename T = typename std::iterator_traits<ForwardIterator>::value_type>
Expand Down

0 comments on commit 66217f2

Please sign in to comment.