Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement weighted sketching for adapter. #5760

Merged
merged 8 commits into from
Jun 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/common/hist_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ void HistogramCuts::Build(DMatrix* dmat, uint32_t const max_num_bins) {

// Convenience overload: forwards the meta info of `dmat` to the
// MetaInfo-based overload and returns its result unchanged.
bool CutsBuilder::UseGroup(DMatrix* dmat) {
  return CutsBuilder::UseGroup(dmat->Info());
}

bool CutsBuilder::UseGroup(MetaInfo const& info) {
size_t const num_groups = info.group_ptr_.size() == 0 ?
0 : info.group_ptr_.size() - 1;
// Use group index for weights?
Expand Down
67 changes: 27 additions & 40 deletions src/common/hist_util.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright 2018 XGBoost contributors
* Copyright 2018~2020 XGBoost contributors
*/

#include <xgboost/logging.h>
Expand Down Expand Up @@ -28,24 +28,10 @@

namespace xgboost {
namespace common {
// Count the entries in each column and exclusive scan
void GetColumnSizesScan(int device,
dh::caching_device_vector<size_t>* column_sizes_scan,
Span<const Entry> entries, size_t num_columns) {
column_sizes_scan->resize(num_columns + 1, 0);
auto d_column_sizes_scan = column_sizes_scan->data().get();
auto d_entries = entries.data();
dh::LaunchN(device, entries.size(), [=] __device__(size_t idx) {
auto& e = d_entries[idx];
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
&d_column_sizes_scan[e.index]),
static_cast<unsigned long long>(1)); // NOLINT
});
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(),
column_sizes_scan->end(), column_sizes_scan->begin());
}

constexpr float SketchContainer::kFactor;

// Count the entries in each column and exclusive scan
void ExtractCuts(int device,
size_t num_cuts_per_feature,
Span<Entry const> sorted_data,
Expand Down Expand Up @@ -158,6 +144,23 @@ void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end,
sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan);
}

// Sort the entries with EntryCompareOp, carrying the weights along, and then
// replace every weight with the running sum of the weights within its run of
// consecutive entries that share the same `index`.
//
// \param alloc          Allocator passed to the thrust execution policy.
// \param weights        One weight per entry; permuted and scanned in place.
// \param sorted_entries Entries; sorted in place.
void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
                  dh::caching_device_vector<float>* weights,
                  dh::caching_device_vector<Entry>* sorted_entries) {
  auto const entries_first = sorted_entries->begin();
  auto const entries_last = sorted_entries->end();
  auto const weights_first = weights->begin();

  // Entries are the keys, weights are the values that follow them.
  thrust::sort_by_key(thrust::cuda::par(*alloc), entries_first, entries_last,
                      weights_first, EntryCompareOp());

  // Adjacent entries belong to the same scan segment when their index matches.
  auto const same_index = [=] __device__(const Entry& a, const Entry& b) {
    return a.index == b.index;
  };

  // Segmented inclusive prefix sum of the weights, written back in place.
  thrust::inclusive_scan_by_key(thrust::cuda::par(*alloc), entries_first,
                                entries_last, weights_first, weights_first,
                                same_index);
}

void ProcessWeightedBatch(int device, const SparsePage& page,
Span<const float> weights, size_t begin, size_t end,
SketchContainer* sketch_container, int num_cuts_per_feature,
Expand Down Expand Up @@ -201,19 +204,7 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
d_temp_weights[idx] = weights[ridx + base_rowid];
});
}

// Sort both entries and weights.
thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), temp_weights.begin(),
EntryCompareOp());

// Scan weights
thrust::inclusive_scan_by_key(thrust::cuda::par(alloc),
sorted_entries.begin(), sorted_entries.end(),
temp_weights.begin(), temp_weights.begin(),
[=] __device__(const Entry& a, const Entry& b) {
return a.index == b.index;
});
SortByWeight(&alloc, &temp_weights, &sorted_entries);

dh::caching_device_vector<size_t> column_sizes_scan;
GetColumnSizesScan(device, &column_sizes_scan,
Expand All @@ -239,13 +230,9 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
// Configure batch size based on available memory
bool has_weights = dmat->Info().weights_.Size() > 0;
size_t num_cuts_per_feature = RequiredSampleCuts(max_bins, dmat->Info().num_row_);
if (sketch_batch_num_elements == 0) {
int bytes_per_element = has_weights ? 24 : 16;
size_t bytes_cuts = num_cuts_per_feature * dmat->Info().num_col_ * sizeof(SketchEntry);
// use up to 80% of available space
sketch_batch_num_elements =
(dh::AvailableMemory(device) - bytes_cuts) * 0.8 / bytes_per_element;
}
sketch_batch_num_elements = SketchBatchNumElements(
sketch_batch_num_elements,
dmat->Info().num_col_, device, num_cuts_per_feature, has_weights);

HistogramCuts cuts;
DenseCuts dense_cuts(&cuts);
Expand All @@ -256,12 +243,12 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
size_t batch_nnz = batch.data.Size();
auto const& info = dmat->Info();
dh::caching_device_vector<uint32_t> groups(info.group_ptr_.cbegin(),
info.group_ptr_.cend());
for (auto begin = 0ull; begin < batch_nnz; begin += sketch_batch_num_elements) {
size_t end = std::min(batch_nnz, size_t(begin + sketch_batch_num_elements));
if (has_weights) {
bool is_ranking = CutsBuilder::UseGroup(dmat);
dh::caching_device_vector<uint32_t> groups(info.group_ptr_.cbegin(),
info.group_ptr_.cend());
ProcessWeightedBatch(
device, batch, dmat->Info().weights_.ConstDeviceSpan(), begin, end,
&sketch_container,
Expand Down
Loading