From c90991a6f5cbfab1097651991315fc6a1324b9d5 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 14 Oct 2021 13:51:42 +0800 Subject: [PATCH] Extract CPUExpandEntry and HistParam. * Remove kRootNid. * Check for empty hessian. --- src/tree/hist/expand_entry.h | 64 +++++++++++++++++++++++++++ src/tree/hist/param.h | 23 ++++++++++ src/tree/updater_quantile_hist.cc | 4 +- src/tree/updater_quantile_hist.h | 48 ++------------------ tests/cpp/tree/hist/test_histogram.cc | 2 +- 5 files changed, 93 insertions(+), 48 deletions(-) create mode 100644 src/tree/hist/expand_entry.h create mode 100644 src/tree/hist/param.h diff --git a/src/tree/hist/expand_entry.h b/src/tree/hist/expand_entry.h new file mode 100644 index 000000000000..d0edfbd379a6 --- /dev/null +++ b/src/tree/hist/expand_entry.h @@ -0,0 +1,64 @@ +/*! + * Copyright 2021 XGBoost contributors + */ +#ifndef XGBOOST_TREE_HIST_EXPAND_ENTRY_H_ +#define XGBOOST_TREE_HIST_EXPAND_ENTRY_H_ + +#include +#include "../param.h" + +namespace xgboost { +namespace tree { + +struct CPUExpandEntry { + int nid; + int depth; + SplitEntry split; + CPUExpandEntry() = default; + XGBOOST_DEVICE + CPUExpandEntry(int nid, int depth, SplitEntry split) + : nid(nid), depth(depth), split(std::move(split)) {} + CPUExpandEntry(int nid, int depth, float loss_chg) + : nid(nid), depth(depth) { + split.loss_chg = loss_chg; + } + + bool IsValid(const TrainParam& param, int num_leaves) const { + if (split.loss_chg <= kRtEps) return false; + if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) { + return false; + } + if (split.loss_chg < param.min_split_loss) { + return false; + } + if (param.max_depth > 0 && depth == param.max_depth) { + return false; + } + if (param.max_leaves > 0 && num_leaves == param.max_leaves) { + return false; + } + return true; + } + + float GetLossChange() const { return split.loss_chg; } + bst_node_t GetNodeId() const { return nid; } + + static bool ChildIsValid(const TrainParam& param, int depth, int num_leaves) { + if (param.max_depth > 0 && depth >= param.max_depth) return false; + if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false; + return true; + } + + friend std::ostream& operator<<(std::ostream& os, const CPUExpandEntry& e) { + os << "ExpandEntry: \n"; + os << "nidx: " << e.nid << "\n"; + os << "depth: " << e.depth << "\n"; + os << "loss: " << e.split.loss_chg << "\n"; + os << "left_sum: " << e.split.left_sum << "\n"; + os << "right_sum: " << e.split.right_sum << "\n"; + return os; + } +}; +} // namespace tree +} // namespace xgboost +#endif // XGBOOST_TREE_HIST_EXPAND_ENTRY_H_ diff --git a/src/tree/hist/param.h b/src/tree/hist/param.h new file mode 100644 index 000000000000..2fbee28c423b --- /dev/null +++ b/src/tree/hist/param.h @@ -0,0 +1,23 @@ +/*! + * Copyright 2021 XGBoost contributors + */ +#ifndef XGBOOST_TREE_HIST_PARAM_H_ +#define XGBOOST_TREE_HIST_PARAM_H_ +#include "xgboost/parameter.h" + +namespace xgboost { +namespace tree { +// training parameters specific to this algorithm +struct CPUHistMakerTrainParam + : public XGBoostParameter { + bool single_precision_histogram; + // declare parameters + DMLC_DECLARE_PARAMETER(CPUHistMakerTrainParam) { + DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe( + "Use single precision to build histograms."); + } +}; +} // namespace tree +} // namespace xgboost + +#endif // XGBOOST_TREE_HIST_PARAM_H_ diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index bc894b4646b6..6d426b2f744a 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -124,7 +124,7 @@ template void QuantileHistMaker::Builder::InitRoot( DMatrix *p_fmat, RegTree *p_tree, const std::vector &gpair_h, int *num_leaves, std::vector *expand) { - CPUExpandEntry node(CPUExpandEntry::kRootNid, p_tree->GetDepth(0), 0.0f); + CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0), 0.0f); nodes_for_explicit_hist_build_.clear(); nodes_for_subtraction_trick_.clear(); @@ -135,7 +135,7 @@ void QuantileHistMaker::Builder::InitRoot( nodes_for_subtraction_trick_, gpair_h); { - auto nid = CPUExpandEntry::kRootNid; + auto nid = RegTree::kRoot; GHistRowT hist = this->histogram_builder_->Histogram()[nid]; GradientPairT grad_stat; if (data_layout_ == DataLayout::kDenseDataZeroBased || diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h index 69e42b90db44..9654ab00a7c0 100644 --- a/src/tree/updater_quantile_hist.h +++ b/src/tree/updater_quantile_hist.h @@ -23,6 +23,9 @@ #include "hist/evaluate_splits.h" #include "hist/histogram.h" +#include "hist/expand_entry.h" +#include "hist/param.h" + #include "constraints.h" #include "./param.h" #include "./driver.h" @@ -89,51 +92,6 @@ using xgboost::common::GHistBuilder; using xgboost::common::ColumnMatrix; using xgboost::common::Column; -// training parameters specific to this algorithm -struct CPUHistMakerTrainParam - : public XGBoostParameter { - bool single_precision_histogram = false; - // declare parameters - DMLC_DECLARE_PARAMETER(CPUHistMakerTrainParam) { - DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe( - "Use single precision to build histograms."); - } -}; - -/* tree growing policies */ -struct CPUExpandEntry { - static const int kRootNid = 0; - static const int kEmptyNid = -1; - int nid; - int depth; - SplitEntry split; - - CPUExpandEntry() = default; - CPUExpandEntry(int nid, int depth, bst_float loss_chg) - : nid(nid), depth(depth) { - split.loss_chg = loss_chg; - } - - bool IsValid(TrainParam const ¶m, int32_t num_leaves) const { - bool invalid = split.loss_chg <= kRtEps || - (param.max_depth > 0 && this->depth == param.max_depth) || - (param.max_leaves > 0 && num_leaves == param.max_leaves); - return !invalid; - } - - bst_float GetLossChange() const { - return split.loss_chg; - } - - int GetNodeId() const { - return nid; - } - - int GetDepth() const { - return depth; - } -}; - /*! \brief construct a tree using quantized feature values */ class QuantileHistMaker: public TreeUpdater { public: diff --git a/tests/cpp/tree/hist/test_histogram.cc b/tests/cpp/tree/hist/test_histogram.cc index a75ce70d4843..f257a683405e 100644 --- a/tests/cpp/tree/hist/test_histogram.cc +++ b/tests/cpp/tree/hist/test_histogram.cc @@ -258,7 +258,7 @@ void TestBuildHistogram(bool is_distributed) { std::iota(row_indices.begin(), row_indices.end(), 0); row_set_collection_.Init(); - CPUExpandEntry node(CPUExpandEntry::kRootNid, tree.GetDepth(0), 0.0f); + CPUExpandEntry node(RegTree::kRoot, tree.GetDepth(0), 0.0f); std::vector nodes_for_explicit_hist_build_; nodes_for_explicit_hist_build_.push_back(node); histogram.BuildHist(p_fmat.get(), &tree, row_set_collection_,