Skip to content

Commit

Permalink
More consistent weight vector for ranking objectives.
Browse files Browse the repository at this point in the history
  • Loading branch information
xydrolase committed Mar 5, 2019
1 parent 9c4ff50 commit 51815a5
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
7 changes: 4 additions & 3 deletions src/objective/rank_obj.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,13 @@ class LambdaRankObj : public ObjFunction {
std::vector< std::pair<bst_float, unsigned> > rec;
bst_float sum_weights = 0;
for (bst_omp_uint k = 0; k < ngroup; ++k) {
sum_weights += info.GetWeight(k);
sum_weights += info.GetWeight(gptr[k]);
}
bst_float weight_normalization_factor = ngroup/sum_weights;
const auto& labels = info.labels_.HostVector();
#pragma omp for schedule(static)
for (bst_omp_uint k = 0; k < ngroup; ++k) {
bst_float group_weight = info.GetWeight(gptr[k]);
lst.clear(); pairs.clear();
for (unsigned j = gptr[k]; j < gptr[k+1]; ++j) {
lst.emplace_back(preds_h[j], labels[j], j);
Expand All @@ -94,10 +95,10 @@ class LambdaRankObj : public ObjFunction {
unsigned ridx = std::uniform_int_distribution<unsigned>(0, nleft + nright - 1)(rnd);
if (ridx < nleft) {
pairs.emplace_back(rec[ridx].second, rec[pid].second,
info.GetWeight(k) * weight_normalization_factor);
group_weight * weight_normalization_factor);
} else {
pairs.emplace_back(rec[pid].second, rec[ridx+j-i].second,
info.GetWeight(k) * weight_normalization_factor);
group_weight * weight_normalization_factor);
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions tests/cpp/objective/test_ranking_obj.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ TEST(Objective, PairwiseRankingGPair) {
CheckRankingObjFunction(obj,
{0, 0.1f, 0, 0.1f},
{0, 1, 0, 1},
{2.0f, 0.0f},
{2.0f, 2.0f, 0.0f, 0.0f},
{0, 2, 4},
{1.9f, -1.9f, 0.0f, 0.0f},
{1.995f, 1.995f, 0.0f, 0.0f});

CheckRankingObjFunction(obj,
{0, 0.1f, 0, 0.1f},
{0, 1, 0, 1},
{1.0f, 1.0f},
{1.0f, 1.0f, 1.0f, 1.0f},
{0, 2, 4},
{0.95f, -0.95f, 0.95f, -0.95f},
{0.9975f, 0.9975f, 0.9975f, 0.9975f});
Expand Down

0 comments on commit 51815a5

Please sign in to comment.