diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc index cf19a0d73aff..c61135353ff2 100644 --- a/amalgamation/xgboost-all0.cc +++ b/amalgamation/xgboost-all0.cc @@ -14,6 +14,7 @@ #include "../src/metric/elementwise_metric.cc" #include "../src/metric/multiclass_metric.cc" #include "../src/metric/rank_metric.cc" +#include "../src/metric/auc.cc" #include "../src/metric/survival_metric.cc" // objectives diff --git a/doc/parameter.rst b/doc/parameter.rst index 05da8de30278..ad45f31ec832 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -400,7 +400,15 @@ Specify the learning task and the corresponding learning objective. The objectiv - ``error@t``: a different than 0.5 binary classification threshold value could be specified by providing a numerical value through 't'. - ``merror``: Multiclass classification error rate. It is calculated as ``#(wrong cases)/#(all cases)``. - ``mlogloss``: `Multiclass logloss `_. - - ``auc``: `Area under the curve `_. Available for binary classification and learning-to-rank tasks. + - ``auc``: `Receiver Operating Characteristic Area under the Curve `_. + Available for classification and learning-to-rank tasks. + + - When used with binary classification, the objective should be ``binary:logistic`` or a similar function that works on probability. + - When used with multi-class classification, the objective should be ``multi:softprob`` instead of ``multi:softmax``, as the latter doesn't output probabilities. The AUC is then calculated by 1-vs-rest, with the reference class weighted by its prevalence. + - When used with an LTR task, the AUC is computed by comparing pairs of documents and counting correctly sorted pairs. This corresponds to pairwise learning to rank. The implementation has some known issues: the average AUC over groups and over distributed workers is not well-defined. + - On a single machine the AUC calculation is exact. In a distributed environment the AUC is a weighted average over the AUC of training rows on each node; therefore, distributed AUC is an approximation sensitive to the distribution of data across workers. Use another metric in distributed environments if precision and reproducibility are important. + - If the input dataset contains only negative or positive samples, the output is `NaN`. + - ``aucpr``: `Area under the PR curve `_. Available for binary classification and learning-to-rank tasks.
- ``ndcg``: `Normalized Discounted Cumulative Gain `_ - ``map``: `Mean Average Precision `_ diff --git a/src/common/common.h b/src/common/common.h index a4397d1c89aa..1f9f23e9884a 100644 --- a/src/common/common.h +++ b/src/common/common.h @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ -163,13 +164,14 @@ inline void AssertOneAPISupport() { #endif // XGBOOST_USE_ONEAPI } -template > -std::vector ArgSort(std::vector const &array, Comp comp = std::less{}) { +template > +std::vector ArgSort(Container const &array, Comp comp = std::less{}) { std::vector result(array.size()); std::iota(result.begin(), result.end(), 0); - std::stable_sort( - result.begin(), result.end(), - [&array, comp](Idx const &l, Idx const &r) { return comp(array[l], array[r]); }); + auto op = [&array, comp](Idx const &l, Idx const &r) { return comp(array[l], array[r]); }; + XGBOOST_PARALLEL_STABLE_SORT(result.begin(), result.end(), op); return result; } } // namespace common diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 1da0a3be6ae4..b1ddfdb20731 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -1198,6 +1198,62 @@ size_t SegmentedUnique(Inputs &&...inputs) { return SegmentedUnique(thrust::cuda::par(alloc), std::forward(inputs)...); } +/** + * \brief Unique by key for many groups of data. Has same constraint as `SegmentedUnique`. + * + * \tparam exec thrust execution policy + * \tparam key_segments_first start iter to segment pointer + * \tparam key_segments_last end iter to segment pointer + * \tparam key_first start iter to key for comparison + * \tparam key_last end iter to key for comparison + * \tparam val_first start iter to values + * \tparam key_segments_out output iterator for new segment pointer + * \tparam val_out output iterator for values + * \tparam comp binary comparison operator + */ +template +size_t SegmentedUniqueByKey( + const thrust::detail::execution_policy_base &exec, + SegInIt key_segments_first, SegInIt key_segments_last, KeyInIt key_first, + KeyInIt key_last, ValInIt val_first, SegOutIt key_segments_out, + ValOutIt val_out, Comp comp) { + using Key = + thrust::pair::value_type>; + + auto unique_key_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(static_cast(0)), + [=] __device__(size_t i) { + size_t seg = dh::SegmentId(key_segments_first, key_segments_last, i); + return thrust::make_pair(seg, *(key_first + i)); + }); + size_t segments_len = key_segments_last - key_segments_first; + thrust::fill(thrust::device, key_segments_out, + key_segments_out + segments_len, 0); + size_t n_inputs = std::distance(key_first, key_last); + // Reduce the number of uniques elements per segment, avoid creating an + // intermediate array for `reduce_by_key`. It's limited by the types that + // atomicAdd supports. For example, size_t is not supported as of CUDA 10.2. + auto reduce_it = thrust::make_transform_output_iterator( + thrust::make_discard_iterator(), + detail::SegmentedUniqueReduceOp{key_segments_out}); + auto uniques_ret = thrust::unique_by_key_copy( + exec, unique_key_it, unique_key_it + n_inputs, val_first, reduce_it, + val_out, [=] __device__(Key const &l, Key const &r) { + if (l.first == r.first) { + // In the same segment. 
+ return comp(thrust::get<1>(l), thrust::get<1>(r)); + } + return false; + }); + auto n_uniques = uniques_ret.second - val_out; + CHECK_LE(n_uniques, n_inputs); + thrust::exclusive_scan(exec, key_segments_out, + key_segments_out + segments_len, key_segments_out, 0); + return n_uniques; +} + template auto Reduce(Policy policy, InputIt first, InputIt second, Init init, Func reduce_op) { size_t constexpr kLimit = std::numeric_limits::max() / 2; @@ -1215,36 +1271,73 @@ auto Reduce(Policy policy, InputIt first, InputIt second, Init init, Func reduce return aggregate; } +// wrapper to avoid integer `num_items`. +template +void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op, + OffsetT num_items) { + size_t bytes = 0; + safe_cuda(( + cub::DispatchScan::Dispatch(nullptr, bytes, d_in, d_out, scan_op, + cub::NullType(), num_items, nullptr, + false))); + dh::TemporaryArray storage(bytes); + safe_cuda(( + cub::DispatchScan::Dispatch(storage.data().get(), bytes, d_in, + d_out, scan_op, cub::NullType(), + num_items, nullptr, false))); +} + +template +void InclusiveSum(InputIteratorT d_in, OutputIteratorT d_out, OffsetT num_items) { + InclusiveScan(d_in, d_out, cub::Sum(), num_items); +} + template -void ArgSort(xgboost::common::Span values, xgboost::common::Span sorted_idx) { +void ArgSort(xgboost::common::Span keys, xgboost::common::Span sorted_idx) { size_t bytes = 0; Iota(sorted_idx); - CHECK_LT(sorted_idx.size(), 1 << 31); - TemporaryArray out(values.size()); + + using KeyT = typename decltype(keys)::value_type; + using ValueT = std::remove_const_t; + + TemporaryArray out(keys.size()); + cub::DoubleBuffer d_keys(const_cast(keys.data()), + out.data().get()); + cub::DoubleBuffer d_values(const_cast(sorted_idx.data()), + sorted_idx.data()); + if (accending) { - cub::DeviceRadixSort::SortPairs(nullptr, bytes, values.data(), - out.data().get(), sorted_idx.data(), - sorted_idx.data(), sorted_idx.size()); + void *d_temp_storage = nullptr; + cub::DispatchRadixSort::Dispatch( + d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, + sizeof(KeyT) * 8, false, nullptr, false); dh::TemporaryArray storage(bytes); - cub::DeviceRadixSort::SortPairs(storage.data().get(), bytes, values.data(), - out.data().get(), sorted_idx.data(), - sorted_idx.data(), sorted_idx.size()); + d_temp_storage = storage.data().get(); + cub::DispatchRadixSort::Dispatch( + d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, + sizeof(KeyT) * 8, false, nullptr, false); } else { - cub::DeviceRadixSort::SortPairsDescending( - nullptr, bytes, values.data(), out.data().get(), sorted_idx.data(), - sorted_idx.data(), sorted_idx.size()); + void *d_temp_storage = nullptr; + safe_cuda((cub::DispatchRadixSort::Dispatch( + d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, + sizeof(KeyT) * 8, false, nullptr, false))); dh::TemporaryArray storage(bytes); - cub::DeviceRadixSort::SortPairsDescending( - storage.data().get(), bytes, values.data(), out.data().get(), - sorted_idx.data(), sorted_idx.data(), sorted_idx.size()); + d_temp_storage = storage.data().get(); + safe_cuda((cub::DispatchRadixSort::Dispatch( + d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, + sizeof(KeyT) * 8, false, nullptr, false))); } } namespace detail { -// Wrapper around cub sort for easier `descending` sort -template +// Wrapper around cub sort for easier `descending` sort and `size_t num_items`. 
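As a plain-C++ reference for the contract of the SegmentedUniqueByKey helper above (a sketch assuming keys are already sorted within each segment; the function name is invented for illustration and is not part of the patch), within every segment adjacent equal keys collapse to one entry and the per-segment counts become new offsets:

#include <cstddef>
#include <vector>

// Count unique adjacent keys per segment and emit CSR-style output offsets.
std::vector<std::size_t> SegmentedUniqueOffsets(
    std::vector<std::size_t> const &seg_ptr, std::vector<float> const &keys) {
  std::vector<std::size_t> out_ptr(seg_ptr.size(), 0);
  for (std::size_t s = 0; s + 1 < seg_ptr.size(); ++s) {
    std::size_t uniques = 0;
    for (std::size_t i = seg_ptr[s]; i < seg_ptr[s + 1]; ++i) {
      // keep the first element of every run of equal keys inside a segment
      if (i == seg_ptr[s] || keys[i] != keys[i - 1]) {
        ++uniques;
      }
    }
    out_ptr[s + 1] = out_ptr[s] + uniques;
  }
  return out_ptr;
}
// E.g. seg_ptr = {0, 3, 6} and keys = {1, 1, 2, 3, 3, 3} give offsets {0, 2, 3}.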
+template void DeviceSegmentedRadixSortPair( - void *d_temp_storage, size_t &temp_storage_bytes, const KeyT *d_keys_in, // NOLINT + void *d_temp_storage, size_t &temp_storage_bytes, const KeyT *d_keys_in, // NOLINT KeyT *d_keys_out, const ValueT *d_values_in, ValueT *d_values_out, size_t num_items, size_t num_segments, OffsetIteratorT d_begin_offsets, OffsetIteratorT d_end_offsets, int begin_bit = 0, @@ -1253,12 +1346,12 @@ void DeviceSegmentedRadixSortPair( cub::DoubleBuffer d_values(const_cast(d_values_in), d_values_out); using OffsetT = size_t; - dh::safe_cuda((cub::DispatchSegmentedRadixSort< - descending, KeyT, ValueT, OffsetIteratorT, - OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys, - d_values, num_items, num_segments, - d_begin_offsets, d_end_offsets, begin_bit, - end_bit, false, nullptr, false))); + safe_cuda((cub::DispatchSegmentedRadixSort< + descending, KeyT, ValueT, OffsetIteratorT, + OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys, + d_values, num_items, num_segments, + d_begin_offsets, d_end_offsets, begin_bit, + end_bit, false, nullptr, false))); } } // namespace detail @@ -1270,12 +1363,11 @@ void SegmentedArgSort(xgboost::common::Span values, size_t n_groups = group_ptr.size() - 1; size_t bytes = 0; Iota(sorted_idx); - CHECK_LT(sorted_idx.size(), 1 << 31); - TemporaryArray values_out(values.size()); + TemporaryArray> values_out(values.size()); detail::DeviceSegmentedRadixSortPair( - nullptr, bytes, values.data(), values_out.data().get(), - sorted_idx.data(), sorted_idx.data(), sorted_idx.size(), n_groups, - group_ptr.data(), group_ptr.data() + 1); + nullptr, bytes, values.data(), values_out.data().get(), sorted_idx.data(), + sorted_idx.data(), sorted_idx.size(), n_groups, group_ptr.data(), + group_ptr.data() + 1); dh::TemporaryArray temp_storage(bytes); detail::DeviceSegmentedRadixSortPair( temp_storage.data().get(), bytes, values.data(), values_out.data().get(), diff --git a/src/common/math.h b/src/common/math.h index 41905310e70d..c189babee954 100644 --- a/src/common/math.h +++ b/src/common/math.h @@ -26,6 +26,9 @@ XGBOOST_DEVICE inline float Sigmoid(float x) { return 1.0f / (1.0f + expf(-x)); } +template +XGBOOST_DEVICE inline static T Sqr(T a) { return a * a; } + /*! * \brief Equality test for both integer and floating point. */ diff --git a/src/common/random.h b/src/common/random.h index 626800597a12..d0ddf06ec830 100644 --- a/src/common/random.h +++ b/src/common/random.h @@ -99,7 +99,7 @@ std::vector WeightedSamplingWithoutReplacement( auto k = std::log(u) / w; keys[i] = k; } - auto ind = ArgSort(keys, std::greater<>{}); + auto ind = ArgSort(Span{keys}, std::greater<>{}); ind.resize(n); std::vector results(ind.size()); diff --git a/src/common/ranking_utils.cuh b/src/common/ranking_utils.cuh new file mode 100644 index 000000000000..c9b71c154919 --- /dev/null +++ b/src/common/ranking_utils.cuh @@ -0,0 +1,84 @@ +/*! + * Copyright 2021 by XGBoost Contributors + */ +#ifndef XGBOOST_COMMON_RANKING_UTILS_H_ +#define XGBOOST_COMMON_RANKING_UTILS_H_ + +#include +#include "xgboost/base.h" +#include "device_helpers.cuh" +#include "./math.h" + +namespace xgboost { +namespace common { +/** + * \param n Number of items (length of the base) + * \param h hight + */ +XGBOOST_DEVICE inline size_t DiscreteTrapezoidArea(size_t n, size_t h) { + n -= 1; // without diagonal entries + h = std::min(n, h); // Specific for ranking. 
+ size_t total = ((n - (h - 1)) + n) * h / 2; + return total; +} + +/** + * Used for mapping many groups of trapezoid shaped computation onto CUDA blocks. The + * trapezoid must be on upper right corner. + * + * Equivalent to loops like: + * + * \code + * for (size i = 0; i < h; ++i) { + * for (size_t j = i + 1; j < n; ++j) { + * do_something(); + * } + * } + * \endcode + */ +template +inline size_t +SegmentedTrapezoidThreads(xgboost::common::Span group_ptr, + xgboost::common::Span out_group_threads_ptr, + size_t h) { + CHECK_GE(group_ptr.size(), 1); + CHECK_EQ(group_ptr.size(), out_group_threads_ptr.size()); + dh::LaunchN( + dh::CurrentDevice(), group_ptr.size(), [=] XGBOOST_DEVICE(size_t idx) { + if (idx == 0) { + out_group_threads_ptr[0] = 0; + return; + } + + size_t cnt = static_cast(group_ptr[idx] - group_ptr[idx - 1]); + out_group_threads_ptr[idx] = DiscreteTrapezoidArea(cnt, h); + }); + dh::InclusiveSum(out_group_threads_ptr.data(), out_group_threads_ptr.data(), + out_group_threads_ptr.size()); + size_t total = 0; + dh::safe_cuda(cudaMemcpy( + &total, out_group_threads_ptr.data() + out_group_threads_ptr.size() - 1, + sizeof(total), cudaMemcpyDeviceToHost)); + return total; +} + +/** + * Called inside kernel to obtain coordinate from trapezoid grid. + */ +XGBOOST_DEVICE inline void UnravelTrapeziodIdx(size_t i_idx, size_t n, + size_t *out_i, size_t *out_j) { + auto &i = *out_i; + auto &j = *out_j; + double idx = static_cast(i_idx); + double N = static_cast(n); + + i = std::ceil(-(0.5 - N + std::sqrt(common::Sqr(N - 0.5) + 2.0 * (-idx - 1.0)))) - 1.0; + + auto I = static_cast(i); + size_t n_elems = -0.5 * common::Sqr(I) + (N - 0.5) * I; + + j = idx - n_elems + i + 1; +} +} // namespace common +} // namespace xgboost +#endif // XGBOOST_COMMON_RANKING_UTILS_H_ diff --git a/src/data/data.cc b/src/data/data.cc index f99d3368e92a..f2ac7bf4a9e2 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -400,7 +400,9 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t group_ptr_.push_back(i); } } - group_ptr_.push_back(query_ids.size()); + if (group_ptr_.back() != query_ids.size()) { + group_ptr_.push_back(query_ids.size()); + } } else if (!std::strcmp(key, "label_lower_bound")) { auto& labels = labels_lower_bound_.HostVector(); labels.resize(num); diff --git a/src/metric/auc.cc b/src/metric/auc.cc new file mode 100644 index 000000000000..9184223da7d8 --- /dev/null +++ b/src/metric/auc.cc @@ -0,0 +1,340 @@ +/*! + * Copyright 2021 by XGBoost Contributors + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "rabit/rabit.h" +#include "xgboost/host_device_vector.h" +#include "xgboost/metric.h" +#include "auc.h" +#include "../common/common.h" +#include "../common/math.h" + +namespace xgboost { +namespace metric { + +namespace detail { +template +constexpr auto UnpackArr(std::array &&arr, std::index_sequence) { + return std::make_tuple(std::forward>(arr)[Idx]...); +} +} // namespace detail + +template +constexpr auto UnpackArr(std::array &&arr) { + return detail::UnpackArr(std::forward>(arr), + std::make_index_sequence{}); +} + +/** + * Calculate AUC for binary classification problem. This function does not normalize the + * AUC by 1 / (num_positive * num_negative), instead it returns a tuple for caller to + * handle the normalization. 
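As a quick sanity check of the DiscreteTrapezoidArea formula above, here is a standalone worked example (host-only C++, not part of the patch); the numbers match the per-group thread counts asserted by the new tests/cpp/common/test_ranking_utils.cu:

#include <algorithm>
#include <cassert>
#include <cstddef>

// Reference copy of the formula: the first h documents of a group of n are each
// compared with all of their successors.
std::size_t DiscreteTrapezoidAreaRef(std::size_t n, std::size_t h) {
  n -= 1;              // without diagonal entries
  h = std::min(n, h);  // cap the height at the base length
  return ((n - (h - 1)) + n) * h / 2;
}

int main() {
  assert(DiscreteTrapezoidAreaRef(8, 1) == 7);   // document 0 against the other 7
  assert(DiscreteTrapezoidAreaRef(8, 2) == 13);  // 7 + 6
  assert(DiscreteTrapezoidAreaRef(8, 7) == 28);  // full upper triangle: 8 * 7 / 2
  return 0;
}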
+ */ +std::tuple BinaryAUC(std::vector const &predts, + std::vector const &labels, + std::vector const &weights) { + CHECK(!labels.empty()); + CHECK_EQ(labels.size(), predts.size()); + + float auc {0}; + auto const sorted_idx = common::ArgSort( + common::Span(predts), std::greater<>{}); + + auto get_weight = [&](size_t i) { + return weights.empty() ? 1.0f : weights[sorted_idx[i]]; + }; + float label = labels[sorted_idx.front()]; + float w = get_weight(0); + float fp = (1.0 - label) * w, tp = label * w; + float tp_prev = 0, fp_prev = 0; + // TODO(jiaming): We can parallize this if we have a parallel scan for CPU. + for (size_t i = 1; i < sorted_idx.size(); ++i) { + if (predts[sorted_idx[i]] != predts[sorted_idx[i-1]]) { + auc += TrapesoidArea(fp_prev, fp, tp_prev, tp); + tp_prev = tp; + fp_prev = fp; + } + label = labels[sorted_idx[i]]; + float w = get_weight(i); + fp += (1.0f - label) * w; + tp += label * w; + } + + auc += TrapesoidArea(fp_prev, fp, tp_prev, tp); + if (fp <= 0.0f || tp <= 0.0f) { + auc = 0; + fp = 0; + tp = 0; + } + + return std::make_tuple(fp, tp, auc); +} + +/** + * Calculate AUC for multi-class classification problem using 1-vs-rest approach. + * + * TODO(jiaming): Use better algorithms like: + * + * - Kleiman, Ross and Page, David. $AUC_{\mu}$: A Performance Metric for Multi-Class + * Machine Learning Models + */ +float MultiClassOVR(std::vector const& predts, MetaInfo const& info) { + auto n_classes = predts.size() / info.labels_.Size(); + CHECK_NE(n_classes, 0); + auto const& labels = info.labels_.ConstHostVector(); + + std::vector results(n_classes * 3, 0); + auto s_results = common::Span(results); + auto local_area = s_results.subspan(0, n_classes); + auto tp = s_results.subspan(n_classes, n_classes); + auto auc = s_results.subspan(2 * n_classes, n_classes); + + if (!info.labels_.Empty()) { + dmlc::OMPException omp_handler; +#pragma omp parallel for + for (omp_ulong c = 0; c < n_classes; ++c) { + omp_handler.Run([&]() { + std::vector proba(info.labels_.Size()); + std::vector response(info.labels_.Size()); + for (size_t i = 0; i < proba.size(); ++i) { + proba[i] = predts[i * n_classes + c]; + response[i] = labels[i] == c ? 1.0f : 0.0; + } + float fp; + std::tie(fp, tp[c], auc[c]) = + BinaryAUC(proba, response, info.weights_.ConstHostVector()); + local_area[c] = fp * tp[c]; + }); + } + omp_handler.Rethrow(); + } + + // we have 2 averages going in here, first is among workers, second is among classes. + // allreduce sums up fp/tp auc for each class. + rabit::Allreduce(results.data(), results.size()); + float auc_sum{0}; + float tp_sum{0}; + for (size_t c = 0; c < n_classes; ++c) { + if (local_area[c] != 0) { + // normalize and weight it by prevalence. After allreduce, `local_area` means the + // total covered area (not area under curve, rather it's the accessible are for each + // worker) for each class. + auc_sum += auc[c] / local_area[c] * tp[c]; + tp_sum += tp[c]; + } else { + auc_sum = std::numeric_limits::quiet_NaN(); + break; + } + } + if (tp_sum == 0 || std::isnan(auc_sum)) { + auc_sum = std::numeric_limits::quiet_NaN(); + } else { + auc_sum /= tp_sum; + } + return auc_sum; +} + +/** + * Calculate AUC for 1 ranking group; + */ +float GroupRankingAUC(common::Span predts, + common::Span labels, float w) { + // on ranking, we just count all pairs. 
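For a concrete trace of the trapezoid accumulation used by BinaryAUC above, the following self-contained host sketch (illustration only: unweighted, and it skips the degenerate-case handling) reproduces one expectation from the new tests/cpp/metric/test_auc.cc, where predictions {1, 0, 0} against labels {0, 0, 1} score 0.25:

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <vector>

// Unweighted ROC AUC via the same bucketed trapezoid rule as BinaryAUC.
double RocAucRef(std::vector<double> const &predt, std::vector<double> const &label) {
  std::vector<std::size_t> idx(predt.size());
  std::iota(idx.begin(), idx.end(), 0);
  std::stable_sort(idx.begin(), idx.end(),
                   [&](std::size_t l, std::size_t r) { return predt[l] > predt[r]; });
  double fp = 0, tp = 0, fp_prev = 0, tp_prev = 0, area = 0;
  for (std::size_t k = 0; k < idx.size(); ++k) {
    if (k > 0 && predt[idx[k]] != predt[idx[k - 1]]) {
      // close a bucket of tied predictions with one trapezoid
      area += std::abs(fp - fp_prev) * (tp + tp_prev) * 0.5;
      fp_prev = fp;
      tp_prev = tp;
    }
    fp += 1.0 - label[idx[k]];
    tp += label[idx[k]];
  }
  area += std::abs(fp - fp_prev) * (tp + tp_prev) * 0.5;
  return area / (fp * tp);  // normalize by #negatives * #positives
}

int main() {
  assert(std::abs(RocAucRef({1, 0, 0}, {0, 0, 1}) - 0.25) < 1e-12);
  assert(std::abs(RocAucRef({0, 1}, {0, 1}) - 1.0) < 1e-12);
  return 0;
}

For the multi-class case, MultiClassOVR above runs the same routine once per class and averages the per-class AUCs weighted by each class's (weighted) positive count; for instance, per-class AUCs {0.9, 0.6} with positive counts {90, 10} average to (0.9 * 90 + 0.6 * 10) / 100 = 0.87.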
+ float auc{0}; + auto const sorted_idx = common::ArgSort(labels, std::greater<>{}); + w = common::Sqr(w); + + float sum_w = 0.0f; + for (size_t i = 0; i < labels.size(); ++i) { + for (size_t j = i + 1; j < labels.size(); ++j) { + auto predt = predts[sorted_idx[i]] - predts[sorted_idx[j]]; + if (predt > 0) { + predt = 1.0; + } else if (predt == 0) { + predt = 0.5; + } else { + predt = 0; + } + auc += predt * w; + sum_w += w; + } + } + if (sum_w != 0) { + auc /= sum_w; + } + CHECK_LE(auc, 1.0f); + return auc; +} + +/** + * Cast LTR problem to binary classification problem by comparing pairs. + */ +std::pair RankingAUC(std::vector const &predts, + MetaInfo const &info) { + CHECK_GE(info.group_ptr_.size(), 2); + uint32_t n_groups = info.group_ptr_.size() - 1; + float sum_auc = 0; + auto s_predts = common::Span{predts}; + auto s_labels = info.labels_.ConstHostSpan(); + auto s_weights = info.weights_.ConstHostSpan(); + + std::atomic invalid_groups{0}; + dmlc::OMPException omp_handler; + +#pragma omp parallel for reduction(+:sum_auc) + for (omp_ulong g = 1; g < info.group_ptr_.size(); ++g) { + omp_handler.Run([&]() { + size_t cnt = info.group_ptr_[g] - info.group_ptr_[g - 1]; + float w = s_weights.empty() ? 1.0f : s_weights[g - 1]; + auto g_predts = s_predts.subspan(info.group_ptr_[g - 1], cnt); + auto g_labels = s_labels.subspan(info.group_ptr_[g - 1], cnt); + float auc; + if (g_labels.size() < 3) { + // With 2 documents, there's only 1 comparison can be made. So either + // TP or FP will be zero. + invalid_groups++; + auc = 0; + } else { + auc = GroupRankingAUC(g_predts, g_labels, w); + } + sum_auc += auc; + }); + } + omp_handler.Rethrow(); + + if (invalid_groups != 0) { + InvalidGroupAUC(); + } + + return std::make_pair(sum_auc, n_groups - invalid_groups); +} + +class EvalAUC : public Metric { + std::shared_ptr d_cache_; + + public: + float Eval(const HostDeviceVector &preds, const MetaInfo &info, + bool distributed) override { + float auc {0}; + if (tparam_->gpu_id != GenericParameter::kCpuId) { + preds.SetDevice(tparam_->gpu_id); + info.labels_.SetDevice(tparam_->gpu_id); + info.weights_.SetDevice(tparam_->gpu_id); + } + if (!info.group_ptr_.empty()) { + /** + * learning to rank + */ + if (!info.weights_.Empty()) { + CHECK_EQ(info.weights_.Size(), info.group_ptr_.size() - 1); + } + uint32_t valid_groups = 0; + if (!info.labels_.Empty()) { + CHECK_EQ(info.group_ptr_.back(), info.labels_.Size()); + if (tparam_->gpu_id == GenericParameter::kCpuId) { + std::tie(auc, valid_groups) = + RankingAUC(preds.ConstHostVector(), info); + } else { + std::tie(auc, valid_groups) = GPURankingAUC( + preds.ConstDeviceSpan(), info, tparam_->gpu_id, &this->d_cache_); + } + } + + std::array results{auc, static_cast(valid_groups)}; + rabit::Allreduce(results.data(), results.size()); + auc = results[0]; + valid_groups = static_cast(results[1]); + + if (valid_groups <= 0) { + auc = std::numeric_limits::quiet_NaN(); + } else { + auc /= valid_groups; + CHECK_LE(auc, 1) << "Total AUC across groups: " << auc * valid_groups + << ", valid groups: " << valid_groups; + } + } else if (info.labels_.Size() != preds.Size() && + preds.Size() % info.labels_.Size() == 0) { + /** + * multi class + */ + if (tparam_->gpu_id == GenericParameter::kCpuId) { + auc = MultiClassOVR(preds.ConstHostVector(), info); + } else { + auc = GPUMultiClassAUCOVR(preds.ConstDeviceSpan(), info, tparam_->gpu_id, + &this->d_cache_); + } + } else { + /** + * binary classification + */ + float fp{0}, tp{0}; + if (!(preds.Empty() || info.labels_.Empty())) { 
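The per-group quantity computed by GroupRankingAUC above is a normalized count of correctly sorted document pairs; a minimal unweighted host-side sketch of that idea (illustration only, the function name is invented):

#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

// Fraction of document pairs in one query group whose predictions agree with the
// label ordering; tied predictions count as half, mirroring the 1 / 0.5 / 0 scoring.
double PairwiseGroupAucRef(std::vector<float> const &label,
                           std::vector<float> const &predt) {
  std::vector<std::size_t> idx(label.size());
  std::iota(idx.begin(), idx.end(), 0);
  // order documents by descending relevance label
  std::stable_sort(idx.begin(), idx.end(),
                   [&](std::size_t l, std::size_t r) { return label[l] > label[r]; });
  double correct = 0, total = 0;
  for (std::size_t i = 0; i < idx.size(); ++i) {
    for (std::size_t j = i + 1; j < idx.size(); ++j) {
      float diff = predt[idx[i]] - predt[idx[j]];
      correct += diff > 0 ? 1.0 : (diff == 0 ? 0.5 : 0.0);
      total += 1.0;
    }
  }
  return total > 0 ? correct / total : 0.0;
}
// E.g. labels {3, 2, 1} with predictions {0.9, 0.2, 0.5}: pairs (doc0, doc1) and
// (doc0, doc2) are ordered correctly, (doc1, doc2) is not, giving 2 / 3.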
+ if (tparam_->gpu_id == GenericParameter::kCpuId) { + std::tie(fp, tp, auc) = + BinaryAUC(preds.ConstHostVector(), info.labels_.ConstHostVector(), + info.weights_.ConstHostVector()); + } else { + std::tie(fp, tp, auc) = GPUBinaryAUC( + preds.ConstDeviceSpan(), info, tparam_->gpu_id, &this->d_cache_); + } + } + float local_area = fp * tp; + std::array result{auc, local_area}; + rabit::Allreduce(result.data(), result.size()); + std::tie(auc, local_area) = UnpackArr(std::move(result)); + if (local_area <= 0) { + // the dataset across all workers have only positive or negative sample + auc = std::numeric_limits::quiet_NaN(); + } else { + // normalization + auc = auc / local_area; + } + } + if (std::isnan(auc)) { + LOG(WARNING) << "Dataset contains only positive or negative samples."; + } + return auc; + } + + char const* Name() const override { + return "auc"; + } +}; + +XGBOOST_REGISTER_METRIC(EvalBinaryAUC, "auc") +.describe("Receiver Operating Characteristic Area Under the Curve.") +.set_body([](const char*) { return new EvalAUC(); }); + +#if !defined(XGBOOST_USE_CUDA) +std::tuple +GPUBinaryAUC(common::Span predts, MetaInfo const &info, + int32_t device, std::shared_ptr *p_cache) { + common::AssertGPUSupport(); + return std::make_tuple(0.0f, 0.0f, 0.0f); +} + +float GPUMultiClassAUCOVR(common::Span predts, MetaInfo const &info, + int32_t device, std::shared_ptr* cache) { + common::AssertGPUSupport(); + return 0; +} + +std::pair +GPURankingAUC(common::Span predts, MetaInfo const &info, + int32_t device, std::shared_ptr *p_cache) { + common::AssertGPUSupport(); + return std::make_pair(0.0f, 0u); +} +struct DeviceAUCCache {}; +#endif // !defined(XGBOOST_USE_CUDA) +} // namespace metric +} // namespace xgboost diff --git a/src/metric/auc.cu b/src/metric/auc.cu new file mode 100644 index 000000000000..9570e31cb29d --- /dev/null +++ b/src/metric/auc.cu @@ -0,0 +1,540 @@ +/*! + * Copyright 2021 by XGBoost Contributors + */ +#include +#include +#include +#include +#include +#include +#include + +#include "rabit/rabit.h" +#include "xgboost/span.h" +#include "xgboost/data.h" +#include "auc.h" +#include "../common/device_helpers.cuh" +#include "../common/ranking_utils.cuh" + +namespace xgboost { +namespace metric { +namespace { +template +class Discard : public thrust::discard_iterator { + public: + using value_type = T; // NOLINT +}; + +struct GetWeightOp { + common::Span weights; + common::Span sorted_idx; + + __device__ float operator()(size_t i) const { + return weights.empty() ? 1.0f : weights[sorted_idx[i]]; + } +}; +} // namespace + +/** + * A cache to GPU data to avoid reallocating memory. + */ +struct DeviceAUCCache { + // Pair of FP/TP + using Pair = thrust::pair; + // index sorted by prediction value + dh::device_vector sorted_idx; + // track FP/TP for computation on trapesoid area + dh::device_vector fptp; + // track FP_PREV/TP_PREV for computation on trapesoid area + dh::device_vector neg_pos; + // index of unique prediction values. 
+ dh::device_vector unique_idx; + // p^T: transposed prediction matrix, used by MultiClassAUC + dh::device_vector predts_t; + std::unique_ptr reducer; + + void Init(common::Span predts, bool is_multi, int32_t device) { + if (sorted_idx.size() != predts.size()) { + sorted_idx.resize(predts.size()); + fptp.resize(sorted_idx.size()); + unique_idx.resize(sorted_idx.size()); + neg_pos.resize(sorted_idx.size()); + if (is_multi) { + predts_t.resize(sorted_idx.size()); + reducer.reset(new dh::AllReducer); + reducer->Init(rabit::GetRank()); + } + } + } +}; + +/** + * The GPU implementation uses same calculation as CPU with a few more steps to distribute + * work across threads: + * + * - Run scan to obtain TP/FP values, which are right coordinates of trapesoid. + * - Find distinct prediction values and get the corresponding FP_PREV/TP_PREV value, + * which are left coordinates of trapesoid. + * - Reduce the scan array into 1 AUC value. + */ +std::tuple +GPUBinaryAUC(common::Span predts, MetaInfo const &info, + int32_t device, std::shared_ptr *p_cache) { + auto& cache = *p_cache; + if (!cache) { + cache.reset(new DeviceAUCCache); + } + cache->Init(predts, false, device); + + auto labels = info.labels_.ConstDeviceSpan(); + auto weights = info.weights_.ConstDeviceSpan(); + dh::safe_cuda(cudaSetDevice(device)); + + CHECK(!labels.empty()); + CHECK_EQ(labels.size(), predts.size()); + + /** + * Create sorted index for each class + */ + auto d_sorted_idx = dh::ToSpan(cache->sorted_idx); + dh::ArgSort(predts, d_sorted_idx); + + /** + * Linear scan + */ + auto get_weight = GetWeightOp{weights, d_sorted_idx}; + using Pair = thrust::pair; + auto get_fp_tp = [=]__device__(size_t i) { + size_t idx = d_sorted_idx[i]; + + float label = labels[idx]; + float w = get_weight(i); + + float fp = (1.0 - label) * w; + float tp = label * w; + + return thrust::make_pair(fp, tp); + }; // NOLINT + auto d_fptp = dh::ToSpan(cache->fptp); + dh::LaunchN(device, d_sorted_idx.size(), + [=] __device__(size_t i) { d_fptp[i] = get_fp_tp(i); }); + + dh::XGBDeviceAllocator alloc; + auto d_unique_idx = dh::ToSpan(cache->unique_idx); + dh::Iota(d_unique_idx, device); + + auto uni_key = dh::MakeTransformIterator( + thrust::make_counting_iterator(0), + [=] __device__(size_t i) { return predts[d_sorted_idx[i]]; }); + auto end_unique = thrust::unique_by_key_copy( + thrust::cuda::par(alloc), uni_key, uni_key + d_sorted_idx.size(), + dh::tbegin(d_unique_idx), thrust::make_discard_iterator(), + dh::tbegin(d_unique_idx)); + d_unique_idx = d_unique_idx.subspan(0, end_unique.second - dh::tbegin(d_unique_idx)); + + dh::InclusiveScan( + dh::tbegin(d_fptp), dh::tbegin(d_fptp), + [=] __device__(Pair const &l, Pair const &r) { + return thrust::make_pair(l.first + r.first, l.second + r.second); + }, + d_fptp.size()); + + auto d_neg_pos = dh::ToSpan(cache->neg_pos); + // scatter unique negaive/positive values + // shift to right by 1 with initial value being 0 + dh::LaunchN(device, d_unique_idx.size(), [=] __device__(size_t i) { + if (d_unique_idx[i] == 0) { // first unique index is 0 + assert(i == 0); + d_neg_pos[0] = {0, 0}; + return; + } + d_neg_pos[d_unique_idx[i]] = d_fptp[d_unique_idx[i] - 1]; + if (i == d_unique_idx.size() - 1) { + // last one needs to be included, may override above assignment if the last + // prediction value is district from previous one. 
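The step that turns the sequential CPU accumulation into data-parallel work is the inclusive scan over per-sample (fp, tp) increments described above; here is a minimal host-side analogue using std::inclusive_scan (illustration only, not the device code):

#include <iostream>
#include <numeric>
#include <utility>
#include <vector>

int main() {
  // Per-sample (fp, tp) increments for predictions sorted by score, labels {1, 0, 1, 0}.
  std::vector<std::pair<float, float>> fptp{{0, 1}, {1, 0}, {0, 1}, {1, 0}};
  // The scan produces cumulative ROC coordinates (the right-hand side of each
  // trapezoid), replacing the running sums of the CPU loop.
  std::inclusive_scan(fptp.begin(), fptp.end(), fptp.begin(),
                      [](std::pair<float, float> l, std::pair<float, float> r) {
                        return std::make_pair(l.first + r.first, l.second + r.second);
                      });
  for (auto const &p : fptp) {
    std::cout << "(" << p.first << ", " << p.second << ")\n";
  }
  // Prints (0, 1) (1, 1) (1, 2) (2, 2).
  return 0;
}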
+ d_neg_pos.back() = d_fptp[d_unique_idx[i] - 1]; + return; + } + }); + + auto in = dh::MakeTransformIterator( + thrust::make_counting_iterator(0), [=] __device__(size_t i) { + float fp, tp; + float fp_prev, tp_prev; + if (i == 0) { + // handle the last element + thrust::tie(fp, tp) = d_fptp.back(); + thrust::tie(fp_prev, tp_prev) = d_neg_pos[d_unique_idx.back()]; + } else { + thrust::tie(fp, tp) = d_fptp[d_unique_idx[i] - 1]; + thrust::tie(fp_prev, tp_prev) = d_neg_pos[d_unique_idx[i - 1]]; + } + return TrapesoidArea(fp_prev, fp, tp_prev, tp); + }); + + Pair last = cache->fptp.back(); + float auc = thrust::reduce(thrust::cuda::par(alloc), in, in + d_unique_idx.size()); + return std::make_tuple(last.first, last.second, auc); +} + +void Transpose(common::Span in, common::Span out, size_t m, + size_t n, int32_t device) { + CHECK_EQ(in.size(), out.size()); + CHECK_EQ(in.size(), m * n); + dh::LaunchN(device, in.size(), [=] __device__(size_t i) { + size_t col = i / m; + size_t row = i % m; + size_t idx = row * n + col; + out[i] = in[idx]; + }); +} + +/** + * Last index of a group in a CSR style of index pointer. + */ +template +XGBOOST_DEVICE size_t LastOf(size_t group, common::Span indptr) { + return indptr[group + 1] - 1; +} + +/** + * MultiClass implementation is similar to binary classification, except we need to split + * up each class in all kernels. + */ +float GPUMultiClassAUCOVR(common::Span predts, MetaInfo const &info, + int32_t device, std::shared_ptr* p_cache) { + auto& cache = *p_cache; + if (!cache) { + cache.reset(new DeviceAUCCache); + } + cache->Init(predts, true, device); + + auto labels = info.labels_.ConstDeviceSpan(); + auto weights = info.weights_.ConstDeviceSpan(); + + size_t n_samples = labels.size(); + size_t n_classes = predts.size() / labels.size(); + CHECK_NE(n_classes, 0); + + /** + * Create sorted index for each class + */ + auto d_sorted_idx = dh::ToSpan(cache->sorted_idx); + dh::Iota(d_sorted_idx, device); + auto d_predts_t = dh::ToSpan(cache->predts_t); + Transpose(predts, d_predts_t, n_samples, n_classes, device); + + dh::TemporaryArray class_ptr(n_classes + 1, 0); + auto d_class_ptr = dh::ToSpan(class_ptr); + dh::LaunchN(device, n_classes + 1, [=]__device__(size_t i) { + d_class_ptr[i] = i * n_samples; + }); + // no out-of-place sort for thrust, cub sort doesn't accept general iterator. So can't + // use transform iterator in sorting. + dh::SegmentedArgSort(d_predts_t, d_class_ptr, d_sorted_idx); + + /** + * Linear scan + */ + dh::caching_device_vector d_auc(n_classes, 0); + auto s_d_auc = dh::ToSpan(d_auc); + auto get_weight = GetWeightOp{weights, d_sorted_idx}; + using Pair = thrust::pair; + auto d_fptp = dh::ToSpan(cache->fptp); + auto get_fp_tp = [=]__device__(size_t i) { + size_t idx = d_sorted_idx[i]; + + size_t class_id = i / n_samples; + // labels is a vector of size n_samples. 
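For reference, the Transpose kernel above just relabels indices so that each class's predictions become one contiguous segment (which the segmented sort requires); a host-side sketch of the same index mapping (illustration only):

#include <cstddef>
#include <vector>

// `in` is an m x n prediction matrix in row-major order (m samples, n classes);
// the result stores it class-major, one segment of m predictions per class.
std::vector<float> TransposeRef(std::vector<float> const &in, std::size_t m,
                                std::size_t n) {
  std::vector<float> out(in.size());
  for (std::size_t i = 0; i < out.size(); ++i) {
    std::size_t col = i / m;
    std::size_t row = i % m;
    out[i] = in[row * n + col];  // same arithmetic as the device kernel
  }
  return out;
}
// E.g. m = 2 samples, n = 3 classes: {a0, b0, c0, a1, b1, c1} -> {a0, a1, b0, b1, c0, c1}.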
+ float label = labels[idx % n_samples] == class_id; + + float w = get_weight(i % n_samples); + float fp = (1.0 - label) * w; + float tp = label * w; + return thrust::make_pair(fp, tp); + }; // NOLINT + dh::LaunchN(device, d_sorted_idx.size(), + [=] __device__(size_t i) { d_fptp[i] = get_fp_tp(i); }); + + /** + * Handle duplicated predictions + */ + dh::XGBDeviceAllocator alloc; + auto d_unique_idx = dh::ToSpan(cache->unique_idx); + dh::Iota(d_unique_idx, device); + auto uni_key = dh::MakeTransformIterator>( + thrust::make_counting_iterator(0), [=] __device__(size_t i) { + uint32_t class_id = i / n_samples; + float predt = d_predts_t[d_sorted_idx[i]]; + return thrust::make_pair(class_id, predt); + }); + + // unique values are sparse, so we need a CSR style indptr + dh::TemporaryArray unique_class_ptr(class_ptr.size() + 1); + auto d_unique_class_ptr = dh::ToSpan(unique_class_ptr); + auto n_uniques = dh::SegmentedUniqueByKey( + thrust::cuda::par(alloc), + dh::tbegin(d_class_ptr), + dh::tend(d_class_ptr), + uni_key, + uni_key + d_sorted_idx.size(), + dh::tbegin(d_unique_idx), + d_unique_class_ptr.data(), + dh::tbegin(d_unique_idx), + thrust::equal_to>{}); + d_unique_idx = d_unique_idx.subspan(0, n_uniques); + + using Triple = thrust::tuple; + // expand to tuple to include class id + auto fptp_it_in = dh::MakeTransformIterator( + thrust::make_counting_iterator(0), [=] __device__(size_t i) { + uint32_t class_id = i / n_samples; + return thrust::make_tuple(class_id, d_fptp[i].first, d_fptp[i].second); + }); + // shrink down to pair + auto fptp_it_out = thrust::make_transform_output_iterator( + dh::tbegin(d_fptp), [=] __device__(Triple const &t) { + return thrust::make_pair(thrust::get<1>(t), thrust::get<2>(t)); + }); + dh::InclusiveScan( + fptp_it_in, fptp_it_out, + [=] __device__(Triple const &l, Triple const &r) { + uint32_t l_cid = thrust::get<0>(l); + uint32_t r_cid = thrust::get<0>(r); + if (l_cid != r_cid) { + return r; + } + + return Triple(r_cid, // class_id + thrust::get<1>(l) + thrust::get<1>(r), // fp + thrust::get<2>(l) + thrust::get<2>(r)); // tp + }, + d_fptp.size()); + + // scatter unique FP_PREV/TP_PREV values + auto d_neg_pos = dh::ToSpan(cache->neg_pos); + // When dataset is not empty, each class must have at least 1 (unique) sample + // prediction, so no need to handle special case. + dh::LaunchN(device, d_unique_idx.size(), [=]__device__(size_t i) { + if (d_unique_idx[i] % n_samples == 0) { // first unique index is 0 + assert(d_unique_idx[i] % n_samples == 0); + d_neg_pos[d_unique_idx[i]] = {0, 0}; // class_id * n_samples = i + return; + } + uint32_t class_id = d_unique_idx[i] / n_samples; + d_neg_pos[d_unique_idx[i]] = d_fptp[d_unique_idx[i] - 1]; + if (i == LastOf(class_id, d_unique_class_ptr)) { + // last one needs to be included. 
+ size_t last = d_unique_idx[LastOf(class_id, d_unique_class_ptr)]; + d_neg_pos[LastOf(class_id, d_class_ptr)] = d_fptp[last - 1]; + return; + } + }); + + /** + * Reduce the result for each class + */ + auto key_in = dh::MakeTransformIterator( + thrust::make_counting_iterator(0), [=] __device__(size_t i) { + size_t class_id = d_unique_idx[i] / n_samples; + return class_id; + }); + auto val_in = dh::MakeTransformIterator( + thrust::make_counting_iterator(0), [=] __device__(size_t i) { + size_t class_id = d_unique_idx[i] / n_samples; + float fp, tp; + float fp_prev, tp_prev; + if (i == d_unique_class_ptr[class_id]) { + // first item is ignored, we use this thread to calculate the last item + thrust::tie(fp, tp) = d_fptp[class_id * n_samples + (n_samples - 1)]; + thrust::tie(fp_prev, tp_prev) = + d_neg_pos[d_unique_idx[LastOf(class_id, d_unique_class_ptr)]]; + } else { + thrust::tie(fp, tp) = d_fptp[d_unique_idx[i] - 1]; + thrust::tie(fp_prev, tp_prev) = d_neg_pos[d_unique_idx[i - 1]]; + } + float auc = TrapesoidArea(fp_prev, fp, tp_prev, tp); + return auc; + }); + + thrust::reduce_by_key(thrust::cuda::par(alloc), key_in, + key_in + d_unique_idx.size(), val_in, + thrust::make_discard_iterator(), d_auc.begin()); + + /** + * Scale the classes with number of samples for each class. + */ + dh::TemporaryArray resutls(n_classes * 4); + auto d_results = dh::ToSpan(resutls); + auto local_area = d_results.subspan(0, n_classes); + auto fp = d_results.subspan(n_classes, n_classes); + auto tp = d_results.subspan(2 * n_classes, n_classes); + auto auc = d_results.subspan(3 * n_classes, n_classes); + + dh::LaunchN(device, n_classes, [=] __device__(size_t c) { + auc[c] = s_d_auc[c]; + auto last = d_fptp[n_samples * c + (n_samples - 1)]; + fp[c] = last.first; + tp[c] = last.second; + local_area[c] = last.first * last.second; + }); + if (rabit::IsDistributed()) { + cache->reducer->AllReduceSum(resutls.data().get(), resutls.data().get(), + resutls.size()); + } + auto reduce_in = dh::MakeTransformIterator>( + thrust::make_counting_iterator(0), [=] __device__(size_t i) { + if (local_area[i] > 0) { + return thrust::make_pair(auc[i] / local_area[i] * tp[i], tp[i]); + } + return thrust::make_pair(std::numeric_limits::quiet_NaN(), 0.0f); + }); + + float tp_sum; + float auc_sum; + thrust::tie(auc_sum, tp_sum) = thrust::reduce( + thrust::cuda::par(alloc), reduce_in, reduce_in + n_classes, + thrust::make_pair(0.0f, 0.0f), + [=] __device__(auto const &l, auto const &r) { + return thrust::make_pair(l.first + r.first, l.second + r.second); + }); + if (tp_sum != 0 && !std::isnan(auc_sum)) { + auc_sum /= tp_sum; + } else { + return std::numeric_limits::quiet_NaN(); + } + return auc_sum; +} + +namespace { +struct RankScanItem { + size_t idx; + float predt; + float w; + bst_group_t group_id; +}; +} // anonymous namespace + +std::pair +GPURankingAUC(common::Span predts, MetaInfo const &info, + int32_t device, std::shared_ptr *p_cache) { + auto& cache = *p_cache; + if (!cache) { + cache.reset(new DeviceAUCCache); + } + cache->Init(predts, false, device); + + dh::caching_device_vector group_ptr(info.group_ptr_); + dh::XGBCachingDeviceAllocator alloc; + + auto d_group_ptr = dh::ToSpan(group_ptr); + /** + * Validate the dataset + */ + auto check_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0), + [=] __device__(size_t i) { return d_group_ptr[i + 1] - d_group_ptr[i]; }); + size_t n_valid = thrust::count_if( + thrust::cuda::par(alloc), check_it, check_it + group_ptr.size() - 1, + [=] __device__(size_t len) { return 
len >= 3; }); + if (n_valid < info.group_ptr_.size() - 1) { + InvalidGroupAUC(); + } + if (n_valid == 0) { + return std::make_pair(0.0f, 0); + } + + /** + * Sort the labels + */ + auto d_sorted_idx = dh::ToSpan(cache->sorted_idx); + auto d_labels = info.labels_.ConstDeviceSpan(); + + dh::Iota(d_sorted_idx, device); + dh::SegmentedArgSort(d_labels, d_group_ptr, d_sorted_idx); + + auto d_weights = info.weights_.ConstDeviceSpan(); + + dh::caching_device_vector threads_group_ptr(group_ptr.size(), 0); + auto d_threads_group_ptr = dh::ToSpan(threads_group_ptr); + // Use max to represent triangle + auto n_threads = common::SegmentedTrapezoidThreads( + d_group_ptr, d_threads_group_ptr, std::numeric_limits::max()); + // get the coordinate in nested summation + auto get_i_j = [=]__device__(size_t idx, size_t query_group_idx) { + auto data_group_begin = d_group_ptr[query_group_idx]; + size_t n_samples = d_group_ptr[query_group_idx + 1] - data_group_begin; + auto thread_group_begin = d_threads_group_ptr[query_group_idx]; + auto idx_in_thread_group = idx - thread_group_begin; + + size_t i, j; + common::UnravelTrapeziodIdx(idx_in_thread_group, n_samples, &i, &j); + // we use global index among all groups for sorted idx, so i, j should also be global + // index. + i += data_group_begin; + j += data_group_begin; + return thrust::make_pair(i, j); + }; // NOLINT + auto in = dh::MakeTransformIterator( + thrust::make_counting_iterator(0), [=] __device__(size_t idx) { + bst_group_t query_group_idx = dh::SegmentId(d_threads_group_ptr, idx); + auto data_group_begin = d_group_ptr[query_group_idx]; + size_t n_samples = d_group_ptr[query_group_idx + 1] - data_group_begin; + if (n_samples < 3) { + // at least 3 documents are required. + return RankScanItem{idx, 0, 0, query_group_idx}; + } + + size_t i, j; + thrust::tie(i, j) = get_i_j(idx, query_group_idx); + + float predt = predts[d_sorted_idx[i]] - predts[d_sorted_idx[j]]; + float w = common::Sqr(d_weights.empty() ? 1.0f : d_weights[query_group_idx]); + if (predt > 0) { + predt = 1.0; + } else if (predt == 0) { + predt = 0.5; + } else { + predt = 0; + } + predt *= w; + return RankScanItem{idx, predt, w, query_group_idx}; + }); + + dh::TemporaryArray d_auc(group_ptr.size() - 1); + auto s_d_auc = dh::ToSpan(d_auc); + auto out = thrust::make_transform_output_iterator( + Discard(), [=] __device__(RankScanItem const &item) -> RankScanItem { + auto group_id = item.group_id; + assert(group_id < d_group_ptr.size()); + auto data_group_begin = d_group_ptr[group_id]; + size_t n_samples = d_group_ptr[group_id + 1] - data_group_begin; + // last item of current group + if (item.idx == LastOf(group_id, d_threads_group_ptr)) { + if (item.w > 0) { + s_d_auc[group_id] = item.predt / item.w; + } else { + s_d_auc[group_id] = 0; + } + } + return {}; // discard + }); + dh::InclusiveScan( + in, out, + [] __device__(RankScanItem const &l, RankScanItem const &r) { + if (l.group_id != r.group_id) { + return r; + } + return RankScanItem{r.idx, l.predt + r.predt, l.w + r.w, l.group_id}; + }, + n_threads); + + /** + * Scale the AUC with number of items in each group. + */ + float auc = thrust::reduce(thrust::cuda::par(alloc), dh::tbegin(s_d_auc), + dh::tend(s_d_auc), 0.0f); + return std::make_pair(auc, n_valid); +} +} // namespace metric +} // namespace xgboost diff --git a/src/metric/auc.h b/src/metric/auc.h new file mode 100644 index 000000000000..cb443f2385a8 --- /dev/null +++ b/src/metric/auc.h @@ -0,0 +1,42 @@ +/*! 
+ * Copyright 2021 by XGBoost Contributors + */ +#ifndef XGBOOST_METRIC_AUC_H_ +#define XGBOOST_METRIC_AUC_H_ +#include +#include +#include +#include + +#include "rabit/rabit.h" +#include "xgboost/base.h" +#include "xgboost/span.h" +#include "xgboost/data.h" + +namespace xgboost { +namespace metric { +XGBOOST_DEVICE inline float TrapesoidArea(float x0, float x1, float y0, float y1) { + return std::abs(x0 - x1) * (y0 + y1) * 0.5f; +} + +struct DeviceAUCCache; + +std::tuple +GPUBinaryAUC(common::Span predts, MetaInfo const &info, + int32_t device, std::shared_ptr *p_cache); + +float GPUMultiClassAUCOVR(common::Span predts, MetaInfo const &info, + int32_t device, std::shared_ptr* cache); + +std::pair +GPURankingAUC(common::Span predts, MetaInfo const &info, + int32_t device, std::shared_ptr *cache); + +inline void InvalidGroupAUC() { + LOG(INFO) << "Invalid group with less than 3 samples is found on worker " + << rabit::GetRank() << ". Calculating AUC value requires at " + << "least 2 pairs of samples."; +} +} // namespace metric +} // namespace xgboost +#endif // XGBOOST_METRIC_AUC_H_ diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index f0ea9b8a671b..bc690dea9baa 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -156,134 +156,6 @@ struct EvalAMS : public Metric { float ratio_; }; -/*! \brief Area Under Curve, for both classification and rank computed on CPU */ -struct EvalAuc : public Metric { - private: - // This is used to compute the AUC metrics on the GPU - for ranking tasks and - // for training jobs that run on the GPU. - std::unique_ptr auc_gpu_; - - template - bst_float Eval(const HostDeviceVector &preds, - const MetaInfo &info, - bool distributed, - const std::vector &gptr) { - const auto ngroups = static_cast(gptr.size() - 1); - // sum of all AUC's across all query groups - double sum_auc = 0.0; - int auc_error = 0; - const auto& labels = info.labels_.ConstHostVector(); - const auto &h_preds = preds.ConstHostVector(); - - dmlc::OMPException exc; - #pragma omp parallel reduction(+:sum_auc, auc_error) if (ngroups > 1) - { - exc.Run([&]() { - // Each thread works on a distinct group and sorts the predictions in that group - PredIndPairContainer rec; - #pragma omp for schedule(static) - for (bst_omp_uint group_id = 0; group_id < ngroups; ++group_id) { - exc.Run([&]() { - // Same thread can work on multiple groups one after another; hence, resize - // the predictions array based on the current group - rec.resize(gptr[group_id + 1] - gptr[group_id]); - #pragma omp parallel for schedule(static) if (!omp_in_parallel()) - for (bst_omp_uint j = gptr[group_id]; j < gptr[group_id + 1]; ++j) { - exc.Run([&]() { - rec[j - gptr[group_id]] = {h_preds[j], j}; - }); - } - - XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst); - // calculate AUC - double sum_pospair = 0.0; - double sum_npos = 0.0, sum_nneg = 0.0, buf_pos = 0.0, buf_neg = 0.0; - for (size_t j = 0; j < rec.size(); ++j) { - const bst_float wt = WeightPolicy::GetWeightOfSortedRecord(info, rec, j, group_id); - const bst_float ctr = labels[rec[j].second]; - // keep bucketing predictions in same bucket - if (j != 0 && rec[j].first != rec[j - 1].first) { - sum_pospair += buf_neg * (sum_npos + buf_pos * 0.5); - sum_npos += buf_pos; - sum_nneg += buf_neg; - buf_neg = buf_pos = 0.0f; - } - buf_pos += ctr * wt; - buf_neg += (1.0f - ctr) * wt; - } - sum_pospair += buf_neg * (sum_npos + buf_pos * 0.5); - sum_npos += buf_pos; - sum_nneg += buf_neg; - // check weird conditions - if (sum_npos <= 
0.0 || sum_nneg <= 0.0) { - auc_error += 1; - } else { - // this is the AUC - sum_auc += sum_pospair / (sum_npos * sum_nneg); - } - }); - } - }); - } - exc.Rethrow(); - - // Report average AUC across all groups - // In distributed mode, workers which only contains pos or neg samples - // will be ignored when aggregate AUC. - bst_float dat[2] = {0.0f, 0.0f}; - if (auc_error < static_cast(ngroups)) { - dat[0] = static_cast(sum_auc); - dat[1] = static_cast(static_cast(ngroups) - auc_error); - } - if (distributed) { - rabit::Allreduce(dat, 2); - } - CHECK_GT(dat[1], 0.0f) - << "AUC: the dataset only contains pos or neg samples"; - return dat[0] / dat[1]; - } - - public: - bst_float Eval(const HostDeviceVector &preds, - const MetaInfo &info, - bool distributed) override { - CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; - CHECK_EQ(preds.Size(), info.labels_.Size()) - << "label size predict size not match"; - std::vector tgptr(2, 0); - tgptr[1] = static_cast(info.labels_.Size()); - - const auto &gptr = info.group_ptr_.empty() ? tgptr : info.group_ptr_; - CHECK_EQ(gptr.back(), info.labels_.Size()) - << "EvalAuc: group structure must match number of prediction"; - - // For ranking task, weights are per-group - // For binary classification task, weights are per-instance - const bool is_ranking_task = - !info.group_ptr_.empty() && info.weights_.Size() != info.num_row_; - - // Check if we have a GPU assignment; else, revert back to CPU - if (tparam_->gpu_id >= 0) { - if (!auc_gpu_) { - // Check and see if we have the GPU metric registered in the internal registry - auc_gpu_.reset(GPUMetric::CreateGPUMetric(this->Name(), tparam_)); - } - - if (auc_gpu_) { - return auc_gpu_->Eval(preds, info, distributed); - } - } - - if (is_ranking_task) { - return Eval(preds, info, distributed, gptr); - } else { - return Eval(preds, info, distributed, gptr); - } - } - - const char *Name() const override { return "auc"; } -}; - /*! \brief Evaluate rank list */ struct EvalRank : public Metric, public EvalRankConfig { private: @@ -672,10 +544,6 @@ XGBOOST_REGISTER_METRIC(AMS, "ams") .describe("AMS metric for higgs.") .set_body([](const char* param) { return new EvalAMS(param); }); -XGBOOST_REGISTER_METRIC(Auc, "auc") -.describe("Area under curve for both classification and rank.") -.set_body([](const char*) { return new EvalAuc(); }); - XGBOOST_REGISTER_METRIC(AucPR, "aucpr") .describe("Area under PR curve for both classification and rank.") .set_body([](const char*) { return new EvalAucPR(); }); diff --git a/src/metric/rank_metric.cu b/src/metric/rank_metric.cu index fbb8e5f854ba..70e4a808a04e 100644 --- a/src/metric/rank_metric.cu +++ b/src/metric/rank_metric.cu @@ -274,237 +274,6 @@ struct EvalMAPGpu { } }; -/*! 
\brief Area Under Curve metric computation for ranking datasets */ -struct EvalAucGpu : public Metric { - public: - // This function object computes the positive precision pair for each prediction group - class ComputePosPair : public thrust::unary_function { - public: - XGBOOST_DEVICE ComputePosPair(const double *pred_group_pos_precision, - const double *pred_group_neg_precision, - const double *pred_group_incr_precision) - : pred_group_pos_precision_(pred_group_pos_precision), - pred_group_neg_precision_(pred_group_neg_precision), - pred_group_incr_precision_(pred_group_incr_precision) {} - - // Compute positive precision pair for the prediction group at 'idx' - __device__ __forceinline__ double operator()(uint32_t idx) const { - return pred_group_neg_precision_[idx] * - (pred_group_incr_precision_[idx] + pred_group_pos_precision_[idx] * 0.5); - } - - private: - // Accumulated positive precision for the prediction group - const double *pred_group_pos_precision_{nullptr}; - // Accumulated negative precision for the prediction group - const double *pred_group_neg_precision_{nullptr}; - // Incremental positive precision for the prediction group - const double *pred_group_incr_precision_{nullptr}; - }; - - template - void ReleaseMemory(dh::caching_device_vector &vec) { // NOLINT - dh::caching_device_vector().swap(vec); - } - - bst_float Eval(const HostDeviceVector &preds, - const MetaInfo &info, - bool distributed) override { - // Sanity check is done by the caller - std::vector tgptr(2, 0); - tgptr[1] = static_cast(info.labels_.Size()); - const std::vector &gptr = info.group_ptr_.empty() ? tgptr : info.group_ptr_; - - auto device = tparam_->gpu_id; - dh::safe_cuda(cudaSetDevice(device)); - - info.labels_.SetDevice(device); - preds.SetDevice(device); - info.weights_.SetDevice(device); - - auto dpreds = preds.ConstDevicePointer(); - auto dlabels = info.labels_.ConstDevicePointer(); - auto dweights = info.weights_.ConstDevicePointer(); - - // Sort all the predictions (from one or more groups) - dh::SegmentSorter segment_pred_sorter; - segment_pred_sorter.SortItems(dpreds, preds.Size(), gptr); - - const auto &dsorted_preds = segment_pred_sorter.GetItemsSpan(); - const auto &dpreds_orig_pos = segment_pred_sorter.GetOriginalPositionsSpan(); - - // Group info on device - const auto &dgroups = segment_pred_sorter.GetGroupsSpan(); - uint32_t ngroups = segment_pred_sorter.GetNumGroups(); - - // Final values - double hsum_auc = 0.0; - unsigned hauc_error = 0; - - int device_id = -1; - dh::safe_cuda(cudaGetDevice(&device_id)); - - // Allocator to be used for managing space overhead while performing reductions - dh::XGBCachingDeviceAllocator alloc; - - if (ngroups == 1) { - const auto nitems = segment_pred_sorter.GetNumItems(); - - // First, segment all the predictions in the group. 
This is required so that we can - // aggregate the positive and negative precisions within that prediction group - dh::caching_device_vector dpred_segs(nitems, 0); - auto *pred_seg_arr = dpred_segs.data().get(); - // This is for getting the next segment number - dh::caching_device_vector seg_idx(1, 0); - auto *seg_idx_ptr = seg_idx.data().get(); - - dh::caching_device_vector dbuf_pos(nitems, 0); - dh::caching_device_vector dbuf_neg(nitems, 0); - auto *buf_pos_arr = dbuf_pos.data().get(); - auto *buf_neg_arr = dbuf_neg.data().get(); - - dh::LaunchN(device_id, nitems, nullptr, [=] __device__(int idx) { - auto ctr = dlabels[dpreds_orig_pos[idx]]; - // For ranking task, weights are per-group - // For binary classification task, weights are per-instance - const auto wt = dweights == nullptr ? 1.0f : dweights[dpreds_orig_pos[idx]]; - buf_pos_arr[idx] = ctr * wt; - buf_neg_arr[idx] = (1.0f - ctr) * wt; - if (idx == nitems - 1 || dsorted_preds[idx] != dsorted_preds[idx + 1]) { - auto new_seg_idx = atomicAdd(seg_idx_ptr, 1); - auto pred_val = dsorted_preds[idx]; - do { - pred_seg_arr[idx] = new_seg_idx; - idx--; - } while (idx >= 0 && dsorted_preds[idx] == pred_val); - } - }); - - std::array h_nunique_preds; - dh::safe_cuda(cudaMemcpyAsync(h_nunique_preds.data(), - seg_idx.data().get() + seg_idx.size() - 1, - sizeof(uint32_t), cudaMemcpyDeviceToHost)); - auto nunique_preds = h_nunique_preds.back(); - ReleaseMemory(seg_idx); - - // Next, accumulate the positive and negative precisions for every prediction group - dh::caching_device_vector sum_dbuf_pos(nunique_preds, 0); - auto itr = thrust::reduce_by_key(thrust::cuda::par(alloc), - dpred_segs.begin(), dpred_segs.end(), // Segmented by this - dbuf_pos.begin(), // Individual precisions - thrust::make_discard_iterator(), // Ignore unique segments - sum_dbuf_pos.begin()); // Write accumulated results here - ReleaseMemory(dbuf_pos); - CHECK(itr.second - sum_dbuf_pos.begin() == nunique_preds); - - dh::caching_device_vector sum_dbuf_neg(nunique_preds, 0); - itr = thrust::reduce_by_key(thrust::cuda::par(alloc), - dpred_segs.begin(), dpred_segs.end(), - dbuf_neg.begin(), - thrust::make_discard_iterator(), - sum_dbuf_neg.begin()); - ReleaseMemory(dbuf_neg); - ReleaseMemory(dpred_segs); - CHECK(itr.second - sum_dbuf_neg.begin() == nunique_preds); - - dh::caching_device_vector sum_nneg(nunique_preds, 0); - thrust::inclusive_scan(thrust::cuda::par(alloc), - sum_dbuf_neg.begin(), sum_dbuf_neg.end(), - sum_nneg.begin()); - double sum_neg_prec_val = sum_nneg.back(); - ReleaseMemory(sum_nneg); - - // Find incremental sum for the positive precisions that is then used to - // compute incremental positive precision pair - dh::caching_device_vector sum_npos(nunique_preds + 1, 0); - thrust::inclusive_scan(thrust::cuda::par(alloc), - sum_dbuf_pos.begin(), sum_dbuf_pos.end(), - sum_npos.begin() + 1); - double sum_pos_prec_val = sum_npos.back(); - - if (sum_pos_prec_val <= 0.0 || sum_neg_prec_val <= 0.0) { - hauc_error = 1; - } else { - dh::caching_device_vector sum_pospair(nunique_preds, 0); - // Finally, compute the positive precision pair - thrust::transform(thrust::make_counting_iterator(static_cast(0)), - thrust::make_counting_iterator(static_cast(nunique_preds)), - sum_pospair.begin(), - ComputePosPair(sum_dbuf_pos.data().get(), - sum_dbuf_neg.data().get(), - sum_npos.data().get())); - ReleaseMemory(sum_dbuf_pos); - ReleaseMemory(sum_dbuf_neg); - ReleaseMemory(sum_npos); - hsum_auc = thrust::reduce(thrust::cuda::par(alloc), - sum_pospair.begin(), sum_pospair.end()) - 
/ (sum_pos_prec_val * sum_neg_prec_val); - } - } else { - // AUC sum for each group - dh::caching_device_vector sum_auc(ngroups, 0); - // AUC error across all groups - dh::caching_device_vector auc_error(1, 0); - auto *dsum_auc = sum_auc.data().get(); - auto *dauc_error = auc_error.data().get(); - - // For each group item compute the aggregated precision - dh::LaunchN<1, 32>(device_id, ngroups, nullptr, [=] __device__(uint32_t gidx) { - double sum_pospair = 0.0, sum_npos = 0.0, sum_nneg = 0.0, buf_pos = 0.0, buf_neg = 0.0; - - for (auto i = dgroups[gidx]; i < dgroups[gidx + 1]; ++i) { - const auto ctr = dlabels[dpreds_orig_pos[i]]; - // Keep bucketing predictions in same bucket - if (i != dgroups[gidx] && dsorted_preds[i] != dsorted_preds[i - 1]) { - sum_pospair += buf_neg * (sum_npos + buf_pos * 0.5); - sum_npos += buf_pos; - sum_nneg += buf_neg; - buf_neg = buf_pos = 0.0f; - } - // For ranking task, weights are per-group - // For binary classification task, weights are per-instance - const auto wt = dweights == nullptr ? 1.0f : dweights[gidx]; - buf_pos += ctr * wt; - buf_neg += (1.0f - ctr) * wt; - } - sum_pospair += buf_neg * (sum_npos + buf_pos * 0.5); - sum_npos += buf_pos; - sum_nneg += buf_neg; - - // Check weird conditions - if (sum_npos <= 0.0 || sum_nneg <= 0.0) { - atomicAdd(dauc_error, 1); - } else { - // This is the AUC - dsum_auc[gidx] = sum_pospair / (sum_npos * sum_nneg); - } - }); - - hsum_auc = thrust::reduce(thrust::cuda::par(alloc), sum_auc.begin(), sum_auc.end()); - hauc_error = auc_error.back(); // Copy it back to host - } - - // Report average AUC across all groups - // In distributed mode, workers which only contains pos or neg samples - // will be ignored when aggregate AUC. - bst_float dat[2] = {0.0f, 0.0f}; - if (hauc_error < ngroups) { - dat[0] = static_cast(hsum_auc); - dat[1] = static_cast(ngroups - hauc_error); - } - if (distributed) { - rabit::Allreduce(dat, 2); - } - CHECK_GT(dat[1], 0.0f) - << "AUC: the dataset only contains pos or neg samples"; - return dat[0] / dat[1]; - } - - const char* Name() const override { - return "auc"; - } -}; - /*! 
\brief Area Under PR Curve metric computation for ranking datasets */ struct EvalAucPRGpu : public Metric { public: @@ -691,10 +460,6 @@ struct EvalAucPRGpu : public Metric { } }; -XGBOOST_REGISTER_GPU_METRIC(AucGpu, "auc") -.describe("Area under curve for rank computed on GPU.") -.set_body([](const char* param) { return new EvalAucGpu(); }); - XGBOOST_REGISTER_GPU_METRIC(AucPRGpu, "aucpr") .describe("Area under PR curve for rank computed on GPU.") .set_body([](const char* param) { return new EvalAucPRGpu(); }); diff --git a/src/objective/rank_obj.cu b/src/objective/rank_obj.cu index f1b350bb3b3e..1fa584930c07 100644 --- a/src/objective/rank_obj.cu +++ b/src/objective/rank_obj.cu @@ -293,7 +293,7 @@ class NDCGLambdaWeightComputer group_segments)), thrust::make_discard_iterator(), // We don't care for the group indices dgroup_dcg_.begin()); // Sum of the item's DCG values in the group - CHECK(static_cast(end_range.second - dgroup_dcg_.begin()) == dgroup_dcg_.size()); + CHECK_EQ(static_cast(end_range.second - dgroup_dcg_.begin()), dgroup_dcg_.size()); } inline const common::Span GetGroupDcgsSpan() const { diff --git a/src/tree/param.h b/src/tree/param.h index b686c6ee75da..2cae7686e6e1 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -15,6 +15,7 @@ #include "xgboost/parameter.h" #include "xgboost/data.h" +#include "../common/math.h" namespace xgboost { namespace tree { @@ -264,14 +265,11 @@ XGBOOST_DEVICE inline static T1 ThresholdL1(T1 w, T2 alpha) { return 0.0; } -template -XGBOOST_DEVICE inline static T Sqr(T a) { return a * a; } - // calculate the cost of loss function template XGBOOST_DEVICE inline T CalcGainGivenWeight(const TrainingParams &p, T sum_grad, T sum_hess, T w) { - return -(T(2.0) * sum_grad * w + (sum_hess + p.reg_lambda) * Sqr(w)); + return -(T(2.0) * sum_grad * w + (sum_hess + p.reg_lambda) * common::Sqr(w)); } // calculate weight given the statistics @@ -296,9 +294,9 @@ XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess } if (p.max_delta_step == 0.0f) { if (p.reg_alpha == 0.0f) { - return Sqr(sum_grad) / (sum_hess + p.reg_lambda); + return common::Sqr(sum_grad) / (sum_hess + p.reg_lambda); } else { - return Sqr(ThresholdL1(sum_grad, p.reg_alpha)) / + return common::Sqr(ThresholdL1(sum_grad, p.reg_alpha)) / (sum_hess + p.reg_lambda); } } else { diff --git a/src/tree/split_evaluator.h b/src/tree/split_evaluator.h index 067a3e5bf37c..069718a27378 100644 --- a/src/tree/split_evaluator.h +++ b/src/tree/split_evaluator.h @@ -114,7 +114,7 @@ class TreeEvaluator { } // Avoiding tree::CalcGainGivenWeight can significantly reduce avg floating point error. 
if (p.max_delta_step == 0.0f && has_constraint == false) { - return Sqr(ThresholdL1(stats.sum_grad, p.reg_alpha)) / + return common::Sqr(ThresholdL1(stats.sum_grad, p.reg_alpha)) / (stats.sum_hess + p.reg_lambda); } return tree::CalcGainGivenWeight(p, stats.sum_grad, diff --git a/tests/cpp/common/test_common.cc b/tests/cpp/common/test_common.cc index 006860b11af2..adaf21feadc9 100644 --- a/tests/cpp/common/test_common.cc +++ b/tests/cpp/common/test_common.cc @@ -1,11 +1,12 @@ #include +#include #include "../../../src/common/common.h" namespace xgboost { namespace common { TEST(ArgSort, Basic) { std::vector inputs {3.0, 2.0, 1.0}; - auto ret = ArgSort(inputs); + auto ret = ArgSort(Span{inputs}); std::vector sol{2, 1, 0}; ASSERT_EQ(ret, sol); } diff --git a/tests/cpp/common/test_ranking_utils.cu b/tests/cpp/common/test_ranking_utils.cu new file mode 100644 index 000000000000..7e0f4244cefb --- /dev/null +++ b/tests/cpp/common/test_ranking_utils.cu @@ -0,0 +1,66 @@ +#include +#include "../../../src/common/ranking_utils.cuh" +#include "../../../src/common/device_helpers.cuh" + +namespace xgboost { +namespace common { + +TEST(SegmentedTrapezoidThreads, Basic) { + size_t constexpr kElements = 24, kGroups = 3; + dh::device_vector offset_ptr(kGroups + 1, 0); + offset_ptr[0] = 0; + offset_ptr[1] = 8; + offset_ptr[2] = 16; + offset_ptr[kGroups] = kElements; + + size_t h = 1; + dh::device_vector thread_ptr(kGroups + 1, 0); + size_t total = SegmentedTrapezoidThreads(dh::ToSpan(offset_ptr), dh::ToSpan(thread_ptr), h); + ASSERT_EQ(total, kElements - kGroups); + + h = 2; + SegmentedTrapezoidThreads(dh::ToSpan(offset_ptr), dh::ToSpan(thread_ptr), h); + std::vector h_thread_ptr(thread_ptr.size()); + thrust::copy(thread_ptr.cbegin(), thread_ptr.cend(), h_thread_ptr.begin()); + for (size_t i = 1; i < h_thread_ptr.size(); ++i) { + ASSERT_EQ(h_thread_ptr[i] - h_thread_ptr[i - 1], 13); + } + + h = 7; + SegmentedTrapezoidThreads(dh::ToSpan(offset_ptr), dh::ToSpan(thread_ptr), h); + thrust::copy(thread_ptr.cbegin(), thread_ptr.cend(), h_thread_ptr.begin()); + for (size_t i = 1; i < h_thread_ptr.size(); ++i) { + ASSERT_EQ(h_thread_ptr[i] - h_thread_ptr[i - 1], 28); + } +} + +TEST(SegmentedTrapezoidThreads, Unravel) { + size_t i = 0, j = 0; + size_t constexpr kN = 8; + + UnravelTrapeziodIdx(6, kN, &i, &j); + ASSERT_EQ(i, 0); + ASSERT_EQ(j, 7); + + UnravelTrapeziodIdx(12, kN, &i, &j); + ASSERT_EQ(i, 1); + ASSERT_EQ(j, 7); + + UnravelTrapeziodIdx(15, kN, &i, &j); + ASSERT_EQ(i, 2); + ASSERT_EQ(j, 5); + + UnravelTrapeziodIdx(21, kN, &i, &j); + ASSERT_EQ(i, 3); + ASSERT_EQ(j, 7); + + UnravelTrapeziodIdx(25, kN, &i, &j); + ASSERT_EQ(i, 5); + ASSERT_EQ(j, 6); + + UnravelTrapeziodIdx(27, kN, &i, &j); + ASSERT_EQ(i, 6); + ASSERT_EQ(j, 7); +} +} // namespace common +} // namespace xgboost diff --git a/tests/cpp/metric/test_auc.cc b/tests/cpp/metric/test_auc.cc new file mode 100644 index 000000000000..310efaa4ddea --- /dev/null +++ b/tests/cpp/metric/test_auc.cc @@ -0,0 +1,133 @@ +#include +#include "../helpers.h" + +namespace xgboost { +namespace metric { + +TEST(Metric, DeclareUnifiedTest(BinaryAUC)) { + auto tparam = xgboost::CreateEmptyGenericParam(GPUIDX); + std::unique_ptr uni_ptr {Metric::Create("auc", &tparam)}; + Metric * metric = uni_ptr.get(); + ASSERT_STREQ(metric->Name(), "auc"); + + // Binary + EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1.0f, 1e-10); + EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {1, 0}), 0.0f, 1e-10); + EXPECT_NEAR(GetMetricEval(metric, {0, 0}, {0, 1}), 0.5f, 1e-10); + 
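Stepping back to the new test_ranking_utils.cu cases above: with groups of 8 items, the expected thread counts (7, 13, 28 per group) and the expected (i, j) pairs follow from a clipped upper-triangle ("trapezoid") layout of pairwise comparisons, where item i is compared with the next h items only. A pure-Python sketch of that arithmetic is below; the function names are hypothetical analogues of the CUDA helpers, not the helpers themselves.

def trapezoid_threads(n, h):
    # threads needed for one group of n items when every item is compared
    # only with the next h items (clipped at the end of the group)
    h = min(h, n - 1)
    return sum(n - k for k in range(1, h + 1))

def unravel_upper_triangle(idx, n):
    # map a linear thread index to the (i, j) pair of the strict upper
    # triangle, where row i holds the pairs (i, i+1) .. (i, n-1)
    i, row_len = 0, n - 1
    while idx >= row_len:
        idx -= row_len
        row_len -= 1
        i += 1
    return i, i + 1 + idx

assert trapezoid_threads(8, 1) == 7    # 3 groups of 8 -> 21 == kElements - kGroups
assert trapezoid_threads(8, 2) == 13   # 7 + 6 neighbours per group
assert trapezoid_threads(8, 7) == 28   # full upper triangle: 8 * 7 / 2
assert unravel_upper_triangle(15, 8) == (2, 5)
assert unravel_upper_triangle(25, 8) == (5, 6)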
EXPECT_NEAR(GetMetricEval(metric, {1, 1}, {0, 1}), 0.5f, 1e-10); + EXPECT_NEAR(GetMetricEval(metric, {0, 0}, {1, 0}), 0.5f, 1e-10); + EXPECT_NEAR(GetMetricEval(metric, {1, 1}, {1, 0}), 0.5f, 1e-10); + EXPECT_NEAR(GetMetricEval(metric, {1, 0, 0}, {0, 0, 1}), 0.25f, 1e-10); + + // Invalid dataset + MetaInfo info; + info.labels_ = {0, 0}; + float auc = metric->Eval({1, 1}, info, false); + ASSERT_TRUE(std::isnan(auc)); + info.labels_ = HostDeviceVector{}; + auc = metric->Eval(HostDeviceVector{}, info, false); + ASSERT_TRUE(std::isnan(auc)); + + EXPECT_NEAR(GetMetricEval(metric, {0, 1, 0, 1}, {0, 1, 0, 1}), 1.0f, 1e-10); + + // AUC with instance weights + EXPECT_NEAR(GetMetricEval(metric, + {0.9f, 0.1f, 0.4f, 0.3f}, + {0, 0, 1, 1}, + {1.0f, 3.0f, 2.0f, 4.0f}), + 0.75f, 0.001f); + + // regression test case + ASSERT_NEAR(GetMetricEval( + metric, + {0.79523796, 0.5201713, 0.79523796, 0.24273258, 0.53452194, + 0.53452194, 0.24273258, 0.5201713, 0.79523796, 0.53452194, + 0.24273258, 0.53452194, 0.79523796, 0.5201713, 0.24273258, + 0.5201713, 0.5201713, 0.53452194, 0.5201713, 0.53452194}, + {0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0}), + 0.5, 1e-10); +} + +TEST(Metric, DeclareUnifiedTest(MultiAUC)) { + auto tparam = CreateEmptyGenericParam(GPUIDX); + std::unique_ptr uni_ptr{ + Metric::Create("auc", &tparam)}; + auto metric = uni_ptr.get(); + + // MultiClass + // 3x3 + EXPECT_NEAR(GetMetricEval(metric, + { + 1.0f, 0.0f, 0.0f, // p_0 + 0.0f, 1.0f, 0.0f, // p_1 + 0.0f, 0.0f, 1.0f // p_2 + }, + {0, 1, 2}), + 1.0f, 1e-10); + EXPECT_NEAR(GetMetricEval(metric, + { + 1.0f, 0.0f, 0.0f, // p_0 + 0.0f, 1.0f, 0.0f, // p_1 + 0.0f, 0.0f, 1.0f // p_2 + }, + {2, 1, 0}), + 0.5f, 1e-10); + EXPECT_NEAR(GetMetricEval(metric, + { + 1.0f, 0.0f, 0.0f, // p_0 + 0.0f, 1.0f, 0.0f, // p_1 + 0.0f, 0.0f, 1.0f // p_2 + }, + {2, 0, 1}), + 0.25f, 1e-10); + + // invalid dataset + float auc = GetMetricEval(metric, + { + 1.0f, 0.0f, 0.0f, // p_0 + 0.0f, 1.0f, 0.0f, // p_1 + 0.0f, 0.0f, 1.0f // p_2 + }, + {0, 1, 1}); // no class 2. 
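For the MultiAUC expectations above: the values 1.0, 0.5 and 0.25 for the identity prediction matrix against the three label permutations are consistent with a one-vs-rest AUC averaged using class prevalence as weights, the same definition the Python tests below compare against via sklearn's average="weighted", multi_class="ovr". A NumPy sketch of that definition (not the production implementation) follows.

import numpy as np

def binary_auc(scores, labels):
    # exhaustive pairwise AUC with half credit for ties; fine for tiny examples
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    if len(pos) == 0 or len(neg) == 0:
        return float("nan")
    correct = (pos[:, None] > neg[None, :]).sum() + 0.5 * (pos[:, None] == neg[None, :]).sum()
    return correct / (len(pos) * len(neg))

def multiclass_auc(proba, labels, n_classes):
    # one-vs-rest AUC, averaged with class prevalence as weights
    aucs, weights = [], []
    for c in range(n_classes):
        y = (labels == c).astype(np.float64)
        aucs.append(binary_auc(proba[:, c], y))
        weights.append(y.mean())
    return float(np.average(aucs, weights=weights))

proba = np.eye(3)
print(multiclass_auc(proba, np.array([0, 1, 2]), 3))  # 1.0
print(multiclass_auc(proba, np.array([2, 1, 0]), 3))  # 0.5
print(multiclass_auc(proba, np.array([2, 0, 1]), 3))  # 0.25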
+ EXPECT_TRUE(std::isnan(auc)) << auc; +} + +TEST(Metric, DeclareUnifiedTest(RankingAUC)) { + auto tparam = CreateEmptyGenericParam(GPUIDX); + std::unique_ptr metric{Metric::Create("auc", &tparam)}; + + // single group + EXPECT_NEAR(GetMetricEval(metric.get(), {0.7f, 0.2f, 0.3f, 0.6f}, + {1.0f, 0.8f, 0.4f, 0.2f}, /*weights=*/{}, + {0, 4}), + 0.5f, 1e-10); + + // multi group + EXPECT_NEAR(GetMetricEval(metric.get(), {0, 1, 2, 0, 1, 2}, + {0, 1, 2, 0, 1, 2}, /*weights=*/{}, {0, 3, 6}), + 1.0f, 1e-10); + + EXPECT_NEAR(GetMetricEval(metric.get(), {0, 1, 2, 0, 1, 2}, + {0, 1, 2, 0, 1, 2}, /*weights=*/{1.0f, 2.0f}, + {0, 3, 6}), + 1.0f, 1e-10); + + // AUC metric for grouped datasets - exception scenarios + ASSERT_TRUE(std::isnan( + GetMetricEval(metric.get(), {0, 1, 2}, {0, 0, 0}, {}, {0, 2, 3}))); + + // regression case + HostDeviceVector predt{0.33935383, 0.5149714, 0.32138085, 1.4547751, + 1.2010975, 0.42651367, 0.23104341, 0.83610827, + 0.8494239, 0.07136688, 0.5623144, 0.8086237, + 1.5066161, -4.094787, 0.76887935, -2.4082742}; + std::vector groups{0, 7, 16}; + std::vector labels{1., 0., 0., 1., 2., 1., 0., 0., + 0., 0., 0., 0., 1., 0., 1., 0.}; + + EXPECT_NEAR(GetMetricEval(metric.get(), std::move(predt), labels, + /*weights=*/{}, groups), + 0.769841f, 1e-6); +} +} // namespace metric +} // namespace xgboost diff --git a/tests/cpp/metric/test_auc.cu b/tests/cpp/metric/test_auc.cu new file mode 100644 index 000000000000..430ab1d374c1 --- /dev/null +++ b/tests/cpp/metric/test_auc.cu @@ -0,0 +1,5 @@ +/*! + * Copyright 2021 XGBoost contributors + */ +// Dummy file to keep the CUDA conditional compile trick. +#include "test_auc.cc" \ No newline at end of file diff --git a/tests/cpp/metric/test_rank_metric.cc b/tests/cpp/metric/test_rank_metric.cc index 29043d32f27a..c8c97bef1930 100644 --- a/tests/cpp/metric/test_rank_metric.cc +++ b/tests/cpp/metric/test_rank_metric.cc @@ -24,49 +24,6 @@ TEST(Metric, AMS) { } #endif -TEST(Metric, DeclareUnifiedTest(AUC)) { - auto tparam = xgboost::CreateEmptyGenericParam(GPUIDX); - xgboost::Metric * metric = xgboost::Metric::Create("auc", &tparam); - ASSERT_STREQ(metric->Name(), "auc"); - EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10); - EXPECT_NEAR(GetMetricEval(metric, - {0.1f, 0.9f, 0.1f, 0.9f}, - { 0, 0, 1, 1}), - 0.5f, 0.001f); - EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {})); - EXPECT_ANY_THROW(GetMetricEval(metric, {0, 0}, {0, 0})); - EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {1, 1})); - - // AUC with instance weights - EXPECT_NEAR(GetMetricEval(metric, - {0.9f, 0.1f, 0.4f, 0.3f}, - {0, 0, 1, 1}, - {1.0f, 3.0f, 2.0f, 4.0f}), - 0.75f, 0.001f); - - // AUC for a ranking task without weights - EXPECT_NEAR(GetMetricEval(metric, - {0.9f, 0.1f, 0.4f, 0.3f, 0.7f}, - {0, 1, 0, 1, 1}, - {}, - {0, 2, 5}), - 0.25f, 0.001f); - - // AUC for a ranking task with weights/group - EXPECT_NEAR(GetMetricEval(metric, - {0.9f, 0.1f, 0.4f, 0.3f, 0.7f}, - {1, 0, 1, 0, 0}, - {1, 2}, - {0, 2, 5}), - 0.75f, 0.001f); - - // AUC metric for grouped datasets - exception scenarios - EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1, 2}, {0, 0, 0}, {}, {0, 2, 3})); - EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1, 2}, {1, 1, 1}, {}, {0, 2, 3})); - - delete metric; -} - TEST(Metric, DeclareUnifiedTest(AUCPR)) { auto tparam = xgboost::CreateEmptyGenericParam(GPUIDX); xgboost::Metric *metric = xgboost::Metric::Create("aucpr", &tparam); diff --git a/tests/python-gpu/conftest.py b/tests/python-gpu/conftest.py index 1e5e96df7c6e..6b7eb531a3e4 100644 --- 
a/tests/python-gpu/conftest.py +++ b/tests/python-gpu/conftest.py @@ -42,6 +42,7 @@ def local_cuda_cluster(request, pytestconfig): def pytest_addoption(parser): parser.addoption('--use-rmm-pool', action='store_true', default=False, help='Use RMM pool') + def pytest_collection_modifyitems(config, items): if config.getoption('--use-rmm-pool'): blocklist = [ @@ -53,3 +54,9 @@ def pytest_collection_modifyitems(config, items): for item in items: if any(item.nodeid.startswith(x) for x in blocklist): item.add_marker(skip_mark) + + # mark dask tests as `mgpu`. + mgpu_mark = pytest.mark.mgpu + for item in items: + if item.nodeid.startswith("python-gpu/test_gpu_with_dask.py"): + item.add_marker(mgpu_mark) diff --git a/tests/python-gpu/test_gpu_eval_metrics.py b/tests/python-gpu/test_gpu_eval_metrics.py new file mode 100644 index 000000000000..36b6a70868b9 --- /dev/null +++ b/tests/python-gpu/test_gpu_eval_metrics.py @@ -0,0 +1,47 @@ +import sys +import xgboost +import pytest + +sys.path.append("tests/python") +import test_eval_metrics as test_em # noqa + + +class TestGPUEvalMetrics: + cpu_test = test_em.TestEvalMetrics() + + @pytest.mark.parametrize("n_samples", [4, 100, 1000]) + def test_roc_auc_binary(self, n_samples): + self.cpu_test.run_roc_auc_binary("gpu_hist", n_samples) + + @pytest.mark.parametrize("n_samples", [4, 100, 1000]) + def test_roc_auc_multi(self, n_samples): + self.cpu_test.run_roc_auc_multi("gpu_hist", n_samples) + + @pytest.mark.parametrize("n_samples", [4, 100, 1000]) + def test_roc_auc_ltr(self, n_samples): + import numpy as np + + rng = np.random.RandomState(1994) + n_samples = n_samples + n_features = 10 + X = rng.randn(n_samples, n_features) + y = rng.randint(0, 16, size=n_samples) + group = np.array([n_samples // 2, n_samples // 2]) + + Xy = xgboost.DMatrix(X, y, group=group) + + cpu = xgboost.train( + {"tree_method": "hist", "eval_metric": "auc", "objective": "rank:ndcg"}, + Xy, + num_boost_round=10, + ) + cpu_auc = float(cpu.eval(Xy).split(":")[1]) + + gpu = xgboost.train( + {"tree_method": "gpu_hist", "eval_metric": "auc", "objective": "rank:ndcg"}, + Xy, + num_boost_round=10, + ) + gpu_auc = float(gpu.eval(Xy).split(":")[1]) + + np.testing.assert_allclose(cpu_auc, gpu_auc) diff --git a/tests/python-gpu/test_gpu_ranking.py b/tests/python-gpu/test_gpu_ranking.py index 556db051f6c1..e95fb78b1bc5 100644 --- a/tests/python-gpu/test_gpu_ranking.py +++ b/tests/python-gpu/test_gpu_ranking.py @@ -5,6 +5,10 @@ import shutil import urllib.request import zipfile +import sys +sys.path.append("tests/python") + +import testing as tm # noqa class TestRanking: @@ -15,9 +19,9 @@ def setup_class(cls): """ from sklearn.datasets import load_svmlight_files # download the test data - cls.dpath = 'demo/rank/' + cls.dpath = os.path.join(tm.PROJECT_ROOT, "demo/rank/") src = 'https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip' - target = cls.dpath + '/MQ2008.zip' + target = os.path.join(cls.dpath, "MQ2008.zip") if os.path.exists(cls.dpath) and os.path.exists(target): print("Skipping dataset download...") @@ -79,8 +83,8 @@ def teardown_class(cls): Cleanup test artifacts from download and unpacking :return: """ - os.remove(cls.dpath + "MQ2008.zip") - shutil.rmtree(cls.dpath + "MQ2008") + os.remove(os.path.join(cls.dpath, "MQ2008.zip")) + shutil.rmtree(os.path.join(cls.dpath, "MQ2008")) @classmethod def __test_training_with_rank_objective(cls, rank_objective, metric_name, tolerance=1e-02): diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py index 
3633efd19774..cd48ccb4dcdb 100644 --- a/tests/python-gpu/test_gpu_with_dask.py +++ b/tests/python-gpu/test_gpu_with_dask.py @@ -17,6 +17,8 @@ sys.path.append("tests/python") from test_with_dask import run_empty_dmatrix_reg # noqa +from test_with_dask import run_empty_dmatrix_auc # noqa +from test_with_dask import run_auc # noqa from test_with_dask import run_boost_from_prediction # noqa from test_with_dask import run_dask_classifier # noqa from test_with_dask import run_empty_dmatrix_cls # noqa @@ -286,6 +288,15 @@ def test_empty_dmatrix(self, local_cuda_cluster: LocalCUDACluster) -> None: run_empty_dmatrix_reg(client, parameters) run_empty_dmatrix_cls(client, parameters) + def test_empty_dmatrix_auc(self, local_cuda_cluster: LocalCUDACluster) -> None: + with Client(local_cuda_cluster) as client: + n_workers = len(_get_client_workers(client)) + run_empty_dmatrix_auc(client, "gpu_hist", n_workers) + + def test_auc(self, local_cuda_cluster: LocalCUDACluster) -> None: + with Client(local_cuda_cluster) as client: + run_auc(client, "gpu_hist") + def test_data_initialization(self, local_cuda_cluster: LocalCUDACluster) -> None: with Client(local_cuda_cluster) as client: X, y, _ = generate_array() diff --git a/tests/python/test_eval_metrics.py b/tests/python/test_eval_metrics.py index 4117491246e0..e54cd71b670b 100644 --- a/tests/python/test_eval_metrics.py +++ b/tests/python/test_eval_metrics.py @@ -123,3 +123,90 @@ def test_gamma_deviance(self): gamma_dev = float(booster.eval(xgb.DMatrix(X, y)).split(":")[1].split(":")[0]) skl_gamma_dev = mean_gamma_deviance(y, score) np.testing.assert_allclose(gamma_dev, skl_gamma_dev, rtol=1e-6) + + def run_roc_auc_binary(self, tree_method, n_samples): + import numpy as np + from sklearn.datasets import make_classification + from sklearn.metrics import roc_auc_score + + rng = np.random.RandomState(1994) + n_samples = n_samples + n_features = 10 + + X, y = make_classification( + n_samples, + n_features, + n_informative=n_features, + n_redundant=0, + random_state=rng + ) + Xy = xgb.DMatrix(X, y) + booster = xgb.train( + { + "tree_method": tree_method, + "eval_metric": "auc", + "objective": "binary:logistic", + }, + Xy, + num_boost_round=8, + ) + score = booster.predict(Xy) + skl_auc = roc_auc_score(y, score) + auc = float(booster.eval(Xy).split(":")[1]) + np.testing.assert_allclose(skl_auc, auc, rtol=1e-6) + + X = rng.randn(*X.shape) + score = booster.predict(xgb.DMatrix(X)) + skl_auc = roc_auc_score(y, score) + auc = float(booster.eval(xgb.DMatrix(X, y)).split(":")[1]) + np.testing.assert_allclose(skl_auc, auc, rtol=1e-6) + + @pytest.mark.skipif(**tm.no_sklearn()) + @pytest.mark.parametrize("n_samples", [4, 100, 1000]) + def test_roc_auc(self, n_samples): + self.run_roc_auc_binary("hist", n_samples) + + def run_roc_auc_multi(self, tree_method, n_samples): + import numpy as np + from sklearn.datasets import make_classification + from sklearn.metrics import roc_auc_score + + rng = np.random.RandomState(1994) + n_samples = n_samples + n_features = 10 + n_classes = 4 + + X, y = make_classification( + n_samples, + n_features, + n_informative=n_features, + n_redundant=0, + n_classes=n_classes, + random_state=rng + ) + + Xy = xgb.DMatrix(X, y) + booster = xgb.train( + { + "tree_method": tree_method, + "eval_metric": "auc", + "objective": "multi:softprob", + "num_class": n_classes, + }, + Xy, + num_boost_round=8, + ) + score = booster.predict(Xy) + skl_auc = roc_auc_score(y, score, average="weighted", multi_class="ovr") + auc = float(booster.eval(Xy).split(":")[1]) + 
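Side note on the float(booster.eval(Xy).split(":")[1]) pattern used throughout these tests: Booster.eval() returns a tab-separated summary string, so splitting on the first ':' yields the metric value as long as a single metric is configured. A minimal illustration is below; the literal string is a made-up example of the format, not captured output.

result = "[0]\teval-auc:0.500000"   # shape of a Booster.eval() result (illustrative)
auc = float(result.split(":")[1])
assert auc == 0.5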
np.testing.assert_allclose(skl_auc, auc, rtol=1e-6) + + X = rng.randn(*X.shape) + score = booster.predict(xgb.DMatrix(X)) + skl_auc = roc_auc_score(y, score, average="weighted", multi_class="ovr") + auc = float(booster.eval(xgb.DMatrix(X, y)).split(":")[1]) + np.testing.assert_allclose(skl_auc, auc, rtol=1e-6) + + @pytest.mark.parametrize("n_samples", [4, 100, 1000]) + def test_roc_auc_multi(self, n_samples): + self.run_roc_auc_multi("hist", n_samples) diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index 5c6418a4c03a..c08584af2613 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -9,6 +9,7 @@ import json from typing import List, Tuple, Dict, Optional, Type, Any import asyncio +from functools import partial from concurrent.futures import ThreadPoolExecutor import tempfile from sklearn.datasets import make_classification @@ -528,9 +529,106 @@ def _check_outputs(out: xgb.dask.TrainReturnT, predictions: np.ndarray) -> None: _check_outputs(out, predictions) +def run_empty_dmatrix_auc(client: "Client", tree_method: str, n_workers: int) -> None: + from sklearn import datasets + n_samples = 100 + n_features = 97 + rng = np.random.RandomState(1994) + + make_classification = partial( + datasets.make_classification, + n_features=n_features, + random_state=rng + ) + + # binary + X_, y_ = make_classification(n_samples=n_samples, random_state=rng) + X = dd.from_array(X_, chunksize=10) + y = dd.from_array(y_, chunksize=10) + + n_samples = n_workers - 1 + valid_X_, valid_y_ = make_classification(n_samples=n_samples, random_state=rng) + valid_X = dd.from_array(valid_X_, chunksize=n_samples) + valid_y = dd.from_array(valid_y_, chunksize=n_samples) + + cls = xgb.dask.DaskXGBClassifier( + tree_method=tree_method, n_estimators=2, use_label_encoder=False + ) + cls.fit(X, y, eval_metric="auc", eval_set=[(valid_X, valid_y)]) + + # multiclass + X_, y_ = make_classification( + n_samples=n_samples, + n_classes=10, + n_informative=n_features, + n_redundant=0, + n_repeated=0 + ) + X = dd.from_array(X_, chunksize=10) + y = dd.from_array(y_, chunksize=10) + + n_samples = n_workers - 1 + valid_X_, valid_y_ = make_classification( + n_samples=n_samples, + n_classes=10, + n_informative=n_features, + n_redundant=0, + n_repeated=0 + ) + valid_X = dd.from_array(valid_X_, chunksize=n_samples) + valid_y = dd.from_array(valid_y_, chunksize=n_samples) + + cls = xgb.dask.DaskXGBClassifier( + tree_method=tree_method, n_estimators=2, use_label_encoder=False + ) + cls.fit(X, y, eval_metric="auc", eval_set=[(valid_X, valid_y)]) + + +def test_empty_dmatrix_auc() -> None: + with LocalCluster(n_workers=2) as cluster: + with Client(cluster) as client: + run_empty_dmatrix_auc(client, "hist", 2) + + +def run_auc(client: "Client", tree_method: str) -> None: + from sklearn import datasets + n_samples = 100 + n_features = 97 + rng = np.random.RandomState(1994) + X_, y_ = datasets.make_classification( + n_samples=n_samples, n_features=n_features, random_state=rng + ) + X = dd.from_array(X_, chunksize=10) + y = dd.from_array(y_, chunksize=10) + + valid_X_, valid_y_ = datasets.make_classification( + n_samples=n_samples, n_features=n_features, random_state=rng + ) + valid_X = dd.from_array(valid_X_, chunksize=10) + valid_y = dd.from_array(valid_y_, chunksize=10) + + cls = xgb.XGBClassifier( + tree_method=tree_method, n_estimators=2, use_label_encoder=False + ) + cls.fit(X_, y_, eval_metric="auc", eval_set=[(valid_X_, valid_y_)]) + + dcls = xgb.dask.DaskXGBClassifier( + 
tree_method=tree_method, n_estimators=2, use_label_encoder=False + ) + dcls.fit(X, y, eval_metric="auc", eval_set=[(valid_X, valid_y)]) + + approx = dcls.evals_result()["validation_0"]["auc"] + exact = cls.evals_result()["validation_0"]["auc"] + for i in range(2): + # approximated test. + assert np.abs(approx[i] - exact[i]) <= 0.06 + + +def test_auc(client: "Client") -> None: + run_auc(client, "hist") + # No test for Exact, as empty DMatrix handling are mostly for distributed # environment and Exact doesn't support it. - def test_empty_dmatrix_hist() -> None: with LocalCluster(n_workers=kWorkers) as cluster: with Client(cluster) as client: