Serialise booster after training to reset state (#5484)
* Serialise booster after training to reset state

* Prevent process_type being set on load

* Check for correct updater sequence
RAMitchell authored Apr 11, 2020
1 parent 4a0c8ef commit 093e222
Showing 5 changed files with 27 additions and 7 deletions.
4 changes: 3 additions & 1 deletion python-package/xgboost/training.py
@@ -109,7 +109,9 @@ def _train_internal(params, dtrain,
     else:
         bst.best_iteration = nboost - 1
     bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
-    return bst
+
+    # Copy to serialise and unserialise booster to reset state and free training memory
+    return bst.copy()
 
 
 def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
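The round-trip behind `bst.copy()` can be sketched as a pickle cycle. This is a minimal illustration, assuming only that `Booster` supports pickling (its pickle protocol serialises the model to a raw buffer); the actual `Booster.copy()` implementation may differ in detail:

import pickle
import xgboost as xgb

def reset_state(bst: xgb.Booster) -> xgb.Booster:
    # Pickling serialises the model itself, not the training internals,
    # so the reconstructed booster carries no cached DMatrix handles,
    # gradient buffers or other training-time state.
    return pickle.loads(pickle.dumps(bst))

Returning `bst.copy()` instead of `bst` therefore hands back an object that behaves like a model that was saved and reloaded, and lets the memory held by the original training handle be freed.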
16 changes: 16 additions & 0 deletions src/gbm/gbtree.cc
@@ -267,6 +267,11 @@ void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair,
   // create the trees
   for (int i = 0; i < tparam_.num_parallel_tree; ++i) {
     if (tparam_.process_type == TreeProcessType::kDefault) {
+      CHECK(!updaters_.front()->CanModifyTree())
+          << "Updater: `" << updaters_.front()->Name() << "` "
+          << "can not be used to create new trees. "
+          << "Set `process_type` to `update` if you want to update existing "
+             "trees.";
       // create new tree
       std::unique_ptr<RegTree> ptr(new RegTree());
       ptr->param.UpdateAllowUnknown(this->cfg_);
@@ -319,6 +324,10 @@ void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees,
 void GBTree::LoadConfig(Json const& in) {
   CHECK_EQ(get<String>(in["name"]), "gbtree");
   FromJson(in["gbtree_train_param"], &tparam_);
+  // Process type cannot be kUpdate from loaded model
+  // This would cause all trees to be pushed to trees_to_update
+  // e.g. updating a model, then saving and loading it would result in an empty model
+  tparam_.process_type = TreeProcessType::kDefault;
   int32_t const n_gpus = xgboost::common::AllVisibleGPUs();
   if (n_gpus == 0 && tparam_.predictor == PredictorType::kGPUPredictor) {
     LOG(WARNING)
@@ -348,6 +357,13 @@ void GBTree::SaveConfig(Json* p_out) const {
   auto& out = *p_out;
   out["name"] = String("gbtree");
   out["gbtree_train_param"] = ToJson(tparam_);
+
+  // Process type cannot be kUpdate from loaded model
+  // This would cause all trees to be pushed to trees_to_update
+  // e.g. updating a model, then saving and loading it would result in an empty
+  // model
+  out["gbtree_train_param"]["process_type"] = String("default");
+
   out["updater"] = Object();
 
   auto& j_updaters = out["updater"];
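The failure mode these two guards close off can be sketched from Python. This is a hedged reconstruction of the scenario described in the comments, using `refresh` as a representative updater that runs under `process_type=update`:

import numpy as np
import xgboost as xgb

dtrain = xgb.DMatrix(np.random.rand(64, 4), label=np.random.rand(64))
bst = xgb.train({"max_depth": 2}, dtrain, num_boost_round=4)

# Rewrite the statistics of the existing four trees in place.
bst = xgb.train({"process_type": "update", "updater": "refresh"},
                dtrain, num_boost_round=4, xgb_model=bst)

# Before this change the persisted config still said process_type=update,
# so loading pushed every tree into trees_to_update and predictions came
# from an effectively empty model. Forcing "default" on save and load
# keeps the round trip lossless.
bst.save_model("updated.model")
restored = xgb.Booster(model_file="updated.model")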
4 changes: 4 additions & 0 deletions tests/cpp/gbm/test_gbtree.cc
@@ -63,6 +63,10 @@ TEST(GBTree, WrongUpdater) {
   // Hist can not be used for updating tree.
   learner->SetParams(Args{{"tree_method", "hist"}, {"process_type", "update"}});
   ASSERT_THROW(learner->UpdateOneIter(0, p_dmat), dmlc::Error);
+  // Prune can not be used for learning new tree.
+  learner->SetParams(
+      Args{{"tree_method", "prune"}, {"process_type", "default"}});
+  ASSERT_THROW(learner->UpdateOneIter(0, p_dmat), dmlc::Error);
 }
 
 #ifdef XGBOOST_USE_CUDA
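From Python, the same misconfiguration now fails fast with the message added in gbtree.cc. A small reproduction mirroring the C++ test above (hedged: it selects the updater via the `updater` key rather than `tree_method`):

import numpy as np
import xgboost as xgb

dtrain = xgb.DMatrix(np.random.rand(64, 4), label=np.random.rand(64))
try:
    # `prune` can only modify existing trees; with the default process
    # type (grow new trees) the new CHECK fires.
    xgb.train({"updater": "prune", "process_type": "default"},
              dtrain, num_boost_round=1)
except xgb.core.XGBoostError as err:
    print(err)  # Updater: `prune` can not be used to create new trees.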
5 changes: 2 additions & 3 deletions tests/distributed/test_basic.py
@@ -20,9 +20,8 @@
 bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
 
 # Save the model, only ask process 0 to save the model.
-if xgb.rabit.get_rank() == 0:
-    bst.save_model("test.model")
-    xgb.rabit.tracker_print("Finished training\n")
+bst.save_model("test.model{}".format(xgb.rabit.get_rank()))
+xgb.rabit.tracker_print("Finished training\n")
 
 # Notify the tracker all training has been successful
 # This is only needed in distributed training.
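With every rank writing `test.model{rank}`, the saved artifacts can be compared to confirm the workers ended training with identical models. A hypothetical follow-up check, not part of this test (`world_size` is an assumed worker count):

import xgboost as xgb

world_size = 4  # assumed number of distributed workers

# Load what each rank saved and compare the raw serialised bytes.
raws = [xgb.Booster(model_file="test.model{}".format(rank)).save_raw()
        for rank in range(world_size)]
assert all(raw == raws[0] for raw in raws), "workers saved diverging models"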
5 changes: 2 additions & 3 deletions tests/distributed/test_issue3402.py
@@ -70,9 +70,8 @@
 num_round = 2
 bst = xgb.train(param, dtrain, num_round, watchlist)
 
-if xgb.rabit.get_rank() == 0:
-    bst.save_model("test_issue3402.model")
-    xgb.rabit.tracker_print("Finished training\n")
+bst.save_model("test_issue3402.model{}".format(xgb.rabit.get_rank()))
+xgb.rabit.tracker_print("Finished training\n")
 
 # Notify the tracker all training has been successful
 # This is only needed in distributed training.
