diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h
index 666e65652a5b..37fa2d42e33b 100644
--- a/include/xgboost/gbm.h
+++ b/include/xgboost/gbm.h
@@ -9,6 +9,7 @@
 #define XGBOOST_GBM_H_

 #include
+#include <dmlc/any.h>
 #include
 #include
 #include
@@ -92,6 +93,22 @@ class GradientBooster : public Model, public Configurable {
                        PredictionCacheEntry* out_preds,
                        bool training,
                        unsigned ntree_limit = 0) = 0;
+
+  /*!
+   * \brief Inplace prediction.
+   *
+   * \param x                      A type erased data adapter.
+   * \param missing                Missing value in the data.
+   * \param [in,out] out_preds     The output preds.
+   * \param layer_begin (Optional) Beginning of boosted tree layer used for prediction.
+   * \param layer_end   (Optional) End of booster layer. 0 means do not limit trees.
+   */
+  virtual void InplacePredict(dmlc::any const &x, float missing,
+                              PredictionCacheEntry *out_preds,
+                              uint32_t layer_begin = 0,
+                              uint32_t layer_end = 0) const {
+    LOG(FATAL) << "Inplace predict is not supported by current booster.";
+  }
   /*!
    * \brief online prediction function, predict score for one instance at a time
    * NOTE: use the batch prediction interface if possible, batch prediction is usually
diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h
index ad034ee70c98..933c8adede48 100644
--- a/include/xgboost/learner.h
+++ b/include/xgboost/learner.h
@@ -8,6 +8,7 @@
 #ifndef XGBOOST_LEARNER_H_
 #define XGBOOST_LEARNER_H_

+#include <dmlc/any.h>
 #include
 #include
 #include
@@ -120,6 +121,21 @@ class Learner : public Model, public Configurable, public rabit::Serializable {
                        bool approx_contribs = false,
                        bool pred_interactions = false) = 0;

+  /*!
+   * \brief Inplace prediction.
+   *
+   * \param x                      A type erased data adapter.
+   * \param type                   Prediction type.
+   * \param missing                Missing value in the data.
+   * \param [in,out] out_preds     Pointer to output prediction vector.
+   * \param layer_begin (Optional) Beginning of boosted tree layer used for prediction.
+   * \param layer_end   (Optional) End of booster layer. 0 means do not limit trees.
+   */
+  virtual void InplacePredict(dmlc::any const& x, std::string const& type,
+                              float missing,
+                              HostDeviceVector<bst_float> **out_preds,
+                              uint32_t layer_begin = 0, uint32_t layer_end = 0) = 0;
+
   void LoadModel(Json const& in) override = 0;
   void SaveModel(Json* out) const override = 0;
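The new entry point is type erased: callers wrap whatever adapter they have in a `dmlc::any`, and the concrete predictor recovers it later with a `typeid` check. Below is a minimal caller-side sketch, not part of the patch; the adapter constructor and include path follow `src/data/adapter.h` and the learner is assumed to be trained already.

#include <limits>
#include <xgboost/learner.h>
#include "../src/data/adapter.h"   // data::DenseAdapter (path illustrative)

// Sketch: run in-place prediction on a raw dense buffer through the new
// type-erased entry point.
void PredictDense(xgboost::Learner *learner, float const *data,
                  size_t rows, size_t cols) {
  xgboost::data::DenseAdapter adapter{data, rows, cols};
  xgboost::HostDeviceVector<xgboost::bst_float> *out_preds{nullptr};
  // The adapter travels as a dmlc::any; the concrete predictor recovers it
  // with x.type() == typeid(data::DenseAdapter).
  learner->InplacePredict(dmlc::any{adapter}, "value",
                          std::numeric_limits<float>::quiet_NaN(), &out_preds,
                          /*layer_begin=*/0, /*layer_end=*/0);
  // out_preds now points at thread-local storage owned by the learner.
}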
diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h
index a872f3437b35..76883c30ae11 100644
--- a/include/xgboost/predictor.h
+++ b/include/xgboost/predictor.h
@@ -16,6 +16,7 @@
 #include
 #include
 #include
+#include <mutex>

 // Forward declarations
 namespace xgboost {
@@ -54,6 +55,7 @@ struct PredictionCacheEntry {
 class PredictionContainer {
   std::unordered_map<DMatrix *, PredictionCacheEntry> container_;
   void ClearExpiredEntries();
+  std::mutex cache_lock_;

  public:
   PredictionContainer() = default;
@@ -133,6 +135,18 @@ class Predictor {
                             const gbm::GBTreeModel& model, int tree_begin,
                             uint32_t const ntree_limit = 0) = 0;

+  /**
+   * \brief Inplace prediction.
+   * \param x                     Type erased data adapter.
+   * \param model                 The model to predict from.
+   * \param missing               Missing value in the data.
+   * \param [in,out] out_preds    The output preds.
+   * \param tree_begin (Optional) Beginning of boosted trees used for prediction.
+   * \param tree_end   (Optional) End of booster trees. 0 means do not limit trees.
+   */
+  virtual void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
+                              float missing, PredictionCacheEntry *out_preds,
+                              uint32_t tree_begin = 0, uint32_t tree_end = 0) const = 0;
   /**
    * \brief online prediction function, predict score for one instance at a time
    * NOTE: use the batch prediction interface if possible, batch prediction is
diff --git a/python-package/setup.py b/python-package/setup.py
index c2e6fd9941c6..8c09f4b3e508 100644
--- a/python-package/setup.py
+++ b/python-package/setup.py
@@ -86,7 +86,7 @@ def __init__(self, name):
         super().__init__(name=name, sources=[])


-class BuildExt(build_ext.build_ext):
+class BuildExt(build_ext.build_ext):  # pylint: disable=too-many-ancestors
     '''Custom build_ext command using CMake.'''

     logger = logging.getLogger('XGBoost build_ext')
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index 5221bed3c25b..733b0cfba8d4 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -207,6 +207,19 @@ def ctypes2numpy(cptr, length, dtype):
     return res


+def ctypes2cupy(cptr, length, dtype):
+    """Convert a ctypes pointer array to a cupy array."""
+    import cupy                 # pylint: disable=import-error
+    mem = cupy.zeros(length.value, dtype=dtype, order='C')
+    addr = ctypes.cast(cptr, ctypes.c_void_p).value
+    # pylint: disable=c-extension-no-member,no-member
+    cupy.cuda.runtime.memcpy(
+        mem.__cuda_array_interface__['data'][0], addr,
+        length.value * ctypes.sizeof(ctypes.c_float),
+        cupy.cuda.runtime.memcpyDeviceToDevice)
+    return mem
+
+
 def ctypes2buffer(cptr, length):
     """Convert ctypes pointer to buffer type."""
     if not isinstance(cptr, ctypes.POINTER(ctypes.c_char)):
@@ -474,6 +487,7 @@ def __init__(self, data, label=None, weight=None, base_margin=None,
             data, feature_names, feature_types = _convert_dataframes(
                 data, feature_names, feature_types
             )
+        missing = np.nan if missing is None else missing

         if isinstance(data, (STRING_TYPES, os_PathLike)):
             handle = ctypes.c_void_p()
@@ -1428,12 +1442,17 @@ def predict(self,
                 training=False):
         """Predict with data.

-        .. note:: This function is not thread safe.
+        .. note:: This function is not thread safe except for the ``gbtree``
+          booster.
+
+          For the ``gbtree`` booster, thread safety is guaranteed by locks.
+          For lock-free prediction use ``inplace_predict`` instead.  Also,
+          the safety does not hold when used in conjunction with other
+          methods.

-          For each booster object, predict can only be called from one thread.
-          If you want to run prediction using multiple thread, call
-          ``bst.copy()`` to make copies of model object and then call
-          ``predict()``.
+          When using a booster other than ``gbtree``, predict can only be
+          called from one thread.  If you want to run prediction using
+          multiple threads, call ``bst.copy()`` to make copies of the model
+          object and then call ``predict()``.

         Parameters
         ----------
@@ -1547,6 +1566,146 @@ def predict(self,
             preds = preds.reshape(nrow, chunk_size)
         return preds

+    def inplace_predict(self, data, iteration_range=(0, 0),
+                        predict_type='value', missing=np.nan):
+        '''Run prediction in-place.  Unlike the ``predict`` method, inplace
+        prediction does not cache the prediction result.
+
+        Calling only ``inplace_predict`` in multiple threads is safe and lock
+        free.  But the safety does not hold when it is used in conjunction
+        with other methods.  E.g. you can't train the booster in one thread
+        and perform prediction in another.
+
+        .. code-block:: python
+
+            booster.set_param({'predictor': 'gpu_predictor'})
+            booster.inplace_predict(cupy_array)
+
+            booster.set_param({'predictor': 'cpu_predictor'})
+            booster.inplace_predict(numpy_array)
+
+        Parameters
+        ----------
+        data : numpy.ndarray/scipy.sparse.csr_matrix/cupy.ndarray/
+               cudf.DataFrame/pd.DataFrame
+            The input data; must not be a view for numpy array.  Set
+            ``predictor`` to ``gpu_predictor`` for running prediction on a
+            CuPy array or CuDF DataFrame.
+        iteration_range : tuple
+            Specifies which layers of trees are used in prediction.  For
+            example, if a random forest is trained with 100 rounds,
+            specifying `iteration_range=(10, 20)` means only the forests
+            built during the [10, 20) (half-open) rounds are used in this
+            prediction.
+        predict_type : str
+            * `value` Output model prediction values.
+            * `margin` Output the raw untransformed margin value.
+        missing : float
+            Value in the input data that should be treated as missing.
+
+        Returns
+        -------
+        prediction : numpy.ndarray/cupy.ndarray
+            The prediction result.  When the input data is on GPU, the
+            prediction result is stored in a cupy array.
+
+        '''
+
+        def reshape_output(predt, rows):
+            '''Reshape for multi-output prediction.'''
+            if predt.size != rows and predt.size % rows == 0:
+                cols = int(predt.size / rows)
+                predt = predt.reshape(rows, cols)
+                return predt
+            return predt
+
+        length = c_bst_ulong()
+        preds = ctypes.POINTER(ctypes.c_float)()
+        iteration_range = (ctypes.c_uint(iteration_range[0]),
+                           ctypes.c_uint(iteration_range[1]))
+
+        # once caching is supported, we can pass id(data) as cache id.
+        if isinstance(data, DataFrame):
+            data = data.values
+        if isinstance(data, np.ndarray):
+            assert data.flags.c_contiguous
+            arr = np.array(data.reshape(data.size), copy=False,
+                           dtype=np.float32)
+            _check_call(_LIB.XGBoosterPredictFromDense(
+                self.handle,
+                arr.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
+                c_bst_ulong(data.shape[0]),
+                c_bst_ulong(data.shape[1]),
+                ctypes.c_float(missing),
+                iteration_range[0],
+                iteration_range[1],
+                c_str(predict_type),
+                c_bst_ulong(0),
+                ctypes.byref(length),
+                ctypes.byref(preds)
+            ))
+            preds = ctypes2numpy(preds, length.value, np.float32)
+            rows = data.shape[0]
+            return reshape_output(preds, rows)
+        if isinstance(data, scipy.sparse.csr_matrix):
+            csr = data
+            _check_call(_LIB.XGBoosterPredictFromCSR(
+                self.handle,
+                c_array(ctypes.c_size_t, csr.indptr),
+                c_array(ctypes.c_uint, csr.indices),
+                c_array(ctypes.c_float, csr.data),
+                ctypes.c_size_t(len(csr.indptr)),
+                ctypes.c_size_t(len(csr.data)),
+                ctypes.c_size_t(csr.shape[1]),
+                ctypes.c_float(missing),
+                iteration_range[0],
+                iteration_range[1],
+                c_str(predict_type),
+                c_bst_ulong(0),
+                ctypes.byref(length),
+                ctypes.byref(preds)))
+            preds = ctypes2numpy(preds, length.value, np.float32)
+            rows = data.shape[0]
+            return reshape_output(preds, rows)
+        if lazy_isinstance(data, 'cupy.core.core', 'ndarray'):
+            assert data.flags.c_contiguous
+            interface = data.__cuda_array_interface__
+            if 'mask' in interface:
+                interface['mask'] = interface['mask'].__cuda_array_interface__
+            interface_str = bytes(json.dumps(interface, indent=2), 'utf-8')
+            _check_call(_LIB.XGBoosterPredictFromArrayInterface(
+                self.handle,
+                interface_str,
+                ctypes.c_float(missing),
+                iteration_range[0],
+                iteration_range[1],
+                c_str(predict_type),
+                c_bst_ulong(0),
+                ctypes.byref(length),
+                ctypes.byref(preds)))
+            mem = ctypes2cupy(preds, length, np.float32)
+            rows = data.shape[0]
+            return reshape_output(mem, rows)
+        if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
+            interfaces_str = _cudf_array_interfaces(data)
+            _check_call(_LIB.XGBoosterPredictFromArrayInterfaceColumns(
+                self.handle,
+                interfaces_str,
+                ctypes.c_float(missing),
+                iteration_range[0],
+                iteration_range[1],
+                c_str(predict_type),
+                c_bst_ulong(0),
+                ctypes.byref(length),
+                ctypes.byref(preds)))
+            mem = ctypes2cupy(preds, length, np.float32)
+            rows = data.shape[0]
+            predt = reshape_output(mem, rows)
+            return predt
+
+        raise TypeError('Data type:' + str(type(data)) +
+                        ' not supported by inplace prediction.')
+
     def save_model(self, fname):
         """Save the model to a file.
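A usage sketch of the lock-free path added above (synthetic data; the thread pool only demonstrates concurrency):

from concurrent.futures import ThreadPoolExecutor
import numpy as np
import xgboost as xgb

X = np.random.randn(1000, 10).astype(np.float32)
y = np.random.randn(1000)
booster = xgb.train({'predictor': 'cpu_predictor'}, xgb.DMatrix(X, y),
                    num_boost_round=8)

# Copies keep each chunk C-contiguous and independent of the parent array,
# which inplace_predict requires for numpy inputs.
chunks = [c.copy() for c in np.array_split(X, 4)]
with ThreadPoolExecutor(max_workers=4) as pool:
    results = list(pool.map(booster.inplace_predict, chunks))
predt = np.concatenate(results)
assert predt.shape[0] == X.shape[0]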
diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index 44002c34061e..fa20605c3a18 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -26,6 +26,7 @@
 from .compat import sparse, scipy_sparse
 from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
 from .compat import CUDF_INSTALLED, CUDF_DataFrame, CUDF_Series, CUDF_concat
+from .compat import lazy_isinstance

 from .core import DMatrix, Booster, _expect
 from .training import train as worker_train
@@ -86,7 +87,7 @@ def __exit__(self, *args):
         LOGGER.debug('--------------- rabit say bye ------------------')


-def concat(value):
+def concat(value):  # pylint: disable=too-many-return-statements
     '''To be replaced with dask builtin.'''
     if isinstance(value[0], numpy.ndarray):
         return numpy.concatenate(value, axis=0)
@@ -98,6 +99,9 @@ def concat(value):
         return pandas_concat(value, axis=0)
     if CUDF_INSTALLED and isinstance(value[0], (CUDF_DataFrame, CUDF_Series)):
         return CUDF_concat(value, axis=0)
+    if lazy_isinstance(value[0], 'cupy.core.core', 'ndarray'):
+        import cupy      # pylint: disable=import-error
+        return cupy.concatenate(value, axis=0)
     return dd.multi.concat(list(value), axis=0)

@@ -370,8 +374,9 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
         Specify the dask client used for training.  Use default client
         returned from dask if it's set to None.
     \\*\\*kwargs:
-        Other parameters are the same as `xgboost.train` except for `evals_result`,
-        which is returned as part of function return value instead of argument.
+        Other parameters are the same as `xgboost.train` except for
+        `evals_result`, which is returned as part of function return value
+        instead of argument.

     Returns
     -------
@@ -500,11 +505,10 @@ def mapped_predict(partition, is_df):
         ).result()
         return predictions
     if isinstance(data, dd.DataFrame):
-        import dask
         predictions = client.submit(
             dd.map_partitions,
             mapped_predict, data, True,
-            meta=dask.dataframe.utils.make_meta({'prediction': 'f4'})
+            meta=dd.utils.make_meta({'prediction': 'f4'})
         ).result()
         return predictions.iloc[:, 0]

@@ -572,6 +576,79 @@ def map_function(func):
     return predictions


+def inplace_predict(client, model, data,
+                    iteration_range=(0, 0),
+                    predict_type='value',
+                    missing=numpy.nan):
+    '''Inplace prediction.
+
+    Parameters
+    ----------
+    client: dask.distributed.Client
+        Specify the dask client used for training.  Use the default client
+        returned from dask if it's set to None.
+    model: Booster/dict
+        The trained model.
+    data: dask.array.Array/dask.dataframe.DataFrame
+        The input data used for prediction.
+    iteration_range: tuple
+        Specify the range of trees used for prediction.
+    predict_type: str
+        * 'value': Normal prediction result.
+        * 'margin': Output the raw untransformed margin value.
+    missing: float
+        Value in the input data that should be treated as missing.  If
+        None, defaults to numpy.nan.
+
+    Returns
+    -------
+    prediction: dask.array.Array
+    '''
+    _assert_dask_support()
+    client = _xgb_get_client(client)
+    if isinstance(model, Booster):
+        booster = model
+    elif isinstance(model, dict):
+        booster = model['booster']
+    else:
+        raise TypeError(_expect([Booster, dict], type(model)))
+    if not isinstance(data, (da.Array, dd.DataFrame)):
+        raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))
+
+    def mapped_predict(data, is_df):
+        worker = distributed_get_worker()
+        booster.set_param({'nthread': worker.nthreads})
+        prediction = booster.inplace_predict(
+            data,
+            iteration_range=iteration_range,
+            predict_type=predict_type,
+            missing=missing)
+        if is_df:
+            if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
+                import cudf      # pylint: disable=import-error
+                # There's an error with cudf saying `concat_cudf` got an
+                # unexpected keyword argument `ignore_index`.  So this is
+                # not yet working.
+                prediction = cudf.DataFrame({'prediction': prediction},
+                                            dtype=numpy.float32)
+            else:
+                # If it's from pandas, the partition is a numpy array
+                prediction = DataFrame(prediction, columns=['prediction'],
+                                       dtype=numpy.float32)
+        return prediction
+
+    if isinstance(data, da.Array):
+        predictions = client.submit(
+            da.map_blocks,
+            mapped_predict, data, False, drop_axis=1,
+            dtype=numpy.float32
+        ).result()
+        return predictions
+    if isinstance(data, dd.DataFrame):
+        predictions = client.submit(
+            dd.map_partitions,
+            mapped_predict, data, True,
+            meta=dd.utils.make_meta({'prediction': 'f4'})
+        ).result()
+        return predictions.iloc[:, 0]
+
+
 def _evaluation_matrices(client, validation_set, sample_weights, missing):
     '''
     Parameters
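A sketch of the distributed path, assuming a local dask scheduler is available (the cudf branch is still a placeholder per the comment above, so this uses a plain dask array):

import dask.array as da
from dask.distributed import Client, LocalCluster
from xgboost.dask import DaskDMatrix, train, inplace_predict

if __name__ == '__main__':
    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        X = da.random.random((1000, 10), chunks=(100, 10))
        y = da.random.random(1000, chunks=100)
        dtrain = DaskDMatrix(client, X, y)
        out = train(client, {'tree_method': 'hist'}, dtrain,
                    num_boost_round=4)
        # Each partition is predicted on the worker that holds it, without
        # constructing an intermediate DMatrix.
        predt = inplace_predict(client, out, X)
        print(predt.compute()[:5])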
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index a93ba897078c..d31deb83d24e 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -12,6 +12,7 @@

 #include "xgboost/base.h"
 #include "xgboost/data.h"
+#include "xgboost/host_device_vector.h"
 #include "xgboost/learner.h"
 #include "xgboost/c_api.h"
 #include "xgboost/logging.h"
@@ -450,6 +451,95 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
   API_END();
 }

+// A hidden API as cache id is not supported yet.
+XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, float *values,
+                                      xgboost::bst_ulong n_rows,
+                                      xgboost::bst_ulong n_cols,
+                                      float missing,
+                                      unsigned iteration_begin,
+                                      unsigned iteration_end,
+                                      char const* c_type,
+                                      xgboost::bst_ulong cache_id,
+                                      xgboost::bst_ulong *out_len,
+                                      const float **out_result) {
+  API_BEGIN();
+  CHECK_HANDLE();
+  CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet";
+  auto *learner = static_cast<xgboost::Learner*>(handle);
+
+  auto x = xgboost::data::DenseAdapter(values, n_rows, n_cols);
+  HostDeviceVector<bst_float>* p_predt { nullptr };
+  std::string type { c_type };
+  learner->InplacePredict(x, type, missing, &p_predt);
+  CHECK(p_predt);
+
+  *out_result = dmlc::BeginPtr(p_predt->HostVector());
+  *out_len = static_cast<xgboost::bst_ulong>(p_predt->Size());
+  API_END();
+}
+
+// A hidden API as cache id is not supported yet.
+XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle,
+                                    const size_t* indptr,
+                                    const unsigned* indices,
+                                    const bst_float* data,
+                                    size_t nindptr,
+                                    size_t nelem,
+                                    size_t num_col,
+                                    float missing,
+                                    unsigned iteration_begin,
+                                    unsigned iteration_end,
+                                    char const *c_type,
+                                    xgboost::bst_ulong cache_id,
+                                    xgboost::bst_ulong *out_len,
+                                    const float **out_result) {
+  API_BEGIN();
+  CHECK_HANDLE();
+  CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet";
+  auto *learner = static_cast<xgboost::Learner*>(handle);
+
+  auto x = data::CSRAdapter(indptr, indices, data, nindptr - 1, nelem, num_col);
+  HostDeviceVector<bst_float>* p_predt { nullptr };
+  std::string type { c_type };
+  learner->InplacePredict(x, type, missing, &p_predt);
+  CHECK(p_predt);
+
+  *out_result = dmlc::BeginPtr(p_predt->HostVector());
+  *out_len = static_cast<xgboost::bst_ulong>(p_predt->Size());
+  API_END();
+}
+
+#if !defined(XGBOOST_USE_CUDA)
+XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(BoosterHandle handle,
+                                                      char const* c_json_strs,
+                                                      float missing,
+                                                      unsigned iteration_begin,
+                                                      unsigned iteration_end,
+                                                      char const* c_type,
+                                                      xgboost::bst_ulong cache_id,
+                                                      xgboost::bst_ulong *out_len,
+                                                      float const** out_result) {
+  API_BEGIN();
+  CHECK_HANDLE();
+  LOG(FATAL) << "XGBoost not compiled with CUDA.";
+  API_END();
+}
+
+XGB_DLL int XGBoosterPredictFromArrayInterface(BoosterHandle handle,
+                                               char const* c_json_strs,
+                                               float missing,
+                                               unsigned iteration_begin,
+                                               unsigned iteration_end,
+                                               char const* c_type,
+                                               xgboost::bst_ulong cache_id,
+                                               xgboost::bst_ulong *out_len,
+                                               const float **out_result) {
+  API_BEGIN();
+  CHECK_HANDLE();
+  LOG(FATAL) << "XGBoost not compiled with CUDA.";
+  API_END();
+}
+#endif  // !defined(XGBOOST_USE_CUDA)
+
 XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
   API_BEGIN();
   CHECK_HANDLE();
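Since these entry points are hidden (not declared in `c_api.h` at this stage), a caller would declare the prototype manually. A minimal C++ sketch against the dense entry point, with error handling elided and a trained `booster` assumed:

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <xgboost/c_api.h>

// Hidden API: prototype declared by hand, mirroring c_api.cc above.
extern "C" int XGBoosterPredictFromDense(
    BoosterHandle handle, float *values, uint64_t n_rows, uint64_t n_cols,
    float missing, unsigned iteration_begin, unsigned iteration_end,
    char const *c_type, uint64_t cache_id, uint64_t *out_len,
    const float **out_result);

void PredictDenseC(BoosterHandle booster, float *values, uint64_t n_rows,
                   uint64_t n_cols) {
  uint64_t out_len = 0;
  const float *out_result = nullptr;
  // The (0, 0) iteration range means "use every tree"; cache_id must be 0.
  XGBoosterPredictFromDense(booster, values, n_rows, n_cols,
                            /*missing=*/NAN, 0, 0, "value", /*cache_id=*/0,
                            &out_len, &out_result);
  for (uint64_t i = 0; i < out_len; ++i) {
    std::printf("%f\n", out_result[i]);
  }
}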
diff --git a/src/c_api/c_api.cu b/src/c_api/c_api.cu
index f7b9253387b8..7fc49b43f74d 100644
--- a/src/c_api/c_api.cu
+++ b/src/c_api/c_api.cu
@@ -52,3 +52,60 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromArrayInterface(char const* c_json_s
       new std::shared_ptr<DMatrix>(
           new data::DeviceDMatrix(&adapter, missing, nthread, max_bin));
   API_END();
 }
+
+// A hidden API as cache id is not supported yet.
+XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(BoosterHandle handle,
+                                                      char const* c_json_strs,
+                                                      float missing,
+                                                      unsigned iteration_begin,
+                                                      unsigned iteration_end,
+                                                      char const* c_type,
+                                                      xgboost::bst_ulong cache_id,
+                                                      xgboost::bst_ulong *out_len,
+                                                      float const** out_result) {
+  API_BEGIN();
+  CHECK_HANDLE();
+  CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet";
+  auto *learner = static_cast<xgboost::Learner*>(handle);
+
+  std::string json_str{c_json_strs};
+  auto x = data::CudfAdapter(json_str);
+  HostDeviceVector<bst_float>* p_predt { nullptr };
+  std::string type { c_type };
+  learner->InplacePredict(x, type, missing, &p_predt);
+  CHECK(p_predt);
+  CHECK(p_predt->DeviceCanRead());
+
+  *out_result = p_predt->ConstDevicePointer();
+  *out_len = static_cast<xgboost::bst_ulong>(p_predt->Size());
+
+  API_END();
+}
+
+// A hidden API as cache id is not supported yet.
+XGB_DLL int XGBoosterPredictFromArrayInterface(BoosterHandle handle,
+                                               char const* c_json_strs,
+                                               float missing,
+                                               unsigned iteration_begin,
+                                               unsigned iteration_end,
+                                               char const* c_type,
+                                               xgboost::bst_ulong cache_id,
+                                               xgboost::bst_ulong *out_len,
+                                               float const** out_result) {
+  API_BEGIN();
+  CHECK_HANDLE();
+  CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet";
+  auto *learner = static_cast<xgboost::Learner*>(handle);
+
+  std::string json_str{c_json_strs};
+  auto x = data::CupyAdapter(json_str);
+  HostDeviceVector<bst_float>* p_predt { nullptr };
+  std::string type { c_type };
+  learner->InplacePredict(x, type, missing, &p_predt);
+  CHECK(p_predt);
+  CHECK(p_predt->DeviceCanRead());
+
+  *out_result = p_predt->ConstDevicePointer();
+  *out_len = static_cast<xgboost::bst_ulong>(p_predt->Size());
+
+  API_END();
+}
diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh
index 0763281f6152..e43971a351cd 100644
--- a/src/data/device_adapter.cuh
+++ b/src/data/device_adapter.cuh
@@ -52,6 +52,13 @@ class CudfAdapterBatch : public detail::NoMetaInfo {
         : std::numeric_limits<float>::quiet_NaN();
     return COOTuple(row_idx, column_idx, value);
   }
+  __device__ float GetValue(size_t ridx, bst_feature_t fidx) const {
+    auto const& column = columns_[fidx];
+    float value = column.valid.Data() == nullptr || column.valid.Check(ridx)
+                      ? column.GetElement(ridx)
+                      : std::numeric_limits<float>::quiet_NaN();
+    return value;
+  }

  private:
   common::Span<ArrayInterface> columns_;
@@ -129,6 +136,7 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
     for (auto& json_col : json_columns) {
       auto column = ArrayInterface(get<Object const>(json_col));
       columns.push_back(column);
+      CHECK_EQ(column.num_cols, 1);
       column_ptr.emplace_back(column_ptr.back() + column.num_rows);
       num_rows_ = std::max(num_rows_, size_t(column.num_rows));
       CHECK_EQ(device_idx_, dh::CudaGetPointerDevice(column.data))
diff --git a/src/data/simple_dmatrix.cu b/src/data/simple_dmatrix.cu
index 02b95ddf2cc2..63c96bc67099 100644
--- a/src/data/simple_dmatrix.cu
+++ b/src/data/simple_dmatrix.cu
@@ -122,8 +122,6 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
     CopyDataColumnMajor(adapter, sparse_page_.data.DeviceSpan(),
                         adapter->DeviceIdx(), missing, s_offset);
   }
-  // Sync
-  sparse_page_.data.HostVector();

   info.num_col_ = adapter->NumColumns();
   info.num_row_ = adapter->NumRows();
diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h
index 3bee789c62dc..329d45e2e23c 100644
--- a/src/gbm/gbtree.h
+++ b/src/gbm/gbtree.h
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2014-2019 by Contributors
+ * Copyright 2014-2020 by Contributors
  * \file gbtree.cc
  * \brief gradient boosted tree implementation.
  * \author Tianqi Chen
@@ -16,6 +16,7 @@
 #include
 #include

+#include "xgboost/base.h"
 #include "xgboost/data.h"
 #include "xgboost/logging.h"
 #include "xgboost/gbm.h"
@@ -203,6 +204,22 @@ class GBTree : public GradientBooster {
                     bool training,
                     unsigned ntree_limit) override;

+  void InplacePredict(dmlc::any const &x, float missing,
+                      PredictionCacheEntry *out_preds,
+                      uint32_t layer_begin = 0,
+                      unsigned layer_end = 0) const override {
+    CHECK(configured_);
+    // From here on, layer becomes concrete trees.
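+    // Worked example (hypothetical numbers): with num_output_group = 3 and
+    // num_parallel_tree = 4, each boosting layer owns 3 * 4 = 12 trees, so
+    // layer_begin = 1, layer_end = 2 selects the concrete tree range
+    // [12, 24) below.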
+    bst_group_t groups = model_.learner_model_param_->num_output_group;
+    uint32_t tree_begin = layer_begin * groups * tparam_.num_parallel_tree;
+    uint32_t tree_end = layer_end * groups * tparam_.num_parallel_tree;
+    if (tree_end == 0 || tree_end > model_.trees.size()) {
+      tree_end = static_cast<uint32_t>(model_.trees.size());
+    }
+    this->GetPredictor()->InplacePredict(x, model_, missing, out_preds,
+                                         tree_begin, tree_end);
+  }
+
   void PredictInstance(const SparsePage::Inst& inst,
                        std::vector<bst_float>* out_preds,
                        unsigned ntree_limit) override {
diff --git a/src/learner.cc b/src/learner.cc
index 0432f4b83a55..c11fe1176078 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -8,6 +8,8 @@
 #include
 #include
+#include <atomic>
+#include <mutex>
 #include
 #include
 #include
@@ -18,6 +20,7 @@
 #include
 #include

+#include "dmlc/any.h"
 #include "xgboost/base.h"
 #include "xgboost/data.h"
 #include "xgboost/model.h"
@@ -205,7 +208,7 @@ class LearnerConfiguration : public Learner {
   PredictionContainer cache_;

  protected:
-  bool need_configuration_;
+  std::atomic<bool> need_configuration_;
   std::map<std::string, std::string> cfg_;
   // Stores information like best-iteration for early stopping.
   std::map<std::string, std::string> attributes_;
@@ -214,6 +217,7 @@ class LearnerConfiguration : public Learner {
   LearnerModelParam learner_model_param_;
   LearnerTrainParam tparam_;
   std::vector<std::string> metric_names_;
+  std::mutex config_lock_;

  public:
   explicit LearnerConfiguration(std::vector<std::shared_ptr<DMatrix> > cache)
@@ -226,6 +230,9 @@ class LearnerConfiguration : public Learner {

   // Configuration before data is known.
   void Configure() override {
+    // Variant of double-checked locking.
+    if (!this->need_configuration_) { return; }
+    std::lock_guard<std::mutex> guard(config_lock_);
     if (!this->need_configuration_) { return; }

     monitor_.Start("Configure");
@@ -1003,6 +1010,23 @@ class LearnerImpl : public LearnerIO {
   XGBAPIThreadLocalEntry& GetThreadLocal() const override {
     return (*XGBAPIThreadLocalStore::Get())[this];
   }
+
+  void InplacePredict(dmlc::any const &x, std::string const &type,
+                      float missing, HostDeviceVector<bst_float> **out_preds,
+                      uint32_t layer_begin = 0, uint32_t layer_end = 0) override {
+    this->Configure();
+    auto& out_predictions = this->GetThreadLocal().prediction_entry;
+    this->gbm_->InplacePredict(x, missing, &out_predictions, layer_begin,
+                               layer_end);
+    if (type == "value") {
+      obj_->PredTransform(&out_predictions.predictions);
+    } else if (type == "margin") {
+    } else {
+      LOG(FATAL) << "Unsupported prediction type: " << type;
+    }
+    *out_preds = &out_predictions.predictions;
+  }
+
   const std::map<std::string, std::string>& GetConfigurationArguments() const override {
     return cfg_;
   }
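The `Configure()` change above is a classic double-checked lock. A self-contained sketch of the pattern, with an illustrative `DoConfigure` standing in for the real configuration work:

#include <atomic>
#include <mutex>

class LazyConfig {
  std::atomic<bool> need_configuration_{true};
  std::mutex config_lock_;

  void DoConfigure() { /* stand-in for the expensive configuration work */ }

 public:
  void Configure() {
    if (!need_configuration_) { return; }     // fast path: no lock taken
    std::lock_guard<std::mutex> guard(config_lock_);
    if (!need_configuration_) { return; }     // re-check under the lock
    DoConfigure();
    need_configuration_ = false;
  }
};

The atomic flag makes the unlocked fast path race free, while the mutex serializes the rare actual reconfiguration.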
diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc
index 6a7666ccfecf..2fc15731df20 100644
--- a/src/predictor/cpu_predictor.cc
+++ b/src/predictor/cpu_predictor.cc
@@ -2,13 +2,22 @@
  * Copyright by Contributors 2017-2020
  */
 #include
+#include <dmlc/any.h>
+#include <mutex>
+#include
+#include
+
+#include "xgboost/base.h"
+#include "xgboost/data.h"
 #include "xgboost/predictor.h"
 #include "xgboost/tree_model.h"
 #include "xgboost/tree_updater.h"
 #include "xgboost/logging.h"
 #include "xgboost/host_device_vector.h"
+#include "../data/adapter.h"
+#include "../common/math.h"
 #include "../gbm/gbtree_model.h"

 namespace xgboost {
@@ -16,91 +25,158 @@ namespace predictor {

 DMLC_REGISTRY_FILE_TAG(cpu_predictor);

-class CPUPredictor : public Predictor {
- protected:
-  static bst_float PredValue(const SparsePage::Inst& inst,
-                             const std::vector<std::unique_ptr<RegTree>>& trees,
-                             const std::vector<int>& tree_info, int bst_group,
-                             RegTree::FVec* p_feats,
-                             unsigned tree_begin, unsigned tree_end) {
-    bst_float psum = 0.0f;
-    p_feats->Fill(inst);
-    for (size_t i = tree_begin; i < tree_end; ++i) {
-      if (tree_info[i] == bst_group) {
-        int tid = trees[i]->GetLeafIndex(*p_feats);
-        psum += (*trees[i])[tid].LeafValue();
-      }
-    }
-    p_feats->Drop(inst);
-    return psum;
-  }
+bst_float PredValue(const SparsePage::Inst &inst,
+                    const std::vector<std::unique_ptr<RegTree>> &trees,
+                    const std::vector<int> &tree_info, int bst_group,
+                    RegTree::FVec *p_feats, unsigned tree_begin,
+                    unsigned tree_end) {
+  bst_float psum = 0.0f;
+  p_feats->Fill(inst);
+  for (size_t i = tree_begin; i < tree_end; ++i) {
+    if (tree_info[i] == bst_group) {
+      int tid = trees[i]->GetLeafIndex(*p_feats);
+      psum += (*trees[i])[tid].LeafValue();
+    }
+  }
+  p_feats->Drop(inst);
+  return psum;
+}

-  // init thread buffers
-  inline void InitThreadTemp(int nthread, int num_feature) {
-    int prev_thread_temp_size = thread_temp.size();
-    if (prev_thread_temp_size < nthread) {
-      thread_temp.resize(nthread, RegTree::FVec());
-      for (int i = prev_thread_temp_size; i < nthread; ++i) {
-        thread_temp[i].Init(num_feature);
-      }
-    }
-  }
+template <size_t kUnrollLen>
+struct SparsePageView {
+  SparsePage const* page;
+  bst_row_t base_rowid;
+  static size_t constexpr kUnroll = kUnrollLen;
+
+  explicit SparsePageView(SparsePage const *p)
+      : page{p}, base_rowid{page->base_rowid} {
+    // Pull to host before entering omp block, as this is not thread safe.
+    page->data.HostVector();
+    page->offset.HostVector();
+  }
+  SparsePage::Inst operator[](size_t i) { return (*page)[i]; }
+  size_t Size() const { return page->Size(); }
+};
+
+template <typename Adapter, size_t kUnrollLen>
+class AdapterView {
+  Adapter* adapter_;
+  float missing_;
+  common::Span<Entry> workspace_;
+  std::vector<size_t> current_unroll_;
+
+ public:
+  static size_t constexpr kUnroll = kUnrollLen;
+
+ public:
+  explicit AdapterView(Adapter *adapter, float missing,
+                       common::Span<Entry> workspace)
+      : adapter_{adapter}, missing_{missing}, workspace_{workspace},
+        current_unroll_(omp_get_max_threads() > 0 ? omp_get_max_threads() : 1, 0) {}
+  SparsePage::Inst operator[](size_t i) {
+    bst_feature_t columns = adapter_->NumColumns();
+    auto const &batch = adapter_->Value();
+    auto row = batch.GetLine(i);
+    auto t = omp_get_thread_num();
+    auto const beg = (columns * kUnroll * t) + (current_unroll_[t] * columns);
+    size_t non_missing {beg};
+    for (size_t c = 0; c < row.Size(); ++c) {
+      auto e = row.GetElement(c);
+      if (missing_ != e.value && !common::CheckNAN(e.value)) {
+        workspace_[non_missing] =
+            Entry{static_cast<bst_feature_t>(e.column_idx), e.value};
+        ++non_missing;
+      }
+    }
+    auto ret = workspace_.subspan(beg, non_missing - beg);
+    current_unroll_[t]++;
+    if (current_unroll_[t] == kUnroll) {
+      current_unroll_[t] = 0;
+    }
+    return ret;
+  }

-  void PredInternal(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
-                    gbm::GBTreeModel const &model, int32_t tree_begin,
-                    int32_t tree_end) {
-    int32_t const num_group = model.learner_model_param_->num_output_group;
-    const int nthread = omp_get_max_threads();
-    InitThreadTemp(nthread, model.learner_model_param_->num_feature);
-    std::vector<bst_float>& preds = *out_preds;
-    CHECK_EQ(model.param.size_leaf_vector, 0)
-        << "size_leaf_vector is enforced to 0 so far";
-    CHECK_EQ(preds.size(), p_fmat->Info().num_row_ * num_group);
-    // start collecting the prediction
-    for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
-      // parallel over local batch
-      constexpr int kUnroll = 8;
-      const auto nsize = static_cast<bst_omp_uint>(batch.Size());
-      const bst_omp_uint rest = nsize % kUnroll;
-      // Pull to host before entering omp block, as this is not thread safe.
-      batch.data.HostVector();
-      batch.offset.HostVector();
-      if (nsize >= kUnroll) {
+  size_t Size() const { return adapter_->NumRows(); }
+
+  bst_row_t const static base_rowid = 0;  // NOLINT
+};
+
+template <typename DataView>
+void PredictBatchKernel(DataView batch, std::vector<bst_float> *out_preds,
+                        gbm::GBTreeModel const &model, int32_t tree_begin,
+                        int32_t tree_end,
+                        std::vector<RegTree::FVec> *p_thread_temp) {
+  auto& thread_temp = *p_thread_temp;
+  int32_t const num_group = model.learner_model_param_->num_output_group;
+
+  std::vector<bst_float> &preds = *out_preds;
+  CHECK_EQ(model.param.size_leaf_vector, 0)
+      << "size_leaf_vector is enforced to 0 so far";
+  // parallel over local batch
+  const auto nsize = static_cast<bst_omp_uint>(batch.Size());
+  auto constexpr kUnroll = DataView::kUnroll;
+  const bst_omp_uint rest = nsize % kUnroll;
+  if (nsize >= kUnroll) {
 #pragma omp parallel for schedule(static)
-      for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
-        const int tid = omp_get_thread_num();
-        RegTree::FVec& feats = thread_temp[tid];
-        int64_t ridx[kUnroll];
-        SparsePage::Inst inst[kUnroll];
-        for (int k = 0; k < kUnroll; ++k) {
-          ridx[k] = static_cast<int64_t>(batch.base_rowid + i + k);
-        }
-        for (int k = 0; k < kUnroll; ++k) {
-          inst[k] = batch[i + k];
-        }
-        for (int k = 0; k < kUnroll; ++k) {
-          for (int gid = 0; gid < num_group; ++gid) {
-            const size_t offset = ridx[k] * num_group + gid;
-            preds[offset] += this->PredValue(
-                inst[k], model.trees, model.tree_info, gid,
-                &feats, tree_begin, tree_end);
-          }
-        }
-      }
-      for (bst_omp_uint i = nsize - rest; i < nsize; ++i) {
-        RegTree::FVec& feats = thread_temp[0];
-        const auto ridx = static_cast<int64_t>(batch.base_rowid + i);
-        auto inst = batch[i];
-        for (int gid = 0; gid < num_group; ++gid) {
-          const size_t offset = ridx * num_group + gid;
-          preds[offset] +=
-              this->PredValue(inst, model.trees, model.tree_info, gid,
-                              &feats, tree_begin, tree_end);
-        }
-      }
-    }
-  }
+    for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
+      const int tid = omp_get_thread_num();
+      RegTree::FVec &feats = thread_temp[tid];
+      int64_t ridx[kUnroll];
+      SparsePage::Inst inst[kUnroll];
+      for (size_t k = 0; k < kUnroll; ++k) {
+        ridx[k] = static_cast<int64_t>(batch.base_rowid + i + k);
+      }
+      for (size_t k = 0; k < kUnroll; ++k) {
+        inst[k] = batch[i + k];
+      }
+      for (size_t k = 0; k < kUnroll; ++k) {
+        for (int gid = 0; gid < num_group; ++gid) {
+          const size_t offset = ridx[k] * num_group + gid;
+          preds[offset] += PredValue(inst[k], model.trees, model.tree_info, gid,
+                                     &feats, tree_begin, tree_end);
+        }
+      }
+    }
+  }
+  for (bst_omp_uint i = nsize - rest; i < nsize; ++i) {
+    RegTree::FVec &feats = thread_temp[0];
+    const auto ridx = static_cast<int64_t>(batch.base_rowid + i);
+    auto inst = batch[i];
+    for (int gid = 0; gid < num_group; ++gid) {
+      const size_t offset = ridx * num_group + gid;
+      preds[offset] += PredValue(inst, model.trees, model.tree_info, gid,
+                                 &feats, tree_begin, tree_end);
+    }
+  }
+}
+
+class CPUPredictor : public Predictor {
+ protected:
+  // init thread buffers
+  static void InitThreadTemp(int nthread, int num_feature, std::vector<RegTree::FVec>* out) {
+    int prev_thread_temp_size = out->size();
+    if (prev_thread_temp_size < nthread) {
+      out->resize(nthread, RegTree::FVec());
+      for (int i = prev_thread_temp_size; i < nthread; ++i) {
+        (*out)[i].Init(num_feature);
+      }
+    }
+  }
+
+  void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
+                      gbm::GBTreeModel const &model, int32_t tree_begin,
+                      int32_t tree_end) {
+    std::lock_guard<std::mutex> guard(lock_);
+    const int threads = omp_get_max_threads();
+    InitThreadTemp(threads, model.learner_model_param_->num_feature, &this->thread_temp_);
+    for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
+      CHECK_EQ(out_preds->size(),
+               p_fmat->Info().num_row_ * model.learner_model_param_->num_output_group);
+      size_t constexpr kUnroll = 8;
+      PredictBatchKernel(SparsePageView<kUnroll>{&batch}, out_preds, model, tree_begin,
+                         tree_end, &thread_temp_);
+    }
+  }

   void InitOutPredictions(const MetaInfo& info,
                           HostDeviceVector<bst_float>* out_preds,
@@ -175,9 +251,9 @@ class CPUPredictor : public Predictor {

     CHECK_LE(beg_version, end_version);
     if (beg_version < end_version) {
-      this->PredInternal(dmat, &out_preds->HostVector(), model,
-                         beg_version * output_groups,
-                         end_version * output_groups);
+      this->PredictDMatrix(dmat, &out_preds->HostVector(), model,
+                           beg_version * output_groups,
+                           end_version * output_groups);
     }

     // delta means {size of forest} * {number of newly accumulated layers}
@@ -189,12 +265,49 @@ class CPUPredictor : public Predictor {
            out_preds->Size() == dmat->Info().num_row_);
   }

+  template <typename Adapter>
+  void DispatchedInplacePredict(dmlc::any const &x,
+                                const gbm::GBTreeModel &model, float missing,
+                                PredictionCacheEntry *out_preds,
+                                uint32_t tree_begin, uint32_t tree_end) const {
+    auto threads = omp_get_max_threads();
+    auto m = dmlc::get<Adapter>(x);
+    CHECK_EQ(m.NumColumns(), model.learner_model_param_->num_feature)
+        << "Number of columns in data must equal to trained model.";
+    MetaInfo info;
+    info.num_col_ = m.NumColumns();
+    info.num_row_ = m.NumRows();
+    this->InitOutPredictions(info, &(out_preds->predictions), model);
+    std::vector<Entry> workspace(info.num_col_ * 8 * threads);
+    auto &predictions = out_preds->predictions.HostVector();
+    std::vector<RegTree::FVec> thread_temp;
+    InitThreadTemp(threads, model.learner_model_param_->num_feature, &thread_temp);
+    size_t constexpr kUnroll = 8;
+    PredictBatchKernel(AdapterView<Adapter, kUnroll>(
+                           &m, missing, common::Span<Entry>{workspace}),
+                       &predictions, model, tree_begin, tree_end, &thread_temp);
+  }
+
+  void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
+                      float missing, PredictionCacheEntry *out_preds,
+                      uint32_t tree_begin, unsigned tree_end) const override {
+    if (x.type() == typeid(data::DenseAdapter)) {
+      this->DispatchedInplacePredict<data::DenseAdapter>(
+          x, model, missing, out_preds, tree_begin, tree_end);
+    } else if (x.type() == typeid(data::CSRAdapter)) {
+      this->DispatchedInplacePredict<data::CSRAdapter>(
+          x, model, missing, out_preds, tree_begin, tree_end);
+    } else {
+      LOG(FATAL) << "Data type is not supported by CPU Predictor.";
+    }
+  }
+
   void PredictInstance(const SparsePage::Inst& inst,
                        std::vector<bst_float>* out_preds,
                        const gbm::GBTreeModel& model, unsigned ntree_limit) override {
-    if (thread_temp.size() == 0) {
-      thread_temp.resize(1, RegTree::FVec());
-      thread_temp[0].Init(model.learner_model_param_->num_feature);
+    if (thread_temp_.size() == 0) {
+      thread_temp_.resize(1, RegTree::FVec());
+      thread_temp_[0].Init(model.learner_model_param_->num_feature);
     }
     ntree_limit *= model.learner_model_param_->num_output_group;
     if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
@@ -204,16 +317,16 @@ class CPUPredictor : public Predictor {
                         (model.param.size_leaf_vector + 1));
     // loop over output groups
     for (uint32_t gid = 0; gid < model.learner_model_param_->num_output_group; ++gid) {
-      (*out_preds)[gid] =
-          PredValue(inst, model.trees, model.tree_info, gid,
-                    &thread_temp[0], 0, ntree_limit) +
-          model.learner_model_param_->base_score;
+      (*out_preds)[gid] = PredValue(inst, model.trees, model.tree_info, gid,
+                                    &thread_temp_[0], 0, ntree_limit) +
+                          model.learner_model_param_->base_score;
     }
   }
+
   void PredictLeaf(DMatrix* p_fmat, std::vector<bst_float>* out_preds,
                    const gbm::GBTreeModel& model, unsigned ntree_limit) override {
     const int nthread = omp_get_max_threads();
-    InitThreadTemp(nthread, model.learner_model_param_->num_feature);
+    InitThreadTemp(nthread, model.learner_model_param_->num_feature, &this->thread_temp_);
     const MetaInfo& info = p_fmat->Info();
     // number of valid trees
     ntree_limit *= model.learner_model_param_->num_output_group;
@@ -230,7 +343,7 @@ class CPUPredictor : public Predictor {
     for (bst_omp_uint i = 0; i < nsize; ++i) {
       const int tid = omp_get_thread_num();
       auto ridx = static_cast<size_t>(batch.base_rowid + i);
-      RegTree::FVec& feats = thread_temp[tid];
+      RegTree::FVec &feats = thread_temp_[tid];
       feats.Fill(batch[i]);
       for (unsigned j = 0; j < ntree_limit; ++j) {
         int tid = model.trees[j]->GetLeafIndex(feats);
@@ -247,7 +360,7 @@ class CPUPredictor : public Predictor {
                            bool approximate, int condition,
                            unsigned condition_feature) override {
     const int nthread = omp_get_max_threads();
-    InitThreadTemp(nthread, model.learner_model_param_->num_feature);
+    InitThreadTemp(nthread, model.learner_model_param_->num_feature, &this->thread_temp_);
     const MetaInfo& info = p_fmat->Info();
     // number of valid trees
     ntree_limit *= model.learner_model_param_->num_output_group;
@@ -277,7 +390,7 @@ class CPUPredictor : public Predictor {
 #pragma omp parallel for schedule(static)
     for (bst_omp_uint i = 0; i < nsize; ++i) {
       auto row_idx = static_cast<size_t>(batch.base_rowid + i);
-      RegTree::FVec& feats = thread_temp[omp_get_thread_num()];
+      RegTree::FVec &feats = thread_temp_[omp_get_thread_num()];
       std::vector<bst_float> this_tree_contribs(ncolumns);
       // loop over all classes
       for (int gid = 0; gid < ngroup; ++gid) {
@@ -359,7 +472,10 @@ class CPUPredictor : public Predictor {
         }
       }
     }
   }
-  std::vector<RegTree::FVec> thread_temp;
+
+ private:
+  std::mutex lock_;
+  std::vector<RegTree::FVec> thread_temp_;
 };

 XGBOOST_REGISTER_PREDICTOR(CPUPredictor, "cpu_predictor")
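The `AdapterView` above avoids allocating inside the parallel region by giving every OpenMP thread `kUnroll` row-sized slots in one shared buffer. A stripped-down sketch of that indexing scheme, with illustrative names rather than the actual XGBoost classes:

#include <omp.h>
#include <cstddef>
#include <vector>

struct Entry { int index; float value; };

// Each thread owns kUnroll row-sized slots inside one shared buffer and
// cycles through them, so up to kUnroll rows per thread stay valid at once.
class RowWorkspace {
  std::vector<Entry> buf_;
  std::vector<size_t> cursor_;   // per-thread ring position
  size_t cols_;
  static constexpr size_t kUnroll = 8;

 public:
  RowWorkspace(size_t cols, int threads)
      : buf_(cols * kUnroll * threads), cursor_(threads, 0), cols_{cols} {}

  Entry* Slot() {
    int t = omp_get_thread_num();
    size_t beg = cols_ * kUnroll * t + cursor_[t] * cols_;
    cursor_[t] = (cursor_[t] + 1) % kUnroll;
    return buf_.data() + beg;
  }
};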
diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu
index 165e99186fb8..a77e564c9a2c 100644
--- a/src/predictor/gpu_predictor.cu
+++ b/src/predictor/gpu_predictor.cu
@@ -15,6 +15,7 @@

 #include "../gbm/gbtree_model.h"
 #include "../data/ellpack_page.cuh"
+#include "../data/device_adapter.cuh"
 #include "../common/common.h"
 #include "../common/device_helpers.cuh"

@@ -116,6 +117,76 @@ struct EllpackLoader {
   }
 };

+struct CuPyAdapterLoader {
+  data::CupyAdapterBatch batch;
+  bst_feature_t columns;
+  float* smem;
+  bool use_shared;
+
+  DEV_INLINE CuPyAdapterLoader(data::CupyAdapterBatch const batch, bool use_shared,
+                               bst_feature_t num_features, bst_row_t num_rows,
+                               size_t entry_start) :
+      batch{batch},
+      columns{num_features},
+      use_shared{use_shared} {
+    extern __shared__ float _smem[];
+    smem = _smem;
+    if (use_shared) {
+      uint32_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
+      size_t shared_elements = blockDim.x * num_features;
+      dh::BlockFill(smem, shared_elements, nanf(""));
+      __syncthreads();
+      if (global_idx < num_rows) {
+        auto beg = global_idx * columns;
+        auto end = (global_idx + 1) * columns;
+        for (size_t i = beg; i < end; ++i) {
+          smem[threadIdx.x * num_features + (i - beg)] = batch.GetElement(i).value;
+        }
+      }
+    }
+    __syncthreads();
+  }
+
+  DEV_INLINE float GetFvalue(bst_row_t ridx, bst_feature_t fidx) const {
+    if (use_shared) {
+      return smem[threadIdx.x * columns + fidx];
+    }
+    return batch.GetElement(ridx * columns + fidx).value;
+  }
+};
+
+struct CuDFAdapterLoader {
+  data::CudfAdapterBatch batch;
+  bst_feature_t columns;
+  float* smem;
+  bool use_shared;
+
+  DEV_INLINE CuDFAdapterLoader(data::CudfAdapterBatch const batch, bool use_shared,
+                               bst_feature_t num_features,
+                               bst_row_t num_rows, size_t entry_start)
+      : batch{batch}, columns{num_features}, use_shared{use_shared} {
+    extern __shared__ float _smem[];
+    smem = _smem;
+    if (use_shared) {
+      uint32_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
+      size_t shared_elements = blockDim.x * num_features;
+      dh::BlockFill(smem, shared_elements, nanf(""));
+      __syncthreads();
+      if (global_idx < num_rows) {
+        for (size_t i = 0; i < columns; ++i) {
+          smem[threadIdx.x * columns + i] = batch.GetValue(global_idx, i);
+        }
+      }
+    }
+    __syncthreads();
+  }
+
+  DEV_INLINE float GetFvalue(bst_row_t ridx, bst_feature_t fidx) const {
+    if (use_shared) {
+      return smem[threadIdx.x * columns + fidx];
+    }
+    return batch.GetValue(ridx, fidx);
+  }
+};
+
 template <typename Loader>
 __device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree,
                                Loader* loader) {
@@ -169,30 +240,61 @@ __global__ void PredictKernel(Data data,
   }
 }

-class GPUPredictor : public xgboost::Predictor {
- private:
-  void InitModel(const gbm::GBTreeModel& model,
+class DeviceModel {
+ public:
+  dh::device_vector<RegTree::Node> nodes;
+  dh::device_vector<size_t> tree_segments;
+  dh::device_vector<int> tree_group;
+  size_t tree_beg_;  // NOLINT
+  size_t tree_end_;  // NOLINT
+  int num_group;
+
+  void CopyModel(const gbm::GBTreeModel& model,
                 const thrust::host_vector<size_t>& h_tree_segments,
                 const thrust::host_vector<RegTree::Node>& h_nodes,
                 size_t tree_begin, size_t tree_end) {
-    dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
-    nodes_.resize(h_nodes.size());
-    dh::safe_cuda(cudaMemcpyAsync(nodes_.data().get(), h_nodes.data(),
+    nodes.resize(h_nodes.size());
+    dh::safe_cuda(cudaMemcpyAsync(nodes.data().get(), h_nodes.data(),
                                   sizeof(RegTree::Node) * h_nodes.size(),
                                   cudaMemcpyHostToDevice));
-    tree_segments_.resize(h_tree_segments.size());
-    dh::safe_cuda(cudaMemcpyAsync(tree_segments_.data().get(), h_tree_segments.data(),
+    tree_segments.resize(h_tree_segments.size());
+    dh::safe_cuda(cudaMemcpyAsync(tree_segments.data().get(), h_tree_segments.data(),
                                   sizeof(size_t) * h_tree_segments.size(),
                                   cudaMemcpyHostToDevice));
-    tree_group_.resize(model.tree_info.size());
-    dh::safe_cuda(cudaMemcpyAsync(tree_group_.data().get(), model.tree_info.data(),
+    tree_group.resize(model.tree_info.size());
+    dh::safe_cuda(cudaMemcpyAsync(tree_group.data().get(), model.tree_info.data(),
                                   sizeof(int) * model.tree_info.size(),
                                   cudaMemcpyHostToDevice));
-    this->tree_begin_ = tree_begin;
+    this->tree_beg_ = tree_begin;
     this->tree_end_ = tree_end;
-    this->num_group_ = model.learner_model_param_->num_output_group;
+    this->num_group = model.learner_model_param_->num_output_group;
+  }
+
+  void Init(const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end, int32_t gpu_id) {
+    dh::safe_cuda(cudaSetDevice(gpu_id));
+    CHECK_EQ(model.param.size_leaf_vector, 0);
+    // Copy decision trees to device
+    thrust::host_vector<size_t> h_tree_segments{};
+    h_tree_segments.reserve((tree_end - tree_begin) + 1);
+    size_t sum = 0;
+    h_tree_segments.push_back(sum);
+    for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
+      sum += model.trees.at(tree_idx)->GetNodes().size();
+      h_tree_segments.push_back(sum);
+    }
+
+    thrust::host_vector<RegTree::Node> h_nodes(h_tree_segments.back());
+    for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
+      auto& src_nodes = model.trees.at(tree_idx)->GetNodes();
+      std::copy(src_nodes.begin(), src_nodes.end(),
+                h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]);
+    }
+    CopyModel(model, h_tree_segments, h_nodes, tree_begin, tree_end);
   }
+};

+class GPUPredictor : public xgboost::Predictor {
 private:
  void PredictInternal(const SparsePage& batch, size_t num_features,
                       HostDeviceVector<bst_float>* predictions,
                       size_t batch_offset) {
@@ -214,10 +316,10 @@ class GPUPredictor : public xgboost::Predictor {
     dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
         PredictKernel<SparsePageLoader, SparsePageView>,
         data,
-        dh::ToSpan(nodes_), predictions->DeviceSpan().subspan(batch_offset),
-        dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_),
-        this->tree_begin_, this->tree_end_, num_features, num_rows,
-        entry_start, use_shared, this->num_group_);
+        dh::ToSpan(model_.nodes), predictions->DeviceSpan().subspan(batch_offset),
+        dh::ToSpan(model_.tree_segments), dh::ToSpan(model_.tree_group),
+        model_.tree_beg_, model_.tree_end_, num_features, num_rows,
+        entry_start, use_shared, model_.num_group);
   }
   void PredictInternal(EllpackDeviceAccessor const& batch,
                        HostDeviceVector<bst_float>* out_preds,
                        size_t batch_offset) {
@@ -230,31 +332,10 @@ class GPUPredictor : public xgboost::Predictor {
     dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS} (
         PredictKernel<EllpackLoader, EllpackDeviceAccessor>,
         batch,
-        dh::ToSpan(nodes_), out_preds->DeviceSpan().subspan(batch_offset),
-        dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_),
-        this->tree_begin_, this->tree_end_, batch.NumFeatures(), num_rows,
-        entry_start, use_shared, this->num_group_);
-  }
-
-  void InitModel(const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end) {
-    CHECK_EQ(model.param.size_leaf_vector, 0);
-    // Copy decision trees to device
-    thrust::host_vector<size_t> h_tree_segments{};
-    h_tree_segments.reserve((tree_end - tree_begin) + 1);
-    size_t sum = 0;
-    h_tree_segments.push_back(sum);
-    for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
-      sum += model.trees.at(tree_idx)->GetNodes().size();
-      h_tree_segments.push_back(sum);
-    }
-
-    thrust::host_vector<RegTree::Node> h_nodes(h_tree_segments.back());
-    for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
-      auto& src_nodes = model.trees.at(tree_idx)->GetNodes();
-      std::copy(src_nodes.begin(), src_nodes.end(),
-                h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]);
-    }
-    InitModel(model, h_tree_segments, h_nodes, tree_begin, tree_end);
+        dh::ToSpan(model_.nodes), out_preds->DeviceSpan().subspan(batch_offset),
+        dh::ToSpan(model_.tree_segments), dh::ToSpan(model_.tree_group),
+        model_.tree_beg_, model_.tree_end_, batch.NumFeatures(), num_rows,
+        entry_start, use_shared, model_.num_group);
   }

   void DevicePredictInternal(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
@@ -264,8 +345,7 @@ class GPUPredictor : public xgboost::Predictor {
     if (tree_end - tree_begin == 0) {
       return;
     }
-    monitor_.StartCuda("DevicePredictInternal");
-    InitModel(model, tree_begin, tree_end);
+    model_.Init(model, tree_begin, tree_end, generic_param_->gpu_id);

     out_preds->SetDevice(generic_param_->gpu_id);
     if (dmat->PageExists<SparsePage>()) {
@@ -284,7 +364,6 @@ class GPUPredictor : public xgboost::Predictor {
       batch_offset += batch.Size() * model.learner_model_param_->num_output_group;
     }
-    monitor_.StopCuda("DevicePredictInternal");
   }

  public:
@@ -302,6 +381,7 @@ class GPUPredictor : public xgboost::Predictor {
                     unsigned ntree_limit = 0) override {
     // This function is duplicated with CPU predictor PredictBatch, see comments in there.
     // FIXME(trivialfis): Remove the duplication.
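+    // PredictBatch rebuilds the shared model_ member for the requested tree
+    // range, so concurrent callers must serialize on lock_; InplacePredict
+    // below stays lock free by building its own local DeviceModel instead.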
+    std::lock_guard<std::mutex> const guard(lock_);
     int device = generic_param_->gpu_id;
     CHECK_GE(device, 0) << "Set `gpu_id' to positive value for processing GPU data.";
     ConfigureDevice(device);
@@ -348,6 +428,63 @@ class GPUPredictor : public xgboost::Predictor {
            out_preds->Size() == dmat->Info().num_row_);
   }

+  template <typename Adapter, typename Loader>
+  void DispatchedInplacePredict(dmlc::any const &x,
+                                const gbm::GBTreeModel &model, float missing,
+                                PredictionCacheEntry *out_preds,
+                                uint32_t tree_begin, uint32_t tree_end) const {
+    auto max_shared_memory_bytes = dh::MaxSharedMemory(this->generic_param_->gpu_id);
+    uint32_t const output_groups = model.learner_model_param_->num_output_group;
+    DeviceModel d_model;
+    d_model.Init(model, tree_begin, tree_end, this->generic_param_->gpu_id);
+
+    auto m = dmlc::get<Adapter>(x);
+    CHECK_EQ(m.NumColumns(), model.learner_model_param_->num_feature)
+        << "Number of columns in data must equal to trained model.";
+    CHECK_EQ(this->generic_param_->gpu_id, m.DeviceIdx())
+        << "XGBoost is running on device: " << this->generic_param_->gpu_id << ", "
+        << "but data is on: " << m.DeviceIdx();
+    MetaInfo info;
+    info.num_col_ = m.NumColumns();
+    info.num_row_ = m.NumRows();
+    this->InitOutPredictions(info, &(out_preds->predictions), model);
+
+    const uint32_t BLOCK_THREADS = 128;
+    auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(info.num_row_, BLOCK_THREADS));
+
+    auto shared_memory_bytes =
+        static_cast<size_t>(sizeof(float) * m.NumColumns() * BLOCK_THREADS);
+    bool use_shared = true;
+    if (shared_memory_bytes > max_shared_memory_bytes) {
+      shared_memory_bytes = 0;
+      use_shared = false;
+    }
+    size_t entry_start = 0;
+
+    dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
+        PredictKernel,
+        m.Value(),
+        dh::ToSpan(d_model.nodes), out_preds->predictions.DeviceSpan(),
+        dh::ToSpan(d_model.tree_segments), dh::ToSpan(d_model.tree_group),
+        tree_begin, tree_end, m.NumColumns(), info.num_row_,
+        entry_start, use_shared, output_groups);
+  }
+
+  void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
+                      float missing, PredictionCacheEntry *out_preds,
+                      uint32_t tree_begin, unsigned tree_end) const override {
+    auto max_shared_memory_bytes = dh::MaxSharedMemory(this->generic_param_->gpu_id);
+    if (x.type() == typeid(data::CupyAdapter)) {
+      this->DispatchedInplacePredict<data::CupyAdapter, CuPyAdapterLoader>(
+          x, model, missing, out_preds, tree_begin, tree_end);
+    } else if (x.type() == typeid(data::CudfAdapter)) {
+      this->DispatchedInplacePredict<data::CudfAdapter, CuDFAdapterLoader>(
+          x, model, missing, out_preds, tree_begin, tree_end);
+    } else {
+      LOG(FATAL) << "Only CuPy and CuDF are supported by GPU Predictor.";
+    }
+  }
+
  protected:
   void InitOutPredictions(const MetaInfo& info,
                           HostDeviceVector<bst_float>* out_preds,
@@ -411,14 +548,9 @@ class GPUPredictor : public xgboost::Predictor {
     }
   }

-  common::Monitor monitor_;
-  dh::device_vector<RegTree::Node> nodes_;
-  dh::device_vector<size_t> tree_segments_;
-  dh::device_vector<int> tree_group_;
+  std::mutex lock_;
+  DeviceModel model_;
   size_t max_shared_memory_bytes_;
-  size_t tree_begin_;
-  size_t tree_end_;
-  int num_group_;
 };

 XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor")
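The predictor.cc change below is the companion to the locking added above: the global prediction cache can now be reached from several threads at once, so `Cache` takes the lock before pruning and mutating the map. A minimal sketch of the hazard it closes, with stand-in types (`CacheEntry` and the key type are illustrative):

#include <memory>
#include <mutex>
#include <unordered_map>

struct CacheEntry { int value {0}; };

class Container {
  std::unordered_map<void*, CacheEntry> container_;
  std::mutex cache_lock_;

 public:
  // Without the lock, one thread erasing expired entries while another
  // inserts can corrupt the unordered_map; with it, each call is atomic.
  CacheEntry& Cache(std::shared_ptr<int> m) {
    std::lock_guard<std::mutex> guard { cache_lock_ };
    return container_[m.get()];
  }
};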
diff --git a/src/predictor/predictor.cc b/src/predictor/predictor.cc
index b20edf880f6c..eb68d382b421 100644
--- a/src/predictor/predictor.cc
+++ b/src/predictor/predictor.cc
@@ -2,8 +2,9 @@
  * Copyright 2017-2020 by Contributors
  */
 #include
-#include
+#include <mutex>

+#include "xgboost/predictor.h"
 #include "xgboost/data.h"
 #include "xgboost/generic_parameters.h"

@@ -25,6 +26,7 @@ void PredictionContainer::ClearExpiredEntries() {
 }

 PredictionCacheEntry &PredictionContainer::Cache(std::shared_ptr<DMatrix> m, int32_t device) {
+  std::lock_guard<std::mutex> guard { cache_lock_ };
   this->ClearExpiredEntries();
   container_[m.get()].ref = m;
   if (device != GenericParameter::kCpuId) {
diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc
index 3864a4f6d06f..5621e0c7d13d 100644
--- a/tests/cpp/helpers.cc
+++ b/tests/cpp/helpers.cc
@@ -177,9 +177,8 @@ void RandomDataGenerator::GenerateDense(HostDeviceVector<float> *out) const {
   }
 }

-void RandomDataGenerator::GenerateArrayInterface(
-    HostDeviceVector<float> *storage, std::string *out) const {
-  CHECK(out);
+Json RandomDataGenerator::ArrayInterfaceImpl(HostDeviceVector<float> *storage,
+                                             size_t rows, size_t cols) const {
   this->GenerateDense(storage);
   Json array_interface {Object()};
   array_interface["data"] = std::vector<Json>(2);
@@ -187,13 +186,37 @@ void RandomDataGenerator::GenerateArrayInterface(
   array_interface["data"][1] = Boolean(false);

   array_interface["shape"] = std::vector<Json>(2);
-  array_interface["shape"][0] = rows_;
-  array_interface["shape"][1] = cols_;
+  array_interface["shape"][0] = rows;
+  array_interface["shape"][1] = cols;

   array_interface["typestr"] = String("<f4");
+  return array_interface;
+}
+
+std::string RandomDataGenerator::GenerateArrayInterface(
+    HostDeviceVector<float> *storage) const {
+  auto array_interface = this->ArrayInterfaceImpl(storage, rows_, cols_);
+  std::string out;
+  Json::Dump(array_interface, &out);
+  return out;
+}

-  Json::Dump(array_interface, out);
+
+std::string RandomDataGenerator::GenerateColumnarArrayInterface(
+    std::vector<HostDeviceVector<float>> *data) const {
+  CHECK(data);
+  CHECK_EQ(data->size(), cols_);
+  auto& storage = *data;
+  Json arr { Array() };
+  for (size_t i = 0; i < cols_; ++i) {
+    auto column = this->ArrayInterfaceImpl(&storage[i], rows_, 1);
+    get<Array>(arr).emplace_back(column);
+  }
+  std::string out;
+  Json::Dump(arr, &out);
+  return out;
 }

 void RandomDataGenerator::GenerateCSR(
diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h
index 662738f1701f..42bc8dde91ce 100644
--- a/tests/cpp/helpers.h
+++ b/tests/cpp/helpers.h
@@ -181,6 +181,9 @@ class RandomDataGenerator {
   int32_t device_;
   int32_t seed_;

+  Json ArrayInterfaceImpl(HostDeviceVector<float> *storage, size_t rows,
+                          size_t cols) const;
+
  public:
   RandomDataGenerator(bst_row_t rows, size_t cols, float sparsity)
      : rows_{rows}, cols_{cols}, sparsity_{sparsity}, lower_{0.0f}, upper_{1.0f},
@@ -204,7 +207,9 @@ class RandomDataGenerator {
   }

   void GenerateDense(HostDeviceVector<float>* out) const;
-  void GenerateArrayInterface(HostDeviceVector<float>* storage, std::string* out) const;
+  std::string GenerateArrayInterface(HostDeviceVector<float>* storage) const;
+  std::string GenerateColumnarArrayInterface(
+      std::vector<HostDeviceVector<float>> *data) const;
   void GenerateCSR(HostDeviceVector<float>* value, HostDeviceVector<bst_row_t>* row_ptr,
                    HostDeviceVector<bst_feature_t>* columns) const;
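A sketch of how the reworked helpers fit together in a test; the returned string follows the standard `__cuda_array_interface__` layout, e.g. `{"data": [<pointer>, false], "shape": [16, 4], "typestr": "<f4"}` where `<f4` marks little-endian float32:

#include <string>
#include <vector>
#include "helpers.h"

void Example() {
  // Assumes a CUDA-capable device 0 for the device-side interfaces.
  xgboost::RandomDataGenerator gen{/*rows=*/16, /*cols=*/4, /*sparsity=*/0.5};
  gen.Device(0);

  xgboost::HostDeviceVector<float> storage;
  std::string cupy_like = gen.GenerateArrayInterface(&storage);

  std::vector<xgboost::HostDeviceVector<float>> columns(4);
  std::string cudf_like = gen.GenerateColumnarArrayInterface(&columns);
}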
"cpu_predictor", kRows, kCols, -1); + } + + { + HostDeviceVector data; + HostDeviceVector rptrs; + HostDeviceVector columns; + gen.GenerateCSR(&data, &rptrs, &columns); + data::CSRAdapter x(rptrs.HostPointer(), columns.HostPointer(), + data.HostPointer(), kRows, data.Size(), kCols); + TestInplacePrediction(x, "cpu_predictor", kRows, kCols, -1); + } +} } // namespace xgboost diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index 9e74ae2ab838..7c8dac3736fe 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -1,16 +1,17 @@ /*! * Copyright 2017-2020 XGBoost contributors */ +#include #include #include #include #include #include - #include -#include "gtest/gtest.h" + #include "../helpers.h" #include "../../../src/gbm/gbtree_model.h" +#include "../../../src/data/device_adapter.cuh" #include "test_predictor.h" namespace xgboost { @@ -104,5 +105,43 @@ TEST(GPUPredictor, ExternalMemoryTest) { } } } + +TEST(GPUPredictor, InplacePredictCupy) { + size_t constexpr kRows{128}, kCols{64}; + RandomDataGenerator gen(kRows, kCols, 0.5); + gen.Device(0); + HostDeviceVector data; + std::string interface_str = gen.GenerateArrayInterface(&data); + data::CupyAdapter x{interface_str}; + TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0); +} + +TEST(GPUPredictor, InplacePredictCuDF) { + size_t constexpr kRows{128}, kCols{64}; + RandomDataGenerator gen(kRows, kCols, 0.5); + gen.Device(0); + std::vector> storage(kCols); + auto interface_str = gen.GenerateColumnarArrayInterface(&storage); + data::CudfAdapter x {interface_str}; + TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0); +} + +TEST(GPUPredictor, MGPU_InplacePredict) { + int32_t n_gpus = xgboost::common::AllVisibleGPUs(); + if (n_gpus <= 1) { + LOG(WARNING) << "GPUPredictor.MGPU_InplacePredict is skipped."; + return; + } + size_t constexpr kRows{128}, kCols{64}; + RandomDataGenerator gen(kRows, kCols, 0.5); + gen.Device(1); + HostDeviceVector data; + std::string interface_str = gen.GenerateArrayInterface(&data); + data::CupyAdapter x{interface_str}; + TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 1); + EXPECT_THROW(TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0), + dmlc::Error); +} + } // namespace predictor } // namespace xgboost diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc index f6c2ea736064..675daea337cc 100644 --- a/tests/cpp/predictor/test_predictor.cc +++ b/tests/cpp/predictor/test_predictor.cc @@ -77,4 +77,59 @@ void TestTrainingPrediction(size_t rows, std::string tree_method) { predictions_0.ConstHostVector()[i], kRtEps); } } + +void TestInplacePrediction(dmlc::any x, std::string predictor, + bst_row_t rows, bst_feature_t cols, + int32_t device) { + size_t constexpr kClasses { 4 }; + auto gen = RandomDataGenerator{rows, cols, 0.5}.Device(device); + std::shared_ptr m = gen.GenerateDMatix(true, false, kClasses); + + std::unique_ptr learner { + Learner::Create({m}) + }; + + learner->SetParam("num_parallel_tree", "4"); + learner->SetParam("num_class", std::to_string(kClasses)); + learner->SetParam("seed", "0"); + learner->SetParam("subsample", "0.5"); + learner->SetParam("gpu_id", std::to_string(device)); + learner->SetParam("predictor", predictor); + for (int32_t it = 0; it < 4; ++it) { + learner->UpdateOneIter(it, m); + } + + HostDeviceVector *p_out_predictions_0{nullptr}; + learner->InplacePredict(x, "margin", std::numeric_limits::quiet_NaN(), + &p_out_predictions_0, 0, 
diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc
index f6c2ea736064..675daea337cc 100644
--- a/tests/cpp/predictor/test_predictor.cc
+++ b/tests/cpp/predictor/test_predictor.cc
@@ -77,4 +77,59 @@ void TestTrainingPrediction(size_t rows, std::string tree_method) {
                 predictions_0.ConstHostVector()[i], kRtEps);
   }
 }
+
+void TestInplacePrediction(dmlc::any x, std::string predictor,
+                           bst_row_t rows, bst_feature_t cols,
+                           int32_t device) {
+  size_t constexpr kClasses { 4 };
+  auto gen = RandomDataGenerator{rows, cols, 0.5}.Device(device);
+  std::shared_ptr<DMatrix> m = gen.GenerateDMatix(true, false, kClasses);
+
+  std::unique_ptr<Learner> learner {
+    Learner::Create({m})
+  };
+
+  learner->SetParam("num_parallel_tree", "4");
+  learner->SetParam("num_class", std::to_string(kClasses));
+  learner->SetParam("seed", "0");
+  learner->SetParam("subsample", "0.5");
+  learner->SetParam("gpu_id", std::to_string(device));
+  learner->SetParam("predictor", predictor);
+  for (int32_t it = 0; it < 4; ++it) {
+    learner->UpdateOneIter(it, m);
+  }
+
+  HostDeviceVector<float> *p_out_predictions_0{nullptr};
+  learner->InplacePredict(x, "margin", std::numeric_limits<float>::quiet_NaN(),
+                          &p_out_predictions_0, 0, 2);
+  CHECK(p_out_predictions_0);
+  HostDeviceVector<float> predict_0(p_out_predictions_0->Size());
+  predict_0.Copy(*p_out_predictions_0);
+
+  HostDeviceVector<float> *p_out_predictions_1{nullptr};
+  learner->InplacePredict(x, "margin", std::numeric_limits<float>::quiet_NaN(),
+                          &p_out_predictions_1, 2, 4);
+  CHECK(p_out_predictions_1);
+  HostDeviceVector<float> predict_1(p_out_predictions_1->Size());
+  predict_1.Copy(*p_out_predictions_1);
+
+  HostDeviceVector<float>* p_out_predictions{nullptr};
+  learner->InplacePredict(x, "margin", std::numeric_limits<float>::quiet_NaN(),
+                          &p_out_predictions, 0, 4);
+
+  auto& h_pred = p_out_predictions->HostVector();
+  auto& h_pred_0 = predict_0.HostVector();
+  auto& h_pred_1 = predict_1.HostVector();
+
+  ASSERT_EQ(h_pred.size(), rows * kClasses);
+  ASSERT_EQ(h_pred.size(), h_pred_0.size());
+  ASSERT_EQ(h_pred.size(), h_pred_1.size());
+  for (size_t i = 0; i < h_pred.size(); ++i) {
+    // Both halves include the global bias, so remove one copy of it here.
+    ASSERT_NEAR(h_pred[i], h_pred_0[i] + h_pred_1[i] - 0.5f, kRtEps);
+  }
+
+  learner->SetParam("gpu_id", "-1");
+  learner->Configure();
+}
 }  // namespace xgboost
diff --git a/tests/cpp/predictor/test_predictor.h b/tests/cpp/predictor/test_predictor.h
index 3234584d1f17..4baa0b4e3ed7 100644
--- a/tests/cpp/predictor/test_predictor.h
+++ b/tests/cpp/predictor/test_predictor.h
@@ -58,6 +58,9 @@ void TestPredictionFromGradientIndex(std::string name, size_t rows, int32_t bins
 
 void TestTrainingPrediction(size_t rows, std::string tree_method);
 
+void TestInplacePrediction(dmlc::any x, std::string predictor,
+                           bst_row_t rows, bst_feature_t cols,
+                           int32_t device = -1);
 }  // namespace xgboost
 
 #endif  // XGBOOST_TEST_PREDICTOR_H_
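The `- 0.5f` in the loop above is the default `base_score`: every margin prediction over a layer range carries one copy of the global bias, so summing two ranges counts it twice. A minimal Python sketch of the same identity, assuming a build that includes this patch and that `predict_type='margin'` maps to the C++ `"margin"` type used here:

    import numpy as np
    import xgboost as xgb

    rng = np.random.RandomState(0)
    X = rng.randn(64, 8)
    y = rng.randn(64)
    booster = xgb.train({'tree_method': 'hist'}, xgb.DMatrix(X, y),
                        num_boost_round=4)

    full = booster.inplace_predict(X, iteration_range=(0, 4),
                                   predict_type='margin')
    first = booster.inplace_predict(X, iteration_range=(0, 2),
                                    predict_type='margin')
    second = booster.inplace_predict(X, iteration_range=(2, 4),
                                     predict_type='margin')
    # Both halves include base_score (0.5 by default), so one copy is
    # subtracted, mirroring the ASSERT_NEAR in the C++ test.
    np.testing.assert_allclose(full, first + second - 0.5,
                               rtol=1e-5, atol=1e-6)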
diff --git a/tests/python-gpu/test_gpu_prediction.py b/tests/python-gpu/test_gpu_prediction.py
index fd092c5e713a..211cb115ddb9 100644
--- a/tests/python-gpu/test_gpu_prediction.py
+++ b/tests/python-gpu/test_gpu_prediction.py
@@ -1,8 +1,12 @@
-from __future__ import print_function
+import sys
+import unittest
+import pytest
 
 import numpy as np
-import unittest
 import xgboost as xgb
 
+sys.path.append("tests/python")
+import testing as tm
+from test_predict import run_threaded_predict  # noqa
 
 rng = np.random.RandomState(1994)
@@ -111,3 +115,65 @@ def test_sklearn(self):
 
         assert np.allclose(cpu_train_score, gpu_train_score)
         assert np.allclose(cpu_test_score, gpu_test_score)
+
+    @pytest.mark.skipif(**tm.no_cupy())
+    def test_inplace_predict_cupy(self):
+        import cupy as cp
+        rows = 1000
+        cols = 10
+        cp_rng = cp.random.RandomState(1994)
+        cp.random.set_random_state(cp_rng)
+        X = cp.random.randn(rows, cols)
+        y = cp.random.randn(rows)
+
+        dtrain = xgb.DMatrix(X, y)
+
+        booster = xgb.train({'tree_method': 'gpu_hist'},
+                            dtrain, num_boost_round=10)
+        test = xgb.DMatrix(X[:10, ...])
+        predt_from_array = booster.inplace_predict(X[:10, ...])
+        predt_from_dmatrix = booster.predict(test)
+
+        cp.testing.assert_allclose(predt_from_array, predt_from_dmatrix)
+
+        def predict_dense(x):
+            inplace_predt = booster.inplace_predict(x)
+            d = xgb.DMatrix(x)
+            copied_predt = cp.array(booster.predict(d))
+            return cp.all(copied_predt == inplace_predt)
+
+        for i in range(10):
+            run_threaded_predict(X, rows, predict_dense)
+
+    @pytest.mark.skipif(**tm.no_cudf())
+    def test_inplace_predict_cudf(self):
+        import cupy as cp
+        import cudf
+        import pandas as pd
+        rows = 1000
+        cols = 10
+        rng = np.random.RandomState(1994)
+        X = rng.randn(rows, cols)
+        X = pd.DataFrame(X)
+        y = rng.randn(rows)
+
+        X = cudf.from_pandas(X)
+
+        dtrain = xgb.DMatrix(X, y)
+
+        booster = xgb.train({'tree_method': 'gpu_hist'},
+                            dtrain, num_boost_round=10)
+        test = xgb.DMatrix(X)
+        predt_from_array = booster.inplace_predict(X)
+        predt_from_dmatrix = booster.predict(test)
+
+        cp.testing.assert_allclose(predt_from_array, predt_from_dmatrix)
+
+        def predict_df(x):
+            inplace_predt = booster.inplace_predict(x)
+            d = xgb.DMatrix(x)
+            copied_predt = cp.array(booster.predict(d))
+            return cp.all(copied_predt == inplace_predt)
+
+        for i in range(10):
+            run_threaded_predict(X, rows, predict_df)
diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py
index eb5ce6530b8d..2de976ce0af3 100644
--- a/tests/python-gpu/test_gpu_with_dask.py
+++ b/tests/python-gpu/test_gpu_with_dask.py
@@ -2,6 +2,7 @@
 import pytest
 import numpy as np
 import unittest
+import xgboost
 
 if sys.platform.startswith("win"):
     pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
@@ -29,6 +30,7 @@ class TestDistributedGPU(unittest.TestCase):
     def test_dask_dataframe(self):
         with LocalCUDACluster() as cluster:
             with Client(cluster) as client:
+                import cupy
                 X, y = generate_array()
 
                 X = dd.from_dask_array(X)
@@ -49,6 +51,42 @@ def test_dask_dataframe(self):
 
                 predictions = dxgb.predict(client, out, dtrain).compute()
                 assert isinstance(predictions, np.ndarray)
+
+                # There's an error in cudf saying `concat_cudf` got an
+                # unexpected argument `ignore_index`, so the test here is
+                # just a placeholder.
+
+                # series_predictions = dxgb.inplace_predict(client, out, X)
+                # assert isinstance(series_predictions, dd.Series)
+
+                single_node = out['booster'].predict(
+                    xgboost.DMatrix(X.compute()))
+                cupy.testing.assert_allclose(single_node, predictions)
+
+    @pytest.mark.skipif(**tm.no_cupy())
+    def test_dask_array(self):
+        with LocalCUDACluster() as cluster:
+            with Client(cluster) as client:
+                import cupy
+                X, y = generate_array()
+
+                X = X.map_blocks(cupy.asarray)
+                y = y.map_blocks(cupy.asarray)
+                dtrain = dxgb.DaskDMatrix(client, X, y)
+                out = dxgb.train(client, {'tree_method': 'gpu_hist'},
+                                 dtrain=dtrain,
+                                 evals=[(dtrain, 'X')],
+                                 num_boost_round=2)
+                from_dmatrix = dxgb.predict(client, out, dtrain).compute()
+                inplace_predictions = dxgb.inplace_predict(
+                    client, out, X).compute()
+                single_node = out['booster'].predict(
+                    xgboost.DMatrix(X.compute()))
+                np.testing.assert_allclose(single_node, from_dmatrix)
+                cupy.testing.assert_allclose(
+                    cupy.array(single_node),
+                    inplace_predictions)
+
+
 @pytest.mark.skipif(**tm.no_dask())
 @pytest.mark.skipif(**tm.no_dask_cuda())
 @pytest.mark.mgpu
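Since the dask-dataframe variant of `inplace_predict` is only stubbed out above, the following sketch shows the intended call pattern on dask arrays, assuming a build that includes this patch; the cluster size and data shapes are illustrative only:

    import numpy as np
    import dask.array as da
    import xgboost as xgb
    from dask.distributed import Client, LocalCluster

    with LocalCluster(n_workers=2) as cluster:
        with Client(cluster) as client:
            X = da.random.random((1000, 10), chunks=(100, 10))
            y = da.random.random(1000, chunks=100)
            dtrain = xgb.dask.DaskDMatrix(client, X, y)
            out = xgb.dask.train(client, {'tree_method': 'hist'},
                                 dtrain, num_boost_round=4)
            # inplace_predict consumes the dask collection directly and
            # returns a lazy dask array; no DMatrix is built for prediction.
            predt = xgb.dask.inplace_predict(client, out, X).compute()
            single_node = out['booster'].predict(xgb.DMatrix(X.compute()))
            np.testing.assert_allclose(predt, single_node)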
diff --git a/tests/python/test_predict.py b/tests/python/test_predict.py
new file mode 100644
index 000000000000..d3258fdc5685
--- /dev/null
+++ b/tests/python/test_predict.py
@@ -0,0 +1,63 @@
+'''Tests for running inplace prediction.'''
+import unittest
+from concurrent.futures import ThreadPoolExecutor
+import numpy as np
+from scipy import sparse
+
+import xgboost as xgb
+
+
+def run_threaded_predict(X, rows, predict_func):
+    results = []
+    per_thread = 20
+    with ThreadPoolExecutor(max_workers=10) as e:
+        for i in range(0, rows, int(rows / per_thread)):
+            try:
+                predictor = X[i:i+per_thread, ...]
+            except TypeError:
+                predictor = X.iloc[i:i+per_thread, ...]
+            f = e.submit(predict_func, predictor)
+            results.append(f)
+
+    for f in results:
+        assert f.result()
+
+
+class TestInplacePredict(unittest.TestCase):
+    '''Tests for running inplace prediction'''
+    def test_predict(self):
+        rows = 1000
+        cols = 10
+
+        np.random.seed(1994)
+
+        X = np.random.randn(rows, cols)
+        y = np.random.randn(rows)
+        dtrain = xgb.DMatrix(X, y)
+
+        booster = xgb.train({'tree_method': 'hist'},
+                            dtrain, num_boost_round=10)
+
+        test = xgb.DMatrix(X[:10, ...])
+        predt_from_array = booster.inplace_predict(X[:10, ...])
+        predt_from_dmatrix = booster.predict(test)
+
+        np.testing.assert_allclose(predt_from_dmatrix, predt_from_array)
+
+        def predict_dense(x):
+            inplace_predt = booster.inplace_predict(x)
+            d = xgb.DMatrix(x)
+            copied_predt = booster.predict(d)
+            return np.all(copied_predt == inplace_predt)
+
+        for i in range(10):
+            run_threaded_predict(X, rows, predict_dense)
+
+        def predict_csr(x):
+            inplace_predt = booster.inplace_predict(sparse.csr_matrix(x))
+            d = xgb.DMatrix(x)
+            copied_predt = booster.predict(d)
+            return np.all(copied_predt == inplace_predt)
+
+        for i in range(10):
+            run_threaded_predict(X, rows, predict_csr)
diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py
index b8d7c9a336db..744da439348c 100644
--- a/tests/python/test_with_dask.py
+++ b/tests/python/test_with_dask.py
@@ -63,8 +63,14 @@ def test_from_dask_dataframe():
         from_df = prediction.compute()
 
         assert isinstance(prediction, dd.Series)
+        assert np.all(prediction.compute().values == from_dmatrix)
         assert np.all(from_dmatrix == from_df.to_numpy())
 
+        series_predictions = xgb.dask.inplace_predict(client, booster, X)
+        assert isinstance(series_predictions, dd.Series)
+        np.testing.assert_allclose(series_predictions.compute().values,
+                                   from_dmatrix)
+
 
 def test_from_dask_array():
     with LocalCluster(n_workers=5, threads_per_worker=5) as cluster: