Remove public access to tree model param. #8902

Merged
merged 1 commit on Mar 13, 2023
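The diff makes `RegTree::param` a private member (`param_`) and routes external access through the public accessors added in `tree_model.h`. As a rough illustration of what downstream code looks like after this change, here is a minimal sketch assuming only the accessors visible in this diff (`NumFeatures`, `NumNodes`, `NumValidNodes`, `NumTargets`, `IsMultiTarget`); the `SummarizeTree` helper itself is hypothetical, not part of the PR:

```cpp
#include <xgboost/tree_model.h>

#include <iostream>

namespace xgboost {
// Hypothetical helper: report tree shape through the new public accessors
// rather than reading the (now private) TreeParam fields directly.
void SummarizeTree(RegTree const& tree) {
  // tree.param.num_nodes;  // no longer accessible after this PR
  std::cout << "features:     " << tree.NumFeatures() << "\n";
  std::cout << "nodes:        " << tree.NumNodes() << "  (includes deleted nodes)\n";
  std::cout << "valid nodes:  " << tree.NumValidNodes() << "\n";
  std::cout << "targets:      " << tree.NumTargets() << "\n";
  std::cout << "multi-target: " << std::boolalpha << tree.IsMultiTarget() << "\n";
}
}  // namespace xgboost
```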
146 changes: 75 additions & 71 deletions include/xgboost/tree_model.h
@@ -178,51 +178,33 @@ class RegTree : public Model {
}

/*! \brief index of left child */
XGBOOST_DEVICE [[nodiscard]] int LeftChild() const {
return this->cleft_;
}
[[nodiscard]] XGBOOST_DEVICE int LeftChild() const { return this->cleft_; }
/*! \brief index of right child */
XGBOOST_DEVICE [[nodiscard]] int RightChild() const {
return this->cright_;
}
[[nodiscard]] XGBOOST_DEVICE int RightChild() const { return this->cright_; }
/*! \brief index of default child when feature is missing */
XGBOOST_DEVICE [[nodiscard]] int DefaultChild() const {
[[nodiscard]] XGBOOST_DEVICE int DefaultChild() const {
return this->DefaultLeft() ? this->LeftChild() : this->RightChild();
}
/*! \brief feature index of split condition */
XGBOOST_DEVICE [[nodiscard]] unsigned SplitIndex() const {
[[nodiscard]] XGBOOST_DEVICE unsigned SplitIndex() const {
return sindex_ & ((1U << 31) - 1U);
}
/*! \brief when feature is unknown, whether goes to left child */
XGBOOST_DEVICE [[nodiscard]] bool DefaultLeft() const {
return (sindex_ >> 31) != 0;
}
[[nodiscard]] XGBOOST_DEVICE bool DefaultLeft() const { return (sindex_ >> 31) != 0; }
/*! \brief whether current node is leaf node */
XGBOOST_DEVICE [[nodiscard]] bool IsLeaf() const {
return cleft_ == kInvalidNodeId;
}
[[nodiscard]] XGBOOST_DEVICE bool IsLeaf() const { return cleft_ == kInvalidNodeId; }
/*! \return get leaf value of leaf node */
XGBOOST_DEVICE [[nodiscard]] float LeafValue() const {
return (this->info_).leaf_value;
}
[[nodiscard]] XGBOOST_DEVICE float LeafValue() const { return (this->info_).leaf_value; }
/*! \return get split condition of the node */
XGBOOST_DEVICE [[nodiscard]] SplitCondT SplitCond() const {
return (this->info_).split_cond;
}
[[nodiscard]] XGBOOST_DEVICE SplitCondT SplitCond() const { return (this->info_).split_cond; }
/*! \brief get parent of the node */
XGBOOST_DEVICE [[nodiscard]] int Parent() const {
return parent_ & ((1U << 31) - 1);
}
[[nodiscard]] XGBOOST_DEVICE int Parent() const { return parent_ & ((1U << 31) - 1); }
/*! \brief whether current node is left child */
XGBOOST_DEVICE [[nodiscard]] bool IsLeftChild() const {
return (parent_ & (1U << 31)) != 0;
}
[[nodiscard]] XGBOOST_DEVICE bool IsLeftChild() const { return (parent_ & (1U << 31)) != 0; }
/*! \brief whether this node is deleted */
XGBOOST_DEVICE [[nodiscard]] bool IsDeleted() const {
return sindex_ == kDeletedNodeMarker;
}
[[nodiscard]] XGBOOST_DEVICE bool IsDeleted() const { return sindex_ == kDeletedNodeMarker; }
/*! \brief whether current node is root */
XGBOOST_DEVICE [[nodiscard]] bool IsRoot() const { return parent_ == kInvalidNodeId; }
[[nodiscard]] XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; }
/*!
* \brief set the left child
* \param nid node id to right child
@@ -337,15 +319,13 @@ class RegTree : public Model {
this->ChangeToLeaf(rid, value);
}

/*! \brief model parameter */
TreeParam param;
RegTree() {
param.Init(Args{});
nodes_.resize(param.num_nodes);
stats_.resize(param.num_nodes);
split_types_.resize(param.num_nodes, FeatureType::kNumerical);
split_categories_segments_.resize(param.num_nodes);
for (int i = 0; i < param.num_nodes; i++) {
param_.Init(Args{});
nodes_.resize(param_.num_nodes);
stats_.resize(param_.num_nodes);
split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
split_categories_segments_.resize(param_.num_nodes);
for (int i = 0; i < param_.num_nodes; i++) {
nodes_[i].SetLeaf(0.0f);
nodes_[i].SetParent(kInvalidNodeId);
}
@@ -354,10 +334,10 @@
* \brief Constructor that initializes the tree model with shape.
*/
explicit RegTree(bst_target_t n_targets, bst_feature_t n_features) : RegTree{} {
param.num_feature = n_features;
param.size_leaf_vector = n_targets;
param_.num_feature = n_features;
param_.size_leaf_vector = n_targets;
if (n_targets > 1) {
this->p_mt_tree_.reset(new MultiTargetTree{&param});
this->p_mt_tree_.reset(new MultiTargetTree{&param_});
}
}

@@ -401,7 +381,7 @@ class RegTree : public Model {

bool operator==(const RegTree& b) const {
return nodes_ == b.nodes_ && stats_ == b.stats_ &&
deleted_nodes_ == b.deleted_nodes_ && param == b.param;
deleted_nodes_ == b.deleted_nodes_ && param_ == b.param_;
}
/* \brief Iterate through all nodes in this tree.
*
@@ -459,7 +439,9 @@ class RegTree : public Model {
bst_float loss_change, float sum_hess, float left_sum,
float right_sum,
bst_node_t leaf_right_child = kInvalidNodeId);

/**
* \brief Expands a leaf node into two additional leaf nodes for a multi-target tree.
*/
void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left,
linalg::VectorView<float const> base_weight,
linalg::VectorView<float const> left_weight,
@@ -485,19 +467,48 @@
bst_float base_weight, bst_float left_leaf_weight,
bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
float left_sum, float right_sum);

[[nodiscard]] bool HasCategoricalSplit() const {
return !split_categories_.empty();
}
/**
* \brief Whether this tree has categorical split.
*/
[[nodiscard]] bool HasCategoricalSplit() const { return !split_categories_.empty(); }
/**
* \brief Whether this is a multi-target tree.
*/
[[nodiscard]] bool IsMultiTarget() const { return static_cast<bool>(p_mt_tree_); }
[[nodiscard]] bst_target_t NumTargets() const { return param.size_leaf_vector; }
/**
* \brief The size of leaf weight.
*/
[[nodiscard]] bst_target_t NumTargets() const { return param_.size_leaf_vector; }
/**
* \brief Get the underlying implementation of multi-target tree.
*/
[[nodiscard]] auto GetMultiTargetTree() const {
CHECK(IsMultiTarget());
return p_mt_tree_.get();
}
/**
* \brief Get the number of features.
*/
[[nodiscard]] bst_feature_t NumFeatures() const noexcept { return param_.num_feature; }
/**
* \brief Get the total number of nodes including deleted ones in this tree.
*/
[[nodiscard]] bst_node_t NumNodes() const noexcept { return param_.num_nodes; }
/**
* \brief Get the total number of valid nodes in this tree.
*/
[[nodiscard]] bst_node_t NumValidNodes() const noexcept {
return param_.num_nodes - param_.num_deleted;
}
/**
* \brief number of extra nodes besides the root
*/
[[nodiscard]] bst_node_t NumExtraNodes() const noexcept {
return param_.num_nodes - 1 - param_.num_deleted;
}
/* \brief Count number of leaves in tree. */
[[nodiscard]] bst_node_t GetNumLeaves() const;
[[nodiscard]] bst_node_t GetNumSplitNodes() const;

/*!
* \brief get current depth
@@ -514,6 +525,9 @@
}
return depth;
}
/**
* \brief Set the leaf weight for a multi-target tree.
*/
void SetLeaf(bst_node_t nidx, linalg::VectorView<float const> weight) {
CHECK(IsMultiTarget());
return this->p_mt_tree_->SetLeaf(nidx, weight);
@@ -525,25 +539,13 @@
*/
[[nodiscard]] int MaxDepth(int nid) const {
if (nodes_[nid].IsLeaf()) return 0;
return std::max(MaxDepth(nodes_[nid].LeftChild())+1,
MaxDepth(nodes_[nid].RightChild())+1);
return std::max(MaxDepth(nodes_[nid].LeftChild()) + 1, MaxDepth(nodes_[nid].RightChild()) + 1);
}

/*!
* \brief get maximum depth
*/
int MaxDepth() {
return MaxDepth(0);
}

/*! \brief number of extra nodes besides the root */
[[nodiscard]] int NumExtraNodes() const {
return param.num_nodes - 1 - param.num_deleted;
}

/* \brief Count number of leaves in tree. */
[[nodiscard]] bst_node_t GetNumLeaves() const;
[[nodiscard]] bst_node_t GetNumSplitNodes() const;
int MaxDepth() { return MaxDepth(0); }

/*!
* \brief dense feature vector that can be taken by RegTree
@@ -735,6 +737,8 @@ class RegTree : public Model {
template <bool typed>
void LoadCategoricalSplit(Json const& in);
void SaveCategoricalSplit(Json* p_out) const;
/*! \brief model parameter */
TreeParam param_;
// vector of nodes
std::vector<Node> nodes_;
// free node space, used during training process
@@ -752,20 +756,20 @@ class RegTree : public Model {
// allocate a new node,
// !!!!!! NOTE: may cause BUG here, nodes.resize
bst_node_t AllocNode() {
if (param.num_deleted != 0) {
if (param_.num_deleted != 0) {
int nid = deleted_nodes_.back();
deleted_nodes_.pop_back();
nodes_[nid].Reuse();
--param.num_deleted;
--param_.num_deleted;
return nid;
}
int nd = param.num_nodes++;
CHECK_LT(param.num_nodes, std::numeric_limits<int>::max())
int nd = param_.num_nodes++;
CHECK_LT(param_.num_nodes, std::numeric_limits<int>::max())
<< "number of nodes in the tree exceed 2^31";
nodes_.resize(param.num_nodes);
stats_.resize(param.num_nodes);
split_types_.resize(param.num_nodes, FeatureType::kNumerical);
split_categories_segments_.resize(param.num_nodes);
nodes_.resize(param_.num_nodes);
stats_.resize(param_.num_nodes);
split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
split_categories_segments_.resize(param_.num_nodes);
return nd;
}
// delete a tree node, keep the parent field to allow trace back
@@ -780,7 +784,7 @@ class RegTree : public Model {

deleted_nodes_.push_back(nid);
nodes_[nid].MarkDelete();
++param.num_deleted;
++param_.num_deleted;
}
};

4 changes: 2 additions & 2 deletions src/gbm/gbtree.cc
@@ -360,8 +360,8 @@ void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fma
<< "Set `process_type` to `update` if you want to update existing "
"trees.";
// create new tree
std::unique_ptr<RegTree> ptr(new RegTree());
ptr->param.UpdateAllowUnknown(this->cfg_);
std::unique_ptr<RegTree> ptr(new RegTree{this->model_.learner_model_param->LeafLength(),
this->model_.learner_model_param->num_feature});
new_trees.push_back(ptr.get());
ret->push_back(std::move(ptr));
} else if (tparam_.process_type == TreeProcessType::kUpdate) {
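With `param` private, `GBTree::BoostNewTrees` can no longer tune the tree parameter after construction via `ptr->param.UpdateAllowUnknown(...)`; the shape is instead passed to the `RegTree(bst_target_t, bst_feature_t)` constructor shown in the header diff. A minimal sketch of that construction pattern follows; the target and feature counts are placeholder values, not taken from a real `LearnerModelParam`:

```cpp
#include <xgboost/tree_model.h>

#include <memory>

int main() {
  // Placeholder shape; in GBTree these come from learner_model_param
  // (LeafLength() and num_feature) as shown in the gbtree.cc hunk above.
  xgboost::bst_target_t n_targets = 1;
  xgboost::bst_feature_t n_features = 4;

  // Shape is fixed at construction instead of being set through ptr->param afterwards.
  auto tree = std::make_unique<xgboost::RegTree>(n_targets, n_features);
  return tree->NumFeatures() == n_features ? 0 : 1;
}
```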
2 changes: 0 additions & 2 deletions src/learner.cc
@@ -775,8 +775,6 @@ class LearnerConfiguration : public Learner {
}
CHECK_NE(mparam_.num_feature, 0)
<< "0 feature is supplied. Are you using raw Booster interface?";
// Remove these once binary IO is gone.
cfg_["num_feature"] = common::ToString(mparam_.num_feature);
}

void ConfigureGBM(LearnerTrainParam const& old, Args const& args) {
2 changes: 1 addition & 1 deletion src/predictor/cpu_predictor.cc
@@ -275,7 +275,7 @@ float FillNodeMeanValues(RegTree const *tree, bst_node_t nidx, std::vector<float
}

void FillNodeMeanValues(RegTree const* tree, std::vector<float>* mean_values) {
size_t num_nodes = tree->param.num_nodes;
size_t num_nodes = tree->NumNodes();
if (mean_values->size() == num_nodes) {
return;
}
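The predictor change is the typical downstream migration: `tree->param.num_nodes` becomes `tree->NumNodes()`. A minimal sketch of the same pattern, sizing a per-node buffer from the accessor; `MakeNodeBuffer` and the zero fill are illustrative only, not the actual `FillNodeMeanValues` logic:

```cpp
#include <xgboost/tree_model.h>

#include <cstddef>
#include <vector>

// Illustrative only: allocate one slot per node. NumNodes() counts deleted
// nodes as well; NumValidNodes() would exclude them.
std::vector<float> MakeNodeBuffer(xgboost::RegTree const& tree) {
  std::vector<float> node_values(static_cast<std::size_t>(tree.NumNodes()), 0.0f);
  return node_values;
}
```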