Skip to content

Commit

Permalink
Unify max nodes. (#5497)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Apr 10, 2020
1 parent bd653fa commit 866a477
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 22 deletions.
11 changes: 1 addition & 10 deletions src/tree/constraints.cu
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,7 @@ void FeatureInteractionConstraint::Configure(
}

// --- Initialize allowed features attached to nodes.
if (param.max_depth == 0 && param.max_leaves == 0) {
LOG(FATAL) << "Max leaves and max depth cannot both be unconstrained for gpu_hist.";
}
int32_t n_nodes {0};
if (param.max_depth != 0) {
n_nodes = std::pow(2, param.max_depth + 1);
} else {
n_nodes = param.max_leaves * 2 - 1;
}
CHECK_NE(n_nodes, 0);
int32_t n_nodes { param.MaxNodes() };
node_constraints_.resize(n_nodes);
node_constraints_storage_.resize(n_nodes);
for (auto& n : node_constraints_storage_) {
Expand Down
14 changes: 14 additions & 0 deletions src/tree/param.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,20 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
CHECK_GT(ret, 0U);
return ret;
}

bst_node_t MaxNodes() const {
if (this->max_depth == 0 && this->max_leaves == 0) {
LOG(FATAL) << "Max leaves and max depth cannot both be unconstrained.";
}
bst_node_t n_nodes{0};
if (this->max_leaves > 0) {
n_nodes = this->max_leaves * 2 - 1;
} else {
n_nodes = (1 << (this->max_depth + 1)) - 1;
}
CHECK_NE(n_nodes, 0);
return n_nodes;
}
};

/*! \brief Loss functions */
Expand Down
5 changes: 0 additions & 5 deletions src/tree/updater_gpu_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,5 @@ struct SumCallbackOp {
return old_prefix;
}
};

// Total number of nodes in tree, given depth
XGBOOST_DEVICE inline int MaxNodesDepth(int depth) {
return (1 << (depth + 1)) - 1;
}
} // namespace tree
} // namespace xgboost
8 changes: 1 addition & 7 deletions src/tree/updater_gpu_hist.cu
Original file line number Diff line number Diff line change
Expand Up @@ -927,13 +927,7 @@ struct GPUHistMakerDevice {

template <typename GradientSumT>
inline void GPUHistMakerDevice<GradientSumT>::InitHistogram() {
CHECK(!(param.max_leaves == 0 && param.max_depth == 0))
<< "Max leaves and max depth cannot both be unconstrained for "
"gpu_hist.";

int max_nodes =
param.max_leaves > 0 ? param.max_leaves * 2 : MaxNodesDepth(param.max_depth);

bst_node_t max_nodes { param.MaxNodes() };
ba.Allocate(device_id,
&prediction_cache, n_rows,
&node_sum_gradients_d, max_nodes,
Expand Down

0 comments on commit 866a477

Please sign in to comment.