Skip to content

Commit

Permalink
Tests were added and the bug they found was fixed (min -> max)
Browse files Browse the repository at this point in the history
  • Loading branch information
SHVETS, KIRILL committed May 11, 2020
1 parent a4d6e96 commit f36e5cf
Show file tree
Hide file tree
Showing 3 changed files with 220 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ void DistributedHistRowsAdder::AddHistRows(int *starting_index, int *sync_count,
builder_->hist_local_worker_.AddHistRow(nid);
}
}
(*sync_count) = std::min(1, n_left);
(*sync_count) = std::max(1, n_left);
builder_->builder_monitor_.Stop("AddHistRows");
}

Expand Down
18 changes: 12 additions & 6 deletions src/tree/updater_quantile_hist.h
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,9 @@ class HistSynchronizer {
virtual void SyncHistograms(int starting_index,
int sync_count,
RegTree *p_tree) = 0;
// Virtual so that concrete synchronizers can be destroyed through a
// HistSynchronizer pointer. builder_ is only stored as a raw pointer here and
// nothing in this destructor releases it, so there is nothing to clean up;
// resetting the pointer in a dying object would be a dead store.
virtual ~HistSynchronizer() = default;

protected:
QuantileHistMaker::Builder* builder_;
Expand All @@ -409,19 +412,19 @@ class BatchHistSynchronizer: public HistSynchronizer {
public:
explicit BatchHistSynchronizer(QuantileHistMaker::Builder* builder): HistSynchronizer(builder) {}

virtual void SyncHistograms(int starting_index,
void SyncHistograms(int starting_index,
int sync_count,
RegTree *p_tree);
RegTree *p_tree) override;
};

class DistributedHistSynchronizer: public HistSynchronizer {
public:
explicit DistributedHistSynchronizer(QuantileHistMaker::Builder* builder):
HistSynchronizer(builder) {}

virtual void SyncHistograms(int starting_index,
void SyncHistograms(int starting_index,
int sync_count,
RegTree *p_tree);
RegTree *p_tree) override;
void ParallelSubtractionHist(const common::BlockedSpace2d& space,
const std::vector<QuantileHistMaker::Builder::ExpandEntry>& nodes,
const RegTree * p_tree);
Expand All @@ -431,6 +434,9 @@ class HistRowsAdder {
public:
explicit HistRowsAdder(QuantileHistMaker::Builder* builder) : builder_(builder) {}
virtual void AddHistRows(int *starting_index, int *sync_count, RegTree *p_tree) = 0;
// Virtual so that concrete row adders can be destroyed through a
// HistRowsAdder pointer. builder_ is only stored as a raw pointer here and
// nothing in this destructor releases it, so there is nothing to clean up;
// resetting the pointer in a dying object would be a dead store.
virtual ~HistRowsAdder() = default;

protected:
QuantileHistMaker::Builder* builder_;
Expand All @@ -440,14 +446,14 @@ class BatchHistRowsAdder: public HistRowsAdder {
public:
explicit BatchHistRowsAdder(QuantileHistMaker::Builder* builder) : HistRowsAdder(builder) {}

void AddHistRows(int *starting_index, int *sync_count, RegTree *p_tree);
void AddHistRows(int *starting_index, int *sync_count, RegTree *p_tree) override;
};

class DistributedHistRowsAdder: public HistRowsAdder {
public:
explicit DistributedHistRowsAdder(QuantileHistMaker::Builder* builder) : HistRowsAdder(builder) {}

void AddHistRows(int *starting_index, int *sync_count, RegTree *p_tree);
void AddHistRows(int *starting_index, int *sync_count, RegTree *p_tree) override;
};


Expand Down
210 changes: 207 additions & 3 deletions tests/cpp/tree/test_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ class QuantileHistMock : public QuantileHistMaker {
std::unique_ptr<SplitEvaluator> spliteval,
FeatureInteractionConstraintHost int_constraint,
DMatrix const* fmat)
: RealImpl(param, std::move(pruner), std::move(spliteval), std::move(int_constraint), fmat) {}
: RealImpl(param, std::move(pruner), std::move(spliteval),
std::move(int_constraint), fmat) {}

public:
void TestInitData(const GHistIndexMatrix& gmat,
Expand Down Expand Up @@ -120,6 +121,147 @@ class QuantileHistMock : public QuantileHistMaker {
omp_set_num_threads(nthreads);
}

// Exercises hist_rows_adder_->AddHistRows() on a hand-built tree.
//
// After InitData(), the root and both of its children are expanded. Entries
// with node ids 3 and 4 are queued for explicit histogram building and entries
// with node ids 5 and 6 for the subtraction trick. The test then asserts that
// AddHistRows() reports sync_count == 2 and starting_index == 3, and that
// hist_ has a row for every queued node.
//
// gmat, gpair, p_fmat and tree are forwarded to RealImpl::InitData(); tree is
// additionally mutated by the ExpandNode() calls below.
void TestAddHistRows(const GHistIndexMatrix& gmat,
                     const std::vector<GradientPair>& gpair,
                     DMatrix* p_fmat,
                     RegTree* tree) {
  RealImpl::InitData(gmat, gpair, *p_fmat, *tree);

  // Start from the largest int so that any index written by AddHistRows()
  // is distinguishable from the initial value.
  int starting_index = std::numeric_limits<int>::max();
  int sync_count = 0;
  nodes_for_explicit_hist_build_.clear();
  nodes_for_subtraction_trick_.clear();

  // Grow the tree by expanding the root and then each of its children.
  tree->ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
  tree->ExpandNode((*tree)[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
  tree->ExpandNode((*tree)[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);

  // Each entry is (nid, sibling nid, depth, 0.0f, 0).
  nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0);
  nodes_for_explicit_hist_build_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0);
  nodes_for_subtraction_trick_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0);
  nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0);

  hist_rows_adder_->AddHistRows(&starting_index, &sync_count, tree);
  ASSERT_EQ(sync_count, 2);
  ASSERT_EQ(starting_index, 3);

  // Every queued node, whichever list it is in, must now own a histogram row.
  for (const ExpandEntry& node : nodes_for_explicit_hist_build_) {
    ASSERT_TRUE(hist_.RowExists(node.nid));
  }
  for (const ExpandEntry& node : nodes_for_subtraction_trick_) {
    ASSERT_TRUE(hist_.RowExists(node.nid));
  }
}


// Exercises hist_synchronizer_->SyncHistograms() on a hand-built tree.
//
// The tree is grown level by level (root, its two children, then their
// children), calling AddHistRows() for each level so that hist_ has rows for
// the queued nodes. Histograms of nodes 1 and 2 are filled with 2 * bin_id and
// those of nodes 3 and 5 with bin_id. After SyncHistograms() the test asserts,
// for every queued node on the last level, that
//   hist[parent] == hist[node] + hist[sibling]
// holds element-wise over all bins.
//
// gmat, gpair, p_fmat and tree are forwarded to RealImpl::InitData(); tree is
// additionally mutated by the ExpandNode() calls below.
void TestSyncHistograms(const GHistIndexMatrix& gmat,
                        const std::vector<GradientPair>& gpair,
                        DMatrix* p_fmat,
                        RegTree* tree) {
  // Scalar type used to view a histogram row as a flat array; the loops below
  // walk 2 * nbins such scalars per row.
  using FPType = decltype(tree::GradStats::sum_grad);

  // init
  RealImpl::InitData(gmat, gpair, *p_fmat, *tree);

  int starting_index = std::numeric_limits<int>::max();
  int sync_count = 0;

  // level 0
  nodes_for_explicit_hist_build_.clear();
  nodes_for_subtraction_trick_.clear();
  nodes_for_explicit_hist_build_.emplace_back(0, -1, tree->GetDepth(0), 0.0f, 0);
  hist_rows_adder_->AddHistRows(&starting_index, &sync_count, tree);
  tree->ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);

  // level 1
  nodes_for_explicit_hist_build_.clear();
  nodes_for_subtraction_trick_.clear();
  nodes_for_explicit_hist_build_.emplace_back((*tree)[0].LeftChild(), (*tree)[0].RightChild(),
                                              tree->GetDepth(1), 0.0f, 0);
  nodes_for_subtraction_trick_.emplace_back((*tree)[0].RightChild(), (*tree)[0].LeftChild(),
                                            tree->GetDepth(2), 0.0f, 0);
  hist_rows_adder_->AddHistRows(&starting_index, &sync_count, tree);
  tree->ExpandNode((*tree)[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
  tree->ExpandNode((*tree)[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);

  // level 2
  nodes_for_explicit_hist_build_.clear();
  nodes_for_subtraction_trick_.clear();
  nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0);
  nodes_for_subtraction_trick_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0);
  nodes_for_explicit_hist_build_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0);
  nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0);
  hist_rows_adder_->AddHistRows(&starting_index, &sync_count, tree);

  const size_t n_nodes = nodes_for_explicit_hist_build_.size();
  // Compare against an unsigned literal to avoid a signed/unsigned comparison
  // inside the gtest helper.
  ASSERT_EQ(n_nodes, static_cast<size_t>(2));

  // Register the row partitions produced by the three splits made above.
  row_set_collection_.AddSplit(0, (*tree)[0].LeftChild(),
                               (*tree)[0].RightChild(), 4, 4);
  row_set_collection_.AddSplit(1, (*tree)[1].LeftChild(),
                               (*tree)[1].RightChild(), 2, 2);
  row_set_collection_.AddSplit(2, (*tree)[2].LeftChild(),
                               (*tree)[2].RightChild(), 2, 2);

  common::BlockedSpace2d space(n_nodes, [&](size_t node) {
    const int32_t nid = nodes_for_explicit_hist_build_[node].nid;
    return row_set_collection_[nid].Size();
  }, 256);

  std::vector<GHistRow> target_hists(n_nodes);
  for (size_t i = 0; i < n_nodes; ++i) {
    const int32_t nid = nodes_for_explicit_hist_build_[i].nid;
    target_hists[i] = hist_[nid];
  }

  const size_t nbins = hist_builder_.GetNumBins();

  // Overwrites every scalar of node `nid`'s histogram with scale * position.
  auto fill_hist = [this, nbins](size_t nid, size_t scale) {
    auto this_hist = hist_[nid];
    FPType* p_hist = reinterpret_cast<FPType*>(this_hist.data());
    for (size_t bin_id = 0; bin_id < 2 * nbins; ++bin_id) {
      p_hist[bin_id] = static_cast<FPType>(scale * bin_id);
    }
  };
  // set values to specific nodes hist: the parents (1, 2) hold twice the value
  // of the explicitly built children (3, 5), so each sibling obtained as
  // parent - child must end up equal to the child.
  fill_hist(1, 2);
  fill_hist(2, 2);
  fill_hist(3, 1);
  fill_hist(5, 1);

  hist_buffer_.Reset(1, n_nodes, space, target_hists);
  // sync hist
  hist_synchronizer_->SyncHistograms(starting_index, sync_count, tree);

  // Asserts parent == left + right for every scalar of bins [begin, end).
  auto check_hist = [] (const GHistRow parent, const GHistRow left,
                        const GHistRow right, size_t begin, size_t end) {
    const FPType* p_parent = reinterpret_cast<const FPType*>(parent.data());
    const FPType* p_left = reinterpret_cast<const FPType*>(left.data());
    const FPType* p_right = reinterpret_cast<const FPType*>(right.data());
    for (size_t i = 2 * begin; i < 2 * end; ++i) {
      ASSERT_EQ(p_parent[i], p_left[i] + p_right[i]);
    }
  };
  for (const ExpandEntry& node : nodes_for_explicit_hist_build_) {
    auto this_hist = hist_[node.nid];
    const size_t parent_id = (*tree)[node.nid].Parent();
    auto parent_hist = hist_[parent_id];
    auto sibling_hist = hist_[node.sibling_nid];

    check_hist(parent_hist, this_hist, sibling_hist, 0, nbins);
  }
  for (const ExpandEntry& node : nodes_for_subtraction_trick_) {
    auto this_hist = hist_[node.nid];
    const size_t parent_id = (*tree)[node.nid].Parent();
    auto parent_hist = hist_[parent_id];
    auto sibling_hist = hist_[node.sibling_nid];

    check_hist(parent_hist, this_hist, sibling_hist, 0, nbins);
  }
}

void TestBuildHist(int nid,
const GHistIndexMatrix& gmat,
Expand Down Expand Up @@ -249,7 +391,6 @@ class QuantileHistMock : public QuantileHistMaker {
TestEvaluateSplit(quantile_index_block, tree);
omp_set_num_threads(1);
}

};

int static constexpr kNRows = 8, kNCols = 16;
Expand All @@ -259,7 +400,7 @@ class QuantileHistMock : public QuantileHistMaker {

public:
explicit QuantileHistMock(
const std::vector<std::pair<std::string, std::string> >& args) :
const std::vector<std::pair<std::string, std::string> >& args, bool batch = true) :
cfg_{args} {
QuantileHistMaker::Configure(args);
spliteval_->Init(&param_);
Expand All @@ -271,6 +412,13 @@ class QuantileHistMock : public QuantileHistMaker {
std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
int_constraint_,
dmat_.get()));
if (batch) {
builder_->SetHistSynchronizer(new BatchHistSynchronizer(builder_.get()));
builder_->SetHistRowsAdder(new BatchHistRowsAdder(builder_.get()));
} else {
builder_->SetHistSynchronizer(new DistributedHistSynchronizer(builder_.get()));
builder_->SetHistRowsAdder(new DistributedHistRowsAdder(builder_.get()));
}
}
~QuantileHistMock() override = default;

