From c56cbab1504fdf37db63eed7f851418124c03ce3 Mon Sep 17 00:00:00 2001 From: fis Date: Tue, 24 Aug 2021 14:27:29 +0800 Subject: [PATCH] Rewrite approx. Save cuts. Prototype on fetching. Copy the code. Simple test. Add gpair to batch parameter. Add hessian to batch parameter. Move. Pass hessian into sketching. Extract a push page function. Make private. Lint. Revert debug. Simple DMatrix. Regenerate the index. ama. Clang tidy. Retain page. Fix. Lint. Tidy. Integer backed enum. Convert to uint32_t. Prototype for saving gidx. Save cuts. Prototype on fetching. Copy the code. Simple test. Add gpair to batch parameter. Add hessian to batch parameter. Move. Pass hessian into sketching. Extract a push page function. Make private. Lint. Revert debug. Simple DMatrix. Initial port. Pass in hessian. Init column sampler. Unused code. Use ctx. Merge sampling. Use ctx in partition. Fix init root. Force regenerate the sketch. Create a ctx. Get it compile. Don't use const method. Use page id. Pass in base row id. Pass the cut instead. Small fixes. Debug. Fix bin size. Debug. Fixes. Debug. Fix empty partition. Remove comment. Lint. Fix tests compilation. Remove check. Merge some fixes. fix. Fix fetching. lint. Extract expand entry. Lint. Fix unittests. Fix windows build. Fix comparison. Make const. Note. const. Fix reduce hist. Fix sparse data. Avoid implicit conversion. private. mem leak. Remove skip initialization. Use maximum space. demo. lint. File link tags. ama. Fix redefinition. Fix ranking. use npy. Comment. Tune it down. Specify the tree method. Get rid of the duplicated partitioner. Allocate task. Tests. make batches. Log. Remove span. Revert "make batches." This reverts commit 33f7072bcd8d32f6b842c71e96843df09d19da9b. small cleanup. Lint. Revert demo. Better make batches. Demo. Test for grow policy. Test feature weights. small cleanup. Remove iterator in evaluation. Fix dask test. Pass n_threads. Start implementation for categorical data. Fix. Add apply split. Enumerate splits. Enable sklearn. Works. d_step. update. Pass feature types into index. Search cut. Add test. As cat. Fix cut. Extract some tests. Fix. Interesting case. Add Python tests. Cleanup. Revert "Interesting case." This reverts commit 6bbaac20f9b9a34809775ebfaff0d20ca509bb94. Bin. Fix. Dispatch. Remove subtraction trick. Lint Use multiple buffers. Revert "Use multiple buffers." This reverts commit 2849f57cfce3b46c69e77015d187aebf4c7c10cd. Test for external memory. Format. Partition based categorical split. Remove debug code. Fix. Lint. Fix test. Fix demo. Fix. Add test. Remove use of omp func. name. Fix. test. Make LCG impl compliant to std. Fix test. Constexpr. Use unsigned type. osx More test. --- amalgamation/xgboost-all0.cc | 1 + demo/guide-python/categorical.py | 4 +- doc/parameter.rst | 16 +- include/xgboost/learner.h | 9 +- include/xgboost/objective.h | 6 + include/xgboost/task.h | 17 + include/xgboost/tree_updater.h | 15 +- .../scala/spark/XGBoostRegressorSuite.scala | 6 +- plugin/example/custom_obj.cc | 3 + python-package/xgboost/sklearn.py | 15 +- src/common/categorical.h | 14 +- src/common/common.h | 8 + src/common/hist_util.cc | 172 +++--- src/common/hist_util.h | 18 +- src/common/partition_builder.h | 494 ++++++++++-------- src/common/quantile.cc | 4 +- src/common/threading_utils.h | 1 + src/data/gradient_index.cc | 19 +- src/data/gradient_index.h | 32 +- src/data/gradient_index_page_source.cc | 3 +- src/data/gradient_index_page_source.h | 6 +- src/data/sparse_page_dmatrix.cc | 3 +- src/gbm/gbtree.cc | 9 +- src/learner.cc | 19 +- src/metric/auc.cu | 38 +- src/objective/aft_obj.cu | 2 + src/objective/hinge.cu | 2 + src/objective/multiclass_obj.cu | 3 + src/objective/rank_obj.cu | 2 + src/objective/regression_obj.cu | 12 + src/tree/hist/evaluate_splits.h | 253 ++++++--- src/tree/hist/histogram.h | 123 +++-- src/tree/param.h | 53 +- src/tree/split_evaluator.h | 8 +- src/tree/tree_model.cc | 1 + src/tree/tree_updater.cc | 8 +- src/tree/updater_approx.cc | 355 +++++++++++++ src/tree/updater_approx.h | 146 ++++++ src/tree/updater_colmaker.cc | 16 +- src/tree/updater_gpu_hist.cu | 17 +- src/tree/updater_histmaker.cc | 10 +- src/tree/updater_prune.cc | 8 +- src/tree/updater_quantile_hist.cc | 52 +- src/tree/updater_quantile_hist.h | 19 +- src/tree/updater_refresh.cc | 2 +- src/tree/updater_sync.cc | 2 +- tests/cpp/categorical_helpers.h | 44 ++ tests/cpp/common/test_quantile.cu | 37 +- tests/cpp/common/test_quantile.h | 4 +- tests/cpp/common/test_span.cc | 3 + tests/cpp/data/test_gradient_index.cc | 34 ++ tests/cpp/gbm/test_gbtree.cc | 2 +- tests/cpp/helpers.cc | 21 +- tests/cpp/helpers.h | 43 +- tests/cpp/tree/gpu_hist/test_histogram.cu | 41 +- tests/cpp/tree/hist/test_evaluate_splits.cc | 152 +++++- tests/cpp/tree/hist/test_histogram.cc | 187 ++++++- tests/cpp/tree/test_approx.cc | 129 +++++ tests/cpp/tree/test_gpu_hist.cu | 10 +- tests/cpp/tree/test_histmaker.cc | 6 +- tests/cpp/tree/test_param.cc | 6 +- tests/cpp/tree/test_prune.cc | 2 +- tests/cpp/tree/test_quantile_hist.cc | 4 +- tests/cpp/tree/test_refresh.cc | 3 +- tests/cpp/tree/test_tree_policy.cc | 2 +- tests/cpp/tree/test_tree_stat.cc | 4 +- tests/python-gpu/test_gpu_updaters.py | 47 +- tests/python/test_updaters.py | 50 ++ tests/python/test_with_dask.py | 8 +- tests/python/test_with_sklearn.py | 11 +- 70 files changed, 2158 insertions(+), 718 deletions(-) create mode 100644 include/xgboost/task.h create mode 100644 src/tree/updater_approx.cc create mode 100644 src/tree/updater_approx.h create mode 100644 tests/cpp/categorical_helpers.h create mode 100644 tests/cpp/tree/test_approx.cc diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc index 1241ced409cd..83d003052149 100644 --- a/amalgamation/xgboost-all0.cc +++ b/amalgamation/xgboost-all0.cc @@ -57,6 +57,7 @@ #include "../src/tree/updater_refresh.cc" #include "../src/tree/updater_sync.cc" #include "../src/tree/updater_histmaker.cc" +#include "../src/tree/updater_approx.cc" #include "../src/tree/constraints.cc" // linear diff --git a/demo/guide-python/categorical.py b/demo/guide-python/categorical.py index 9476c1ed6232..86377025346f 100644 --- a/demo/guide-python/categorical.py +++ b/demo/guide-python/categorical.py @@ -1,5 +1,5 @@ -"""Experimental support for categorical data. After 1.5 XGBoost `gpu_hist` tree method -has experimental support for one-hot encoding based tree split. +"""Experimental support for categorical data. After 1.6 XGBoost `gpu_hist` and `approx` +tree method have experimental support for one-hot encoding based tree split. In before, users need to run an encoder themselves before passing the data into XGBoost, which creates a sparse matrix and potentially increase memory usage. This demo showcases diff --git a/doc/parameter.rst b/doc/parameter.rst index f4b56949ff42..fcad21f94536 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -232,15 +232,25 @@ Parameters for Tree Booster - Constraints for interaction representing permitted interactions. The constraints must be specified in the form of a nest list, e.g. ``[[0, 1], [2, 3, 4]]``, where each inner list is a group of indices of features that are allowed to interact with each other. - See tutorial for more information + See tutorial for more information. -Additional parameters for ``hist`` and ``gpu_hist`` tree method -================================================================ +Additional parameters for ``hist`` and ``gpu_hist`` and ``approx`` tree method +============================================================================== * ``single_precision_histogram``, [default=``false``] - Use single precision to build histograms instead of double precision. +Additional parameters for ``approx`` tree method +================================================ + +* ``max_cat_to_onehot`` + + - A threshold for deciding whether XGBoost should use one-hot encoding based split for + categorical data. When number of categories is lesser than the threshold then one-hot + encoding is chosen, otherwise the categories will be partitioned into children nodes. + Only relevant for regression and binary classification and `approx` tree method. + Additional parameters for Dart Booster (``booster=dart``) ========================================================= diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h index 09c16eff6cfa..9060a2e1ccaa 100644 --- a/include/xgboost/learner.h +++ b/include/xgboost/learner.h @@ -11,15 +11,16 @@ #include #include #include -#include #include #include #include +#include +#include -#include #include #include #include +#include #include namespace xgboost { @@ -307,11 +308,13 @@ struct LearnerModelParam { uint32_t num_feature { 0 }; /* \brief number of classes, if it is multi-class classification */ uint32_t num_output_group { 0 }; + /* \brief Current task, determined by objective. */ + Task task{Task::kRegression}; LearnerModelParam() = default; // As the old `LearnerModelParamLegacy` is still used by binary IO, we keep // this one as an immutable copy. - LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin); + LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin, Task t); /* \brief Whether this parameter is initialized with LearnerModelParamLegacy. */ bool Initialized() const { return num_feature != 0; } }; diff --git a/include/xgboost/objective.h b/include/xgboost/objective.h index 3e722a18f37a..86e137dc7641 100644 --- a/include/xgboost/objective.h +++ b/include/xgboost/objective.h @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -72,6 +73,11 @@ class ObjFunction : public Configurable { virtual bst_float ProbToMargin(bst_float base_score) const { return base_score; } + /*! + * \brief Return task of this objective. + */ + virtual enum Task Task() const = 0; + /*! * \brief Create an objective function according to name. * \param tparam Generic parameters. diff --git a/include/xgboost/task.h b/include/xgboost/task.h new file mode 100644 index 000000000000..433efcbd45bb --- /dev/null +++ b/include/xgboost/task.h @@ -0,0 +1,17 @@ +/*! + * Copyright 2015-2021 by XGBoost Contributors + */ +#include +#ifndef XGBOOST_TASK_H_ +#define XGBOOST_TASK_H_ +namespace xgboost { +enum class Task : uint8_t { + kRegression = 0, + kBinary = 1, + kClassification = 2, + kSurvival = 3, + kRanking = 4, + kOther = 5, +}; +} +#endif // XGBOOST_TASK_H_ diff --git a/include/xgboost/tree_updater.h b/include/xgboost/tree_updater.h index f36005a9a69e..540f9500cea0 100644 --- a/include/xgboost/tree_updater.h +++ b/include/xgboost/tree_updater.h @@ -11,16 +11,17 @@ #include #include #include -#include #include #include -#include #include +#include +#include +#include #include -#include -#include #include +#include +#include namespace xgboost { @@ -83,16 +84,14 @@ class TreeUpdater : public Configurable { * \param name Name of the tree updater. * \param tparam A global runtime parameter */ - static TreeUpdater* Create(const std::string& name, GenericParameter const* tparam); + static TreeUpdater* Create(const std::string& name, GenericParameter const* tparam, Task task); }; /*! * \brief Registry entry for tree updater. */ struct TreeUpdaterReg - : public dmlc::FunctionRegEntryBase > { -}; + : public dmlc::FunctionRegEntryBase > {}; /*! * \brief Macro to register tree updater. diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala index 7a8bf6fa4d90..7017615c2b0d 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala @@ -219,8 +219,12 @@ abstract class XGBoostRegressorSuiteBase extends FunSuite with PerTest { } } -class XGBoostCpuRegressorSuite extends XGBoostRegressorSuiteBase { +class XGBoostCpuRegressorSuiteApprox extends XGBoostRegressorSuiteBase { + override protected val treeMethod: String = "approx" +} +class XGBoostCpuRegressorSuiteHist extends XGBoostRegressorSuiteBase { + override protected val treeMethod: String = "hist" } @GpuTestSuite diff --git a/plugin/example/custom_obj.cc b/plugin/example/custom_obj.cc index a18d8aecc490..11d18faf2baa 100644 --- a/plugin/example/custom_obj.cc +++ b/plugin/example/custom_obj.cc @@ -34,6 +34,9 @@ class MyLogistic : public ObjFunction { void Configure(const std::vector >& args) override { param_.UpdateAllowUnknown(args); } + + enum Task Task() const override { return Task::kRegression; } + void GetGradient(const HostDeviceVector &preds, const MetaInfo &info, int iter, diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 7313351fd62b..507090b7781b 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -257,6 +257,16 @@ def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]: This parameter replaces `early_stopping_rounds` in :py:meth:`fit` method. + max_cat_to_onehot : bool + + .. versionadded:: 1.6.0 + + A threshold for deciding whether XGBoost should use one-hot encoding based split + for categorical data. When number of categories is lesser than the threshold then + one-hot encoding is chosen, otherwise the categories will be partitioned into + children nodes. Only relevant for regression and binary classification and + `approx` tree method. + kwargs : dict, optional Keyword arguments for XGBoost Booster object. Full documentation of parameters can be found here: @@ -473,6 +483,7 @@ def __init__( enable_categorical: bool = False, eval_metric: Optional[Union[str, List[str], Callable]] = None, early_stopping_rounds: Optional[int] = None, + max_cat_to_onehot: Optional[int] = None, **kwargs: Any ) -> None: if not SKLEARN_INSTALLED: @@ -511,6 +522,7 @@ def __init__( self.enable_categorical = enable_categorical self.eval_metric = eval_metric self.early_stopping_rounds = early_stopping_rounds + self.max_cat_to_onehot = max_cat_to_onehot if kwargs: self.kwargs = kwargs @@ -779,7 +791,8 @@ def _duplicated(parameter: str) -> None: else early_stopping_rounds ) - if self.enable_categorical and params.get("tree_method", None) != "gpu_hist": + tree_method = params.get("tree_method", None) + if self.enable_categorical and tree_method not in ("gpu_hist", "approx"): raise ValueError( "Experimental support for categorical data is not implemented for" " current tree method yet." diff --git a/src/common/categorical.h b/src/common/categorical.h index 3706c4f2370d..9b66d8e429da 100644 --- a/src/common/categorical.h +++ b/src/common/categorical.h @@ -5,11 +5,12 @@ #ifndef XGBOOST_COMMON_CATEGORICAL_H_ #define XGBOOST_COMMON_CATEGORICAL_H_ +#include "bitfield.h" #include "xgboost/base.h" #include "xgboost/data.h" -#include "xgboost/span.h" #include "xgboost/parameter.h" -#include "bitfield.h" +#include "xgboost/span.h" +#include "xgboost/task.h" namespace xgboost { namespace common { @@ -47,6 +48,15 @@ inline void CheckCat(bst_cat_t cat) { "should be non-negative."; } +/*! + * \brief Whether should we use onehot encoding for categorical data. + */ +inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot, Task task) { + bool use_one_hot = + n_cats < max_cat_to_onehot || (task != Task::kRegression && task != Task::kBinary); + return use_one_hot; +} + struct IsCatOp { XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; diff --git a/src/common/common.h b/src/common/common.h index 8230e532ff69..562941c222a0 100644 --- a/src/common/common.h +++ b/src/common/common.h @@ -188,6 +188,14 @@ std::vector ArgSort(Container const &array, Comp comp = std::less{}) { XGBOOST_PARALLEL_STABLE_SORT(result.begin(), result.end(), op); return result; } + +/** + * Last index of a group in a CSR style of index pointer. + */ +template +XGBOOST_DEVICE size_t LastOf(Idx group, common::Span indptr) { + return indptr[group + 1] - 1; +} } // namespace common } // namespace xgboost #endif // XGBOOST_COMMON_COMMON_H_ diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index 40392a2bcfaa..c14da59a7f60 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -15,7 +15,6 @@ #include "random.h" #include "column_matrix.h" #include "quantile.h" -#include "./../tree/updater_quantile_hist.h" #include "../data/gradient_index.h" #if defined(XGBOOST_MM_PREFETCH_PRESENT) @@ -133,74 +132,108 @@ struct Prefetch { constexpr size_t Prefetch::kNoPrefetchSize; - -template -void BuildHistKernel(const std::vector& gpair, +template +void BuildHistKernel(const std::vector &gpair, const RowSetCollection::Elem row_indices, - const GHistIndexMatrix& gmat, - GHistRow hist) { + const GHistIndexMatrix &gmat, GHistRow hist) { const size_t size = row_indices.Size(); - const size_t* rid = row_indices.begin; - const float* pgh = reinterpret_cast(gpair.data()); - const BinIdxType* gradient_index = gmat.index.data(); - const size_t* row_ptr = gmat.row_ptr.data(); - const uint32_t* offsets = gmat.index.Offset(); - const size_t n_features = row_ptr[row_indices.begin[0]+1] - row_ptr[row_indices.begin[0]]; - FPType* hist_data = reinterpret_cast(hist.data()); - const uint32_t two {2}; // Each element from 'gpair' and 'hist' contains - // 2 FP values: gradient and hessian. - // So we need to multiply each row-index/bin-index by 2 - // to work with gradient pairs as a singe row FP array + const size_t *rid = row_indices.begin; + auto const *pgh = reinterpret_cast(gpair.data()); + const BinIdxType *gradient_index = gmat.index.data(); + + auto const &row_ptr = gmat.row_ptr.data(); + auto base_rowid = gmat.base_rowid; + const uint32_t *offsets = gmat.index.Offset(); + auto get_row_ptr = [&](size_t ridx) { + return first_page ? row_ptr[ridx] : row_ptr[ridx - base_rowid]; + }; + auto get_rid = [&](size_t ridx) { + return first_page ? ridx : (ridx - base_rowid); + }; + + const size_t n_features = + get_row_ptr(row_indices.begin[0] + 1) - get_row_ptr(row_indices.begin[0]); + auto hist_data = reinterpret_cast(hist.data()); + const uint32_t two{2}; // Each element from 'gpair' and 'hist' contains + // 2 FP values: gradient and hessian. + // So we need to multiply each row-index/bin-index by 2 + // to work with gradient pairs as a singe row FP array for (size_t i = 0; i < size; ++i) { - const size_t icol_start = any_missing ? row_ptr[rid[i]] : rid[i] * n_features; - const size_t icol_end = any_missing ? row_ptr[rid[i]+1] : icol_start + n_features; + const size_t icol_start = + any_missing ? get_row_ptr(rid[i]) : get_rid(rid[i]) * n_features; + const size_t icol_end = + any_missing ? get_row_ptr(rid[i] + 1) : icol_start + n_features; + const size_t row_size = icol_end - icol_start; const size_t idx_gh = two * rid[i]; if (do_prefetch) { - const size_t icol_start_prefetch = any_missing ? row_ptr[rid[i+Prefetch::kPrefetchOffset]] : - rid[i + Prefetch::kPrefetchOffset] * n_features; - const size_t icol_end_prefetch = any_missing ? row_ptr[rid[i+Prefetch::kPrefetchOffset]+1] : - icol_start_prefetch + n_features; + const size_t icol_start_prefetch = + any_missing + ? get_row_ptr(rid[i + Prefetch::kPrefetchOffset]) + : get_rid(rid[i + Prefetch::kPrefetchOffset]) * n_features; + const size_t icol_end_prefetch = + any_missing ? get_row_ptr(rid[i + Prefetch::kPrefetchOffset] + 1) + : icol_start_prefetch + n_features; PREFETCH_READ_T0(pgh + two * rid[i + Prefetch::kPrefetchOffset]); for (size_t j = icol_start_prefetch; j < icol_end_prefetch; - j+=Prefetch::GetPrefetchStep()) { + j += Prefetch::GetPrefetchStep()) { PREFETCH_READ_T0(gradient_index + j); } } - const BinIdxType* gr_index_local = gradient_index + icol_start; + const BinIdxType *gr_index_local = gradient_index + icol_start; for (size_t j = 0; j < row_size; ++j) { - const uint32_t idx_bin = two * (static_cast(gr_index_local[j]) + ( - any_missing ? 0 : offsets[j])); - - hist_data[idx_bin] += pgh[idx_gh]; - hist_data[idx_bin+1] += pgh[idx_gh+1]; + const uint32_t idx_bin = two * (static_cast(gr_index_local[j]) + + (any_missing ? 0 : offsets[j])); + hist_data[idx_bin] += pgh[idx_gh]; + hist_data[idx_bin + 1] += pgh[idx_gh + 1]; } } } -template -void BuildHistDispatch(const std::vector& gpair, +template +void BuildHistDispatch(const std::vector &gpair, const RowSetCollection::Elem row_indices, - const GHistIndexMatrix& gmat, GHistRow hist) { - switch (gmat.index.GetBinTypeSize()) { + const GHistIndexMatrix &gmat, GHistRow hist) { + auto first_page = gmat.base_rowid == 0; + if (first_page) { + switch (gmat.index.GetBinTypeSize()) { case kUint8BinsTypeSize: - BuildHistKernel(gpair, row_indices, - gmat, hist); + BuildHistKernel( + gpair, row_indices, gmat, hist); break; case kUint16BinsTypeSize: - BuildHistKernel(gpair, row_indices, - gmat, hist); + BuildHistKernel( + gpair, row_indices, gmat, hist); break; case kUint32BinsTypeSize: - BuildHistKernel(gpair, row_indices, - gmat, hist); + BuildHistKernel( + gpair, row_indices, gmat, hist); break; default: CHECK(false); // no default behavior + } + } else { + switch (gmat.index.GetBinTypeSize()) { + case kUint8BinsTypeSize: + BuildHistKernel( + gpair, row_indices, gmat, hist); + break; + case kUint16BinsTypeSize: + BuildHistKernel( + gpair, row_indices, gmat, hist); + break; + case kUint32BinsTypeSize: + BuildHistKernel( + gpair, row_indices, gmat, hist); + break; + default: + CHECK(false); // no default behavior + } } } @@ -208,73 +241,52 @@ template template void GHistBuilder::BuildHist( const std::vector &gpair, - const RowSetCollection::Elem row_indices, - const GHistIndexMatrix &gmat, - GHistRowT hist) { + const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, + GHistRowT hist) const { const size_t nrows = row_indices.Size(); const size_t no_prefetch_size = Prefetch::NoPrefetchSize(nrows); // if need to work with all rows from bin-matrix (e.g. root node) - const bool contiguousBlock = (row_indices.begin[nrows - 1] - row_indices.begin[0]) == (nrows - 1); + const bool contiguousBlock = + (row_indices.begin[nrows - 1] - row_indices.begin[0]) == (nrows - 1); if (contiguousBlock) { // contiguous memory access, built-in HW prefetching is enough - BuildHistDispatch(gpair, row_indices, gmat, hist); + BuildHistDispatch(gpair, row_indices, + gmat, hist); } else { - const RowSetCollection::Elem span1(row_indices.begin, row_indices.end - no_prefetch_size); - const RowSetCollection::Elem span2(row_indices.end - no_prefetch_size, row_indices.end); + const RowSetCollection::Elem span1(row_indices.begin, + row_indices.end - no_prefetch_size); + const RowSetCollection::Elem span2(row_indices.end - no_prefetch_size, + row_indices.end); - BuildHistDispatch(gpair, span1, gmat, hist); + BuildHistDispatch(gpair, span1, gmat, + hist); // no prefetching to avoid loading extra memory - BuildHistDispatch(gpair, span2, gmat, hist); + BuildHistDispatch(gpair, span2, gmat, + hist); } } + template void GHistBuilder::BuildHist(const std::vector &gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, - GHistRow hist); + GHistRow hist) const; template void GHistBuilder::BuildHist(const std::vector &gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, - GHistRow hist); + GHistRow hist) const; template void GHistBuilder::BuildHist(const std::vector &gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, - GHistRow hist); + GHistRow hist) const; template void GHistBuilder::BuildHist(const std::vector &gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, - GHistRow hist); - -template -void GHistBuilder::SubtractionTrick(GHistRowT self, - GHistRowT sibling, - GHistRowT parent) { - const size_t size = self.size(); - CHECK_EQ(sibling.size(), size); - CHECK_EQ(parent.size(), size); - - const size_t block_size = 1024; // aproximatly 1024 values per block - size_t n_blocks = size/block_size + !!(size%block_size); - - ParallelFor(omp_ulong(n_blocks), [&](omp_ulong iblock) { - const size_t ibegin = iblock*block_size; - const size_t iend = (((iblock+1)*block_size > size) ? size : ibegin + block_size); - SubtractionHist(self, parent, sibling, ibegin, iend); - }); -} -template -void GHistBuilder::SubtractionTrick(GHistRow self, - GHistRow sibling, - GHistRow parent); -template -void GHistBuilder::SubtractionTrick(GHistRow self, - GHistRow sibling, - GHistRow parent); - + GHistRow hist) const; } // namespace common } // namespace xgboost diff --git a/src/common/hist_util.h b/src/common/hist_util.h index 9dc0bd1c5fd1..fa309105a905 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -441,7 +441,7 @@ class ParallelGHistBuilder { } // Reduce following bins (begin, end] for nid-node in dst across threads - void ReduceHist(size_t nid, size_t begin, size_t end) { + void ReduceHist(size_t nid, size_t begin, size_t end) const { CHECK_GT(end, begin); CHECK_LT(nid, nodes_); @@ -467,7 +467,6 @@ class ParallelGHistBuilder { } } - protected: void MatchThreadsToNodes(const BlockedSpace2d& space) { const size_t space_size = space.Size(); const size_t chunck_size = space_size / nthreads_ + !!(space_size % nthreads_); @@ -514,6 +513,7 @@ class ParallelGHistBuilder { } } + private: void MatchNodeNidPairToHist() { size_t hist_allocated_additionally = 0; @@ -567,26 +567,18 @@ class GHistBuilder { using GHistRowT = GHistRow; GHistBuilder() = default; - GHistBuilder(size_t nthread, uint32_t nbins) : nthread_{nthread}, nbins_{nbins} {} + explicit GHistBuilder(uint32_t nbins): nbins_{nbins} {} // construct a histogram via histogram aggregation template - void BuildHist(const std::vector& gpair, + void BuildHist(const std::vector &gpair, const RowSetCollection::Elem row_indices, - const GHistIndexMatrix& gmat, - GHistRowT hist); - // construct a histogram via subtraction trick - void SubtractionTrick(GHistRowT self, - GHistRowT sibling, - GHistRowT parent); - + const GHistIndexMatrix &gmat, GHistRowT hist) const; uint32_t GetNumBins() const { return nbins_; } private: - /*! \brief number of threads for parallel computation */ - size_t nthread_ { 0 }; /*! \brief number of all bins over all features */ uint32_t nbins_ { 0 }; }; diff --git a/src/common/partition_builder.h b/src/common/partition_builder.h index 98612359ec73..0a59b6522428 100644 --- a/src/common/partition_builder.h +++ b/src/common/partition_builder.h @@ -1,228 +1,266 @@ - -/*! - * Copyright 2021 by Contributors - * \file row_set.h - * \brief Quick Utility to compute subset of rows - * \author Philip Cho, Tianqi Chen - */ -#ifndef XGBOOST_COMMON_PARTITION_BUILDER_H_ -#define XGBOOST_COMMON_PARTITION_BUILDER_H_ - -#include -#include -#include -#include -#include -#include "xgboost/tree_model.h" -#include "../common/column_matrix.h" - -namespace xgboost { -namespace common { - -// The builder is required for samples partition to left and rights children for set of nodes -// Responsible for: -// 1) Effective memory allocation for intermediate results for multi-thread work -// 2) Merging partial results produced by threads into original row set (row_set_collection_) -// BlockSize is template to enable memory alignment easily with C++11 'alignas()' feature -template -class PartitionBuilder { - public: - template - void Init(const size_t n_tasks, size_t n_nodes, Func funcNTask) { - left_right_nodes_sizes_.resize(n_nodes); - blocks_offsets_.resize(n_nodes+1); - - blocks_offsets_[0] = 0; - for (size_t i = 1; i < n_nodes+1; ++i) { - blocks_offsets_[i] = blocks_offsets_[i-1] + funcNTask(i-1); - } - - if (n_tasks > max_n_tasks_) { - mem_blocks_.resize(n_tasks); - max_n_tasks_ = n_tasks; - } - } - - // split row indexes (rid_span) to 2 parts (left_part, right_part) depending - // on comparison of indexes values (idx_span) and split point (split_cond) - // Handle dense columns - // Analog of std::stable_partition, but in no-inplace manner - template - inline std::pair PartitionKernel(const ColumnType& column, - common::Span rid_span, const int32_t split_cond, - common::Span left_part, common::Span right_part) { - size_t* p_left_part = left_part.data(); - size_t* p_right_part = right_part.data(); - size_t nleft_elems = 0; - size_t nright_elems = 0; - auto state = column.GetInitialState(rid_span.front()); - - for (auto rid : rid_span) { - const int32_t bin_id = column.GetBinIdx(rid, &state); - if (any_missing && bin_id == ColumnType::kMissingId) { - if (default_left) { - p_left_part[nleft_elems++] = rid; - } else { - p_right_part[nright_elems++] = rid; - } - } else { - if (bin_id <= split_cond) { - p_left_part[nleft_elems++] = rid; - } else { - p_right_part[nright_elems++] = rid; - } - } - } - - return {nleft_elems, nright_elems}; - } - - - template - void Partition(const size_t node_in_set, const size_t nid, const common::Range1d range, - const int32_t split_cond, - const ColumnMatrix& column_matrix, const RegTree& tree, const size_t* rid) { - common::Span rid_span(rid + range.begin(), rid + range.end()); - common::Span left = GetLeftBuffer(node_in_set, - range.begin(), range.end()); - common::Span right = GetRightBuffer(node_in_set, - range.begin(), range.end()); - const bst_uint fid = tree[nid].SplitIndex(); - const bool default_left = tree[nid].DefaultLeft(); - const auto column_ptr = column_matrix.GetColumn(fid); - - std::pair child_nodes_sizes; - - if (column_ptr->GetType() == xgboost::common::kDenseColumn) { - const common::DenseColumn& column = - static_cast& >(*(column_ptr.get())); - if (default_left) { - child_nodes_sizes = PartitionKernel(column, rid_span, - split_cond, left, right); - } else { - child_nodes_sizes = PartitionKernel(column, rid_span, - split_cond, left, right); - } - } else { - CHECK_EQ(any_missing, true); - const common::SparseColumn& column - = static_cast& >(*(column_ptr.get())); - if (default_left) { - child_nodes_sizes = PartitionKernel(column, rid_span, - split_cond, left, right); - } else { - child_nodes_sizes = PartitionKernel(column, rid_span, - split_cond, left, right); - } - } - - const size_t n_left = child_nodes_sizes.first; - const size_t n_right = child_nodes_sizes.second; - - SetNLeftElems(node_in_set, range.begin(), range.end(), n_left); - SetNRightElems(node_in_set, range.begin(), range.end(), n_right); - } - - - // allocate thread local memory, should be called for each specific task - void AllocateForTask(size_t id) { - if (mem_blocks_[id].get() == nullptr) { - BlockInfo* local_block_ptr = new BlockInfo; - CHECK_NE(local_block_ptr, (BlockInfo*)nullptr); - mem_blocks_[id].reset(local_block_ptr); - } - } - - common::Span GetLeftBuffer(int nid, size_t begin, size_t end) { - const size_t task_idx = GetTaskIdx(nid, begin); - return { mem_blocks_.at(task_idx)->Left(), end - begin }; - } - - common::Span GetRightBuffer(int nid, size_t begin, size_t end) { - const size_t task_idx = GetTaskIdx(nid, begin); - return { mem_blocks_.at(task_idx)->Right(), end - begin }; - } - - void SetNLeftElems(int nid, size_t begin, size_t end, size_t n_left) { - size_t task_idx = GetTaskIdx(nid, begin); - mem_blocks_.at(task_idx)->n_left = n_left; - } - - void SetNRightElems(int nid, size_t begin, size_t end, size_t n_right) { - size_t task_idx = GetTaskIdx(nid, begin); - mem_blocks_.at(task_idx)->n_right = n_right; - } - - - size_t GetNLeftElems(int nid) const { - return left_right_nodes_sizes_[nid].first; - } - - size_t GetNRightElems(int nid) const { - return left_right_nodes_sizes_[nid].second; - } - - // Each thread has partial results for some set of tree-nodes - // The function decides order of merging partial results into final row set - void CalculateRowOffsets() { - for (size_t i = 0; i < blocks_offsets_.size()-1; ++i) { - size_t n_left = 0; - for (size_t j = blocks_offsets_[i]; j < blocks_offsets_[i+1]; ++j) { - mem_blocks_[j]->n_offset_left = n_left; - n_left += mem_blocks_[j]->n_left; - } - size_t n_right = 0; - for (size_t j = blocks_offsets_[i]; j < blocks_offsets_[i+1]; ++j) { - mem_blocks_[j]->n_offset_right = n_left + n_right; - n_right += mem_blocks_[j]->n_right; - } - left_right_nodes_sizes_[i] = {n_left, n_right}; - } - } - - void MergeToArray(int nid, size_t begin, size_t* rows_indexes) { - size_t task_idx = GetTaskIdx(nid, begin); - - size_t* left_result = rows_indexes + mem_blocks_[task_idx]->n_offset_left; - size_t* right_result = rows_indexes + mem_blocks_[task_idx]->n_offset_right; - - const size_t* left = mem_blocks_[task_idx]->Left(); - const size_t* right = mem_blocks_[task_idx]->Right(); - - std::copy_n(left, mem_blocks_[task_idx]->n_left, left_result); - std::copy_n(right, mem_blocks_[task_idx]->n_right, right_result); - } - - size_t GetTaskIdx(int nid, size_t begin) { - return blocks_offsets_[nid] + begin / BlockSize; - } - - protected: - struct BlockInfo{ - size_t n_left; - size_t n_right; - - size_t n_offset_left; - size_t n_offset_right; - - size_t* Left() { - return &left_data_[0]; - } - - size_t* Right() { - return &right_data_[0]; - } - private: - size_t left_data_[BlockSize]; - size_t right_data_[BlockSize]; - }; - std::vector> left_right_nodes_sizes_; - std::vector blocks_offsets_; - std::vector> mem_blocks_; - size_t max_n_tasks_ = 0; -}; - -} // namespace common -} // namespace xgboost - -#endif // XGBOOST_COMMON_PARTITION_BUILDER_H_ +/*! + * Copyright 2021 by Contributors + * \file row_set.h + * \brief Quick Utility to compute subset of rows + * \author Philip Cho, Tianqi Chen + */ +#ifndef XGBOOST_COMMON_PARTITION_BUILDER_H_ +#define XGBOOST_COMMON_PARTITION_BUILDER_H_ + +#include +#include +#include +#include +#include +#include "xgboost/tree_model.h" +#include "../common/column_matrix.h" + +namespace xgboost { +namespace common { + +// The builder is required for samples partition to left and rights children for set of nodes +// Responsible for: +// 1) Effective memory allocation for intermediate results for multi-thread work +// 2) Merging partial results produced by threads into original row set (row_set_collection_) +// BlockSize is template to enable memory alignment easily with C++11 'alignas()' feature +template +class PartitionBuilder { + public: + template + void Init(const size_t n_tasks, size_t n_nodes, Func funcNTask) { + left_right_nodes_sizes_.resize(n_nodes); + blocks_offsets_.resize(n_nodes+1); + + blocks_offsets_[0] = 0; + for (size_t i = 1; i < n_nodes+1; ++i) { + blocks_offsets_[i] = blocks_offsets_[i-1] + funcNTask(i-1); + } + + if (n_tasks > max_n_tasks_) { + mem_blocks_.resize(n_tasks); + max_n_tasks_ = n_tasks; + } + } + + // split row indexes (rid_span) to 2 parts (left_part, right_part) depending + // on comparison of indexes values (idx_span) and split point (split_cond) + // Handle dense columns + // Analog of std::stable_partition, but in no-inplace manner + template + inline std::pair PartitionKernel(const ColumnType& column, + common::Span rid_span, const int32_t split_cond, + common::Span left_part, common::Span right_part) { + size_t* p_left_part = left_part.data(); + size_t* p_right_part = right_part.data(); + size_t nleft_elems = 0; + size_t nright_elems = 0; + auto state = column.GetInitialState(rid_span.front()); + + for (auto rid : rid_span) { + const int32_t bin_id = column.GetBinIdx(rid, &state); + if (any_missing && bin_id == ColumnType::kMissingId) { + if (default_left) { + p_left_part[nleft_elems++] = rid; + } else { + p_right_part[nright_elems++] = rid; + } + } else { + if (bin_id <= split_cond) { + p_left_part[nleft_elems++] = rid; + } else { + p_right_part[nright_elems++] = rid; + } + } + } + + return {nleft_elems, nright_elems}; + } + + template + inline std::pair + PartitionRangeKernel(common::Span ridx, + common::Span left_part, + common::Span right_part, Pred pred) { + size_t *p_left_part = left_part.data(); + size_t *p_right_part = right_part.data(); + size_t nleft_elems = 0; + size_t nright_elems = 0; + for (auto row_id : ridx) { + if (pred(row_id)) { + p_left_part[nleft_elems++] = row_id; + } else { + p_right_part[nright_elems++] = row_id; + } + } + return {nleft_elems, nright_elems}; + } + + template + void Partition(const size_t node_in_set, const size_t nid, const common::Range1d range, + const int32_t split_cond, + const ColumnMatrix& column_matrix, const RegTree& tree, const size_t* rid) { + common::Span rid_span(rid + range.begin(), rid + range.end()); + common::Span left = GetLeftBuffer(node_in_set, + range.begin(), range.end()); + common::Span right = GetRightBuffer(node_in_set, + range.begin(), range.end()); + const bst_uint fid = tree[nid].SplitIndex(); + const bool default_left = tree[nid].DefaultLeft(); + const auto column_ptr = column_matrix.GetColumn(fid); + + std::pair child_nodes_sizes; + + if (column_ptr->GetType() == xgboost::common::kDenseColumn) { + const common::DenseColumn& column = + static_cast& >(*(column_ptr.get())); + if (default_left) { + child_nodes_sizes = PartitionKernel(column, rid_span, + split_cond, left, right); + } else { + child_nodes_sizes = PartitionKernel(column, rid_span, + split_cond, left, right); + } + } else { + CHECK_EQ(any_missing, true); + const common::SparseColumn& column + = static_cast& >(*(column_ptr.get())); + if (default_left) { + child_nodes_sizes = PartitionKernel(column, rid_span, + split_cond, left, right); + } else { + child_nodes_sizes = PartitionKernel(column, rid_span, + split_cond, left, right); + } + } + + const size_t n_left = child_nodes_sizes.first; + const size_t n_right = child_nodes_sizes.second; + + SetNLeftElems(node_in_set, range.begin(), range.end(), n_left); + SetNRightElems(node_in_set, range.begin(), range.end(), n_right); + } + + template + void PartitionRange(const size_t node_in_set, const size_t nid, + common::Range1d range, bst_feature_t fidx, + common::RowSetCollection *p_row_set_collection, + Pred pred) { + auto &row_set_collection = *p_row_set_collection; + const size_t *p_ridx = row_set_collection[nid].begin; + common::Span ridx(p_ridx + range.begin(), p_ridx + range.end()); + common::Span left = + this->GetLeftBuffer(node_in_set, range.begin(), range.end()); + common::Span right = + this->GetRightBuffer(node_in_set, range.begin(), range.end()); + std::pair child_nodes_sizes = + PartitionRangeKernel(ridx, left, right, pred); + + const size_t n_left = child_nodes_sizes.first; + const size_t n_right = child_nodes_sizes.second; + + this->SetNLeftElems(node_in_set, range.begin(), range.end(), n_left); + this->SetNRightElems(node_in_set, range.begin(), range.end(), n_right); + } + + // allocate thread local memory, should be called for each specific task + void AllocateForTask(size_t id) { + if (mem_blocks_[id].get() == nullptr) { + BlockInfo* local_block_ptr = new BlockInfo; + CHECK_NE(local_block_ptr, (BlockInfo*)nullptr); + mem_blocks_[id].reset(local_block_ptr); + } + } + + common::Span GetLeftBuffer(int nid, size_t begin, size_t end) { + const size_t task_idx = GetTaskIdx(nid, begin); + return { mem_blocks_.at(task_idx)->Left(), end - begin }; + } + + common::Span GetRightBuffer(int nid, size_t begin, size_t end) { + const size_t task_idx = GetTaskIdx(nid, begin); + return { mem_blocks_.at(task_idx)->Right(), end - begin }; + } + + void SetNLeftElems(int nid, size_t begin, size_t end, size_t n_left) { + size_t task_idx = GetTaskIdx(nid, begin); + mem_blocks_.at(task_idx)->n_left = n_left; + } + + void SetNRightElems(int nid, size_t begin, size_t end, size_t n_right) { + size_t task_idx = GetTaskIdx(nid, begin); + mem_blocks_.at(task_idx)->n_right = n_right; + } + + + size_t GetNLeftElems(int nid) const { + return left_right_nodes_sizes_[nid].first; + } + + size_t GetNRightElems(int nid) const { + return left_right_nodes_sizes_[nid].second; + } + + // Each thread has partial results for some set of tree-nodes + // The function decides order of merging partial results into final row set + void CalculateRowOffsets() { + for (size_t i = 0; i < blocks_offsets_.size()-1; ++i) { + size_t n_left = 0; + for (size_t j = blocks_offsets_[i]; j < blocks_offsets_[i+1]; ++j) { + mem_blocks_[j]->n_offset_left = n_left; + n_left += mem_blocks_[j]->n_left; + } + size_t n_right = 0; + for (size_t j = blocks_offsets_[i]; j < blocks_offsets_[i+1]; ++j) { + mem_blocks_[j]->n_offset_right = n_left + n_right; + n_right += mem_blocks_[j]->n_right; + } + left_right_nodes_sizes_[i] = {n_left, n_right}; + } + } + + void MergeToArray(int nid, size_t begin, size_t* rows_indexes) { + size_t task_idx = GetTaskIdx(nid, begin); + + size_t* left_result = rows_indexes + mem_blocks_[task_idx]->n_offset_left; + size_t* right_result = rows_indexes + mem_blocks_[task_idx]->n_offset_right; + + const size_t* left = mem_blocks_[task_idx]->Left(); + const size_t* right = mem_blocks_[task_idx]->Right(); + + std::copy_n(left, mem_blocks_[task_idx]->n_left, left_result); + std::copy_n(right, mem_blocks_[task_idx]->n_right, right_result); + } + + size_t GetTaskIdx(int nid, size_t begin) { + return blocks_offsets_[nid] + begin / BlockSize; + } + + private: + struct BlockInfo{ + size_t n_left; + size_t n_right; + + size_t n_offset_left; + size_t n_offset_right; + + size_t* Left() { + return &left_data_[0]; + } + + size_t* Right() { + return &right_data_[0]; + } + private: + size_t left_data_[BlockSize]; + size_t right_data_[BlockSize]; + }; + std::vector> left_right_nodes_sizes_; + std::vector blocks_offsets_; + std::vector> mem_blocks_; + size_t max_n_tasks_ = 0; +}; + +} // namespace common +} // namespace xgboost + +#endif // XGBOOST_COMMON_PARTITION_BUILDER_H_ diff --git a/src/common/quantile.cc b/src/common/quantile.cc index 4e84719c0537..2bbe6b937f79 100644 --- a/src/common/quantile.cc +++ b/src/common/quantile.cc @@ -189,7 +189,7 @@ void HostSketchContainer::PushRowPage( if (is_dense) { for (size_t ii = begin; ii < end; ii++) { if (IsCat(feature_types_, ii)) { - categories_[ii].emplace(p_inst[ii].fvalue); + categories_[ii].emplace(AsCat(p_inst[ii].fvalue)); } else { sketches_[ii].Push(p_inst[ii].fvalue, w); } @@ -199,7 +199,7 @@ void HostSketchContainer::PushRowPage( auto const& entry = p_inst[i]; if (entry.index >= begin && entry.index < end) { if (IsCat(feature_types_, entry.index)) { - categories_[entry.index].emplace(entry.fvalue); + categories_[entry.index].emplace(AsCat(entry.fvalue)); } else { sketches_[entry.index].Push(entry.fvalue, w); } diff --git a/src/common/threading_utils.h b/src/common/threading_utils.h index ab3765f501fe..d6cd44b71466 100644 --- a/src/common/threading_utils.h +++ b/src/common/threading_utils.h @@ -251,6 +251,7 @@ inline int32_t OmpSetNumThreadsWithoutHT(int32_t* p_threads) { inline int32_t OmpGetNumThreads(int32_t n_threads) { if (n_threads <= 0) { n_threads = omp_get_num_procs(); + n_threads = std::min(n_threads, omp_get_max_threads()); } return n_threads; } diff --git a/src/data/gradient_index.cc b/src/data/gradient_index.cc index f2e14882e80b..b0eea203ea56 100644 --- a/src/data/gradient_index.cc +++ b/src/data/gradient_index.cc @@ -9,8 +9,9 @@ namespace xgboost { -void GHistIndexMatrix::PushBatch(SparsePage const &batch, size_t rbegin, - size_t prev_sum, uint32_t nbins, +void GHistIndexMatrix::PushBatch(SparsePage const &batch, + common::Span ft, + size_t rbegin, size_t prev_sum, uint32_t nbins, int32_t n_threads) { // The number of threads is pegged to the batch size. If the OMP // block is parallelized on anything other than the batch/block size, @@ -86,7 +87,7 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch, size_t rbegin, common::BinTypeSize curent_bin_size = index.GetBinTypeSize(); if (curent_bin_size == common::kUint8BinsTypeSize) { common::Span index_data_span = {index.data(), n_index}; - SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins, + SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins, [offsets](auto idx, auto j) { return static_cast(idx - offsets[j]); }); @@ -94,7 +95,7 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch, size_t rbegin, } else if (curent_bin_size == common::kUint16BinsTypeSize) { common::Span index_data_span = {index.data(), n_index}; - SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins, + SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins, [offsets](auto idx, auto j) { return static_cast(idx - offsets[j]); }); @@ -102,7 +103,7 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch, size_t rbegin, CHECK_EQ(curent_bin_size, common::kUint32BinsTypeSize); common::Span index_data_span = {index.data(), n_index}; - SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins, + SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins, [offsets](auto idx, auto j) { return static_cast(idx - offsets[j]); }); @@ -113,7 +114,7 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch, size_t rbegin, not reduced */ } else { common::Span index_data_span = {index.data(), n_index}; - SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins, + SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins, [](auto idx, auto) { return idx; }); } @@ -147,15 +148,17 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_bins, common::Span h size_t prev_sum = 0; const bool isDense = p_fmat->IsDense(); this->isDense_ = isDense; + auto ft = p_fmat->Info().feature_types.ConstHostSpan(); for (const auto &batch : p_fmat->GetBatches()) { - this->PushBatch(batch, rbegin, prev_sum, nbins, nthread); + this->PushBatch(batch, ft, rbegin, prev_sum, nbins, nthread); prev_sum = row_ptr[rbegin + batch.Size()]; rbegin += batch.Size(); } } void GHistIndexMatrix::Init(SparsePage const &batch, + common::Span ft, common::HistogramCuts const &cuts, int32_t max_bins_per_feat, bool isDense, int32_t n_threads) { @@ -176,7 +179,7 @@ void GHistIndexMatrix::Init(SparsePage const &batch, size_t rbegin = 0; size_t prev_sum = 0; - this->PushBatch(batch, rbegin, prev_sum, nbins, n_threads); + this->PushBatch(batch, ft, rbegin, prev_sum, nbins, n_threads); } void GHistIndexMatrix::ResizeIndex(const size_t n_index, diff --git a/src/data/gradient_index.h b/src/data/gradient_index.h index 971e82d4f081..353b7022c67d 100644 --- a/src/data/gradient_index.h +++ b/src/data/gradient_index.h @@ -7,6 +7,7 @@ #include #include "xgboost/base.h" #include "xgboost/data.h" +#include "../common/categorical.h" #include "../common/hist_util.h" #include "../common/threading_utils.h" @@ -18,8 +19,9 @@ namespace xgboost { * index for CPU histogram. On GPU ellpack page is used. */ class GHistIndexMatrix { - void PushBatch(SparsePage const &batch, size_t rbegin, size_t prev_sum, - uint32_t nbins, int32_t n_threads); + void PushBatch(SparsePage const &batch, common::Span ft, + size_t rbegin, size_t prev_sum, uint32_t nbins, + int32_t n_threads); public: /*! \brief row pointer to rows by element position */ @@ -40,12 +42,14 @@ class GHistIndexMatrix { } // Create a global histogram matrix, given cut void Init(DMatrix* p_fmat, int max_num_bins, common::Span hess); - void Init(SparsePage const &page, common::HistogramCuts const &cuts, - int32_t max_bins_per_feat, bool is_dense, int32_t n_threads); + void Init(SparsePage const &page, common::Span ft, + common::HistogramCuts const &cuts, int32_t max_bins_per_feat, + bool is_dense, int32_t n_threads); // specific method for sparse data as no possibility to reduce allocated memory template void SetIndexData(common::Span index_data_span, + common::Span ft, size_t batch_threads, const SparsePage &batch, size_t rbegin, size_t nbins, GetOffset get_offset) { const xgboost::Entry *data_ptr = batch.data.HostVector().data(); @@ -61,9 +65,23 @@ class GHistIndexMatrix { SparsePage::Inst inst = {data_ptr + offset_vec[i], size}; CHECK_EQ(ibegin + inst.size(), iend); for (bst_uint j = 0; j < inst.size(); ++j) { - uint32_t idx = cut.SearchBin(inst[j]); - index_data[ibegin + j] = get_offset(idx, j); - ++hit_count_tloc_[tid * nbins + idx]; + auto e = inst[j]; + if (common::IsCat(ft, e.index)) { + auto const& cut_ptrs_ = cut.Ptrs(); + auto const& cut_values_ = cut.Values(); + auto beg = cut_ptrs_.at(e.index) + cut_values_.cbegin(); + auto end = cut_ptrs_.at(e.index + 1) + cut_values_.cbegin(); + auto bin_idx = std::lower_bound(beg, end, e.fvalue) - cut_values_.cbegin(); + if (bin_idx == cut_ptrs_.at(e.index + 1)) { + bin_idx -= 1; + } + index_data[ibegin + j] = get_offset(bin_idx, j); + ++hit_count_tloc_[tid * nbins + bin_idx]; + } else { + uint32_t idx = cut.SearchBin(inst[j]); + index_data[ibegin + j] = get_offset(idx, j); + ++hit_count_tloc_[tid * nbins + idx]; + } } }); } diff --git a/src/data/gradient_index_page_source.cc b/src/data/gradient_index_page_source.cc index e35970bf3e4e..8f592213f58f 100644 --- a/src/data/gradient_index_page_source.cc +++ b/src/data/gradient_index_page_source.cc @@ -10,7 +10,8 @@ void GradientIndexPageSource::Fetch() { auto const& csr = source_->Page(); this->page_.reset(new GHistIndexMatrix()); CHECK_NE(cuts_.Values().size(), 0); - this->page_->Init(*csr, cuts_, max_bin_per_feat_, is_dense_, nthreads_); + this->page_->Init(*csr, feature_types_, cuts_, max_bin_per_feat_, is_dense_, + nthreads_); this->WriteCache(); } } diff --git a/src/data/gradient_index_page_source.h b/src/data/gradient_index_page_source.h index db66a1cda02f..a11057d5492c 100644 --- a/src/data/gradient_index_page_source.h +++ b/src/data/gradient_index_page_source.h @@ -16,16 +16,18 @@ class GradientIndexPageSource : public PageSourceIncMixIn { common::HistogramCuts cuts_; bool is_dense_; int32_t max_bin_per_feat_; + common::Span feature_types_; public: GradientIndexPageSource(float missing, int nthreads, bst_feature_t n_features, size_t n_batches, std::shared_ptr cache, BatchParam param, common::HistogramCuts cuts, bool is_dense, int32_t max_bin_per_feat, + common::Span feature_types, std::shared_ptr source) : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache), - cuts_{std::move(cuts)}, is_dense_{is_dense}, max_bin_per_feat_{ - max_bin_per_feat} { + cuts_{std::move(cuts)}, is_dense_{is_dense}, + max_bin_per_feat_{max_bin_per_feat}, feature_types_{feature_types} { this->source_ = source; this->Fetch(); } diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc index 18c81a654ad3..db2e298df8b4 100644 --- a/src/data/sparse_page_dmatrix.cc +++ b/src/data/sparse_page_dmatrix.cc @@ -184,10 +184,11 @@ BatchSet SparsePageDMatrix::GetGradientIndex(const BatchParam& batch_param_ = param; ghist_index_source_.reset(); CHECK_NE(cuts.Values().size(), 0); + auto ft = this->info_.feature_types.ConstHostSpan(); ghist_index_source_.reset(new GradientIndexPageSource( this->missing_, this->ctx_.Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id), param, std::move(cuts), - this->IsDense(), param.max_bin, sparse_page_source_)); + this->IsDense(), param.max_bin, ft, sparse_page_source_)); } else { CHECK(ghist_index_source_); ghist_index_source_->Reset(); diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 859e5ba9d3ad..4ddfd2cd0f13 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -168,7 +168,8 @@ void GBTree::ConfigureUpdaters() { // calling this function. break; case TreeMethod::kApprox: - tparam_.updater_seq = "grow_histmaker,prune"; + // grow_histmaker,prune + tparam_.updater_seq = "grow_global_approx_histmaker"; break; case TreeMethod::kExact: tparam_.updater_seq = "grow_colmaker,prune"; @@ -306,7 +307,8 @@ void GBTree::InitUpdater(Args const& cfg) { // create new updaters for (const std::string& pstr : ups) { - std::unique_ptr up(TreeUpdater::Create(pstr.c_str(), generic_param_)); + std::unique_ptr up( + TreeUpdater::Create(pstr.c_str(), generic_param_, model_.learner_model_param->task)); up->Configure(cfg); updaters_.push_back(std::move(up)); } @@ -391,7 +393,8 @@ void GBTree::LoadConfig(Json const& in) { auto const& j_updaters = get(in["updater"]); updaters_.clear(); for (auto const& kv : j_updaters) { - std::unique_ptr up(TreeUpdater::Create(kv.first, generic_param_)); + std::unique_ptr up( + TreeUpdater::Create(kv.first, generic_param_, model_.learner_model_param->task)); up->LoadConfig(kv.second); updaters_.push_back(std::move(up)); } diff --git a/src/learner.cc b/src/learner.cc index 399d299f5358..f2d128ffd3be 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -159,13 +159,12 @@ struct LearnerModelParamLegacy : public dmlc::Parameter } }; -LearnerModelParam::LearnerModelParam( - LearnerModelParamLegacy const &user_param, float base_margin) - : base_score{base_margin}, num_feature{user_param.num_feature}, - num_output_group{user_param.num_class == 0 - ? 1 - : static_cast(user_param.num_class)} -{} +LearnerModelParam::LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin, + Task t) + : base_score{base_margin}, + num_feature{user_param.num_feature}, + num_output_group{user_param.num_class == 0 ? 1 : static_cast(user_param.num_class)}, + task{t} {} struct LearnerTrainParam : public XGBoostParameter { // data split mode, can be row, col, or none. @@ -339,8 +338,8 @@ class LearnerConfiguration : public Learner { // - model is created from scratch. // - model is configured second time due to change of parameter if (!learner_model_param_.Initialized() || mparam_.base_score != mparam_backup.base_score) { - learner_model_param_ = LearnerModelParam(mparam_, - obj_->ProbToMargin(mparam_.base_score)); + learner_model_param_ = + LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score), obj_->Task()); } this->ConfigureGBM(old_tparam, args); @@ -832,7 +831,7 @@ class LearnerIO : public LearnerConfiguration { } learner_model_param_ = - LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score)); + LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score), obj_->Task()); if (attributes_.find("objective") != attributes_.cend()) { auto obj_str = attributes_.at("objective"); auto j_obj = Json::Load({obj_str.c_str(), obj_str.size()}); diff --git a/src/metric/auc.cu b/src/metric/auc.cu index 153a0290afcc..3488d6104c1e 100644 --- a/src/metric/auc.cu +++ b/src/metric/auc.cu @@ -15,6 +15,7 @@ #include "xgboost/span.h" #include "xgboost/data.h" #include "auc.h" +#include "../common/common.h" #include "../common/device_helpers.cuh" #include "../common/ranking_utils.cuh" @@ -201,17 +202,8 @@ void Transpose(common::Span in, common::Span out, size_t m, }); } -/** - * Last index of a group in a CSR style of index pointer. - */ -template -XGBOOST_DEVICE size_t LastOf(size_t group, common::Span indptr) { - return indptr[group + 1] - 1; -} - -double ScaleClasses(common::Span results, - common::Span local_area, common::Span fp, - common::Span tp, common::Span auc, +double ScaleClasses(common::Span results, common::Span local_area, + common::Span fp, common::Span tp, common::Span auc, std::shared_ptr cache, size_t n_classes) { dh::XGBDeviceAllocator alloc; if (rabit::IsDistributed()) { @@ -229,8 +221,8 @@ double ScaleClasses(common::Span results, double tp_sum; double auc_sum; thrust::tie(auc_sum, tp_sum) = - thrust::reduce(thrust::cuda::par(alloc), reduce_in, reduce_in + n_classes, - Pair{0.0, 0.0}, PairPlus{}); + thrust::reduce(thrust::cuda::par(alloc), reduce_in, reduce_in + n_classes, Pair{0.0, 0.0}, + PairPlus{}); if (tp_sum != 0 && !std::isnan(auc_sum)) { auc_sum /= tp_sum; } else { @@ -300,9 +292,9 @@ void SegmentedReduceAUC(common::Span d_unique_idx, double fp, tp, fp_prev, tp_prev; if (i == d_unique_class_ptr[class_id]) { // first item is ignored, we use this thread to calculate the last item - thrust::tie(fp, tp) = d_fptp[LastOf(class_id, d_class_ptr)]; + thrust::tie(fp, tp) = d_fptp[common::LastOf(class_id, d_class_ptr)]; thrust::tie(fp_prev, tp_prev) = - d_neg_pos[d_unique_idx[LastOf(class_id, d_unique_class_ptr)]]; + d_neg_pos[d_unique_idx[common::LastOf(class_id, d_unique_class_ptr)]]; } else { thrust::tie(fp, tp) = d_fptp[d_unique_idx[i] - 1]; thrust::tie(fp_prev, tp_prev) = d_neg_pos[d_unique_idx[i - 1]]; @@ -413,10 +405,10 @@ double GPUMultiClassAUCOVR(common::Span predts, } uint32_t class_id = d_unique_idx[i] / n_samples; d_neg_pos[d_unique_idx[i]] = d_fptp[d_unique_idx[i] - 1]; - if (i == LastOf(class_id, d_unique_class_ptr)) { + if (i == common::LastOf(class_id, d_unique_class_ptr)) { // last one needs to be included. - size_t last = d_unique_idx[LastOf(class_id, d_unique_class_ptr)]; - d_neg_pos[LastOf(class_id, d_class_ptr)] = d_fptp[last - 1]; + size_t last = d_unique_idx[common::LastOf(class_id, d_unique_class_ptr)]; + d_neg_pos[common::LastOf(class_id, d_class_ptr)] = d_fptp[last - 1]; return; } }); @@ -592,7 +584,7 @@ GPURankingAUC(common::Span predts, MetaInfo const &info, auto data_group_begin = d_group_ptr[group_id]; size_t n_samples = d_group_ptr[group_id + 1] - data_group_begin; // last item of current group - if (item.idx == LastOf(group_id, d_threads_group_ptr)) { + if (item.idx == common::LastOf(group_id, d_threads_group_ptr)) { if (item.w > 0) { s_d_auc[group_id] = item.predt / item.w; } else { @@ -797,10 +789,10 @@ GPURankingPRAUCImpl(common::Span predts, MetaInfo const &info, } auto group_idx = dh::SegmentId(d_group_ptr, d_unique_idx[i]); d_neg_pos[d_unique_idx[i]] = d_fptp[d_unique_idx[i] - 1]; - if (i == LastOf(group_idx, d_unique_class_ptr)) { + if (i == common::LastOf(group_idx, d_unique_class_ptr)) { // last one needs to be included. - size_t last = d_unique_idx[LastOf(group_idx, d_unique_class_ptr)]; - d_neg_pos[LastOf(group_idx, d_group_ptr)] = d_fptp[last - 1]; + size_t last = d_unique_idx[common::LastOf(group_idx, d_unique_class_ptr)]; + d_neg_pos[common::LastOf(group_idx, d_group_ptr)] = d_fptp[last - 1]; return; } }); @@ -821,7 +813,7 @@ GPURankingPRAUCImpl(common::Span predts, MetaInfo const &info, auto it = dh::MakeTransformIterator>( thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t g) { double fp, tp; - thrust::tie(fp, tp) = d_fptp[LastOf(g, d_group_ptr)]; + thrust::tie(fp, tp) = d_fptp[common::LastOf(g, d_group_ptr)]; double area = fp * tp; auto n_documents = d_group_ptr[g + 1] - d_group_ptr[g]; if (area > 0 && n_documents >= 2) { diff --git a/src/objective/aft_obj.cu b/src/objective/aft_obj.cu index 95e4f4c5525d..bc046c57a2c4 100644 --- a/src/objective/aft_obj.cu +++ b/src/objective/aft_obj.cu @@ -38,6 +38,8 @@ class AFTObj : public ObjFunction { param_.UpdateAllowUnknown(args); } + enum Task Task() const override { return Task::kSurvival; } + template void GetGradientImpl(const HostDeviceVector &preds, const MetaInfo &info, diff --git a/src/objective/hinge.cu b/src/objective/hinge.cu index 0c8c2f317b1c..532ff65a8b5a 100644 --- a/src/objective/hinge.cu +++ b/src/objective/hinge.cu @@ -27,6 +27,8 @@ class HingeObj : public ObjFunction { void Configure( const std::vector > &args) override {} + enum Task Task() const override { return Task::kRegression; } + void GetGradient(const HostDeviceVector &preds, const MetaInfo &info, int iter, diff --git a/src/objective/multiclass_obj.cu b/src/objective/multiclass_obj.cu index 6ffa6eac2ca8..bd7da3805f65 100644 --- a/src/objective/multiclass_obj.cu +++ b/src/objective/multiclass_obj.cu @@ -45,6 +45,9 @@ class SoftmaxMultiClassObj : public ObjFunction { void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); } + + enum Task Task() const override { return Task::kClassification; } + void GetGradient(const HostDeviceVector& preds, const MetaInfo& info, int iter, diff --git a/src/objective/rank_obj.cu b/src/objective/rank_obj.cu index 164b60611ef3..b89c046365a9 100644 --- a/src/objective/rank_obj.cu +++ b/src/objective/rank_obj.cu @@ -754,6 +754,8 @@ class LambdaRankObj : public ObjFunction { param_.UpdateAllowUnknown(args); } + enum Task Task() const override { return Task::kRanking; } + void GetGradient(const HostDeviceVector& preds, const MetaInfo& info, int iter, diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index ccb3a723d32e..1acddd0f1613 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -52,6 +52,10 @@ class RegLossObj : public ObjFunction { param_.UpdateAllowUnknown(args); } + enum Task Task() const override { + return std::is_same::value ? Task::kBinary : Task::kRegression; + } + void GetGradient(const HostDeviceVector& preds, const MetaInfo &info, int, HostDeviceVector* out_gpair) override { @@ -207,6 +211,8 @@ class PoissonRegression : public ObjFunction { param_.UpdateAllowUnknown(args); } + enum Task Task() const override { return Task::kRegression; } + void GetGradient(const HostDeviceVector& preds, const MetaInfo &info, int, HostDeviceVector *out_gpair) override { @@ -298,6 +304,8 @@ class CoxRegression : public ObjFunction { void Configure( const std::vector >&) override {} + enum Task Task() const override { return Task::kRegression; } + void GetGradient(const HostDeviceVector& preds, const MetaInfo &info, int, HostDeviceVector *out_gpair) override { @@ -395,6 +403,8 @@ class GammaRegression : public ObjFunction { void Configure( const std::vector >&) override {} + enum Task Task() const override { return Task::kRegression; } + void GetGradient(const HostDeviceVector &preds, const MetaInfo &info, int, HostDeviceVector *out_gpair) override { @@ -491,6 +501,8 @@ class TweedieRegression : public ObjFunction { metric_ = os.str(); } + enum Task Task() const override { return Task::kRegression; } + void GetGradient(const HostDeviceVector& preds, const MetaInfo &info, int, HostDeviceVector *out_gpair) override { diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h index 24b99ed4a8c7..dcfcb293abba 100644 --- a/src/tree/hist/evaluate_splits.h +++ b/src/tree/hist/evaluate_splits.h @@ -6,13 +6,16 @@ #include #include +#include #include #include #include +#include "xgboost/task.h" #include "../param.h" #include "../constraints.h" #include "../split_evaluator.h" +#include "../../common/categorical.h" #include "../../common/random.h" #include "../../common/hist_util.h" #include "../../data/gradient_index.h" @@ -36,13 +39,13 @@ template class HistEvaluator { int32_t n_threads_ {0}; FeatureInteractionConstraintHost interaction_constraints_; std::vector snode_; + Task task_; // if sum of statistics for non-missing values in the node // is equal to sum of statistics for all values: // then - there are no missing values // else - there are missing values - bool static SplitContainsMissingValues(const GradStats e, - const NodeEntry &snode) { + bool static SplitContainsMissingValues(const GradStats e, const NodeEntry &snode) { if (e.GetGrad() == snode.stats.GetGrad() && e.GetHess() == snode.stats.GetHess()) { return false; @@ -50,38 +53,40 @@ template class HistEvaluator { return true; } } + enum SplitType { kNum = 0, kOneHot = 1, kPart = 2 }; // Enumerate/Scan the split values of specific feature // Returns the sum of gradients corresponding to the data points that contains // a non-missing value for the particular feature fid. - template - GradStats EnumerateSplit( - common::HistogramCuts const &cut, const common::GHistRow &hist, - const NodeEntry &snode, SplitEntry *p_best, bst_feature_t fidx, - bst_node_t nidx, - TreeEvaluator::SplitEvaluator const &evaluator) const { + template + GradStats EnumerateSplit(common::HistogramCuts const &cut, common::Span sorted_idx, + const common::GHistRow &hist, bst_feature_t fidx, + bst_node_t nidx, + TreeEvaluator::SplitEvaluator const &evaluator, + SplitEntry *p_best) const { static_assert(d_step == +1 || d_step == -1, "Invalid step."); // aliases const std::vector &cut_ptr = cut.Ptrs(); const std::vector &cut_val = cut.Values(); + auto const &parent = snode_[nidx]; + int32_t n_bins{static_cast(cut_ptr.at(fidx + 1) - cut_ptr[fidx])}; + auto f_hist = hist.subspan(cut_ptr[fidx], n_bins); // statistics on both sides of split - GradStats c; - GradStats e; + GradStats left_sum; + GradStats right_sum; // best split so far SplitEntry best; // bin boundaries - CHECK_LE(cut_ptr[fidx], - static_cast(std::numeric_limits::max())); - CHECK_LE(cut_ptr[fidx + 1], - static_cast(std::numeric_limits::max())); - // imin: index (offset) of the minimum value for feature fid - // need this for backward enumeration + CHECK_LE(cut_ptr[fidx], static_cast(std::numeric_limits::max())); + CHECK_LE(cut_ptr[fidx + 1], static_cast(std::numeric_limits::max())); + // imin: index (offset) of the minimum value for feature fid need this for backward + // enumeration const auto imin = static_cast(cut_ptr[fidx]); - // ibegin, iend: smallest/largest cut points for feature fid - // use int to allow for value -1 + // ibegin, iend: smallest/largest cut points for feature fid use int to allow for + // value -1 int32_t ibegin, iend; if (d_step > 0) { ibegin = static_cast(cut_ptr[fidx]); @@ -91,49 +96,118 @@ template class HistEvaluator { iend = static_cast(cut_ptr[fidx]) - 1; } + auto calc_bin_value = [&](auto i) { + switch (split_type) { + case kNum: { + left_sum.Add(hist[i].GetGrad(), hist[i].GetHess()); + right_sum.SetSubstract(parent.stats, left_sum); + break; + } + case kOneHot: { + // not-chosen categories go to left + right_sum = GradStats{hist[i]}; + left_sum.SetSubstract(parent.stats, right_sum); + break; + } + case kPart: { + auto j = d_step == 1 ? (i - ibegin) : (ibegin - i); + right_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess()); + left_sum.SetSubstract(parent.stats, right_sum); + break; + } + default: { + std::terminate(); + } + } + }; + + int32_t best_thresh{-1}; for (int32_t i = ibegin; i != iend; i += d_step) { // start working // try to find a split - e.Add(hist[i].GetGrad(), hist[i].GetHess()); - if (e.GetHess() >= param_.min_child_weight) { - c.SetSubstract(snode.stats, e); - if (c.GetHess() >= param_.min_child_weight) { - bst_float loss_chg; - bst_float split_pt; - if (d_step > 0) { - // forward enumeration: split at right bound of each bin - loss_chg = static_cast( - evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{e}, - GradStats{c}) - - snode.root_gain); - split_pt = cut_val[i]; - best.Update(loss_chg, fidx, split_pt, d_step == -1, e, c); - } else { - // backward enumeration: split at left bound of each bin - loss_chg = static_cast( - evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{c}, - GradStats{e}) - - snode.root_gain); - if (i == imin) { - // for leftmost bin, left bound is the smallest feature value - split_pt = cut.MinValues()[fidx]; - } else { - split_pt = cut_val[i - 1]; + calc_bin_value(i); + bool improved{false}; + if (left_sum.GetHess() >= param_.min_child_weight && + right_sum.GetHess() >= param_.min_child_weight) { + bst_float loss_chg; + bst_float split_pt; + if (d_step > 0) { + // forward enumeration: split at right bound of each bin + loss_chg = + static_cast(evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, + GradStats{right_sum}) - + parent.root_gain); + split_pt = cut_val[i]; + improved = best.Update(loss_chg, fidx, split_pt, d_step == -1, split_type != kNum, + left_sum, right_sum); + } else { + // backward enumeration: split at left bound of each bin + loss_chg = + static_cast(evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{right_sum}, + GradStats{left_sum}) - + parent.root_gain); + switch (split_type) { + case kNum: { + if (i == imin) { + split_pt = cut.MinValues()[fidx]; + } else { + split_pt = cut_val[i - 1]; + } + break; + } + case kOneHot: { + split_pt = cut_val[i]; + break; + } + case kPart: { + split_pt = cut_val[i]; + break; } - best.Update(loss_chg, fidx, split_pt, d_step == -1, c, e); } + improved = best.Update(loss_chg, fidx, split_pt, d_step == -1, split_type != kNum, + right_sum, left_sum); + } + if (improved) { + best_thresh = i; } } } + + if (split_type == kPart && best_thresh != -1) { + auto n = common::CatBitField::ComputeStorageSize(n_bins); + best.cat_bits.resize(n, 0); + common::CatBitField cat_bits{best.cat_bits}; + + if (d_step == 1) { + std::for_each(sorted_idx.begin(), sorted_idx.begin() + (best_thresh - ibegin + 1), + [&cat_bits](size_t c) { cat_bits.Set(c); }); + } else { + std::for_each(sorted_idx.rbegin(), sorted_idx.rbegin() + (ibegin - best_thresh), + [&cat_bits](size_t c) { cat_bits.Set(c); }); + } + } p_best->Update(best); - return e; + switch (split_type) { + case kNum: + // Normal, accumulated to left + return left_sum; + case kOneHot: + // Doesn't matter, not accumulating. + return {}; + case kPart: + // Accumulated to right due to chosen cats go to right. + return right_sum; + } + return left_sum; } public: void EvaluateSplits(const common::HistCollection &hist, - common::HistogramCuts const &cut, const RegTree &tree, - std::vector* p_entries) { + common::HistogramCuts const &cut, + common::Span feature_types, + const RegTree &tree, + std::vector *p_entries) { auto& entries = *p_entries; // All nodes are on the same level, so we can store the shared ptr. std::vector>> features( @@ -150,7 +224,7 @@ template class HistEvaluator { return features[nidx_in_set]->Size(); }, grain_size); - std::vector tloc_candidates(omp_get_max_threads() * entries.size()); + std::vector tloc_candidates(n_threads_ * entries.size()); for (size_t i = 0; i < entries.size(); ++i) { for (decltype(n_threads_) j = 0; j < n_threads_; ++j) { tloc_candidates[i * n_threads_ + j] = entries[i]; @@ -167,12 +241,37 @@ template class HistEvaluator { auto features_set = features[nidx_in_set]->ConstHostSpan(); for (auto fidx_in_set = r.begin(); fidx_in_set < r.end(); fidx_in_set++) { auto fidx = features_set[fidx_in_set]; - if (interaction_constraints_.Query(nidx, fidx)) { - auto grad_stats = EnumerateSplit<+1>(cut, histogram, snode_[nidx], - best, fidx, nidx, evaluator); + bool is_cat = common::IsCat(feature_types, fidx); + if (!interaction_constraints_.Query(nidx, fidx)) { + continue; + } + if (is_cat) { + auto n_bins = cut.Ptrs().at(fidx + 1) - cut.Ptrs()[fidx]; + if (common::UseOneHot(n_bins, param_.max_cat_to_onehot, task_)) { + EnumerateSplit<+1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best); + EnumerateSplit<-1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best); + } else { + auto const &cut_ptr = cut.Ptrs(); + std::vector sorted_idx(n_bins); + std::iota(sorted_idx.begin(), sorted_idx.end(), 0); + auto feat_hist = histogram.subspan(cut_ptr[fidx], n_bins); + std::stable_sort(sorted_idx.begin(), sorted_idx.end(), [&](size_t l, size_t r) { + auto ret = evaluator.CalcWeightCat(param_, feat_hist[l]) < + evaluator.CalcWeightCat(param_, feat_hist[r]); + static_assert(std::is_same::value, ""); + return ret; + }); + auto grad_stats = + EnumerateSplit<+1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best); + if (SplitContainsMissingValues(grad_stats, snode_[nidx])) { + EnumerateSplit<-1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best); + } + } + } else { + auto grad_stats = + EnumerateSplit<+1, kNum>(cut, {}, histogram, fidx, nidx, evaluator, best); if (SplitContainsMissingValues(grad_stats, snode_[nidx])) { - EnumerateSplit<-1>(cut, histogram, snode_[nidx], best, fidx, nidx, - evaluator); + EnumerateSplit<-1, kNum>(cut, {}, histogram, fidx, nidx, evaluator, best); } } } @@ -187,7 +286,7 @@ template class HistEvaluator { } } // Add splits to tree, handles all statistic - void ApplyTreeSplit(ExpandEntry candidate, RegTree *p_tree) { + void ApplyTreeSplit(ExpandEntry const& candidate, RegTree *p_tree) { auto evaluator = tree_evaluator_.GetEvaluator(); RegTree &tree = *p_tree; @@ -201,13 +300,31 @@ template class HistEvaluator { auto right_weight = evaluator.CalcWeight( candidate.nid, param_, GradStats{candidate.split.right_sum}); - tree.ExpandNode(candidate.nid, candidate.split.SplitIndex(), - candidate.split.split_value, candidate.split.DefaultLeft(), - base_weight, left_weight * param_.learning_rate, - right_weight * param_.learning_rate, - candidate.split.loss_chg, parent_sum.GetHess(), - candidate.split.left_sum.GetHess(), - candidate.split.right_sum.GetHess()); + if (candidate.split.is_cat) { + std::vector split_cats; + if (candidate.split.cat_bits.empty()) { + CHECK_LT(candidate.split.split_value, std::numeric_limits::max()) + << "Categorical feature value too large."; + auto cat = common::AsCat(candidate.split.split_value); + split_cats.resize(LBitField32::ComputeStorageSize(std::max(cat + 1, 1)), 0); + LBitField32 cat_bits; + cat_bits = LBitField32(split_cats); + cat_bits.Set(cat); + } else { + split_cats = candidate.split.cat_bits; + } + + tree.ExpandCategorical( + candidate.nid, candidate.split.SplitIndex(), split_cats, candidate.split.DefaultLeft(), + base_weight, left_weight, right_weight, candidate.split.loss_chg, parent_sum.GetHess(), + candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess()); + } else { + tree.ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value, + candidate.split.DefaultLeft(), base_weight, + left_weight * param_.learning_rate, right_weight * param_.learning_rate, + candidate.split.loss_chg, parent_sum.GetHess(), + candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess()); + } // Set up child constraints auto left_child = tree[candidate.nid].LeftChild(); @@ -249,14 +366,14 @@ template class HistEvaluator { public: // The column sampler must be constructed by caller since we need to preserve the rng // for the entire training session. - explicit HistEvaluator(TrainParam const ¶m, MetaInfo const &info, - int32_t n_threads, - std::shared_ptr sampler, + explicit HistEvaluator(TrainParam const ¶m, MetaInfo const &info, int32_t n_threads, + std::shared_ptr sampler, Task task, bool skip_0_index = false) - : param_{param}, column_sampler_{std::move(sampler)}, - tree_evaluator_{param, static_cast(info.num_col_), - GenericParameter::kCpuId}, - n_threads_{n_threads} { + : param_{param}, + column_sampler_{std::move(sampler)}, + tree_evaluator_{param, static_cast(info.num_col_), GenericParameter::kCpuId}, + n_threads_{n_threads}, + task_{task} { interaction_constraints_.Configure(param, info.num_col_); column_sampler_->Init(info.num_col_, info.feature_weigths.HostVector(), param_.colsample_bynode, param_.colsample_bylevel, diff --git a/src/tree/hist/histogram.h b/src/tree/hist/histogram.h index 70c756e765e6..0ff16fdc60fe 100644 --- a/src/tree/hist/histogram.h +++ b/src/tree/hist/histogram.h @@ -11,6 +11,8 @@ #include "rabit/rabit.h" #include "xgboost/tree_model.h" #include "../../common/hist_util.h" +#include "../../data/gradient_index.h" +#include "../../common/observer.h" namespace xgboost { namespace tree { @@ -25,8 +27,9 @@ template class HistogramBuilder { common::GHistBuilder builder_; common::ParallelGHistBuilder buffer_; rabit::Reducer reducer_; - int32_t max_bin_ {-1}; + BatchParam param_; int32_t n_threads_ {-1}; + size_t n_batches_ {0}; // Whether XGBoost is running in distributed environment. bool is_distributed_ {false}; @@ -39,59 +42,56 @@ template class HistogramBuilder { * \param is_distributed Mostly used for testing to allow injecting parameters instead * of using global rabit variable. */ - void Reset(uint32_t total_bins, int32_t max_bin_per_feat, int32_t n_threads, - bool is_distributed = rabit::IsDistributed()) { + void Reset(uint32_t total_bins, BatchParam p, int32_t n_threads, + size_t n_batches, bool is_distributed) { CHECK_GE(n_threads, 1); n_threads_ = n_threads; - CHECK_GE(max_bin_per_feat, 2); - max_bin_ = max_bin_per_feat; + n_batches_ = n_batches; + param_ = p; hist_.Init(total_bins); hist_local_worker_.Init(total_bins); buffer_.Init(total_bins); - builder_ = common::GHistBuilder(n_threads, total_bins); + builder_ = common::GHistBuilder(total_bins); is_distributed_ = is_distributed; } template - void - BuildLocalHistograms(DMatrix *p_fmat, - std::vector nodes_for_explicit_hist_build, - common::RowSetCollection const &row_set_collection, - const std::vector &gpair_h) { + void BuildLocalHistograms( + size_t page_idx, + common::BlockedSpace2d space, + GHistIndexMatrix const &gidx, + std::vector const &nodes_for_explicit_hist_build, + common::RowSetCollection const &row_set_collection, + const std::vector &gpair_h) { const size_t n_nodes = nodes_for_explicit_hist_build.size(); - - // create space of size (# rows in each node) - common::BlockedSpace2d space( - n_nodes, - [&](size_t node) { - const int32_t nid = nodes_for_explicit_hist_build[node].nid; - return row_set_collection[nid].Size(); - }, - 256); + CHECK_GT(n_nodes, 0); std::vector target_hists(n_nodes); for (size_t i = 0; i < n_nodes; ++i) { const int32_t nid = nodes_for_explicit_hist_build[i].nid; target_hists[i] = hist_[nid]; } - buffer_.Reset(this->n_threads_, n_nodes, space, target_hists); + if (page_idx == 0) { + // FIXME: Handle different size of space. + buffer_.Reset(this->n_threads_, n_nodes, space, target_hists); + } // Parallel processing by nodes and data in each node - for (auto const &gmat : p_fmat->GetBatches( - BatchParam{GenericParameter::kCpuId, max_bin_})) { - common::ParallelFor2d( - space, this->n_threads_, [&](size_t nid_in_set, common::Range1d r) { - const auto tid = static_cast(omp_get_thread_num()); - const int32_t nid = nodes_for_explicit_hist_build[nid_in_set].nid; - - auto start_of_row_set = row_set_collection[nid].begin; - auto rid_set = common::RowSetCollection::Elem( - start_of_row_set + r.begin(), start_of_row_set + r.end(), nid); - builder_.template BuildHist( - gpair_h, rid_set, gmat, - buffer_.GetInitializedHist(tid, nid_in_set)); - }); - } + common::ParallelFor2d( + space, this->n_threads_, [&](size_t nid_in_set, common::Range1d r) { + const auto tid = static_cast(omp_get_thread_num()); + const int32_t nid = nodes_for_explicit_hist_build[nid_in_set].nid; + auto elem = row_set_collection[nid]; + auto start_of_row_set = std::min(r.begin(), elem.Size()); + auto end_of_row_set = std::min(r.end(), elem.Size()); + auto rid_set = common::RowSetCollection::Elem( + elem.begin + start_of_row_set, elem.begin + end_of_row_set, nid); + auto hist = buffer_.GetInitializedHist(tid, nid_in_set); + if (rid_set.Size() != 0) { + builder_.template BuildHist(gpair_h, rid_set, gidx, + hist); + } + }); } void @@ -110,24 +110,36 @@ template class HistogramBuilder { } } - /* Main entry point of this class, build histogram for tree nodes. */ - void BuildHist(DMatrix *p_fmat, RegTree *p_tree, + /** Main entry point of this class, build histogram for tree nodes. */ + void BuildHist(size_t page_id, + common::BlockedSpace2d space, + GHistIndexMatrix const& gidx, RegTree *p_tree, common::RowSetCollection const &row_set_collection, std::vector const &nodes_for_explicit_hist_build, std::vector const &nodes_for_subtraction_trick, std::vector const &gpair) { int starting_index = std::numeric_limits::max(); int sync_count = 0; - this->AddHistRows(&starting_index, &sync_count, - nodes_for_explicit_hist_build, - nodes_for_subtraction_trick, p_tree); - if (p_fmat->IsDense()) { - BuildLocalHistograms(p_fmat, nodes_for_explicit_hist_build, - row_set_collection, gpair); + if (page_id == 0) { + this->AddHistRows(&starting_index, &sync_count, + nodes_for_explicit_hist_build, + nodes_for_subtraction_trick, p_tree); + } + if (gidx.IsDense()) { + this->BuildLocalHistograms(page_id, space, gidx, + nodes_for_explicit_hist_build, + row_set_collection, gpair); } else { - BuildLocalHistograms(p_fmat, nodes_for_explicit_hist_build, - row_set_collection, gpair); + this->BuildLocalHistograms(page_id, space, gidx, + nodes_for_explicit_hist_build, + row_set_collection, gpair); } + + CHECK_GE(n_batches_, 1); + if (page_id != n_batches_ - 1) { + return; + } + if (is_distributed_) { this->SyncHistogramDistributed(p_tree, nodes_for_explicit_hist_build, nodes_for_subtraction_trick, @@ -138,6 +150,25 @@ template class HistogramBuilder { sync_count); } } + /** same as the other build hist but handles only single batch data (in-core) */ + void BuildHist(size_t page_id, GHistIndexMatrix const &gidx, RegTree *p_tree, + common::RowSetCollection const &row_set_collection, + std::vector const &nodes_for_explicit_hist_build, + std::vector const &nodes_for_subtraction_trick, + std::vector const &gpair) { + const size_t n_nodes = nodes_for_explicit_hist_build.size(); + // create space of size (# rows in each node) + common::BlockedSpace2d space( + n_nodes, + [&](size_t nidx_in_set) { + const int32_t nidx = nodes_for_explicit_hist_build[nidx_in_set].nid; + return row_set_collection[nidx].Size(); + }, + 256); + this->BuildHist(page_id, space, gidx, p_tree, row_set_collection, + nodes_for_explicit_hist_build, nodes_for_subtraction_trick, + gpair); + } void SyncHistogramDistributed( RegTree *p_tree, diff --git a/src/tree/param.h b/src/tree/param.h index bebf5b8d6363..968e05c4a71c 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -1,5 +1,5 @@ /*! - * Copyright 2014-2019 by Contributors + * Copyright 2014-2021 by Contributors * \file param.h * \brief training parameters, statistics used to support tree construction. * \author Tianqi Chen @@ -7,6 +7,7 @@ #ifndef XGBOOST_TREE_PARAM_H_ #define XGBOOST_TREE_PARAM_H_ +#include #include #include #include @@ -15,6 +16,7 @@ #include "xgboost/parameter.h" #include "xgboost/data.h" +#include "../common/categorical.h" #include "../common/math.h" namespace xgboost { @@ -36,6 +38,9 @@ struct TrainParam : public XGBoostParameter { enum TreeGrowPolicy { kDepthWise = 0, kLossGuide = 1 }; int grow_policy; + uint32_t max_cat_to_onehot{1}; + float cat_smooth{10.0}; + //----- the rest parameters are less important ---- // minimum amount of hessian(weight) allowed in a child float min_child_weight; @@ -119,6 +124,13 @@ struct TrainParam : public XGBoostParameter { "Tree growing policy. 0: favor splitting at nodes closest to the node, " "i.e. grow depth-wise. 1: favor splitting at nodes with highest loss " "change. (cf. LightGBM)"); + DMLC_DECLARE_FIELD(max_cat_to_onehot) + .set_default(4) + .set_lower_bound(1) + .describe("Maximum number of categories to use one-hot encoding based split."); + DMLC_DECLARE_FIELD(cat_smooth) + .set_default(10.0) + .describe(""); DMLC_DECLARE_FIELD(min_child_weight) .set_lower_bound(0.0f) .set_default(1.0f) @@ -281,6 +293,11 @@ XGBOOST_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad, return dw; } +template +XGBOOST_DEVICE T CalcWeightCat(T grad, T hess, T cat_smooth) { + return grad / (hess + cat_smooth); +} + // calculate the cost of loss function template XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess) { @@ -384,6 +401,8 @@ struct SplitEntryContainer { /*! \brief split index */ bst_feature_t sindex{0}; bst_float split_value{0.0f}; + std::vector cat_bits; + bool is_cat{false}; GradientT left_sum; GradientT right_sum; @@ -433,6 +452,8 @@ struct SplitEntryContainer { this->loss_chg = e.loss_chg; this->sindex = e.sindex; this->split_value = e.split_value; + this->is_cat = e.is_cat; + this->cat_bits = e.cat_bits; this->left_sum = e.left_sum; this->right_sum = e.right_sum; return true; @@ -449,9 +470,8 @@ struct SplitEntryContainer { * \return whether the proposed split is better and can replace current split */ bool Update(bst_float new_loss_chg, unsigned split_index, - bst_float new_split_value, bool default_left, - const GradientT &left_sum, - const GradientT &right_sum) { + bst_float new_split_value, bool default_left, bool is_cat, + const GradientT &left_sum, const GradientT &right_sum) { if (this->NeedReplace(new_loss_chg, split_index)) { this->loss_chg = new_loss_chg; if (default_left) { @@ -459,6 +479,31 @@ struct SplitEntryContainer { } this->sindex = split_index; this->split_value = new_split_value; + this->is_cat = is_cat; + this->left_sum = left_sum; + this->right_sum = right_sum; + return true; + } else { + return false; + } + } + + /*! + * \brief Update with partition based categorical split. + * + * \return Whether the proposed split is better and can replace current split. + */ + bool Update(float new_loss_chg, bst_feature_t split_index, common::KCatBitField cats, + bool default_left, GradientT const &left_sum, GradientT const &right_sum) { + if (this->NeedReplace(new_loss_chg, split_index)) { + this->loss_chg = new_loss_chg; + if (default_left) { + split_index |= (1U << 31); + } + this->sindex = split_index; + cat_bits.resize(cats.Bits().size()); + std::copy(cats.Bits().begin(), cats.Bits().end(), cat_bits.begin()); + this->is_cat = true; this->left_sum = left_sum; this->right_sum = right_sum; return true; diff --git a/src/tree/split_evaluator.h b/src/tree/split_evaluator.h index 069718a27378..4fdf70145a95 100644 --- a/src/tree/split_evaluator.h +++ b/src/tree/split_evaluator.h @@ -92,7 +92,7 @@ class TreeEvaluator { XGBOOST_DEVICE float CalcWeight(bst_node_t nodeid, const ParamT ¶m, tree::GradStats const& stats) const { - float w = xgboost::tree::CalcWeight(param, stats); + float w = ::xgboost::tree::CalcWeight(param, stats); if (!has_constraint) { return w; } @@ -107,6 +107,12 @@ class TreeEvaluator { return w; } } + + template + XGBOOST_DEVICE double CalcWeightCat(ParamT const& param, GradientSumT const& stats) const { + return ::xgboost::tree::CalcWeight(param, stats); + } + XGBOOST_DEVICE float CalcGainGivenWeight(ParamT const &p, tree::GradStats const& stats, float w) const { if (stats.GetHess() <= 0) { diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index b30f0d65330c..cbb41adc9416 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -973,6 +973,7 @@ void RegTree::SaveCategoricalSplit(Json* p_out) const { } size_t size = categories.size() - begin; categories_sizes.emplace_back(static_cast(size)); + CHECK_NE(size, 0); } } diff --git a/src/tree/tree_updater.cc b/src/tree/tree_updater.cc index a619713e043a..79fc49a5c311 100644 --- a/src/tree/tree_updater.cc +++ b/src/tree/tree_updater.cc @@ -14,12 +14,13 @@ DMLC_REGISTRY_ENABLE(::xgboost::TreeUpdaterReg); namespace xgboost { -TreeUpdater* TreeUpdater::Create(const std::string& name, GenericParameter const* tparam) { - auto *e = ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->Find(name); +TreeUpdater* TreeUpdater::Create(const std::string& name, GenericParameter const* tparam, + Task task) { + auto* e = ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->Find(name); if (e == nullptr) { LOG(FATAL) << "Unknown tree updater " << name; } - auto p_updater = (e->body)(); + auto p_updater = (e->body)(task); p_updater->tparam_ = tparam; return p_updater; } @@ -34,6 +35,7 @@ DMLC_REGISTRY_LINK_TAG(updater_refresh); DMLC_REGISTRY_LINK_TAG(updater_prune); DMLC_REGISTRY_LINK_TAG(updater_quantile_hist); DMLC_REGISTRY_LINK_TAG(updater_histmaker); +DMLC_REGISTRY_LINK_TAG(updater_approx); DMLC_REGISTRY_LINK_TAG(updater_sync); #ifdef XGBOOST_USE_CUDA DMLC_REGISTRY_LINK_TAG(updater_gpu_hist); diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc new file mode 100644 index 000000000000..8a33ea874fa3 --- /dev/null +++ b/src/tree/updater_approx.cc @@ -0,0 +1,355 @@ +/*! + * Copyright 2021 XGBoost contributors + * + * \brief Implementation for the approx tree method. + */ +#include +#include +#include + +#include "../common/random.h" +#include "../data/gradient_index.h" +#include "constraints.h" +#include "driver.h" +#include "updater_approx.h" +#include "hist/evaluate_splits.h" +#include "hist/histogram.h" +#include "hist/param.h" +#include "param.h" +#include "xgboost/base.h" +#include "xgboost/json.h" +#include "xgboost/tree_updater.h" + +namespace xgboost { +namespace tree { + +DMLC_REGISTRY_FILE_TAG(updater_approx); + +template +class GloablApproxBuilder { + protected: + TrainParam param_; + std::shared_ptr col_sampler_; + HistEvaluator evaluator_; + HistogramBuilder histogram_builder_; + GenericParameter const *ctx_; + + std::vector partitioner_; + RegTree *p_last_tree_{nullptr}; + common::Monitor *monitor_; + size_t n_batches_{0}; + common::HistogramCuts feature_values_; + + public: + void InitData(DMatrix *p_fmat, common::Span hess) { + monitor_->Start(__func__); + n_batches_ = 0; + int32_t n_total_bins = 0; + partitioner_.clear(); + // Generating the GHistIndexMatrix is quite slow, is there a way to speed it up? + for (auto const &page : p_fmat->GetBatches( + {GenericParameter::kCpuId, param_.max_bin, hess, true})) { + if (n_total_bins == 0) { + n_total_bins = page.cut.TotalBins(); + feature_values_ = page.cut; + } else { + CHECK_EQ(n_total_bins, page.cut.TotalBins()); + } + partitioner_.emplace_back(page.Size(), page.base_rowid); + n_batches_++; + } + + histogram_builder_.Reset(n_total_bins, + BatchParam{GenericParameter::kCpuId, param_.max_bin, hess}, + ctx_->Threads(), n_batches_, rabit::IsDistributed()); + monitor_->Stop(__func__); + } + + CPUExpandEntry InitRoot(DMatrix *p_fmat, std::vector const &gpair, + common::Span hess, RegTree *p_tree) { + monitor_->Start(__func__); + CPUExpandEntry best; + best.nid = RegTree::kRoot; + best.depth = 0; + GradStats root_sum; + for (auto const &g : gpair) { + root_sum.Add(g); + } + rabit::Allreduce(reinterpret_cast(&root_sum), 2); + std::vector nodes{best}; + size_t i = 0; + auto space = this->ConstructHistSpace(nodes); + for (auto const &page : + p_fmat->GetBatches({GenericParameter::kCpuId, param_.max_bin, hess})) { + histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), nodes, + {}, gpair); + i++; + } + + auto weight = evaluator_.InitRoot(root_sum); + p_tree->Stat(RegTree::kRoot).sum_hess = root_sum.GetHess(); + p_tree->Stat(RegTree::kRoot).base_weight = weight; + (*p_tree)[RegTree::kRoot].SetLeaf(param_.learning_rate * weight); + + auto const &histograms = histogram_builder_.Histogram(); + auto ft = p_fmat->Info().feature_types.ConstHostSpan(); + evaluator_.EvaluateSplits(histograms, feature_values_, ft, *p_tree, &nodes); + monitor_->Stop(__func__); + + return nodes.front(); + } + + void UpdatePredictionCache(const DMatrix *data, VectorView out_preds) { + monitor_->Start(__func__); + // Caching prediction seems redundant for approx tree method, as sketching takes up + // majority of training time. + CHECK_EQ(out_preds.Size(), data->Info().num_row_); + CHECK(p_last_tree_); + + size_t n_nodes = p_last_tree_->GetNodes().size(); + + auto evaluator = evaluator_.Evaluator(); + auto const &tree = *p_last_tree_; + auto const &snode = evaluator_.Stats(); + for (auto &part : partitioner_) { + CHECK_EQ(part.Size(), n_nodes); + common::BlockedSpace2d space( + part.Size(), [&](size_t node) { return part[node].Size(); }, 1024); + common::ParallelFor2d(space, ctx_->Threads(), [&](size_t nidx, common::Range1d r) { + if (tree[nidx].IsLeaf()) { + const auto rowset = part[nidx]; + auto const &stats = snode.at(nidx); + auto leaf_value = + evaluator.CalcWeight(nidx, param_, GradStats{stats.stats}) * param_.learning_rate; + for (const size_t *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) { + out_preds[*it] += leaf_value; + } + } + }); + } + monitor_->Stop(__func__); + } + + // Construct a work space for building histogram. Eventually we should move this + // function into histogram builder once hist tree method supports external memory. + common::BlockedSpace2d ConstructHistSpace( + std::vector const &nodes_to_build) const { + std::vector partition_size(nodes_to_build.size(), 0); + for (auto const &partition : partitioner_) { + size_t k = 0; + for (auto node : nodes_to_build) { + auto n_rows_in_node = partition.Partitions()[node.nid].Size(); + partition_size[k] = std::max(partition_size[k], n_rows_in_node); + k++; + } + } + common::BlockedSpace2d space{nodes_to_build.size(), + [&](size_t nidx_in_set) { return partition_size[nidx_in_set]; }, + 256}; + return space; + } + + void BuildHistogram(DMatrix *p_fmat, RegTree *p_tree, + std::vector const &valid_candidates, + std::vector const &gpair, common::Span hess) { + std::vector nodes_to_build; + std::vector nodes_to_sub; + + for (auto const &c : valid_candidates) { + auto left_nidx = (*p_tree)[c.nid].LeftChild(); + auto right_nidx = (*p_tree)[c.nid].RightChild(); + auto fewer_right = c.split.right_sum.GetHess() < c.split.left_sum.GetHess(); + + auto build_nidx = left_nidx; + auto subtract_nidx = right_nidx; + if (fewer_right) { + std::swap(build_nidx, subtract_nidx); + } + nodes_to_build.push_back(CPUExpandEntry{build_nidx, p_tree->GetDepth(build_nidx), {}}); + nodes_to_sub.push_back(CPUExpandEntry{subtract_nidx, p_tree->GetDepth(subtract_nidx), {}}); + } + + size_t i = 0; + auto space = this->ConstructHistSpace(nodes_to_build); + for (auto const &page : + p_fmat->GetBatches({GenericParameter::kCpuId, param_.max_bin, hess})) { + histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), + nodes_to_build, nodes_to_sub, gpair); + i++; + } + auto histograms = histogram_builder_.Histogram(); + } + + public: + explicit GloablApproxBuilder(TrainParam param, MetaInfo const &info, GenericParameter const *ctx, + std::shared_ptr column_sampler, Task task, + common::Monitor *monitor) + : param_{std::move(param)}, + col_sampler_{std::move(column_sampler)}, + evaluator_{param_, info, ctx->Threads(), col_sampler_, task}, + ctx_{ctx}, + monitor_{monitor} {} + + void UpdateTree(RegTree *p_tree, MetaInfo const &info, std::vector const &gpair, + common::Span hess, DMatrix *p_fmat) { + p_last_tree_ = p_tree; + this->InitData(p_fmat, hess); + + Driver driver(static_cast(param_.grow_policy)); + auto &tree = *p_tree; + driver.Push({this->InitRoot(p_fmat, gpair, hess, p_tree)}); + bst_node_t num_leaves = 1; + auto expand_set = driver.Pop(); + + while (!expand_set.empty()) { + // candidates that can be further splited. + std::vector valid_candidates; + // candidaates that can be applied. + std::vector applied; + for (auto const &candidate : expand_set) { + if (!candidate.IsValid(param_, num_leaves)) { + continue; + } + evaluator_.ApplyTreeSplit(candidate, p_tree); + applied.push_back(candidate); + num_leaves++; + int left_child_nidx = tree[candidate.nid].LeftChild(); + if (CPUExpandEntry::ChildIsValid(param_, p_tree->GetDepth(left_child_nidx), num_leaves)) { + valid_candidates.emplace_back(candidate); + } + } + size_t i = 0; + for (auto const &page : + p_fmat->GetBatches({GenericParameter::kCpuId, param_.max_bin, hess})) { + partitioner_.at(i).UpdatePosition(ctx_, page, applied, p_tree); + i++; + } + + std::vector best_splits; + if (!valid_candidates.empty()) { + this->BuildHistogram(p_fmat, p_tree, valid_candidates, gpair, hess); + for (auto const &candidate : valid_candidates) { + int left_child_nidx = tree[candidate.nid].LeftChild(); + int right_child_nidx = tree[candidate.nid].RightChild(); + CPUExpandEntry l_best{left_child_nidx, tree.GetDepth(left_child_nidx), {}}; + CPUExpandEntry r_best{right_child_nidx, tree.GetDepth(right_child_nidx), {}}; + best_splits.push_back(l_best); + best_splits.push_back(r_best); + } + auto const &histograms = histogram_builder_.Histogram(); + auto ft = p_fmat->Info().feature_types.ConstHostSpan(); + evaluator_.EvaluateSplits(histograms, feature_values_, ft, *p_tree, &best_splits); + } + driver.Push(best_splits.begin(), best_splits.end()); + expand_set = driver.Pop(); + } + } +}; + +class GlobalApproxUpdater : public TreeUpdater { + TrainParam param_; + common::Monitor monitor_; + CPUHistMakerTrainParam hist_param_; + + std::unique_ptr> f32_impl_; + std::unique_ptr> f64_impl_; + DMatrix *cached_{nullptr}; + std::shared_ptr column_sampler_ = + std::make_shared(); + Task task_; + + public: + explicit GlobalApproxUpdater(Task task) : task_{task} { monitor_.Init(__func__); } + + void Configure(const Args &args) override { + param_.UpdateAllowUnknown(args); + hist_param_.UpdateAllowUnknown(args); + } + void LoadConfig(Json const &in) override { + auto const &config = get(in); + FromJson(config.at("train_param"), &this->param_); + FromJson(config.at("hist_param"), &this->hist_param_); + } + void SaveConfig(Json *p_out) const override { + auto &out = *p_out; + out["train_param"] = ToJson(param_); + out["hist_param"] = ToJson(hist_param_); + } + + void InitData(TrainParam const ¶m, HostDeviceVector *gpair, + std::vector *sampled) { + auto const &h_gpair = gpair->HostVector(); + sampled->resize(h_gpair.size()); + std::copy(h_gpair.cbegin(), h_gpair.cend(), sampled->begin()); + auto &rnd = common::GlobalRandom(); + if (param.subsample != 1.0) { + CHECK(param.sampling_method != TrainParam::kGradientBased) + << "Gradient based sampling is not supported for approx tree method."; + std::bernoulli_distribution coin_flip(param.subsample); + std::transform(sampled->begin(), sampled->end(), sampled->begin(), [&](GradientPair &g) { + if (coin_flip(rnd)) { + return g; + } else { + return GradientPair{}; + } + }); + } + } + + char const *Name() const override { return "grow_global_approx_histmaker"; } + + void Update(HostDeviceVector *gpair, DMatrix *m, + const std::vector &trees) override { + float lr = param_.learning_rate; + param_.learning_rate = lr / trees.size(); + + if (hist_param_.single_precision_histogram) { + f32_impl_ = std::make_unique>(param_, m->Info(), tparam_, + column_sampler_, task_, &monitor_); + } else { + f64_impl_ = std::make_unique>(param_, m->Info(), tparam_, + column_sampler_, task_, &monitor_); + } + + std::vector h_gpair; + InitData(param_, gpair, &h_gpair); + std::vector hess(h_gpair.size()); + std::transform(h_gpair.begin(), h_gpair.end(), hess.begin(), + [](auto g) { return g.GetHess(); }); + + cached_ = m; + auto const &info = m->Info(); + + for (auto p_tree : trees) { + if (hist_param_.single_precision_histogram) { + this->f32_impl_->UpdateTree(p_tree, info, h_gpair, hess, m); + } else { + this->f64_impl_->UpdateTree(p_tree, info, h_gpair, hess, m); + } + } + param_.learning_rate = lr; + } + + bool UpdatePredictionCache(const DMatrix *data, VectorView out_preds) override { + if (data != cached_) { + return false; + } + + if (hist_param_.single_precision_histogram) { + this->f32_impl_->UpdatePredictionCache(data, out_preds); + } else { + this->f64_impl_->UpdatePredictionCache(data, out_preds); + } + return true; + } +}; + +DMLC_REGISTRY_FILE_TAG(grow_global_approx_histmaker); + +XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_global_approx_histmaker") + .describe( + "Tree constructor that uses approximate histogram construction " + "for each node.") + .set_body([](Task task) { return new GlobalApproxUpdater(task); }); +} // namespace tree +} // namespace xgboost diff --git a/src/tree/updater_approx.h b/src/tree/updater_approx.h new file mode 100644 index 000000000000..1cf54f0dea07 --- /dev/null +++ b/src/tree/updater_approx.h @@ -0,0 +1,146 @@ +/*! + * Copyright 2021 XGBoost contributors + * + * \brief Implementation for the approx tree method. + */ +#ifndef XGBOOST_TREE_UPDATER_APPROX_H_ +#define XGBOOST_TREE_UPDATER_APPROX_H_ + +#include +#include +#include + +#include "../common/partition_builder.h" +#include "../common/random.h" +#include "constraints.h" +#include "driver.h" +#include "hist/evaluate_splits.h" +#include "hist/expand_entry.h" +#include "hist/param.h" +#include "param.h" +#include "xgboost/json.h" +#include "xgboost/tree_updater.h" + +namespace xgboost { +namespace tree { +class ApproxRowPartitioner { + static constexpr size_t kPartitionBlockSize = 2048; + common::PartitionBuilder partition_builder_; + common::RowSetCollection row_set_collection_; + + public: + bst_row_t base_rowid = 0; + + static auto SearchCutValue(bst_row_t ridx, bst_feature_t fidx, GHistIndexMatrix const &index, + std::vector const &cut_ptrs, + std::vector const &cut_values) { + int32_t gidx = -1; + auto const &row_ptr = index.row_ptr; + auto get_row_ptr = [&](size_t ridx) { return row_ptr.at(ridx - index.base_rowid); }; + + if (index.IsDense()) { + gidx = index.index[get_row_ptr(ridx) + fidx]; + } else { + auto begin = get_row_ptr(ridx); + auto end = get_row_ptr(ridx + 1); + auto f_begin = cut_ptrs[fidx]; + auto f_end = cut_ptrs[fidx + 1]; + gidx = common::BinarySearchBin(begin, end, index.index, f_begin, f_end); + } + if (gidx == -1) { + return std::numeric_limits::quiet_NaN(); + } + return cut_values[gidx]; + } + + public: + void UpdatePosition(GenericParameter const *ctx, GHistIndexMatrix const &index, + std::vector const &candidates, RegTree const *p_tree) { + size_t n_nodes = candidates.size(); + + auto const &cut_values = index.cut.Values(); + auto const &cut_ptrs = index.cut.Ptrs(); + + common::BlockedSpace2d space{n_nodes, + [&](size_t node_in_set) { + auto candidate = candidates[node_in_set]; + int32_t nid = candidate.nid; + return row_set_collection_[nid].Size(); + }, + kPartitionBlockSize}; + partition_builder_.Init(space.Size(), n_nodes, [&](size_t node_in_set) { + auto candidate = candidates[node_in_set]; + const int32_t nid = candidate.nid; + const size_t size = row_set_collection_[nid].Size(); + const size_t n_tasks = size / kPartitionBlockSize + !!(size % kPartitionBlockSize); + return n_tasks; + }); + auto node_ptr = p_tree->GetCategoriesMatrix().node_ptr; + auto categories = p_tree->GetCategoriesMatrix().categories; + common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) { + auto candidate = candidates[node_in_set]; + auto is_cat = candidate.split.is_cat; + const int32_t nid = candidate.nid; + auto fidx = candidate.split.SplitIndex(); + const size_t task_id = partition_builder_.GetTaskIdx(node_in_set, r.begin()); + partition_builder_.AllocateForTask(task_id); + partition_builder_.PartitionRange( + node_in_set, nid, r, fidx, &row_set_collection_, [&](size_t row_id) { + auto cut_value = SearchCutValue(row_id, fidx, index, cut_ptrs, cut_values); + if (std::isnan(cut_value)) { + return candidate.split.DefaultLeft(); + } + bst_node_t nidx = candidate.nid; + auto segment = node_ptr[nidx]; + auto node_cats = categories.subspan(segment.beg, segment.size); + bool go_left = true; + if (is_cat) { + go_left = common::Decision(node_cats, common::AsCat(cut_value)); + } else { + go_left = cut_value <= candidate.split.split_value; + } + return go_left; + }); + }); + + partition_builder_.CalculateRowOffsets(); + common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) { + auto candidate = candidates[node_in_set]; + const int32_t nid = candidate.nid; + partition_builder_.MergeToArray(node_in_set, r.begin(), + const_cast(row_set_collection_[nid].begin)); + }); + for (size_t i = 0; i < candidates.size(); ++i) { + auto const &candidate = candidates[i]; + auto nidx = candidate.nid; + auto n_left = partition_builder_.GetNLeftElems(i); + auto n_right = partition_builder_.GetNRightElems(i); + CHECK_EQ(n_left + n_right, row_set_collection_[nidx].Size()); + bst_node_t left_nidx = (*p_tree)[nidx].LeftChild(); + bst_node_t right_nidx = (*p_tree)[nidx].RightChild(); + row_set_collection_.AddSplit(nidx, left_nidx, right_nidx, n_left, n_right); + } + } + + auto const &Partitions() const { return row_set_collection_; } + + auto operator[](bst_node_t nidx) { return row_set_collection_[nidx]; } + auto const &operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; } + + size_t Size() const { + return std::distance(row_set_collection_.begin(), row_set_collection_.end()); + } + + ApproxRowPartitioner() = default; + explicit ApproxRowPartitioner(bst_row_t num_row, bst_row_t _base_rowid) + : base_rowid{_base_rowid} { + row_set_collection_.Clear(); + auto p_positions = row_set_collection_.Data(); + p_positions->resize(num_row); + std::iota(p_positions->begin(), p_positions->end(), base_rowid); + row_set_collection_.Init(); + } +}; +} // namespace tree +} // namespace xgboost +#endif // XGBOOST_TREE_UPDATER_APPROX_H_ diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index 952a60f0fcb1..9d5584261690 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -336,10 +336,10 @@ class ColMaker: public TreeUpdater { bst_float proposed_split = (fvalue + e.last_fvalue) * 0.5f; if ( proposed_split == fvalue ) { e.best.Update(loss_chg, fid, e.last_fvalue, - d_step == -1, c, e.stats); + d_step == -1, false, c, e.stats); } else { e.best.Update(loss_chg, fid, proposed_split, - d_step == -1, c, e.stats); + d_step == -1, false, c, e.stats); } } else { loss_chg = static_cast( @@ -348,10 +348,10 @@ class ColMaker: public TreeUpdater { bst_float proposed_split = (fvalue + e.last_fvalue) * 0.5f; if ( proposed_split == fvalue ) { e.best.Update(loss_chg, fid, e.last_fvalue, - d_step == -1, e.stats, c); + d_step == -1, false, e.stats, c); } else { e.best.Update(loss_chg, fid, proposed_split, - d_step == -1, e.stats, c); + d_step == -1, false, e.stats, c); } } } @@ -430,14 +430,14 @@ class ColMaker: public TreeUpdater { loss_chg = static_cast( evaluator.CalcSplitGain(param_, nid, fid, c, e.stats) - snode_[nid].root_gain); - e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1, c, - e.stats); + e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1, + false, c, e.stats); } else { loss_chg = static_cast( evaluator.CalcSplitGain(param_, nid, fid, e.stats, c) - snode_[nid].root_gain); e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1, - e.stats, c); + false, e.stats, c); } } } @@ -628,7 +628,7 @@ class ColMaker: public TreeUpdater { XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker") .describe("Grow tree with parallelization over columns.") -.set_body([]() { +.set_body([](Task) { return new ColMaker(); }); } // namespace tree diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index cbe63d243da4..19eaffa51ea5 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -696,7 +696,7 @@ struct GPUHistMakerDevice { int right_child_nidx = tree[candidate.nid].RightChild(); // Only create child entries if needed if (GPUExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx), - num_leaves)) { + num_leaves)) { monitor.Start("UpdatePosition"); this->UpdatePosition(candidate.nid, p_tree); monitor.Stop("UpdatePosition"); @@ -730,7 +730,7 @@ struct GPUHistMakerDevice { template class GPUHistMakerSpecialised { public: - GPUHistMakerSpecialised() = default; + explicit GPUHistMakerSpecialised(Task task) : task_{task} {}; void Configure(const Args& args, GenericParameter const* generic_param) { param_.UpdateAllowUnknown(args); generic_param_ = generic_param; @@ -857,12 +857,14 @@ class GPUHistMakerSpecialised { DMatrix* p_last_fmat_ { nullptr }; int device_{-1}; + Task task_; common::Monitor monitor_; }; class GPUHistMaker : public TreeUpdater { public: + explicit GPUHistMaker(Task task) : task_{task} {} void Configure(const Args& args) override { // Used in test to count how many configurations are performed LOG(DEBUG) << "[GPU Hist]: Configure"; @@ -876,11 +878,11 @@ class GPUHistMaker : public TreeUpdater { param = double_maker_->param_; } if (hist_maker_param_.single_precision_histogram) { - float_maker_.reset(new GPUHistMakerSpecialised()); + float_maker_.reset(new GPUHistMakerSpecialised(task_)); float_maker_->param_ = param; float_maker_->Configure(args, tparam_); } else { - double_maker_.reset(new GPUHistMakerSpecialised()); + double_maker_.reset(new GPUHistMakerSpecialised(task_)); double_maker_->param_ = param; double_maker_->Configure(args, tparam_); } @@ -890,10 +892,10 @@ class GPUHistMaker : public TreeUpdater { auto const& config = get(in); FromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_); if (hist_maker_param_.single_precision_histogram) { - float_maker_.reset(new GPUHistMakerSpecialised()); + float_maker_.reset(new GPUHistMakerSpecialised(task_)); FromJson(config.at("train_param"), &float_maker_->param_); } else { - double_maker_.reset(new GPUHistMakerSpecialised()); + double_maker_.reset(new GPUHistMakerSpecialised(task_)); FromJson(config.at("train_param"), &double_maker_->param_); } } @@ -931,6 +933,7 @@ class GPUHistMaker : public TreeUpdater { private: GPUHistMakerTrainParam hist_maker_param_; + Task task_; std::unique_ptr> float_maker_; std::unique_ptr> double_maker_; }; @@ -938,7 +941,7 @@ class GPUHistMaker : public TreeUpdater { #if !defined(GTEST_TEST) XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist") .describe("Grow tree with GPU.") - .set_body([]() { return new GPUHistMaker(); }); + .set_body([](Task task) { return new GPUHistMaker(task); }); #endif // !defined(GTEST_TEST) } // namespace tree diff --git a/src/tree/updater_histmaker.cc b/src/tree/updater_histmaker.cc index 1c086b69a8e9..3a9534a5ad38 100644 --- a/src/tree/updater_histmaker.cc +++ b/src/tree/updater_histmaker.cc @@ -173,7 +173,8 @@ class HistMaker: public BaseMaker { if (c.sum_hess >= param_.min_child_weight) { double loss_chg = CalcGain(param_, s.GetGrad(), s.GetHess()) + CalcGain(param_, c.GetGrad(), c.GetHess()) - root_gain; - if (best->Update(static_cast(loss_chg), fid, hist.cut[i], false, s, c)) { + if (best->Update(static_cast(loss_chg), fid, hist.cut[i], + false, false, s, c)) { *left_sum = s; } } @@ -187,7 +188,8 @@ class HistMaker: public BaseMaker { if (c.sum_hess >= param_.min_child_weight) { double loss_chg = CalcGain(param_, s.GetGrad(), s.GetHess()) + CalcGain(param_, c.GetGrad(), c.GetHess()) - root_gain; - if (best->Update(static_cast(loss_chg), fid, hist.cut[i-1], true, c, s)) { + if (best->Update(static_cast(loss_chg), fid, + hist.cut[i - 1], true, false, c, s)) { *left_sum = c; } } @@ -750,14 +752,14 @@ class GlobalProposalHistMaker: public CQHistMaker { XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker") .describe("Tree constructor that uses approximate histogram construction.") -.set_body([]() { +.set_body([](Task) { return new CQHistMaker(); }); // The updater for approx tree method. XGBOOST_REGISTER_TREE_UPDATER(HistMaker, "grow_histmaker") .describe("Tree constructor that uses approximate global of histogram construction.") -.set_body([]() { +.set_body([](Task) { return new GlobalProposalHistMaker(); }); } // namespace tree diff --git a/src/tree/updater_prune.cc b/src/tree/updater_prune.cc index 76a8916a0598..884545a0a115 100644 --- a/src/tree/updater_prune.cc +++ b/src/tree/updater_prune.cc @@ -23,8 +23,8 @@ DMLC_REGISTRY_FILE_TAG(updater_prune); /*! \brief pruner that prunes a tree after growing finishes */ class TreePruner: public TreeUpdater { public: - TreePruner() { - syncher_.reset(TreeUpdater::Create("sync", tparam_)); + explicit TreePruner(Task task) { + syncher_.reset(TreeUpdater::Create("sync", tparam_, task)); pruner_monitor_.Init("TreePruner"); } char const* Name() const override { @@ -113,8 +113,8 @@ class TreePruner: public TreeUpdater { XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune") .describe("Pruner that prune the tree according to statistics.") -.set_body([]() { - return new TreePruner(); +.set_body([](Task task) { + return new TreePruner(task); }); } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 19c300b30672..e5d598604087 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -40,7 +40,7 @@ DMLC_REGISTER_PARAMETER(CPUHistMakerTrainParam); void QuantileHistMaker::Configure(const Args& args) { // initialize pruner if (!pruner_) { - pruner_.reset(TreeUpdater::Create("prune", tparam_)); + pruner_.reset(TreeUpdater::Create("prune", tparam_, task_)); } pruner_->Configure(args); param_.UpdateAllowUnknown(args); @@ -52,7 +52,7 @@ void QuantileHistMaker::SetBuilder(const size_t n_trees, std::unique_ptr>* builder, DMatrix *dmat) { builder->reset( - new Builder(n_trees, param_, std::move(pruner_), dmat)); + new Builder(n_trees, param_, std::move(pruner_), dmat, task_)); } template @@ -130,9 +130,14 @@ void QuantileHistMaker::Builder::InitRoot( nodes_for_subtraction_trick_.clear(); nodes_for_explicit_hist_build_.push_back(node); - this->histogram_builder_->BuildHist(p_fmat, p_tree, row_set_collection_, - nodes_for_explicit_hist_build_, - nodes_for_subtraction_trick_, gpair_h); + size_t page_id = 0; + for (auto const &gidx : p_fmat->GetBatches( + {GenericParameter::kCpuId, param_.max_bin})) { + this->histogram_builder_->BuildHist( + page_id, gidx, p_tree, row_set_collection_, + nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_, gpair_h); + ++page_id; + } { auto nid = RegTree::kRoot; @@ -168,9 +173,11 @@ void QuantileHistMaker::Builder::InitRoot( std::vector entries{node}; builder_monitor_.Start("EvaluateSplits"); + auto ft = p_fmat->Info().feature_types.ConstHostSpan(); for (auto const &gmat : p_fmat->GetBatches( BatchParam{GenericParameter::kCpuId, param_.max_bin})) { - evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat.cut, *p_tree, &entries); + evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat.cut, ft, + *p_tree, &entries); break; } builder_monitor_.Stop("EvaluateSplits"); @@ -260,9 +267,15 @@ void QuantileHistMaker::Builder::ExpandTree( SplitSiblings(nodes_for_apply_split, &nodes_to_evaluate, p_tree); if (depth < param_.max_depth) { - this->histogram_builder_->BuildHist( - p_fmat, p_tree, row_set_collection_, nodes_for_explicit_hist_build_, - nodes_for_subtraction_trick_, gpair_h); + size_t i = 0; + for (auto const &gidx : p_fmat->GetBatches( + {GenericParameter::kCpuId, param_.max_bin})) { + this->histogram_builder_->BuildHist( + i, gidx, p_tree, row_set_collection_, + nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_, + gpair_h); + ++i; + } } else { int starting_index = std::numeric_limits::max(); int sync_count = 0; @@ -272,8 +285,9 @@ void QuantileHistMaker::Builder::ExpandTree( } builder_monitor_.Start("EvaluateSplits"); - evaluator_->EvaluateSplits(this->histogram_builder_->Histogram(), gmat.cut, - *p_tree, &nodes_to_evaluate); + auto ft = p_fmat->Info().feature_types.ConstHostSpan(); + evaluator_->EvaluateSplits(this->histogram_builder_->Histogram(), + gmat.cut, ft, *p_tree, &nodes_to_evaluate); builder_monitor_.Stop("EvaluateSplits"); for (size_t i = 0; i < nodes_for_apply_split.size(); ++i) { @@ -432,7 +446,9 @@ void QuantileHistMaker::Builder::InitData( }); } exc.Rethrow(); - this->histogram_builder_->Reset(nbins, param_.max_bin, this->nthread_); + this->histogram_builder_->Reset( + nbins, BatchParam{GenericParameter::kCpuId, param_.max_bin}, + this->nthread_, 1, rabit::IsDistributed()); std::vector& row_indices = *row_set_collection_.Data(); row_indices.resize(info.num_row_); @@ -530,10 +546,10 @@ void QuantileHistMaker::Builder::InitData( p_last_tree_ = &tree; if (data_layout_ == DataLayout::kDenseDataOneBased) { evaluator_.reset(new HistEvaluator{ - param_, info, this->nthread_, column_sampler_, true}); + param_, info, this->nthread_, column_sampler_, task_, true}); } else { evaluator_.reset(new HistEvaluator{ - param_, info, this->nthread_, column_sampler_, false}); + param_, info, this->nthread_, column_sampler_, task_, false}); } if (data_layout_ == DataLayout::kDenseDataZeroBased @@ -677,17 +693,17 @@ XGBOOST_REGISTER_TREE_UPDATER(FastHistMaker, "grow_fast_histmaker") .describe("(Deprecated, use grow_quantile_histmaker instead.)" " Grow tree using quantized histogram.") .set_body( - []() { + [](Task task) { LOG(WARNING) << "grow_fast_histmaker is deprecated, " << "use grow_quantile_histmaker instead."; - return new QuantileHistMaker(); + return new QuantileHistMaker(task); }); XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker") .describe("Grow tree using quantized histogram.") .set_body( - []() { - return new QuantileHistMaker(); + [](Task task) { + return new QuantileHistMaker(task); }); } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h index 9654ab00a7c0..8a68312c68b4 100644 --- a/src/tree/updater_quantile_hist.h +++ b/src/tree/updater_quantile_hist.h @@ -95,7 +95,7 @@ using xgboost::common::Column; /*! \brief construct a tree using quantized feature values */ class QuantileHistMaker: public TreeUpdater { public: - QuantileHistMaker() { + explicit QuantileHistMaker(Task task) { updater_monitor_.Init("QuantileHistMaker"); } void Configure(const Args& args) override; @@ -154,12 +154,15 @@ class QuantileHistMaker: public TreeUpdater { using GHistRowT = GHistRow; using GradientPairT = xgboost::detail::GradientPairInternal; // constructor - explicit Builder(const size_t n_trees, const TrainParam ¶m, - std::unique_ptr pruner, DMatrix const *fmat) - : n_trees_(n_trees), param_(param), pruner_(std::move(pruner)), - p_last_tree_(nullptr), p_last_fmat_(fmat), - histogram_builder_{ - new HistogramBuilder} { + explicit Builder(const size_t n_trees, const TrainParam& param, + std::unique_ptr pruner, DMatrix const* fmat, Task task) + : n_trees_(n_trees), + param_(param), + pruner_(std::move(pruner)), + p_last_tree_(nullptr), + p_last_fmat_(fmat), + histogram_builder_{new HistogramBuilder}, + task_{task} { builder_monitor_.Init("Quantile::Builder"); } ~Builder(); @@ -261,6 +264,7 @@ class QuantileHistMaker: public TreeUpdater { DataLayout data_layout_; std::unique_ptr> histogram_builder_; + Task task_; common::Monitor builder_monitor_; }; @@ -281,6 +285,7 @@ class QuantileHistMaker: public TreeUpdater { std::unique_ptr> double_builder_; std::unique_ptr pruner_; + Task task_; }; } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_refresh.cc b/src/tree/updater_refresh.cc index 1d54ad9e3d9b..520aad70b0d9 100644 --- a/src/tree/updater_refresh.cc +++ b/src/tree/updater_refresh.cc @@ -161,7 +161,7 @@ class TreeRefresher: public TreeUpdater { XGBOOST_REGISTER_TREE_UPDATER(TreeRefresher, "refresh") .describe("Refresher that refreshes the weight and statistics according to data.") -.set_body([]() { +.set_body([](Task) { return new TreeRefresher(); }); } // namespace tree diff --git a/src/tree/updater_sync.cc b/src/tree/updater_sync.cc index 7979d10c21bf..b01d9203a48c 100644 --- a/src/tree/updater_sync.cc +++ b/src/tree/updater_sync.cc @@ -53,7 +53,7 @@ class TreeSyncher: public TreeUpdater { XGBOOST_REGISTER_TREE_UPDATER(TreeSyncher, "sync") .describe("Syncher that synchronize the tree in all distributed nodes.") -.set_body([]() { +.set_body([](Task) { return new TreeSyncher(); }); } // namespace tree diff --git a/tests/cpp/categorical_helpers.h b/tests/cpp/categorical_helpers.h new file mode 100644 index 000000000000..f4470a6c910e --- /dev/null +++ b/tests/cpp/categorical_helpers.h @@ -0,0 +1,44 @@ +/*! + * Copyright 2021 by XGBoost Contributors + * + * \brief Utilities for testing categorical data support. + */ +#include +#include + +#include "xgboost/span.h" +#include "helpers.h" +#include "../../src/common/categorical.h" + +namespace xgboost { +inline std::vector OneHotEncodeFeature(std::vector x, + size_t num_cat) { + std::vector ret(x.size() * num_cat, 0); + size_t n_rows = x.size(); + for (size_t r = 0; r < n_rows; ++r) { + bst_cat_t cat = common::AsCat(x[r]); + ret.at(num_cat * r + cat) = 1; + } + return ret; +} + +template +void ValidateCategoricalHistogram(size_t n_categories, + common::Span onehot, + common::Span cat) { + auto cat_sum = std::accumulate(cat.cbegin(), cat.cend(), GradientPairPrecise{}); + for (size_t c = 0; c < n_categories; ++c) { + auto zero = onehot[c * 2]; + auto one = onehot[c * 2 + 1]; + + auto chosen = cat[c]; + auto not_chosen = cat_sum - chosen; + + ASSERT_LE(RelError(zero.GetGrad(), not_chosen.GetGrad()), kRtEps); + ASSERT_LE(RelError(zero.GetHess(), not_chosen.GetHess()), kRtEps); + + ASSERT_LE(RelError(one.GetGrad(), chosen.GetGrad()), kRtEps); + ASSERT_LE(RelError(one.GetHess(), chosen.GetHess()), kRtEps); + } +} +} // namespace xgboost diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu index 586ff762fc2b..c124ab5055e6 100644 --- a/tests/cpp/common/test_quantile.cu +++ b/tests/cpp/common/test_quantile.cu @@ -5,6 +5,13 @@ #include "../../../src/common/quantile.cuh" namespace xgboost { +namespace { +struct IsSorted { + XGBOOST_DEVICE bool operator()(common::SketchEntry const& a, common::SketchEntry const& b) const { + return a.value < b.value; + } +}; +} namespace common { TEST(GPUQuantile, Basic) { constexpr size_t kRows = 1000, kCols = 100, kBins = 256; @@ -52,9 +59,15 @@ void TestSketchUnique(float sparsity) { ASSERT_EQ(sketch.Data().size(), h_columns_ptr.back()); sketch.Unique(); - ASSERT_TRUE(thrust::is_sorted(thrust::device, sketch.Data().data(), - sketch.Data().data() + sketch.Data().size(), - detail::SketchUnique{})); + + std::vector h_data(sketch.Data().size()); + thrust::copy(dh::tcbegin(sketch.Data()), dh::tcend(sketch.Data()), h_data.begin()); + + for (size_t i = 1; i < h_columns_ptr.size(); ++i) { + auto begin = h_columns_ptr[i - 1]; + auto column = common::Span(h_data).subspan(begin, h_columns_ptr[i] - begin); + ASSERT_TRUE(std::is_sorted(column.begin(), column.end(), IsSorted{})); + } }); } @@ -84,8 +97,7 @@ void TestQuantileElemRank(int32_t device, Span in, if (with_error) { ASSERT_GE(in_column[idx].rmin + in_column[idx].rmin * kRtEps, prev_rmin); - ASSERT_GE(in_column[idx].rmax + in_column[idx].rmin * kRtEps, - prev_rmax); + ASSERT_GE(in_column[idx].rmax + in_column[idx].rmin * kRtEps, prev_rmax); ASSERT_GE(in_column[idx].rmax + in_column[idx].rmin * kRtEps, rmin_next); } else { @@ -169,7 +181,7 @@ TEST(GPUQuantile, MergeEmpty) { TEST(GPUQuantile, MergeBasic) { constexpr size_t kRows = 1000, kCols = 100; - RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) { + RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const &info) { HostDeviceVector ft; SketchContainer sketch_0(ft, n_bins, kCols, kRows, 0); HostDeviceVector storage_0; @@ -265,9 +277,16 @@ void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) { ASSERT_EQ(h_columns_ptr.back(), sketch_1.Data().size() + size_before_merge); sketch_0.Unique(); - ASSERT_TRUE(thrust::is_sorted(thrust::device, sketch_0.Data().data(), - sketch_0.Data().data() + sketch_0.Data().size(), - detail::SketchUnique{})); + columns_ptr = sketch_0.ColumnsPtr(); + dh::CopyDeviceSpanToVector(&h_columns_ptr, columns_ptr); + + std::vector h_data(sketch_0.Data().size()); + dh::CopyDeviceSpanToVector(&h_data, sketch_0.Data()); + for (size_t i = 1; i < h_columns_ptr.size(); ++i) { + auto begin = h_columns_ptr[i - 1]; + auto column = Span {h_data}.subspan(begin, h_columns_ptr[i] - begin); + ASSERT_TRUE(std::is_sorted(column.begin(), column.end(), IsSorted{})); + } } TEST(GPUQuantile, MergeDuplicated) { diff --git a/tests/cpp/common/test_quantile.h b/tests/cpp/common/test_quantile.h index 083766cfbb5a..8118248dc939 100644 --- a/tests/cpp/common/test_quantile.h +++ b/tests/cpp/common/test_quantile.h @@ -48,7 +48,9 @@ template void RunWithSeedsAndBins(size_t rows, Fn fn) { std::vector infos(2); auto& h_weights = infos.front().weights_.HostVector(); h_weights.resize(rows); - std::generate(h_weights.begin(), h_weights.end(), [&]() { return dist(&lcg); }); + + SimpleRealUniformDistribution weight_dist(0, 10); + std::generate(h_weights.begin(), h_weights.end(), [&]() { return weight_dist(&lcg); }); for (auto seed : seeds) { for (auto n_bin : bins) { diff --git a/tests/cpp/common/test_span.cc b/tests/cpp/common/test_span.cc index 3ee99c0ae465..bd73b839cf9c 100644 --- a/tests/cpp/common/test_span.cc +++ b/tests/cpp/common/test_span.cc @@ -277,6 +277,9 @@ TEST(Span, RBeginREnd) { int status = 1; TestRBeginREnd{&status}(); ASSERT_EQ(status, 1); + + std::vector idx(10); + Span span{idx}; } TEST(Span, ElementAccess) { diff --git a/tests/cpp/data/test_gradient_index.cc b/tests/cpp/data/test_gradient_index.cc index 2c19b9e58c9b..2dcb5ed1a40f 100644 --- a/tests/cpp/data/test_gradient_index.cc +++ b/tests/cpp/data/test_gradient_index.cc @@ -23,5 +23,39 @@ TEST(GradientIndex, ExternalMemory) { ++i; } } + +TEST(GradientIndex, FromCategoricalBasic) { + size_t constexpr kRows = 1000, kCats = 13, kCols = 1; + size_t max_bins = 8; + auto x = GenerateRandomCategoricalSingleColumn(kRows, kCats); + auto m = GetDMatrixFromData(x, kRows, 1); + + auto &h_ft = m->Info().feature_types.HostVector(); + h_ft.resize(kCols, FeatureType::kCategorical); + + BatchParam p(0, max_bins); + GHistIndexMatrix gidx; + + gidx.Init(m.get(), max_bins, {}); + + auto x_copy = x; + std::sort(x_copy.begin(), x_copy.end()); + auto n_uniques = std::unique(x_copy.begin(), x_copy.end()) - x_copy.begin(); + ASSERT_EQ(n_uniques, kCats); + + auto const &h_cut_ptr = gidx.cut.Ptrs(); + auto const &h_cut_values = gidx.cut.Values(); + + ASSERT_EQ(h_cut_ptr.size(), 2); + ASSERT_EQ(h_cut_values.size(), kCats); + + auto const &index = gidx.index; + + for (size_t i = 0; i < x.size(); ++i) { + auto bin = index[i]; + auto bin_value = h_cut_values.at(bin); + ASSERT_EQ(common::AsCat(x[i]), common::AsCat(bin_value)); + } +} } // namespace data } // namespace xgboost diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc index 9255bf2c32dc..00fa56278a0c 100644 --- a/tests/cpp/gbm/test_gbtree.cc +++ b/tests/cpp/gbm/test_gbtree.cc @@ -35,7 +35,7 @@ TEST(GBTree, SelectTreeMethod) { gbtree.Configure(args); auto const& tparam = gbtree.GetTrainParam(); gbtree.Configure({{"tree_method", "approx"}}); - ASSERT_EQ(tparam.updater_seq, "grow_histmaker,prune"); + ASSERT_EQ(tparam.updater_seq, "grow_global_approx_histmaker"); gbtree.Configure({{"tree_method", "exact"}}); ASSERT_EQ(tparam.updater_seq, "grow_colmaker,prune"); gbtree.Configure({{"tree_method", "hist"}}); diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index 1e4731454d9b..a3ca415c3113 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -172,12 +172,10 @@ SimpleLCG::StateType SimpleLCG::operator()() { state_ = (alpha_ * state_) % mod_; return state_; } -SimpleLCG::StateType SimpleLCG::Min() const { - return seed_ * alpha_; -} -SimpleLCG::StateType SimpleLCG::Max() const { - return max_value_; -} +SimpleLCG::StateType SimpleLCG::Min() const { return min(); } +SimpleLCG::StateType SimpleLCG::Max() const { return max(); } +// Make sure it's compile time constant. +static_assert(SimpleLCG::max() - SimpleLCG::min(), ""); void RandomDataGenerator::GenerateDense(HostDeviceVector *out) const { xgboost::SimpleRealUniformDistribution dist(lower_, upper_); @@ -291,6 +289,7 @@ void RandomDataGenerator::GenerateCSR( xgboost::SimpleRealUniformDistribution dist(lower_, upper_); float sparsity = sparsity_ * (upper_ - lower_) + lower_; + SimpleRealUniformDistribution cat(0.0, max_cat_); h_rptr.emplace_back(0); for (size_t i = 0; i < rows_; ++i) { @@ -298,7 +297,11 @@ void RandomDataGenerator::GenerateCSR( for (size_t j = 0; j < cols_; ++j) { auto g = dist(&lcg); if (g >= sparsity) { - g = dist(&lcg); + if (common::IsCat(ft_, j)) { + g = common::AsCat(cat(&lcg)); + } else { + g = dist(&lcg); + } h_value.emplace_back(g); rptr++; h_cols.emplace_back(j); @@ -347,11 +350,15 @@ RandomDataGenerator::GenerateDMatrix(bool with_label, bool float_label, } if (device_ >= 0) { out->Info().labels_.SetDevice(device_); + out->Info().feature_types.SetDevice(device_); for (auto const& page : out->GetBatches()) { page.data.SetDevice(device_); page.offset.SetDevice(device_); } } + if (!ft_.empty()) { + out->Info().feature_types.HostVector() = ft_; + } return out; } diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index fc9b594c0183..d5084d06d0f2 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -106,42 +106,39 @@ bool IsNear(std::vector::const_iterator _beg1, */ class SimpleLCG { private: - using StateType = int64_t; + using StateType = uint64_t; static StateType constexpr kDefaultInit = 3; - static StateType constexpr default_alpha_ = 61; - static StateType constexpr max_value_ = ((StateType)1 << 32) - 1; + static StateType constexpr kDefaultAlpha = 61; + static StateType constexpr kMaxValue = (static_cast(1) << 32) - 1; StateType state_; StateType const alpha_; StateType const mod_; - StateType seed_; + public: + using result_type = StateType; // NOLINT public: - SimpleLCG() : state_{kDefaultInit}, - alpha_{default_alpha_}, mod_{max_value_}, seed_{state_}{} + SimpleLCG() : state_{kDefaultInit}, alpha_{kDefaultAlpha}, mod_{kMaxValue} {} SimpleLCG(SimpleLCG const& that) = default; SimpleLCG(SimpleLCG&& that) = default; - void Seed(StateType seed) { - seed_ = seed; - } + void Seed(StateType seed) { state_ = seed % mod_; } /*! * \brief Initialize SimpleLCG. * * \param state Initial state, can also be considered as seed. If set to * zero, SimpleLCG will use internal default value. - * \param alpha multiplier - * \param mod modulo */ - explicit SimpleLCG(StateType state, - StateType alpha=default_alpha_, StateType mod=max_value_) - : state_{state == 0 ? kDefaultInit : state}, - alpha_{alpha}, mod_{mod} , seed_{state} {} + explicit SimpleLCG(StateType state) + : state_{state == 0 ? kDefaultInit : state}, alpha_{kDefaultAlpha}, mod_{kMaxValue} {} StateType operator()(); StateType Min() const; StateType Max() const; + + constexpr result_type static min() { return 0; }; // NOLINT + constexpr result_type static max() { return kMaxValue; } // NOLINT }; template @@ -217,10 +214,12 @@ class RandomDataGenerator { float upper_; int32_t device_; - int32_t seed_; + uint64_t seed_; SimpleLCG lcg_; size_t bins_; + std::vector ft_; + bst_cat_t max_cat_; Json ArrayInterfaceImpl(HostDeviceVector *storage, size_t rows, size_t cols) const; @@ -242,7 +241,7 @@ class RandomDataGenerator { device_ = d; return *this; } - RandomDataGenerator& Seed(int32_t s) { + RandomDataGenerator& Seed(uint64_t s) { seed_ = s; lcg_.Seed(seed_); return *this; @@ -251,6 +250,16 @@ class RandomDataGenerator { bins_ = b; return *this; } + RandomDataGenerator& Type(common::Span ft) { + CHECK_EQ(ft.size(), cols_); + ft_.resize(ft.size()); + std::copy(ft.cbegin(), ft.cend(), ft_.begin()); + return *this; + } + RandomDataGenerator& MaxCategory(bst_cat_t cat) { + max_cat_ = cat; + return *this; + } void GenerateDense(HostDeviceVector* out) const; diff --git a/tests/cpp/tree/gpu_hist/test_histogram.cu b/tests/cpp/tree/gpu_hist/test_histogram.cu index 5c659965759b..3b543a48d7cc 100644 --- a/tests/cpp/tree/gpu_hist/test_histogram.cu +++ b/tests/cpp/tree/gpu_hist/test_histogram.cu @@ -1,9 +1,11 @@ #include #include -#include "../../helpers.h" + #include "../../../../src/common/categorical.h" -#include "../../../../src/tree/gpu_hist/row_partitioner.cuh" #include "../../../../src/tree/gpu_hist/histogram.cuh" +#include "../../../../src/tree/gpu_hist/row_partitioner.cuh" +#include "../../categorical_helpers.h" +#include "../../helpers.h" namespace xgboost { namespace tree { @@ -99,16 +101,6 @@ TEST(Histogram, GPUDeterministic) { } } -std::vector OneHotEncodeFeature(std::vector x, size_t num_cat) { - std::vector ret(x.size() * num_cat, 0); - size_t n_rows = x.size(); - for (size_t r = 0; r < n_rows; ++r) { - bst_cat_t cat = common::AsCat(x[r]); - ret.at(num_cat * r + cat) = 1; - } - return ret; -} - // Test 1 vs rest categorical histogram is equivalent to one hot encoded data. void TestGPUHistogramCategorical(size_t num_categories) { size_t constexpr kRows = 340; @@ -123,7 +115,9 @@ void TestGPUHistogramCategorical(size_t num_categories) { auto gpair = GenerateRandomGradients(kRows, 0, 2); gpair.SetDevice(0); auto rounding = CreateRoundingFactor(gpair.DeviceSpan()); - // Generate hist with cat data. + /** + * Generate hist with cat data. + */ for (auto const &batch : cat_m->GetBatches(batch_param)) { auto* page = batch.Impl(); FeatureGroups single_group(page->Cuts()); @@ -133,7 +127,9 @@ void TestGPUHistogramCategorical(size_t num_categories) { rounding); } - // Generate hist with one hot encoded data. + /** + * Generate hist with one hot encoded data. + */ auto x_encoded = OneHotEncodeFeature(x, num_categories); auto encode_m = GetDMatrixFromData(x_encoded, kRows, num_categories); dh::device_vector encode_hist(2 * num_categories); @@ -152,20 +148,9 @@ void TestGPUHistogramCategorical(size_t num_categories) { std::vector h_encode_hist(encode_hist.size()); thrust::copy(encode_hist.begin(), encode_hist.end(), h_encode_hist.begin()); - - for (size_t c = 0; c < num_categories; ++c) { - auto zero = h_encode_hist[c * 2]; - auto one = h_encode_hist[c * 2 + 1]; - - auto chosen = h_cat_hist[c]; - auto not_chosen = cat_sum - chosen; - - ASSERT_LE(RelError(zero.GetGrad(), not_chosen.GetGrad()), kRtEps); - ASSERT_LE(RelError(zero.GetHess(), not_chosen.GetHess()), kRtEps); - - ASSERT_LE(RelError(one.GetGrad(), chosen.GetGrad()), kRtEps); - ASSERT_LE(RelError(one.GetHess(), chosen.GetHess()), kRtEps); - } + ValidateCategoricalHistogram(num_categories, + common::Span{h_encode_hist}, + common::Span{h_cat_hist}); } TEST(Histogram, GPUHistCategorical) { diff --git a/tests/cpp/tree/hist/test_evaluate_splits.cc b/tests/cpp/tree/hist/test_evaluate_splits.cc index cb0171269305..a3e71c7856fb 100644 --- a/tests/cpp/tree/hist/test_evaluate_splits.cc +++ b/tests/cpp/tree/hist/test_evaluate_splits.cc @@ -7,7 +7,6 @@ namespace xgboost { namespace tree { - template void TestEvaluateSplits() { int static constexpr kRows = 8, kCols = 16; auto orig = omp_get_max_threads(); @@ -16,14 +15,12 @@ template void TestEvaluateSplits() { auto sampler = std::make_shared(); TrainParam param; - param.UpdateAllowUnknown(Args{{}}); - param.min_child_weight = 0; - param.reg_lambda = 0; + param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}}); auto dmat = RandomDataGenerator(kRows, kCols, 0).Seed(3).GenerateDMatrix(); - auto evaluator = - HistEvaluator{param, dmat->Info(), n_threads, sampler}; + auto evaluator = HistEvaluator{param, dmat->Info(), n_threads, + sampler, Task::kRegression}; common::HistCollection hist; std::vector row_gpairs = { {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f}, @@ -39,7 +36,7 @@ template void TestEvaluateSplits() { std::iota(row_indices.begin(), row_indices.end(), 0); row_set_collection.Init(); - auto hist_builder = GHistBuilder(n_threads, gmat.cut.Ptrs().back()); + auto hist_builder = GHistBuilder(gmat.cut.Ptrs().back()); hist.Init(gmat.cut.Ptrs().back()); hist.AddHistRow(0); hist.AllocateAllData(); @@ -58,7 +55,7 @@ template void TestEvaluateSplits() { entries.front().depth = 0; evaluator.InitRoot(GradStats{total_gpair}); - evaluator.EvaluateSplits(hist, gmat.cut, tree, &entries); + evaluator.EvaluateSplits(hist, gmat.cut, {}, tree, &entries); auto best_loss_chg = evaluator.Evaluator().CalcSplitGain( @@ -97,7 +94,7 @@ TEST(HistEvaluator, Apply) { auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatrix(); auto sampler = std::make_shared(); auto evaluator_ = - HistEvaluator{param, dmat->Info(), 4, sampler}; + HistEvaluator{param, dmat->Info(), 4, sampler, Task::kRegression}; CPUExpandEntry entry{0, 0, 10.0f}; entry.split.left_sum = GradStats{0.4, 0.6f}; @@ -108,5 +105,142 @@ TEST(HistEvaluator, Apply) { ASSERT_EQ(tree.Stat(tree[0].LeftChild()).sum_hess, 0.6f); ASSERT_EQ(tree.Stat(tree[0].RightChild()).sum_hess, 0.7f); } + +TEST(HistEvaluator, CategoricalPartition) { + int static constexpr kRows = 128, kCols = 1; + using GradientSumT = double; + std::vector ft(kCols, FeatureType::kCategorical); + + TrainParam param; + param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}}); + + size_t n_cats{8}; + + auto dmat = + RandomDataGenerator(kRows, kCols, 0).Seed(3).Type(ft).MaxCategory(n_cats).GenerateDMatrix(); + + int32_t n_threads = 16; + auto sampler = std::make_shared(); + auto evaluator = HistEvaluator{param, dmat->Info(), n_threads, + sampler, Task::kRegression}; + + for (auto const &gmat : dmat->GetBatches({GenericParameter::kCpuId, 32})) { + common::HistCollection hist; + + std::vector entries(1); + entries.front().nid = 0; + entries.front().depth = 0; + + hist.Init(gmat.cut.TotalBins()); + hist.AddHistRow(0); + hist.AllocateAllData(); + auto node_hist = hist[0]; + ASSERT_EQ(node_hist.size(), n_cats); + ASSERT_EQ(node_hist.size(), gmat.cut.Ptrs().back()); + + GradientPairPrecise total_gpair; + for (size_t i = 0; i < node_hist.size(); ++i) { + node_hist[i] = {static_cast(node_hist.size() - i), 1.0}; + total_gpair += node_hist[i]; + } + SimpleLCG lcg; + std::shuffle(node_hist.begin(), node_hist.end(), lcg); + + RegTree tree; + evaluator.InitRoot(GradStats{total_gpair}); + evaluator.EvaluateSplits(hist, gmat.cut, ft, tree, &entries); + ASSERT_TRUE(entries.front().split.is_cat); + + auto run_eval = [&](auto fn) { + for (size_t i = 1; i < gmat.cut.Ptrs().size(); ++i) { + GradStats left, right; + for (size_t j = gmat.cut.Ptrs()[i - 1]; j < gmat.cut.Ptrs()[i]; ++j) { + auto loss_chg = evaluator.Evaluator().CalcSplitGain(param, 0, i - 1, left, right) - + evaluator.Stats().front().root_gain; + fn(loss_chg); + left.Add(node_hist[j].GetGrad(), node_hist[j].GetHess()); + right.SetSubstract(GradStats{total_gpair}, left); + } + } + }; + // Assert that's the best split + auto best_loss_chg = entries.front().split.loss_chg; + run_eval([&](auto loss_chg) { + // Approximated test that gain returned by optimal partition is greater than + // numerical split. + ASSERT_GT(best_loss_chg, loss_chg); + }); + // node_hist is captured in lambda. + std::sort(node_hist.begin(), node_hist.end(), [&](auto l, auto r) { + return evaluator.Evaluator().CalcWeightCat(param, l) < + evaluator.Evaluator().CalcWeightCat(param, r); + }); + + double reimpl = 0; + run_eval([&](auto loss_chg) { reimpl = std::max(loss_chg, reimpl); }); + CHECK_EQ(reimpl, best_loss_chg); + } +} + +namespace { +auto CompareOneHotAndPartition(bool onehot) { + int static constexpr kRows = 128, kCols = 1; + using GradientSumT = double; + std::vector ft(kCols, FeatureType::kCategorical); + + TrainParam param; + if (onehot) { + // force use one-hot + param.UpdateAllowUnknown( + Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}, {"max_cat_to_onehot", "100"}}); + } else { + param.UpdateAllowUnknown( + Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}, {"max_cat_to_onehot", "1"}}); + } + + size_t n_cats{2}; + + auto dmat = + RandomDataGenerator(kRows, kCols, 0).Seed(3).Type(ft).MaxCategory(n_cats).GenerateDMatrix(); + + int32_t n_threads = 16; + auto sampler = std::make_shared(); + auto evaluator = HistEvaluator{param, dmat->Info(), n_threads, + sampler, Task::kRegression}; + std::vector entries(1); + + for (auto const &gmat : dmat->GetBatches({GenericParameter::kCpuId, 32})) { + common::HistCollection hist; + + entries.front().nid = 0; + entries.front().depth = 0; + + hist.Init(gmat.cut.TotalBins()); + hist.AddHistRow(0); + hist.AllocateAllData(); + auto node_hist = hist[0]; + + CHECK_EQ(node_hist.size(), n_cats); + CHECK_EQ(node_hist.size(), gmat.cut.Ptrs().back()); + + GradientPairPrecise total_gpair; + for (size_t i = 0; i < node_hist.size(); ++i) { + node_hist[i] = {static_cast(node_hist.size() - i), 1.0}; + total_gpair += node_hist[i]; + } + RegTree tree; + evaluator.InitRoot(GradStats{total_gpair}); + evaluator.EvaluateSplits(hist, gmat.cut, ft, tree, &entries); + } + return entries.front(); +} +} // anonymous namespace + +TEST(HistEvaluator, Categorical) { + auto with_onehot = CompareOneHotAndPartition(true); + auto with_part = CompareOneHotAndPartition(false); + + ASSERT_EQ(with_onehot.split.loss_chg, with_part.split.loss_chg); +} } // namespace tree } // namespace xgboost diff --git a/tests/cpp/tree/hist/test_histogram.cc b/tests/cpp/tree/hist/test_histogram.cc index f257a683405e..799d1dafdbb9 100644 --- a/tests/cpp/tree/hist/test_histogram.cc +++ b/tests/cpp/tree/hist/test_histogram.cc @@ -2,7 +2,11 @@ * Copyright 2018-2021 by Contributors */ #include + #include "../../helpers.h" +#include "../../categorical_helpers.h" + +#include "../../../../src/common/categorical.h" #include "../../../../src/tree/hist/histogram.h" #include "../../../../src/tree/updater_quantile_hist.h" @@ -35,8 +39,9 @@ void TestAddHistRows(bool is_distributed) { nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6), 0.0f); HistogramBuilder histogram_builder; - histogram_builder.Reset(gmat.cut.TotalBins(), kMaxBins, omp_get_max_threads(), - is_distributed); + histogram_builder.Reset(gmat.cut.TotalBins(), + {GenericParameter::kCpuId, kMaxBins}, + omp_get_max_threads(), 1, is_distributed); histogram_builder.AddHistRows(&starting_index, &sync_count, nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_, &tree); @@ -81,7 +86,8 @@ void TestSyncHist(bool is_distributed) { HistogramBuilder histogram; uint32_t total_bins = gmat.cut.Ptrs().back(); - histogram.Reset(total_bins, kMaxBins, omp_get_max_threads(), is_distributed); + histogram.Reset(total_bins, {GenericParameter::kCpuId, kMaxBins}, + omp_get_max_threads(), 1, is_distributed); RowSetCollection row_set_collection_; { @@ -247,22 +253,26 @@ void TestBuildHistogram(bool is_distributed) { bst_node_t nid = 0; HistogramBuilder histogram; - histogram.Reset(total_bins, kMaxBins, omp_get_max_threads(), is_distributed); + histogram.Reset(total_bins, {GenericParameter::kCpuId, kMaxBins}, + omp_get_max_threads(), 1, is_distributed); RegTree tree; - RowSetCollection row_set_collection_; - row_set_collection_.Clear(); - std::vector &row_indices = *row_set_collection_.Data(); + RowSetCollection row_set_collection; + row_set_collection.Clear(); + std::vector &row_indices = *row_set_collection.Data(); row_indices.resize(kNRows); std::iota(row_indices.begin(), row_indices.end(), 0); - row_set_collection_.Init(); + row_set_collection.Init(); CPUExpandEntry node(RegTree::kRoot, tree.GetDepth(0), 0.0f); - std::vector nodes_for_explicit_hist_build_; - nodes_for_explicit_hist_build_.push_back(node); - histogram.BuildHist(p_fmat.get(), &tree, row_set_collection_, - nodes_for_explicit_hist_build_, {}, gpair); + std::vector nodes_for_explicit_hist_build; + nodes_for_explicit_hist_build.push_back(node); + for (auto const &gidx : p_fmat->GetBatches( + {GenericParameter::kCpuId, kMaxBins})) { + histogram.BuildHist(0, gidx, &tree, row_set_collection, + nodes_for_explicit_hist_build, {}, gpair); + } // Check if number of histogram bins is correct ASSERT_EQ(histogram.Histogram()[nid].size(), gmat.cut.Ptrs().back()); @@ -294,5 +304,158 @@ TEST(CPUHistogram, BuildHist) { TestBuildHistogram(false); TestBuildHistogram(false); } + +namespace { +void TestHistogramCategorical(size_t n_categories) { + size_t constexpr kRows = 340; + int32_t constexpr kBins = 256; + auto x = GenerateRandomCategoricalSingleColumn(kRows, n_categories); + auto cat_m = GetDMatrixFromData(x, kRows, 1); + cat_m->Info().feature_types.HostVector().push_back(FeatureType::kCategorical); + BatchParam batch_param{0, static_cast(kBins)}; + + RegTree tree; + CPUExpandEntry node(RegTree::kRoot, tree.GetDepth(0), 0.0f); + std::vector nodes_for_explicit_hist_build; + nodes_for_explicit_hist_build.push_back(node); + + auto gpair = GenerateRandomGradients(kRows, 0, 2); + + RowSetCollection row_set_collection; + row_set_collection.Clear(); + std::vector &row_indices = *row_set_collection.Data(); + row_indices.resize(kRows); + std::iota(row_indices.begin(), row_indices.end(), 0); + row_set_collection.Init(); + + /** + * Generate hist with cat data. + */ + HistogramBuilder cat_hist; + for (auto const &gidx : cat_m->GetBatches( + {GenericParameter::kCpuId, kBins})) { + auto total_bins = gidx.cut.TotalBins(); + cat_hist.Reset(total_bins, {GenericParameter::kCpuId, kBins}, + omp_get_max_threads(), 1, false); + cat_hist.BuildHist(0, gidx, &tree, row_set_collection, + nodes_for_explicit_hist_build, {}, gpair.HostVector()); + } + + /** + * Generate hist with one hot encoded data. + */ + auto x_encoded = OneHotEncodeFeature(x, n_categories); + auto encode_m = GetDMatrixFromData(x_encoded, kRows, n_categories); + HistogramBuilder onehot_hist; + for (auto const &gidx : encode_m->GetBatches( + {GenericParameter::kCpuId, kBins})) { + auto total_bins = gidx.cut.TotalBins(); + onehot_hist.Reset(total_bins, {GenericParameter::kCpuId, kBins}, + omp_get_max_threads(), 1, false); + onehot_hist.BuildHist(0, gidx, &tree, row_set_collection, + nodes_for_explicit_hist_build, {}, + gpair.HostVector()); + } + + auto cat = cat_hist.Histogram()[0]; + auto onehot = onehot_hist.Histogram()[0]; + ValidateCategoricalHistogram(n_categories, onehot, cat); +} +} // anonymous namespace + +TEST(CPUHistogram, Categorical) { + for (size_t n_categories = 2; n_categories < 8; ++n_categories) { + TestHistogramCategorical(n_categories); + } +} + +TEST(CPUHistogram, ExternalMemory) { + size_t constexpr kEntries = 1 << 16; + int32_t constexpr kBins = 32; + auto m = CreateSparsePageDMatrix(kEntries, "cache"); + std::vector partition_size(1, 0); + size_t total_bins{0}; + size_t n_samples{0}; + + auto gpair = GenerateRandomGradients(m->Info().num_row_, 0.0, 1.0); + auto const &h_gpair = gpair.HostVector(); + + RegTree tree; + std::vector nodes; + nodes.emplace_back(0, tree.GetDepth(0), 0.0f); + + GHistRow multi_page; + HistogramBuilder multi_build; + { + /** + * Multi page + */ + std::vector rows_set; + std::vector hess(m->Info().num_row_, 1.0); + for (auto const &page : m->GetBatches( + {GenericParameter::kCpuId, kBins, hess})) { + CHECK_LT(page.base_rowid, m->Info().num_row_); + auto n_rows_in_node = page.Size(); + partition_size[0] = std::max(partition_size[0], n_rows_in_node); + total_bins = page.cut.TotalBins(); + n_samples += n_rows_in_node; + + rows_set.emplace_back(); + std::vector &row_indices = *rows_set.back().Data(); + row_indices.resize(n_rows_in_node); + std::iota(row_indices.begin(), row_indices.end(), page.base_rowid); + rows_set.back().Init(); + } + ASSERT_EQ(n_samples, m->Info().num_row_); + + common::BlockedSpace2d space{ + 1, [&](size_t nidx_in_set) { return partition_size.at(nidx_in_set); }, + 256}; + + multi_build.Reset(total_bins, {GenericParameter::kCpuId, kBins}, + omp_get_max_threads(), rows_set.size(), false); + + size_t page_idx{0}; + CHECK_EQ(h_gpair.size(), n_samples); + for (auto const &page : m->GetBatches( + {GenericParameter::kCpuId, kBins, hess})) { + multi_build.BuildHist(page_idx, space, page, &tree, + rows_set.at(page_idx), nodes, {}, h_gpair); + ++page_idx; + } + multi_page = multi_build.Histogram()[0]; + } + + GHistRow single_page; + HistogramBuilder single_build; + { + /** + * Single page + */ + RowSetCollection row_set_collection; + auto &row_indices = *row_set_collection.Data(); + row_indices.resize(n_samples); + std::iota(row_indices.begin(), row_indices.end(), 0); + row_set_collection.Init(); + + single_build.Reset(total_bins, {GenericParameter::kCpuId, kBins}, + omp_get_max_threads(), 1, false); + size_t n_batches{0}; + for (auto const &page : + m->GetBatches({GenericParameter::kCpuId, kBins})) { + single_build.BuildHist(0, page, &tree, row_set_collection, nodes, {}, + h_gpair); + n_batches ++; + } + ASSERT_EQ(n_batches, 1); + single_page = single_build.Histogram()[0]; + } + + for (size_t i = 0; i < single_page.size(); ++i) { + ASSERT_NEAR(single_page[i].GetGrad(), multi_page[i].GetGrad(), kRtEps); + ASSERT_NEAR(single_page[i].GetHess(), multi_page[i].GetHess(), kRtEps); + } + +} } // namespace tree } // namespace xgboost diff --git a/tests/cpp/tree/test_approx.cc b/tests/cpp/tree/test_approx.cc new file mode 100644 index 000000000000..6d3a803e3da5 --- /dev/null +++ b/tests/cpp/tree/test_approx.cc @@ -0,0 +1,129 @@ +/*! + * Copyright 2021 XGBoost contributors + */ +#include +#include "../helpers.h" +#include "../../../src/tree/updater_approx.h" + +namespace xgboost { +namespace tree { +TEST(Approx, Partitioner) { + size_t n_samples = 1024, n_features = 1, base_rowid = 0; + ApproxRowPartitioner partitioner{n_samples, base_rowid}; + ASSERT_EQ(partitioner.base_rowid, base_rowid); + ASSERT_EQ(partitioner.Size(), 1); + ASSERT_EQ(partitioner.Partitions()[0].Size(), n_samples); + + auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true); + GenericParameter ctx; + ctx.InitAllowUnknown(Args{}); + std::vector candidates{{0, 0, 0.4}}; + + for (auto const &page : + Xy->GetBatches({GenericParameter::kCpuId, 64})) { + bst_feature_t split_ind = 0; + { + auto min_value = page.cut.MinValues()[split_ind]; + RegTree tree; + tree.ExpandNode( + /*nid=*/0, /*split_index=*/0, /*split_value=*/min_value, + /*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + /*left_sum=*/0.0f, + /*right_sum=*/0.0f); + ApproxRowPartitioner partitioner{n_samples, base_rowid}; + candidates.front().split.split_value = min_value; + candidates.front().split.sindex = 0; + candidates.front().split.sindex |= (1U << 31); + partitioner.UpdatePosition(&ctx, page, candidates, &tree); + ASSERT_EQ(partitioner.Size(), 3); + ASSERT_EQ(partitioner[1].Size(), 0); + ASSERT_EQ(partitioner[2].Size(), n_samples); + } + { + ApproxRowPartitioner partitioner{n_samples, base_rowid}; + auto ptr = page.cut.Ptrs()[split_ind + 1]; + float split_value = page.cut.Values().at(ptr / 2); + RegTree tree; + tree.ExpandNode( + /*nid=*/RegTree::kRoot, /*split_index=*/split_ind, + /*split_value=*/split_value, + /*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + /*left_sum=*/0.0f, + /*right_sum=*/0.0f); + auto left_nidx = tree[RegTree::kRoot].LeftChild(); + candidates.front().split.split_value = split_value; + candidates.front().split.sindex = 0; + candidates.front().split.sindex |= (1U << 31); + partitioner.UpdatePosition(&ctx, page, candidates, &tree); + + auto elem = partitioner[left_nidx]; + ASSERT_LT(elem.Size(), n_samples); + ASSERT_GT(elem.Size(), 1); + for (auto it = elem.begin; it != elem.end; ++it) { + auto value = page.cut.Values().at(page.index[*it]); + ASSERT_LE(value, split_value); + } + auto right_nidx = tree[RegTree::kRoot].RightChild(); + elem = partitioner[right_nidx]; + for (auto it = elem.begin; it != elem.end; ++it) { + auto value = page.cut.Values().at(page.index[*it]); + ASSERT_GT(value, split_value) << *it; + } + } + } +} + +TEST(Approx, PredictionCache) { + size_t n_samples = 2048, n_features = 13; + auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true); + + { + GenericParameter ctx; + ctx.InitAllowUnknown(Args{}); + std::unique_ptr approx{ + TreeUpdater::Create("grow_global_approx_histmaker", &ctx, Task::kRegression)}; + RegTree tree; + std::vector trees{&tree}; + auto gpair = GenerateRandomGradients(n_samples); + approx->Configure(Args{{"max_bin", "64"}}); + approx->Update(&gpair, Xy.get(), trees); + HostDeviceVector out_prediction_cached; + out_prediction_cached.Resize(n_samples); + MatrixView m(&out_prediction_cached, {n_samples, 1}, + GenericParameter::kCpuId); + VectorView v(m, 0); + ASSERT_TRUE(approx->UpdatePredictionCache(Xy.get(), v)); + } + + std::unique_ptr learner{Learner::Create({Xy})}; + learner->SetParam("tree_method", "approx"); + learner->SetParam("nthread", "0"); + learner->Configure(); + + for (size_t i = 0; i < 8; ++i) { + learner->UpdateOneIter(i, Xy); + } + + HostDeviceVector out_prediction_cached; + learner->Predict(Xy, false, &out_prediction_cached, 0, 0); + + Json model{Object()}; + learner->SaveModel(&model); + + HostDeviceVector out_prediction; + { + std::unique_ptr learner{Learner::Create({Xy})}; + learner->LoadModel(model); + learner->Predict(Xy, false, &out_prediction, 0, 0); + } + + auto const h_predt_cached = out_prediction_cached.ConstHostSpan(); + auto const h_predt = out_prediction.ConstHostSpan(); + + ASSERT_EQ(h_predt.size(), h_predt_cached.size()); + for (size_t i = 0; i < h_predt.size(); ++i) { + ASSERT_NEAR(h_predt[i], h_predt_cached[i], kRtEps); + } +} +} // namespace tree +} // namespace xgboost diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 72c22539679f..42688320c3cb 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -275,7 +275,8 @@ void TestHistogramIndexImpl() { int constexpr kNRows = 1000, kNCols = 10; // Build 2 matrices and build a histogram maker with that - tree::GPUHistMakerSpecialised hist_maker, hist_maker_ext; + tree::GPUHistMakerSpecialised hist_maker{Task::kRegression}, + hist_maker_ext{Task::kRegression}; std::unique_ptr hist_maker_dmat( CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true)); @@ -333,7 +334,7 @@ int32_t TestMinSplitLoss(DMatrix* dmat, float gamma, HostDeviceVector hist_maker; + tree::GPUHistMakerSpecialised hist_maker{Task::kRegression}; GenericParameter generic_param(CreateEmptyGenericParam(0)); hist_maker.Configure(args, &generic_param); @@ -394,7 +395,7 @@ void UpdateTree(HostDeviceVector* gpair, DMatrix* dmat, {"sampling_method", sampling_method}, }; - tree::GPUHistMakerSpecialised hist_maker; + tree::GPUHistMakerSpecialised hist_maker{Task::kRegression}; GenericParameter generic_param(CreateEmptyGenericParam(0)); hist_maker.Configure(args, &generic_param); @@ -539,7 +540,8 @@ TEST(GpuHist, ExternalMemoryWithSampling) { TEST(GpuHist, ConfigIO) { GenericParameter generic_param(CreateEmptyGenericParam(0)); - std::unique_ptr updater {TreeUpdater::Create("grow_gpu_hist", &generic_param) }; + std::unique_ptr updater{ + TreeUpdater::Create("grow_gpu_hist", &generic_param, Task::kRegression)}; updater->Configure(Args{}); Json j_updater { Object() }; diff --git a/tests/cpp/tree/test_histmaker.cc b/tests/cpp/tree/test_histmaker.cc index e1cb3568d5ef..60fb835df215 100644 --- a/tests/cpp/tree/test_histmaker.cc +++ b/tests/cpp/tree/test_histmaker.cc @@ -34,7 +34,8 @@ TEST(GrowHistMaker, InteractionConstraint) { RegTree tree; tree.param.num_feature = kCols; - std::unique_ptr updater { TreeUpdater::Create("grow_histmaker", ¶m) }; + std::unique_ptr updater{ + TreeUpdater::Create("grow_histmaker", ¶m, Task::kRegression)}; updater->Configure(Args{ {"interaction_constraints", "[[0, 1]]"}, {"num_feature", std::to_string(kCols)}}); @@ -51,7 +52,8 @@ TEST(GrowHistMaker, InteractionConstraint) { RegTree tree; tree.param.num_feature = kCols; - std::unique_ptr updater { TreeUpdater::Create("grow_histmaker", ¶m) }; + std::unique_ptr updater{ + TreeUpdater::Create("grow_histmaker", ¶m, Task::kRegression)}; updater->Configure(Args{{"num_feature", std::to_string(kCols)}}); updater->Update(&gradients, p_dmat.get(), {&tree}); diff --git a/tests/cpp/tree/test_param.cc b/tests/cpp/tree/test_param.cc index b4cc4005e3ad..d4194bb74c58 100644 --- a/tests/cpp/tree/test_param.cc +++ b/tests/cpp/tree/test_param.cc @@ -88,14 +88,14 @@ TEST(Param, SplitEntry) { xgboost::tree::SplitEntry se2; EXPECT_FALSE(se1.Update(se2)); - EXPECT_FALSE(se2.Update(-1, 100, 0, true, xgboost::tree::GradStats(), + EXPECT_FALSE(se2.Update(-1, 100, 0, true, false, xgboost::tree::GradStats(), xgboost::tree::GradStats())); - ASSERT_TRUE(se2.Update(1, 100, 0, true, xgboost::tree::GradStats(), + ASSERT_TRUE(se2.Update(1, 100, 0, true, false, xgboost::tree::GradStats(), xgboost::tree::GradStats())); ASSERT_TRUE(se1.Update(se2)); xgboost::tree::SplitEntry se3; - se3.Update(2, 101, 0, false, xgboost::tree::GradStats(), + se3.Update(2, 101, 0, false, false, xgboost::tree::GradStats(), xgboost::tree::GradStats()); xgboost::tree::SplitEntry::Reduce(se2, se3); EXPECT_EQ(se2.SplitIndex(), 101); diff --git a/tests/cpp/tree/test_prune.cc b/tests/cpp/tree/test_prune.cc index dbe910a8f183..e970c48bccdc 100644 --- a/tests/cpp/tree/test_prune.cc +++ b/tests/cpp/tree/test_prune.cc @@ -38,7 +38,7 @@ TEST(Updater, Prune) { tree.param.UpdateAllowUnknown(cfg); std::vector trees {&tree}; // prepare pruner - std::unique_ptr pruner(TreeUpdater::Create("prune", &lparam)); + std::unique_ptr pruner(TreeUpdater::Create("prune", &lparam, Task::kRegression)); pruner->Configure(cfg); // loss_chg < min_split_loss; diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc index 938205aae024..37c60eb51edf 100644 --- a/tests/cpp/tree/test_quantile_hist.cc +++ b/tests/cpp/tree/test_quantile_hist.cc @@ -28,7 +28,7 @@ class QuantileHistMock : public QuantileHistMaker { BuilderMock(const TrainParam ¶m, std::unique_ptr pruner, DMatrix const *fmat) - : RealImpl(1, param, std::move(pruner), fmat) {} + : RealImpl(1, param, std::move(pruner), fmat, Task::kRegression) {} public: void TestInitData(const GHistIndexMatrix& gmat, @@ -230,7 +230,7 @@ class QuantileHistMock : public QuantileHistMaker { explicit QuantileHistMock( const std::vector >& args, const bool single_precision_histogram = false, bool batch = true) : - cfg_{args} { + QuantileHistMaker{Task::kRegression}, cfg_{args} { QuantileHistMaker::Configure(args); dmat_ = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix(); if (single_precision_histogram) { diff --git a/tests/cpp/tree/test_refresh.cc b/tests/cpp/tree/test_refresh.cc index 3689940fda35..3817abf8902a 100644 --- a/tests/cpp/tree/test_refresh.cc +++ b/tests/cpp/tree/test_refresh.cc @@ -32,7 +32,8 @@ TEST(Updater, Refresh) { auto lparam = CreateEmptyGenericParam(GPUIDX); tree.param.UpdateAllowUnknown(cfg); std::vector trees {&tree}; - std::unique_ptr refresher(TreeUpdater::Create("refresh", &lparam)); + std::unique_ptr refresher( + TreeUpdater::Create("refresh", &lparam, Task::kRegression)); tree.ExpandNode(0, 2, 0.2f, false, 0.0, 0.2f, 0.8f, 0.0f, 0.0f, /*left_sum=*/0.0f, /*right_sum=*/0.0f); diff --git a/tests/cpp/tree/test_tree_policy.cc b/tests/cpp/tree/test_tree_policy.cc index 68a720a8fba6..65dc975f2319 100644 --- a/tests/cpp/tree/test_tree_policy.cc +++ b/tests/cpp/tree/test_tree_policy.cc @@ -61,7 +61,7 @@ class TestGrowPolicy : public ::testing::Test { } }; -TEST_F(TestGrowPolicy, DISABLED_Approx) { +TEST_F(TestGrowPolicy, Approx) { this->TestTreeGrowPolicy("approx", "depthwise"); this->TestTreeGrowPolicy("approx", "lossguide"); } diff --git a/tests/cpp/tree/test_tree_stat.cc b/tests/cpp/tree/test_tree_stat.cc index eb8a7c5d910c..f05a534ed5cc 100644 --- a/tests/cpp/tree/test_tree_stat.cc +++ b/tests/cpp/tree/test_tree_stat.cc @@ -22,8 +22,8 @@ class UpdaterTreeStatTest : public ::testing::Test { void RunTest(std::string updater) { auto tparam = CreateEmptyGenericParam(0); - auto up = std::unique_ptr{ - TreeUpdater::Create(updater, &tparam)}; + auto up = + std::unique_ptr{TreeUpdater::Create(updater, &tparam, Task::kRegression)}; up->Configure(Args{}); RegTree tree; tree.param.num_feature = kCols; diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index a2da32d2f705..c0869221bd33 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -7,6 +7,8 @@ sys.path.append("tests/python") import testing as tm +import test_updaters as test_up + parameter_strategy = strategies.fixed_dictionaries({ 'max_depth': strategies.integers(0, 11), @@ -32,6 +34,8 @@ def train_result(param, dmat, num_rounds): class TestGPUUpdaters: + cputest = test_up.TestTreeMethod() + @given(parameter_strategy, strategies.integers(1, 20), tm.dataset_strategy) @settings(deadline=None) def test_gpu_hist(self, param, num_rounds, dataset): @@ -41,51 +45,12 @@ def test_gpu_hist(self, param, num_rounds, dataset): note(result) assert tm.non_increasing(result["train"][dataset.metric]) - def run_categorical_basic(self, rows, cols, rounds, cats): - onehot, label = tm.make_categorical(rows, cols, cats, True) - cat, _ = tm.make_categorical(rows, cols, cats, False) - - by_etl_results = {} - by_builtin_results = {} - - parameters = {"tree_method": "gpu_hist", "predictor": "gpu_predictor"} - - m = xgb.DMatrix(onehot, label, enable_categorical=False) - xgb.train( - parameters, - m, - num_boost_round=rounds, - evals=[(m, "Train")], - evals_result=by_etl_results, - ) - - m = xgb.DMatrix(cat, label, enable_categorical=True) - xgb.train( - parameters, - m, - num_boost_round=rounds, - evals=[(m, "Train")], - evals_result=by_builtin_results, - ) - - # There are guidelines on how to specify tolerance based on considering output as - # random variables. But in here the tree construction is extremely sensitive to - # floating point errors. An 1e-5 error in a histogram bin can lead to an entirely - # different tree. So even though the test is quite lenient, hypothesis can still - # pick up falsifying examples from time to time. - np.testing.assert_allclose( - np.array(by_etl_results["Train"]["rmse"]), - np.array(by_builtin_results["Train"]["rmse"]), - rtol=1e-3, - ) - assert tm.non_increasing(by_builtin_results["Train"]["rmse"]) - @given(strategies.integers(10, 400), strategies.integers(3, 8), strategies.integers(1, 2), strategies.integers(4, 7)) @settings(deadline=None) @pytest.mark.skipif(**tm.no_pandas()) def test_categorical(self, rows, cols, rounds, cats): - self.run_categorical_basic(rows, cols, rounds, cats) + self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist") def test_categorical_32_cat(self): '''32 hits the bound of integer bitset, so special test''' @@ -93,7 +58,7 @@ def test_categorical_32_cat(self): cols = 10 cats = 32 rounds = 4 - self.run_categorical_basic(rows, cols, rounds, cats) + self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist") @pytest.mark.skipif(**tm.no_cupy()) @given(parameter_strategy, strategies.integers(1, 20), diff --git a/tests/python/test_updaters.py b/tests/python/test_updaters.py index 07e6d44c6f2a..0bdbb703090a 100644 --- a/tests/python/test_updaters.py +++ b/tests/python/test_updaters.py @@ -126,3 +126,53 @@ def test_hist_degenerate_case(self): y = [1000000., 0., 0., 500000.] w = [0, 0, 1, 0] model.fit(X, y, sample_weight=w) + + def run_categorical_basic(self, rows, cols, rounds, cats, tree_method): + onehot, label = tm.make_categorical(rows, cols, cats, True) + cat, _ = tm.make_categorical(rows, cols, cats, False) + + by_etl_results = {} + by_builtin_results = {} + + predictor = "gpu_predictor" if tree_method == "gpu_hist" else None + # Use one-hot exclusively + parameters = { + "tree_method": tree_method, "predictor": predictor, "max_cat_to_onehot": 9999 + } + + m = xgb.DMatrix(onehot, label, enable_categorical=False) + xgb.train( + parameters, + m, + num_boost_round=rounds, + evals=[(m, "Train")], + evals_result=by_etl_results, + ) + + m = xgb.DMatrix(cat, label, enable_categorical=True) + xgb.train( + parameters, + m, + num_boost_round=rounds, + evals=[(m, "Train")], + evals_result=by_builtin_results, + ) + + # There are guidelines on how to specify tolerance based on considering output as + # random variables. But in here the tree construction is extremely sensitive to + # floating point errors. An 1e-5 error in a histogram bin can lead to an entirely + # different tree. So even though the test is quite lenient, hypothesis can still + # pick up falsifying examples from time to time. + np.testing.assert_allclose( + np.array(by_etl_results["Train"]["rmse"]), + np.array(by_builtin_results["Train"]["rmse"]), + rtol=1e-3, + ) + assert tm.non_increasing(by_builtin_results["Train"]["rmse"]) + + @given(strategies.integers(10, 400), strategies.integers(3, 8), + strategies.integers(1, 2), strategies.integers(4, 7)) + @settings(deadline=None) + @pytest.mark.skipif(**tm.no_pandas()) + def test_categorical(self, rows, cols, rounds, cats): + self.run_categorical_basic(rows, cols, rounds, cats, "approx") diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index 3315d3d8f184..72f6ebaaad68 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -1255,13 +1255,17 @@ def test_feature_weights(self, client: "Client") -> None: for i in range(kCols): fw[i] *= float(i) fw = da.from_array(fw) - poly_increasing = run_feature_weights(X, y, fw, model=xgb.dask.DaskXGBRegressor) + poly_increasing = run_feature_weights( + X, y, fw, "approx", model=xgb.dask.DaskXGBRegressor + ) fw = np.ones(shape=(kCols,)) for i in range(kCols): fw[i] *= float(kCols - i) fw = da.from_array(fw) - poly_decreasing = run_feature_weights(X, y, fw, model=xgb.dask.DaskXGBRegressor) + poly_decreasing = run_feature_weights( + X, y, fw, "approx", model=xgb.dask.DaskXGBRegressor + ) # Approxmated test, this is dependent on the implementation of random # number generator in std library. diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 9d63b160c244..0a37a1cd7c90 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -1088,10 +1088,10 @@ def test_pandas_input(): np.array([0, 1])) -def run_feature_weights(X, y, fw, model=xgb.XGBRegressor): +def run_feature_weights(X, y, fw, tree_method, model=xgb.XGBRegressor): with TemporaryDirectory() as tmpdir: colsample_bynode = 0.5 - reg = model(tree_method='hist', colsample_bynode=colsample_bynode) + reg = model(tree_method=tree_method, colsample_bynode=colsample_bynode) reg.fit(X, y, feature_weights=fw) model_path = os.path.join(tmpdir, 'model.json') @@ -1126,7 +1126,8 @@ def run_feature_weights(X, y, fw, model=xgb.XGBRegressor): return w -def test_feature_weights(): +@pytest.mark.parametrize("tree_method", ["approx", "hist"]) +def test_feature_weights(tree_method): kRows = 512 kCols = 64 X = rng.randn(kRows, kCols) @@ -1135,12 +1136,12 @@ def test_feature_weights(): fw = np.ones(shape=(kCols,)) for i in range(kCols): fw[i] *= float(i) - poly_increasing = run_feature_weights(X, y, fw, xgb.XGBRegressor) + poly_increasing = run_feature_weights(X, y, fw, tree_method, xgb.XGBRegressor) fw = np.ones(shape=(kCols,)) for i in range(kCols): fw[i] *= float(kCols - i) - poly_decreasing = run_feature_weights(X, y, fw, xgb.XGBRegressor) + poly_decreasing = run_feature_weights(X, y, fw, tree_method, xgb.XGBRegressor) # Approxmated test, this is dependent on the implementation of random # number generator in std library.