Skip to content

Commit

Permalink
Just guard against USE_CUDA.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Dec 5, 2019
1 parent 0087d8f commit eb0c2d7
Showing 1 changed file with 7 additions and 12 deletions.
19 changes: 7 additions & 12 deletions src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ class LearnerImpl : public Learner {
kv.second = "cpu_predictor";
}
#endif // XGBOOST_USE_CUDA
#if defined(XGBOOST_USE_CUDA)
// NO visible GPU in current environment
if (is_gpu_predictor && common::AllVisibleGPUs() == 0) {
cfg_["predictor"] = "cpu_predictor";
Expand All @@ -322,20 +323,13 @@ class LearnerImpl : public Learner {
} else if (is_gpu_predictor) {
cfg_["predictor"] = "gpu_predictor";
}
#endif // defined(XGBOOST_USE_CUDA)
if (saved_configs_.find(saved_param) != saved_configs_.end()) {
cfg_[saved_param] = kv.second;
}
}
}
Args filtered(attr.size());
size_t n {0};
std::copy_if(attr.begin(), attr.end(), filtered.begin(),
[&n](std::pair<std::string const&, std::string const&> const& kv) {
bool r = kv.first != "SAVED_PARAM_gpu_id";
n++;
return r;
});
attributes_ = std::map<std::string, std::string>(attr.begin(), attr.begin() + n);
attributes_ = std::map<std::string, std::string>(attr.begin(), attr.end());
}
if (tparam_.objective == "count:poisson") {
std::string max_delta_step;
Expand Down Expand Up @@ -417,7 +411,9 @@ class LearnerImpl : public Learner {
}
}
}
#if defined(XGBOOST_USE_CUDA)
{
// Force save gpu_id.
if (std::none_of(extra_attr.cbegin(), extra_attr.cend(),
[](std::pair<std::string, std::string> const& it) {
return it.first == "SAVED_PARAM_gpu_id";
Expand All @@ -426,16 +422,15 @@ class LearnerImpl : public Learner {
extra_attr.emplace_back("SAVED_PARAM_gpu_id", std::to_string(generic_param_.gpu_id));
}
}
#endif // defined(XGBOOST_USE_CUDA)
fo->Write(&mparam, sizeof(LearnerModelParam));
fo->Write(tparam_.objective);
fo->Write(tparam_.booster);
gbm_->Save(fo);
if (mparam.contain_extra_attrs != 0) {
std::map<std::string, std::string> attr(attributes_);
for (const auto& kv : extra_attr) {
if (kv.first != "SAVED_PARAM_gpu_id") {
attr[kv.first] = kv.second;
}
attr[kv.first] = kv.second;
}
fo->Write(std::vector<std::pair<std::string, std::string>>(
attr.begin(), attr.end()));
Expand Down

0 comments on commit eb0c2d7

Please sign in to comment.