Expand Down Expand Up @@ -305,6 +453,34 @@ class QuantileHistMock : public QuantileHistMaker {

builder_->TestInitDataSampling(gmat, gpair, dmat_.get(), tree);
}

void TestAddHistRows() {
size_t constexpr kMaxBins = 4;
common::GHistIndexMatrix gmat;
gmat.Init(dmat_.get(), kMaxBins);

RegTree tree = RegTree();
tree.param.UpdateAllowUnknown(cfg_);
std::vector<GradientPair> gpair =
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
builder_->TestAddHistRows(gmat, gpair, dmat_.get(), &tree);
}

void TestSyncHistograms() {
size_t constexpr kMaxBins = 4;
common::GHistIndexMatrix gmat;
gmat.Init(dmat_.get(), kMaxBins);

RegTree tree = RegTree();
tree.param.UpdateAllowUnknown(cfg_);
std::vector<GradientPair> gpair =
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
builder_->TestSyncHistograms(gmat, gpair, dmat_.get(), &tree);
}


void TestBuildHist() {
RegTree tree = RegTree();
tree.param.UpdateAllowUnknown(cfg_);
Expand Down Expand Up @@ -340,6 +516,34 @@ TEST(QuantileHist, InitDataSampling) {
maker.TestInitDataSampling();
}

// AddHistRows check with the mock's default (batch) synchronizer and row adder.
TEST(QuantileHist, AddHistRows) {
  using Args = std::vector<std::pair<std::string, std::string>>;
  const Args cfg{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}};

  QuantileHistMock maker(cfg);
  maker.TestAddHistRows();
}

// SyncHistograms check with the mock's default (batch) synchronizer and row adder.
TEST(QuantileHist, SyncHistograms) {
  using Args = std::vector<std::pair<std::string, std::string>>;
  const Args cfg{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}};

  QuantileHistMock maker(cfg);
  maker.TestSyncHistograms();
}

// AddHistRows check with batch == false passed to the mock constructor.
TEST(QuantileHist, DistributedAddHistRows) {
  using Args = std::vector<std::pair<std::string, std::string>>;
  const Args cfg{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}};

  QuantileHistMock maker(cfg, false);
  maker.TestAddHistRows();
}

// SyncHistograms check with batch == false passed to the mock constructor.
TEST(QuantileHist, DistributedSyncHistograms) {
  using Args = std::vector<std::pair<std::string, std::string>>;
  const Args cfg{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}};

  QuantileHistMock maker(cfg, false);
  maker.TestSyncHistograms();
}

TEST(QuantileHist, BuildHist) {
// Don't enable feature grouping
std::vector<std::pair<std::string, std::string>> cfg
Expand Down

0 comments on commit f36e5cf

Please sign in to comment.