Skip to content

Commit

Permalink
Fix trapezoid index.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Mar 12, 2021
1 parent 3121b4c commit 310059c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 18 deletions.
6 changes: 3 additions & 3 deletions src/metric/auc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ float GroupRankingAUC(common::Span<float const> predts,
// on ranking, we just count all pairs.
float auc{0};
auto const sorted_idx = common::ArgSort<size_t>(labels, std::greater<>{});
w = common::Sqr(w);

float sum_w = 0.0f;
for (size_t i = 0; i < labels.size(); ++i) {
Expand All @@ -159,11 +160,10 @@ float GroupRankingAUC(common::Span<float const> predts,
} else {
predt = 0;
}
auc += predt * common::Sqr(w);
sum_w += common::Sqr(w);
auc += predt * w;
sum_w += w;
}
}

if (sum_w != 0) {
auc /= sum_w;
}
Expand Down
28 changes: 13 additions & 15 deletions src/metric/auc.cu
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,10 @@ GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,

size_t i, j;
common::UnravelTrapeziodIdx(idx_in_thread_group, n_samples, &i, &j);
// we use global index among all groups for sorted idx, so i, j should also be global
// index.
i += data_group_begin;
j += data_group_begin;
return thrust::make_pair(i, j);
}; // NOLINT
auto in = dh::MakeTransformIterator<RankScanItem>(
Expand All @@ -490,6 +494,7 @@ GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,

size_t i, j;
thrust::tie(i, j) = get_i_j(idx, query_group_idx);

float predt = predts[d_sorted_idx[i]] - predts[d_sorted_idx[j]];
float w = common::Sqr(d_weights.empty() ? 1.0f : d_weights[query_group_idx]);
if (predt > 0) {
Expand All @@ -503,19 +508,21 @@ GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
return RankScanItem{idx, predt, w, query_group_idx};
});

dh::caching_device_vector<RankScanItem> d_auc(group_ptr.size() - 1);
dh::caching_device_vector<float> d_auc(group_ptr.size() - 1);
auto s_d_auc = dh::ToSpan(d_auc);
auto out = thrust::make_transform_output_iterator(
Discard<RankScanItem>(), [=] __device__(RankScanItem const &item) -> RankScanItem {
auto group_id = item.group_id;
assert(group_id < d_group_ptr.size());
size_t i, j;
thrust::tie(i, j) = get_i_j(item.idx, group_id);
auto data_group_begin = d_group_ptr[group_id];
size_t n_samples = d_group_ptr[group_id + 1] - data_group_begin;
// last item of current group
if (item.idx == LastOf(group_id, d_threads_group_ptr)) {
s_d_auc[group_id] = item;
if (item.w > 0) {
s_d_auc[group_id] = item.predt / item.w;
} else {
s_d_auc[group_id] = 0;
}
}
return {}; // discard
});
Expand All @@ -532,17 +539,8 @@ GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
/**
* Scale the AUC with number of items in each group.
*/
auto key = dh::MakeTransformIterator<float>(
thrust::make_counting_iterator(0), [=] __device__(size_t i) {
if (s_d_auc[i].w > 0) {
return s_d_auc[i].predt / s_d_auc[i].w;
}
return 0.0f;
});

float auc =
thrust::reduce(thrust::cuda::par(alloc), key, key + s_d_auc.size(), 0.0f);

float auc = thrust::reduce(thrust::cuda::par(alloc), d_auc.begin(),
d_auc.end(), 0.0f);
return std::make_pair(auc, n_valid);
}
} // namespace metric
Expand Down

0 comments on commit 310059c

Please sign in to comment.