Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract CPUExpandEntry and HistParam. #7321

Merged
merged 1 commit into from
Oct 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions src/tree/hist/expand_entry.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*!
* Copyright 2021 XGBoost contributors
*/
#ifndef XGBOOST_TREE_HIST_EXPAND_ENTRY_H_
#define XGBOOST_TREE_HIST_EXPAND_ENTRY_H_

#include <utility>
#include "../param.h"

namespace xgboost {
namespace tree {

struct CPUExpandEntry {
int nid;
int depth;
SplitEntry split;
CPUExpandEntry() = default;
XGBOOST_DEVICE
CPUExpandEntry(int nid, int depth, SplitEntry split)
: nid(nid), depth(depth), split(std::move(split)) {}
CPUExpandEntry(int nid, int depth, float loss_chg)
: nid(nid), depth(depth) {
split.loss_chg = loss_chg;
}

bool IsValid(const TrainParam& param, int num_leaves) const {
if (split.loss_chg <= kRtEps) return false;
if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
return false;
}
if (split.loss_chg < param.min_split_loss) {
return false;
}
if (param.max_depth > 0 && depth == param.max_depth) {
return false;
}
if (param.max_leaves > 0 && num_leaves == param.max_leaves) {
return false;
}
return true;
}

float GetLossChange() const { return split.loss_chg; }
bst_node_t GetNodeId() const { return nid; }

static bool ChildIsValid(const TrainParam& param, int depth, int num_leaves) {
if (param.max_depth > 0 && depth >= param.max_depth) return false;
if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false;
return true;
}

friend std::ostream& operator<<(std::ostream& os, const CPUExpandEntry& e) {
os << "ExpandEntry: \n";
os << "nidx: " << e.nid << "\n";
os << "depth: " << e.depth << "\n";
os << "loss: " << e.split.loss_chg << "\n";
os << "left_sum: " << e.split.left_sum << "\n";
os << "right_sum: " << e.split.right_sum << "\n";
return os;
}
};
} // namespace tree
} // namespace xgboost
#endif // XGBOOST_TREE_HIST_EXPAND_ENTRY_H_
23 changes: 23 additions & 0 deletions src/tree/hist/param.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*!
* Copyright 2021 XGBoost contributors
*/
#ifndef XGBOOST_TREE_HIST_PARAM_H_
#define XGBOOST_TREE_HIST_PARAM_H_
#include "xgboost/parameter.h"

namespace xgboost {
namespace tree {
// training parameters specific to this algorithm
struct CPUHistMakerTrainParam
: public XGBoostParameter<CPUHistMakerTrainParam> {
bool single_precision_histogram;
// declare parameters
DMLC_DECLARE_PARAMETER(CPUHistMakerTrainParam) {
DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
"Use single precision to build histograms.");
}
};
} // namespace tree
} // namespace xgboost

#endif // XGBOOST_TREE_HIST_PARAM_H_
4 changes: 2 additions & 2 deletions src/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ template <bool any_missing>
void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
DMatrix *p_fmat, RegTree *p_tree, const std::vector<GradientPair> &gpair_h,
int *num_leaves, std::vector<CPUExpandEntry> *expand) {
CPUExpandEntry node(CPUExpandEntry::kRootNid, p_tree->GetDepth(0), 0.0f);
CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0), 0.0f);

nodes_for_explicit_hist_build_.clear();
nodes_for_subtraction_trick_.clear();
Expand All @@ -135,7 +135,7 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
nodes_for_subtraction_trick_, gpair_h);

{
auto nid = CPUExpandEntry::kRootNid;
auto nid = RegTree::kRoot;
GHistRowT hist = this->histogram_builder_->Histogram()[nid];
GradientPairT grad_stat;
if (data_layout_ == DataLayout::kDenseDataZeroBased ||
Expand Down
48 changes: 3 additions & 45 deletions src/tree/updater_quantile_hist.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@

#include "hist/evaluate_splits.h"
#include "hist/histogram.h"
#include "hist/expand_entry.h"
#include "hist/param.h"

#include "constraints.h"
#include "./param.h"
#include "./driver.h"
Expand Down Expand Up @@ -89,51 +92,6 @@ using xgboost::common::GHistBuilder;
using xgboost::common::ColumnMatrix;
using xgboost::common::Column;

// training parameters specific to this algorithm
struct CPUHistMakerTrainParam
: public XGBoostParameter<CPUHistMakerTrainParam> {
bool single_precision_histogram = false;
// declare parameters
DMLC_DECLARE_PARAMETER(CPUHistMakerTrainParam) {
DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
"Use single precision to build histograms.");
}
};

/* tree growing policies */
struct CPUExpandEntry {
static const int kRootNid = 0;
static const int kEmptyNid = -1;
int nid;
int depth;
SplitEntry split;

CPUExpandEntry() = default;
CPUExpandEntry(int nid, int depth, bst_float loss_chg)
: nid(nid), depth(depth) {
split.loss_chg = loss_chg;
}

bool IsValid(TrainParam const &param, int32_t num_leaves) const {
bool invalid = split.loss_chg <= kRtEps ||
(param.max_depth > 0 && this->depth == param.max_depth) ||
(param.max_leaves > 0 && num_leaves == param.max_leaves);
return !invalid;
}

bst_float GetLossChange() const {
return split.loss_chg;
}

int GetNodeId() const {
return nid;
}

int GetDepth() const {
return depth;
}
};

/*! \brief construct a tree using quantized feature values */
class QuantileHistMaker: public TreeUpdater {
public:
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/tree/hist/test_histogram.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ void TestBuildHistogram(bool is_distributed) {
std::iota(row_indices.begin(), row_indices.end(), 0);
row_set_collection_.Init();

CPUExpandEntry node(CPUExpandEntry::kRootNid, tree.GetDepth(0), 0.0f);
CPUExpandEntry node(RegTree::kRoot, tree.GetDepth(0), 0.0f);
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
nodes_for_explicit_hist_build_.push_back(node);
histogram.BuildHist(p_fmat.get(), &tree, row_set_collection_,
Expand Down