Skip to content

Commit

Permalink
Remove caching vector.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Mar 16, 2021
1 parent 988ae3e commit 4fff3aa
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/metric/auc.cu
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ float GPUMultiClassAUC(common::Span<float const> predts, MetaInfo const &info,
auto d_predts_t = dh::ToSpan(cache->predts_t);
Transpose(predts, d_predts_t, n_samples, n_classes, device);

dh::caching_device_vector<uint32_t> class_ptr(n_classes + 1, 0);
dh::TemporaryArray<uint32_t> class_ptr(n_classes + 1, 0);
auto d_class_ptr = dh::ToSpan(class_ptr);
dh::LaunchN(device, n_classes + 1, [=]__device__(size_t i) {
d_class_ptr[i] = i * n_samples;
Expand Down Expand Up @@ -274,8 +274,8 @@ float GPUMultiClassAUC(common::Span<float const> predts, MetaInfo const &info,
auto d_unique_class_ptr = dh::ToSpan(unique_class_ptr);
auto n_uniques = dh::SegmentedUniqueByKey(
thrust::cuda::par(alloc),
class_ptr.begin(),
class_ptr.end(),
dh::tbegin(d_class_ptr),
dh::tend(d_class_ptr),
uni_key,
uni_key + d_sorted_idx.size(),
dh::tbegin(d_unique_idx),
Expand Down Expand Up @@ -364,7 +364,7 @@ float GPUMultiClassAUC(common::Span<float const> predts, MetaInfo const &info,
/**
* Scale the classes with number of samples for each class.
*/
dh::caching_device_vector<float> resutls(n_classes * 4);
dh::TemporaryArray<float> resutls(n_classes * 4);
auto d_results = dh::ToSpan(resutls);
auto local_area = d_results.subspan(0, n_classes);
auto fp = d_results.subspan(n_classes, n_classes);
Expand Down Expand Up @@ -501,7 +501,7 @@ GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
return RankScanItem{idx, predt, w, query_group_idx};
});

dh::caching_device_vector<float> d_auc(group_ptr.size() - 1);
dh::TemporaryArray<float> d_auc(group_ptr.size() - 1);
auto s_d_auc = dh::ToSpan(d_auc);
auto out = thrust::make_transform_output_iterator(
Discard<RankScanItem>(), [=] __device__(RankScanItem const &item) -> RankScanItem {
Expand Down Expand Up @@ -532,8 +532,8 @@ GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
/**
* Scale the AUC with number of items in each group.
*/
float auc = thrust::reduce(thrust::cuda::par(alloc), d_auc.begin(),
d_auc.end(), 0.0f);
float auc = thrust::reduce(thrust::cuda::par(alloc), dh::tbegin(s_d_auc),
dh::tend(s_d_auc), 0.0f);
return std::make_pair(auc, n_valid);
}
} // namespace metric
Expand Down

0 comments on commit 4fff3aa

Please sign in to comment.