Skip to content

Commit

Permalink
- ndcg ltr implementation on gpu (#5004)
Browse files Browse the repository at this point in the history
* - ndcg ltr implementation on gpu
  - this is a follow-up to the pairwise ltr implementation
  • Loading branch information
sriramch authored and RAMitchell committed Nov 12, 2019
1 parent f4e7b70 commit 2abe69d
Show file tree
Hide file tree
Showing 5 changed files with 778 additions and 200 deletions.
10 changes: 5 additions & 5 deletions src/metric/rank_metric.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ struct EvalAMS : public Metric {
for (bst_omp_uint i = 0; i < ndata; ++i) {
rec[i] = std::make_pair(h_preds[i], i);
}
std::sort(rec.begin(), rec.end(), common::CmpFirst);
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
auto ntop = static_cast<unsigned>(ratio_ * ndata);
if (ntop == 0) ntop = ndata;
const double br = 10.0;
Expand Down Expand Up @@ -168,7 +168,7 @@ struct EvalAuc : public Metric {
for (unsigned j = gptr[group_id]; j < gptr[group_id + 1]; ++j) {
rec.emplace_back(h_preds[j], j);
}
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
XGBOOST_PARALLEL_STABLE_SORT(rec.begin(), rec.end(), common::CmpFirst);
// calculate AUC
double sum_pospair = 0.0;
double sum_npos = 0.0, sum_nneg = 0.0, buf_pos = 0.0, buf_neg = 0.0;
Expand Down Expand Up @@ -321,7 +321,7 @@ struct EvalPrecision : public EvalRankList{
protected:
bst_float EvalMetric(std::vector< std::pair<bst_float, unsigned> > &rec) const override {
// calculate Precision
std::sort(rec.begin(), rec.end(), common::CmpFirst);
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
unsigned nhit = 0;
for (size_t j = 0; j < rec.size() && j < this->topn_; ++j) {
nhit += (rec[j].second != 0);
Expand Down Expand Up @@ -369,7 +369,7 @@ struct EvalMAP : public EvalRankList {

protected:
bst_float EvalMetric(std::vector< std::pair<bst_float, unsigned> > &rec) const override {
std::sort(rec.begin(), rec.end(), common::CmpFirst);
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
unsigned nhits = 0;
double sumap = 0.0;
for (size_t i = 0; i < rec.size(); ++i) {
Expand Down Expand Up @@ -481,7 +481,7 @@ struct EvalAucPR : public Metric {
total_neg += wt * (1.0f - h_labels[j]);
rec.emplace_back(h_preds[j], j);
}
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
XGBOOST_PARALLEL_STABLE_SORT(rec.begin(), rec.end(), common::CmpFirst);
// we need pos > 0 && neg > 0
if (0.0 == total_pos || 0.0 == total_neg) {
auc_error += 1;
Expand Down
Loading

0 comments on commit 2abe69d

Please sign in to comment.