Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small refactor for hist builder. #8698

Merged
merged 1 commit into from
Jan 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 19 additions & 22 deletions src/common/hist_util.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright 2017-2020 by Contributors
/**
* Copyright 2017-2023 by XGBoost Contributors
* \file hist_util.cc
*/
#include <dmlc/timer.h>
Expand Down Expand Up @@ -193,9 +193,9 @@ class GHistBuildingManager {
};

template <bool do_prefetch, class BuildingManager>
void RowsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
GHistRow hist) {
void RowsWiseBuildHistKernel(Span<GradientPair const> gpair,
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
GHistRow hist) {
constexpr bool kAnyMissing = BuildingManager::kAnyMissing;
constexpr bool kFirstPage = BuildingManager::kFirstPage;
using BinIdxType = typename BuildingManager::BinIdxType;
Expand Down Expand Up @@ -262,9 +262,9 @@ void RowsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
}

template <class BuildingManager>
void ColsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
GHistRow hist) {
void ColsWiseBuildHistKernel(Span<GradientPair const> gpair,
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
GHistRow hist) {
constexpr bool kAnyMissing = BuildingManager::kAnyMissing;
constexpr bool kFirstPage = BuildingManager::kFirstPage;
using BinIdxType = typename BuildingManager::BinIdxType;
Expand Down Expand Up @@ -315,9 +315,8 @@ void ColsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
}

template <class BuildingManager>
void BuildHistDispatch(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
GHistRow hist) {
void BuildHistDispatch(Span<GradientPair const> gpair, const RowSetCollection::Elem row_indices,
const GHistIndexMatrix &gmat, GHistRow hist) {
if (BuildingManager::kReadByColumn) {
ColsWiseBuildHistKernel<BuildingManager>(gpair, row_indices, gmat, hist);
} else {
Expand All @@ -344,33 +343,31 @@ void BuildHistDispatch(const std::vector<GradientPair> &gpair,
}

template <bool any_missing>
void GHistBuilder::BuildHist(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix &gmat,
void GHistBuilder::BuildHist(Span<GradientPair const> gpair,
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
GHistRow hist, bool force_read_by_column) const {
/* force_read_by_column is used for testing the columnwise building of histograms.
* default force_read_by_column = false
*/
constexpr double kAdhocL2Size = 1024 * 1024 * 0.8;
const bool hist_fit_to_l2 = kAdhocL2Size > 2*sizeof(float)*gmat.cut.Ptrs().back();
const bool hist_fit_to_l2 = kAdhocL2Size > 2 * sizeof(float) * gmat.cut.Ptrs().back();
bool first_page = gmat.base_rowid == 0;
bool read_by_column = !hist_fit_to_l2 && !any_missing;
auto bin_type_size = gmat.index.GetBinTypeSize();

GHistBuildingManager<any_missing>::DispatchAndExecute(
{first_page, read_by_column || force_read_by_column, bin_type_size},
[&](auto t) {
using BuildingManager = decltype(t);
BuildHistDispatch<BuildingManager>(gpair, row_indices, gmat, hist);
});
{first_page, read_by_column || force_read_by_column, bin_type_size}, [&](auto t) {
using BuildingManager = decltype(t);
BuildHistDispatch<BuildingManager>(gpair, row_indices, gmat, hist);
});
}

template void GHistBuilder::BuildHist<true>(const std::vector<GradientPair> &gpair,
template void GHistBuilder::BuildHist<true>(Span<GradientPair const> gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix &gmat, GHistRow hist,
bool force_read_by_column) const;

template void GHistBuilder::BuildHist<false>(const std::vector<GradientPair> &gpair,
template void GHistBuilder::BuildHist<false>(Span<GradientPair const> gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix &gmat, GHistRow hist,
bool force_read_by_column) const;
Expand Down
13 changes: 7 additions & 6 deletions src/common/hist_util.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright 2017-2022 by XGBoost Contributors
/**
* Copyright 2017-2023 by XGBoost Contributors
* \file hist_util.h
* \brief Utility for fast histogram aggregation
* \author Philip Cho, Tianqi Chen
Expand All @@ -23,6 +23,7 @@
#include "row_set.h"
#include "threading_utils.h"
#include "timer.h"
#include "xgboost/base.h" // bst_feature_t, bst_bin_t

namespace xgboost {
class GHistIndexMatrix;
Expand Down Expand Up @@ -320,10 +321,10 @@ struct Index {
};

template <typename GradientIndex>
bst_bin_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(size_t begin, size_t end,
bst_bin_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(std::size_t begin, std::size_t end,
GradientIndex const& data,
uint32_t const fidx_begin,
uint32_t const fidx_end) {
bst_feature_t const fidx_begin,
bst_feature_t const fidx_end) {
size_t previous_middle = std::numeric_limits<size_t>::max();
while (end != begin) {
size_t middle = begin + (end - begin) / 2;
Expand Down Expand Up @@ -635,7 +636,7 @@ class GHistBuilder {

// construct a histogram via histogram aggregation
template <bool any_missing>
void BuildHist(const std::vector<GradientPair>& gpair, const RowSetCollection::Elem row_indices,
void BuildHist(Span<GradientPair const> gpair, const RowSetCollection::Elem row_indices,
const GHistIndexMatrix& gmat, GHistRow hist,
bool force_read_by_column = false) const;
uint32_t GetNumBins() const {
Expand Down
161 changes: 69 additions & 92 deletions src/tree/hist/histogram.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright 2021-2022 by XGBoost Contributors
/**
* Copyright 2021-2023 by XGBoost Contributors
*/
#ifndef XGBOOST_TREE_HIST_HISTOGRAM_H_
#define XGBOOST_TREE_HIST_HISTOGRAM_H_
Expand Down Expand Up @@ -59,15 +59,14 @@ class HistogramBuilder {
GHistIndexMatrix const &gidx,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
common::RowSetCollection const &row_set_collection,
const std::vector<GradientPair> &gpair_h,
bool force_read_by_column) {
common::Span<GradientPair const> gpair_h, bool force_read_by_column) {
const size_t n_nodes = nodes_for_explicit_hist_build.size();
CHECK_GT(n_nodes, 0);

std::vector<common::GHistRow> target_hists(n_nodes);
for (size_t i = 0; i < n_nodes; ++i) {
const int32_t nid = nodes_for_explicit_hist_build[i].nid;
target_hists[i] = hist_[nid];
auto const nidx = nodes_for_explicit_hist_build[i].nid;
target_hists[i] = hist_[nidx];
}
if (page_idx == 0) {
// FIXME(jiamingy): Handle different size of space. Right now we use the maximum
Expand All @@ -93,46 +92,37 @@ class HistogramBuilder {
});
}

void
AddHistRows(int *starting_index, int *sync_count,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
RegTree *p_tree) {
void AddHistRows(int *starting_index, int *sync_count,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
RegTree const *p_tree) {
if (is_distributed_) {
this->AddHistRowsDistributed(starting_index, sync_count,
nodes_for_explicit_hist_build,
this->AddHistRowsDistributed(starting_index, sync_count, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick, p_tree);
} else {
this->AddHistRowsLocal(starting_index, sync_count,
nodes_for_explicit_hist_build,
this->AddHistRowsLocal(starting_index, sync_count, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick);
}
}

/** Main entry point of this class, build histogram for tree nodes. */
void BuildHist(size_t page_id, common::BlockedSpace2d space, GHistIndexMatrix const &gidx,
RegTree *p_tree, common::RowSetCollection const &row_set_collection,
RegTree const *p_tree, common::RowSetCollection const &row_set_collection,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
std::vector<GradientPair> const &gpair,
bool force_read_by_column = false) {
common::Span<GradientPair const> gpair, bool force_read_by_column = false) {
int starting_index = std::numeric_limits<int>::max();
int sync_count = 0;
if (page_id == 0) {
this->AddHistRows(&starting_index, &sync_count,
nodes_for_explicit_hist_build,
this->AddHistRows(&starting_index, &sync_count, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick, p_tree);
}
if (gidx.IsDense()) {
this->BuildLocalHistograms<false>(page_id, space, gidx,
nodes_for_explicit_hist_build,
row_set_collection, gpair,
force_read_by_column);
this->BuildLocalHistograms<false>(page_id, space, gidx, nodes_for_explicit_hist_build,
row_set_collection, gpair, force_read_by_column);
} else {
this->BuildLocalHistograms<true>(page_id, space, gidx,
nodes_for_explicit_hist_build,
row_set_collection, gpair,
force_read_by_column);
this->BuildLocalHistograms<true>(page_id, space, gidx, nodes_for_explicit_hist_build,
row_set_collection, gpair, force_read_by_column);
}

CHECK_GE(n_batches_, 1);
Expand All @@ -153,8 +143,7 @@ class HistogramBuilder {
common::RowSetCollection const &row_set_collection,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
std::vector<GradientPair> const &gpair,
bool force_read_by_column = false) {
common::Span<GradientPair const> gpair, bool force_read_by_column = false) {
const size_t n_nodes = nodes_for_explicit_hist_build.size();
// create space of size (# rows in each node)
common::BlockedSpace2d space(
Expand All @@ -164,83 +153,72 @@ class HistogramBuilder {
return row_set_collection[nidx].Size();
},
256);
this->BuildHist(page_id, space, gidx, p_tree, row_set_collection,
nodes_for_explicit_hist_build, nodes_for_subtraction_trick,
gpair, force_read_by_column);
this->BuildHist(page_id, space, gidx, p_tree, row_set_collection, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick, gpair, force_read_by_column);
}

void SyncHistogramDistributed(
RegTree *p_tree,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
int starting_index, int sync_count) {
void SyncHistogramDistributed(RegTree const *p_tree,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
int starting_index, int sync_count) {
const size_t nbins = builder_.GetNumBins();
common::BlockedSpace2d space(
nodes_for_explicit_hist_build.size(), [&](size_t) { return nbins; },
1024);
common::ParallelFor2d(
space, n_threads_, [&](size_t node, common::Range1d r) {
const auto &entry = nodes_for_explicit_hist_build[node];
auto this_hist = this->hist_[entry.nid];
// Merging histograms from each thread into once
buffer_.ReduceHist(node, r.begin(), r.end());
// Store posible parent node
auto this_local = hist_local_worker_[entry.nid];
common::CopyHist(this_local, this_hist, r.begin(), r.end());
nodes_for_explicit_hist_build.size(), [&](size_t) { return nbins; }, 1024);
common::ParallelFor2d(space, n_threads_, [&](size_t node, common::Range1d r) {
const auto &entry = nodes_for_explicit_hist_build[node];
auto this_hist = this->hist_[entry.nid];
// Merging histograms from each thread into once
buffer_.ReduceHist(node, r.begin(), r.end());
// Store posible parent node
auto this_local = hist_local_worker_[entry.nid];
common::CopyHist(this_local, this_hist, r.begin(), r.end());

if (!(*p_tree)[entry.nid].IsRoot()) {
const size_t parent_id = (*p_tree)[entry.nid].Parent();
const int subtraction_node_id =
nodes_for_subtraction_trick[node].nid;
auto parent_hist = this->hist_local_worker_[parent_id];
auto sibling_hist = this->hist_[subtraction_node_id];
common::SubtractionHist(sibling_hist, parent_hist, this_hist,
r.begin(), r.end());
// Store posible parent node
auto sibling_local = hist_local_worker_[subtraction_node_id];
common::CopyHist(sibling_local, sibling_hist, r.begin(), r.end());
}
});
if (!(*p_tree)[entry.nid].IsRoot()) {
const size_t parent_id = (*p_tree)[entry.nid].Parent();
const int subtraction_node_id = nodes_for_subtraction_trick[node].nid;
auto parent_hist = this->hist_local_worker_[parent_id];
auto sibling_hist = this->hist_[subtraction_node_id];
common::SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
// Store posible parent node
auto sibling_local = hist_local_worker_[subtraction_node_id];
common::CopyHist(sibling_local, sibling_hist, r.begin(), r.end());
}
});

collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<double *>(this->hist_[starting_index].data()),
builder_.GetNumBins() * sync_count * 2);

ParallelSubtractionHist(space, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick, p_tree);
ParallelSubtractionHist(space, nodes_for_explicit_hist_build, nodes_for_subtraction_trick,
p_tree);

common::BlockedSpace2d space2(
nodes_for_subtraction_trick.size(), [&](size_t) { return nbins; },
1024);
ParallelSubtractionHist(space2, nodes_for_subtraction_trick,
nodes_for_explicit_hist_build, p_tree);
nodes_for_subtraction_trick.size(), [&](size_t) { return nbins; }, 1024);
ParallelSubtractionHist(space2, nodes_for_subtraction_trick, nodes_for_explicit_hist_build,
p_tree);
}

void SyncHistogramLocal(RegTree *p_tree,
void SyncHistogramLocal(RegTree const *p_tree,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick) {
const size_t nbins = this->builder_.GetNumBins();
common::BlockedSpace2d space(
nodes_for_explicit_hist_build.size(), [&](size_t) { return nbins; },
1024);
nodes_for_explicit_hist_build.size(), [&](size_t) { return nbins; }, 1024);

common::ParallelFor2d(
space, this->n_threads_, [&](size_t node, common::Range1d r) {
const auto &entry = nodes_for_explicit_hist_build[node];
auto this_hist = this->hist_[entry.nid];
// Merging histograms from each thread into once
this->buffer_.ReduceHist(node, r.begin(), r.end());
common::ParallelFor2d(space, this->n_threads_, [&](size_t node, common::Range1d r) {
const auto &entry = nodes_for_explicit_hist_build[node];
auto this_hist = this->hist_[entry.nid];
// Merging histograms from each thread into once
this->buffer_.ReduceHist(node, r.begin(), r.end());

if (!(*p_tree)[entry.nid].IsRoot()) {
const size_t parent_id = (*p_tree)[entry.nid].Parent();
const int subtraction_node_id =
nodes_for_subtraction_trick[node].nid;
auto parent_hist = this->hist_[parent_id];
auto sibling_hist = this->hist_[subtraction_node_id];
common::SubtractionHist(sibling_hist, parent_hist, this_hist,
r.begin(), r.end());
}
});
if (!(*p_tree)[entry.nid].IsRoot()) {
auto const parent_id = (*p_tree)[entry.nid].Parent();
auto const subtraction_node_id = nodes_for_subtraction_trick[node].nid;
auto parent_hist = this->hist_[parent_id];
auto sibling_hist = this->hist_[subtraction_node_id];
common::SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
}
});
}

public:
Expand Down Expand Up @@ -289,11 +267,10 @@ class HistogramBuilder {
this->hist_.AllocateAllData();
}

void AddHistRowsDistributed(
int *starting_index, int *sync_count,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
RegTree *p_tree) {
void AddHistRowsDistributed(int *starting_index, int *sync_count,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
RegTree const *p_tree) {
const size_t explicit_size = nodes_for_explicit_hist_build.size();
const size_t subtaction_size = nodes_for_subtraction_trick.size();
std::vector<int> merged_node_ids(explicit_size + subtaction_size);
Expand Down