Skip to content

Commit

Permalink
Revert C API change.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Feb 11, 2020
1 parent 817ff54 commit 2bce357
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
#include <string>
#include <memory>


#include "xgboost/data.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/learner.h"
#include "xgboost/c_api.h"
#include "xgboost/logging.h"
Expand Down Expand Up @@ -146,7 +146,7 @@ struct XGBAPIThreadLocalEntry {
/*! \brief result holder for returning string pointers */
std::vector<const char *> ret_vec_charp;
/*! \brief returning float vector. */
HostDeviceVector<bst_float> ret_vec_float;
std::vector<bst_float> ret_vec_float;
/*! \brief temp variable of gradient pairs. */
std::vector<GradientPair> tmp_gpair;
};
Expand Down Expand Up @@ -553,22 +553,24 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
int32_t training,
xgboost::bst_ulong *len,
const bst_float **out_result) {
std::vector<bst_float>& preds =
XGBAPIThreadLocalStore::Get()->ret_vec_float;
API_BEGIN();
CHECK_HANDLE();
HostDeviceVector<bst_float>& preds =
XGBAPIThreadLocalStore::Get()->ret_vec_float;
auto *bst = static_cast<Learner*>(handle);
HostDeviceVector<bst_float> tmp_preds;
bst->Predict(
*static_cast<std::shared_ptr<DMatrix>*>(dmat),
(option_mask & 1) != 0,
&preds, ntree_limit,
&tmp_preds, ntree_limit,
static_cast<bool>(training),
(option_mask & 2) != 0,
(option_mask & 4) != 0,
(option_mask & 8) != 0,
(option_mask & 16) != 0);
*out_result = dmlc::BeginPtr(preds.HostVector());
*len = static_cast<xgboost::bst_ulong>(preds.Size());
preds = tmp_preds.HostVector();
*out_result = dmlc::BeginPtr(preds);
*len = static_cast<xgboost::bst_ulong>(preds.size());
API_END();
}

Expand Down

0 comments on commit 2bce357

Please sign in to comment.