Skip to content

Commit

Permalink
comments were applied
Browse files Browse the repository at this point in the history
  • Loading branch information
SHVETS, KIRILL committed May 6, 2020
1 parent 77734fe commit 4d4b11e
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 71 deletions.
6 changes: 5 additions & 1 deletion src/tree/updater_prune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include "xgboost/json.h"
#include "./param.h"
#include "../common/io.h"

#include "../common/timer.h"
namespace xgboost {
namespace tree {

Expand All @@ -25,6 +25,7 @@ class TreePruner: public TreeUpdater {
public:
/*! \brief Construct the pruner: create its "sync" companion updater and set up timing. */
TreePruner() {
// Build the updater registered under the name "sync" and take ownership of it.
syncher_.reset(TreeUpdater::Create("sync", tparam_));
// Label this updater's monitor so its timings are reported under "TreePruner".
pruner_monitor_.Init("TreePruner");
}
char const* Name() const override {
return "prune";
Expand Down Expand Up @@ -52,6 +53,7 @@ class TreePruner: public TreeUpdater {
void Update(HostDeviceVector<GradientPair> *gpair,
DMatrix *p_fmat,
const std::vector<RegTree*> &trees) override {
pruner_monitor_.Start("PrunerUpdate");
// rescale learning rate according to size of trees
float lr = param_.learning_rate;
param_.learning_rate = lr / trees.size();
Expand All @@ -60,6 +62,7 @@ class TreePruner: public TreeUpdater {
}
param_.learning_rate = lr;
syncher_->Update(gpair, p_fmat, trees);
pruner_monitor_.Stop("PrunerUpdate");
}

private:
Expand Down Expand Up @@ -105,6 +108,7 @@ class TreePruner: public TreeUpdater {
std::unique_ptr<TreeUpdater> syncher_;
// training parameter
TrainParam param_;
common::Monitor pruner_monitor_;
};

XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune")
Expand Down
115 changes: 47 additions & 68 deletions src/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,23 @@ bool QuantileHistMaker::UpdatePredictionCache(
}
}

/*!
 * \brief Apply the histogram subtraction step to a list of expand entries.
 *
 * The work is dispatched through common::ParallelFor2d over \p space using
 * this->nthread_. Entries whose node is a left child are skipped. For the
 * remaining entries, when the node is not the root and has a valid sibling
 * (sibling_nid > -1), SubtractionHist is invoked on the node's histogram with
 * the parent's and sibling's histograms for the current bin range
 * (presumably computing node = parent - sibling; confirm against SubtractionHist).
 *
 * \param space  blocked 2-d iteration space; the first index selects an entry of \p nodes
 * \param nodes  expand entries to process
 * \param p_tree tree used to query child/root status and parent ids
 */
void QuantileHistMaker::Builder::ParallelSubtractionHist(const common::BlockedSpace2d& space,
                                                         const std::vector<ExpandEntry>& nodes,
                                                         const RegTree * p_tree) {
  const auto subtract_block = [&](size_t node_idx, common::Range1d bin_range) {
    const ExpandEntry& current = nodes[node_idx];
    const auto& tree_node = (*p_tree)[current.nid];

    // Left children are not handled by this pass.
    if (tree_node.IsLeftChild()) {
      return;
    }

    // Looked up before the root/sibling check, matching the original evaluation order.
    auto target_hist = hist_[current.nid];

    // Nothing to subtract for the root or for a node without a recorded sibling.
    if (tree_node.IsRoot() || !(current.sibling_nid > -1)) {
      return;
    }

    auto parent_hist = hist_[tree_node.Parent()];
    auto sibling_hist = hist_[current.sibling_nid];
    SubtractionHist(target_hist, parent_hist, sibling_hist, bin_range.begin(), bin_range.end());
  };

  common::ParallelFor2d(space, this->nthread_, subtract_block);
}

void QuantileHistMaker::Builder::SyncHistograms(
int starting_index,
int sync_count,
Expand All @@ -105,81 +122,44 @@ void QuantileHistMaker::Builder::SyncHistograms(

const bool isDistributed = rabit::IsDistributed();
const size_t nbins = hist_builder_.GetNumBins();
if (!isDistributed) {
common::BlockedSpace2d space(nodes_for_explicit_hist_build_.size(), [&](size_t node) {
return nbins;
}, 1024);

common::ParallelFor2d(space, this->nthread_, [&](size_t node, common::Range1d r) {
const auto entry = nodes_for_explicit_hist_build_[node];
auto this_hist = hist_[entry.nid];
// Merging histograms from each thread into once
hist_buffer_.ReduceHist(node, r.begin(), r.end());

if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1 /*&& !isDistributed*/) {
auto parent_hist = hist_[(*p_tree)[entry.nid].Parent()];
auto sibling_hist = hist_[entry.sibling_nid];
SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
}
});
} else {
common::BlockedSpace2d space(nodes_for_explicit_hist_build_.size(), [&](size_t node) {
return nbins;
}, 1024);

common::ParallelFor2d(space, this->nthread_, [&](size_t node, common::Range1d r) {
const auto entry = nodes_for_explicit_hist_build_[node];
auto this_hist = hist_[entry.nid];
// Merging histograms from each thread into once
hist_buffer_.ReduceHist(node, r.begin(), r.end());

common::BlockedSpace2d space(nodes_for_explicit_hist_build_.size(), [&](size_t node) {
return nbins;
}, 1024);
common::ParallelFor2d(space, this->nthread_, [&](size_t node, common::Range1d r) {
const auto entry = nodes_for_explicit_hist_build_[node];
auto this_hist = hist_[entry.nid];
// Merging histograms from each thread into one
hist_buffer_.ReduceHist(node, r.begin(), r.end());
if (isDistributed) {
// Store possible parent node
auto this_local = phist_local_[entry.nid];
auto this_local = hist_local_worker_[entry.nid];
CopyHist(this_local, this_hist, r.begin(), r.end());
}

if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) {
auto sibling_hist = hist_[entry.sibling_nid];
auto parent_hist = phist_local_[(*p_tree)[entry.nid].Parent()];
SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) {
const size_t parent_id = (*p_tree)[entry.nid].Parent();
auto parent_hist = isDistributed ? hist_local_worker_[parent_id] : hist_[parent_id];
auto sibling_hist = hist_[entry.sibling_nid];
SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
if (isDistributed) {
// Store possible parent node
auto sibling_local = phist_local_[entry.sibling_nid];
auto sibling_local = hist_local_worker_[entry.sibling_nid];
CopyHist(sibling_local, sibling_hist, r.begin(), r.end());
}
});
}
});

if (isDistributed) {
builder_monitor_.Start("SyncHistogramsAllreduce");
this->histred_.Allreduce(hist_[starting_index].data(), hist_builder_.GetNumBins() * sync_count);
builder_monitor_.Stop("SyncHistogramsAllreduce");

common::ParallelFor2d(space, this->nthread_, [&](size_t node, common::Range1d r) {
const auto entry = nodes_for_explicit_hist_build_[node];
if (!((*p_tree)[entry.nid].IsLeftChild())) {
auto this_hist = hist_[entry.nid];

if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) {
auto parent_hist = hist_[(*p_tree)[entry.nid].Parent()];
auto sibling_hist = hist_[entry.sibling_nid];
SubtractionHist(this_hist, parent_hist, sibling_hist, r.begin(), r.end());
}
}
});
ParallelSubtractionHist(space, nodes_for_explicit_hist_build_, p_tree);

common::BlockedSpace2d space2(nodes_for_subtraction_trick_.size(), [&](size_t node) {
return nbins;
}, 1024);

common::ParallelFor2d(space2, this->nthread_, [&](size_t node, common::Range1d r) {
const auto entry = nodes_for_subtraction_trick_[node];
if (!((*p_tree)[entry.nid].IsLeftChild())) {
auto this_hist = hist_[entry.nid];

if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) {
auto parent_hist = hist_[(*p_tree)[entry.nid].Parent()];
auto sibling_hist = hist_[entry.sibling_nid];
SubtractionHist(this_hist, parent_hist, sibling_hist, r.begin(), r.end());
}
}
});
ParallelSubtractionHist(space2, nodes_for_subtraction_trick_, p_tree);
}

builder_monitor_.Stop("SyncHistograms");
Expand Down Expand Up @@ -230,24 +210,25 @@ void QuantileHistMaker::Builder::AddHistRows(int *starting_index, int *sync_coun
(*starting_index) = std::min(nid, (*starting_index));
n_left++;
if (rabit::IsDistributed()) {
phist_local_.AddHistRow(nid);
hist_local_worker_.AddHistRow(nid);
}
}
}
for (auto const& nid : merged_hist) {
if (!((*p_tree)[nid].IsLeftChild())) {
hist_.AddHistRow(nid);
if (rabit::IsDistributed()) {
phist_local_.AddHistRow(nid);
hist_local_worker_.AddHistRow(nid);
}
}
}

(*sync_count) = merged_hist.size() / 2;
if (*sync_count == 0) {
if (n_left == 0) {
(*sync_count) = 1;
} else {
(*sync_count) = n_left;
}
CHECK_EQ(n_left, (*sync_count));

builder_monitor_.Stop("AddHistRows");
}

Expand Down Expand Up @@ -549,9 +530,7 @@ void QuantileHistMaker::Builder::Update(const GHistIndexMatrix& gmat,
p_tree->Stat(nid).base_weight = snode_[nid].weight;
p_tree->Stat(nid).sum_hess = static_cast<float>(snode_[nid].stats.sum_hess);
}
builder_monitor_.Start("PrunerUpdate");
pruner_->Update(gpair, p_fmat, std::vector<RegTree*>{p_tree});
builder_monitor_.Stop("PrunerUpdate");

builder_monitor_.Stop("Update");
}
Expand Down Expand Up @@ -685,7 +664,7 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
// initialize histogram collection
uint32_t nbins = gmat.cut.Ptrs().back();
hist_.Init(nbins);
phist_local_.Init(nbins);
hist_local_worker_.Init(nbins);
hist_buffer_.Init(nbins);

// initialize histogram builder
Expand Down
7 changes: 5 additions & 2 deletions src/tree/updater_quantile_hist.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ using xgboost::common::Column;
class QuantileHistMaker: public TreeUpdater {
public:
QuantileHistMaker() {
updater_monitor_.Init("Quantile");
updater_monitor_.Init("QuantileHistMaker");
}
void Configure(const Args& args) override;

Expand Down Expand Up @@ -281,6 +281,9 @@ class QuantileHistMaker: public TreeUpdater {
void SyncHistograms(int starting_index,
int sync_count,
RegTree *p_tree);
/*!
 * \brief Run the histogram subtraction step in parallel over the given nodes.
 *
 * Entries whose node is a left child are skipped; for the others, when the node
 * is not the root and has a valid sibling, SubtractionHist is called with the
 * node's, parent's and sibling's histograms.
 *
 * \param space  blocked 2-d iteration space; its first index selects an entry of nodes
 * \param nodes  expand entries to process
 * \param p_tree tree used to query child/root status and parent ids
 */
void ParallelSubtractionHist(const common::BlockedSpace2d& space,
const std::vector<ExpandEntry>& nodes,
const RegTree * p_tree);

void BuildNodeStats(const GHistIndexMatrix &gmat,
DMatrix *p_fmat,
Expand Down Expand Up @@ -334,7 +337,7 @@ class QuantileHistMaker: public TreeUpdater {
/*! \brief cumulative histogram of gradients. */
HistCollection hist_;
/*! \brief cumulative local parent histogram of gradients. */
HistCollection phist_local_;
HistCollection hist_local_worker_;

/*! \brief feature with least # of bins. to be used for dense specialization
of InitNewNode() */
Expand Down

0 comments on commit 4d4b11e

Please sign in to comment.