diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc index 1241ced409cd..83d003052149 100644 --- a/amalgamation/xgboost-all0.cc +++ b/amalgamation/xgboost-all0.cc @@ -57,6 +57,7 @@ #include "../src/tree/updater_refresh.cc" #include "../src/tree/updater_sync.cc" #include "../src/tree/updater_histmaker.cc" +#include "../src/tree/updater_approx.cc" #include "../src/tree/constraints.cc" // linear diff --git a/demo/guide-python/categorical.py b/demo/guide-python/categorical.py index 7f358fcbb212..eed823ae8bb3 100644 --- a/demo/guide-python/categorical.py +++ b/demo/guide-python/categorical.py @@ -3,7 +3,8 @@ ===================================== Experimental support for categorical data. After 1.5 XGBoost `gpu_hist` tree method has -experimental support for one-hot encoding based tree split. +experimental support for one-hot encoding based tree split, and in 1.6 `approx` support was added. In before, users need to run an encoder themselves before passing the data into XGBoost, which creates a sparse matrix and potentially increase memory usage. This demo showcases diff --git a/doc/parameter.rst b/doc/parameter.rst index bb2666737c5b..0ec58026a15c 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -154,7 +154,7 @@ Parameters for Tree Booster * ``sketch_eps`` [default=0.03] - - Only used for ``tree_method=approx``. + - Only used for ``updater=grow_local_histmaker``. - This roughly translates into ``O(1 / sketch_eps)`` number of bins. Compared to directly select number of bins, this comes with theoretical guarantee with sketch accuracy. - Usually user does not have to tune this. @@ -238,13 +238,27 @@ Parameters for Tree Booster list is a group of indices of features that are allowed to interact with each other. See :doc:`/tutorials/feature_interaction_constraint` for more information. -Additional parameters for ``hist`` and ``gpu_hist`` tree method -================================================================ +Additional parameters for ``hist``, ``gpu_hist`` and ``approx`` tree method +=========================================================================== * ``single_precision_histogram``, [default= ``false``] - Use single precision to build histograms instead of double precision. +Additional parameters for ``approx`` tree method +================================================ + +* ``max_cat_to_onehot`` + + .. versionadded:: 1.6 + + .. note:: The support for this parameter is experimental. + + - A threshold for deciding whether XGBoost should use one-hot encoding based split for + categorical data. When the number of categories is smaller than the threshold, one-hot + encoding is chosen; otherwise the categories will be partitioned into children nodes. + Only relevant for regression and binary classification with the `approx` tree method.
+ Additional parameters for Dart Booster (``booster=dart``) ========================================================= diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala index 1522cfbb3af8..b06ffc9399a5 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala @@ -53,7 +53,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest { "max_depth" -> "6", "silent" -> "1", "objective" -> "reg:squarederror", - "max_bin" -> 16, + "max_bin" -> 64, "tree_method" -> treeMethod) val model1 = ScalaXGBoost.train(trainingDM, paramMap, round) diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index a6d486cd615f..54970af6d07a 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -267,6 +267,16 @@ def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]: callbacks = [xgb.callback.EarlyStopping(rounds=early_stopping_rounds, save_best=True)] + max_cat_to_onehot : Optional[int] + + .. versionadded:: 1.6.0 + + A threshold for deciding whether XGBoost should use one-hot encoding based split + for categorical data. When the number of categories is smaller than the threshold, + one-hot encoding is chosen; otherwise the categories will be partitioned into + children nodes. Only relevant for regression and binary classification with the + `approx` tree method. + kwargs : dict, optional Keyword arguments for XGBoost Booster object. Full documentation of parameters can be found :doc:`here `. @@ -483,6 +493,7 @@ def __init__( eval_metric: Optional[Union[str, List[str], Callable]] = None, early_stopping_rounds: Optional[int] = None, callbacks: Optional[List[TrainingCallback]] = None, + max_cat_to_onehot: Optional[int] = None, **kwargs: Any ) -> None: if not SKLEARN_INSTALLED: @@ -522,6 +533,7 @@ def __init__( self.eval_metric = eval_metric self.early_stopping_rounds = early_stopping_rounds self.callbacks = callbacks + self.max_cat_to_onehot = max_cat_to_onehot if kwargs: self.kwargs = kwargs @@ -800,8 +812,8 @@ def _duplicated(parameter: str) -> None: _duplicated("callbacks") callbacks = self.callbacks if self.callbacks is not None else callbacks - # lastly check categorical data support. - if self.enable_categorical and params.get("tree_method", None) != "gpu_hist": + tree_method = params.get("tree_method", None) + if self.enable_categorical and tree_method not in ("gpu_hist", "approx"): raise ValueError( "Experimental support for categorical data is not implemented for" " current tree method yet." ) @@ -876,8 +888,7 @@ def fit( feature_weights : Weight for each feature, defines the probability of each feature being selected when colsample is being used. All values must be greater than 0, - otherwise a `ValueError` is thrown. Only available for `hist`, `gpu_hist` and - `exact` tree methods. + otherwise a `ValueError` is thrown. callbacks : .. deprecated: 1.6.0 @@ -1750,8 +1761,7 @@ def fit( feature_weights : Weight for each feature, defines the probability of each feature being selected when colsample is being used. All values must be greater than 0, - otherwise a `ValueError` is thrown. Only available for `hist`, `gpu_hist` and - `exact` tree methods. + otherwise a `ValueError` is thrown. callbacks : ..
deprecated: 1.6.0 diff --git a/src/common/threading_utils.h b/src/common/threading_utils.h index feb936e333f8..da8ddf3c2c6d 100644 --- a/src/common/threading_utils.h +++ b/src/common/threading_utils.h @@ -130,8 +130,7 @@ class BlockedSpace2d { template void ParallelFor2d(const BlockedSpace2d& space, int nthreads, Func func) { const size_t num_blocks_in_space = space.Size(); - nthreads = std::min(nthreads, omp_get_max_threads()); - nthreads = std::max(nthreads, 1); + CHECK_GE(nthreads, 1); dmlc::OMPException exc; #pragma omp parallel num_threads(nthreads) @@ -277,9 +276,10 @@ inline int32_t OmpSetNumThreadsWithoutHT(int32_t* p_threads) { inline int32_t OmpGetNumThreads(int32_t n_threads) { if (n_threads <= 0) { - n_threads = omp_get_num_procs(); + n_threads = std::min(omp_get_num_procs(), omp_get_max_threads()); } n_threads = std::min(n_threads, OmpGetThreadLimit()); + n_threads = std::max(n_threads, 1); return n_threads; } } // namespace common diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index ff736b8ba260..e127e3e4869e 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -168,7 +168,7 @@ void GBTree::ConfigureUpdaters() { // calling this function. break; case TreeMethod::kApprox: - tparam_.updater_seq = "grow_histmaker,prune"; + tparam_.updater_seq = "grow_histmaker"; break; case TreeMethod::kExact: tparam_.updater_seq = "grow_colmaker,prune"; diff --git a/src/tree/param.h b/src/tree/param.h index c660142ebc9f..7ed796a1ef1c 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -38,7 +38,7 @@ struct TrainParam : public XGBoostParameter { enum TreeGrowPolicy { kDepthWise = 0, kLossGuide = 1 }; int grow_policy; - uint32_t max_cat_to_onehot{1}; + uint32_t max_cat_to_onehot{4}; //----- the rest parameters are less important ---- // minimum amount of hessian(weight) allowed in a child diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index a0720a34b5d4..96efb4d68d11 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -973,6 +973,7 @@ void RegTree::SaveCategoricalSplit(Json* p_out) const { } size_t size = categories.size() - begin; categories_sizes.emplace_back(static_cast(size)); + CHECK_NE(size, 0); } } diff --git a/src/tree/tree_updater.cc b/src/tree/tree_updater.cc index 293dfb53a15f..6d0aed009f6a 100644 --- a/src/tree/tree_updater.cc +++ b/src/tree/tree_updater.cc @@ -35,6 +35,7 @@ DMLC_REGISTRY_LINK_TAG(updater_refresh); DMLC_REGISTRY_LINK_TAG(updater_prune); DMLC_REGISTRY_LINK_TAG(updater_quantile_hist); DMLC_REGISTRY_LINK_TAG(updater_histmaker); +DMLC_REGISTRY_LINK_TAG(updater_approx); DMLC_REGISTRY_LINK_TAG(updater_sync); #ifdef XGBOOST_USE_CUDA DMLC_REGISTRY_LINK_TAG(updater_gpu_hist); diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc new file mode 100644 index 000000000000..1f3c7342b120 --- /dev/null +++ b/src/tree/updater_approx.cc @@ -0,0 +1,369 @@ +/*! + * Copyright 2021 XGBoost contributors + * + * \brief Implementation for the approx tree method. 
+ */ +#include "updater_approx.h" + +#include +#include +#include + +#include "../common/random.h" +#include "../data/gradient_index.h" +#include "constraints.h" +#include "driver.h" +#include "hist/evaluate_splits.h" +#include "hist/histogram.h" +#include "hist/param.h" +#include "param.h" +#include "xgboost/base.h" +#include "xgboost/json.h" +#include "xgboost/tree_updater.h" + +namespace xgboost { +namespace tree { + +DMLC_REGISTRY_FILE_TAG(updater_approx); + +template +class GloablApproxBuilder { + protected: + TrainParam param_; + std::shared_ptr col_sampler_; + HistEvaluator evaluator_; + HistogramBuilder histogram_builder_; + GenericParameter const *ctx_; + + std::vector partitioner_; + // Pointer to last updated tree, used for update prediction cache. + RegTree *p_last_tree_{nullptr}; + common::Monitor *monitor_; + size_t n_batches_{0}; + // Cache for histogram cuts. + common::HistogramCuts feature_values_; + + public: + void InitData(DMatrix *p_fmat, common::Span hess) { + monitor_->Start(__func__); + n_batches_ = 0; + int32_t n_total_bins = 0; + partitioner_.clear(); + // Generating the GHistIndexMatrix is quite slow, is there a way to speed it up? + for (auto const &page : p_fmat->GetBatches( + {GenericParameter::kCpuId, param_.max_bin, hess, true})) { + if (n_total_bins == 0) { + n_total_bins = page.cut.TotalBins(); + feature_values_ = page.cut; + } else { + CHECK_EQ(n_total_bins, page.cut.TotalBins()); + } + partitioner_.emplace_back(page.Size(), page.base_rowid); + n_batches_++; + } + + histogram_builder_.Reset(n_total_bins, + BatchParam{GenericParameter::kCpuId, param_.max_bin, hess}, + ctx_->Threads(), n_batches_, rabit::IsDistributed()); + monitor_->Stop(__func__); + } + + CPUExpandEntry InitRoot(DMatrix *p_fmat, std::vector const &gpair, + common::Span hess, RegTree *p_tree) { + monitor_->Start(__func__); + CPUExpandEntry best; + best.nid = RegTree::kRoot; + best.depth = 0; + GradStats root_sum; + for (auto const &g : gpair) { + root_sum.Add(g); + } + rabit::Allreduce(reinterpret_cast(&root_sum), 2); + std::vector nodes{best}; + size_t i = 0; + auto space = this->ConstructHistSpace(nodes); + for (auto const &page : + p_fmat->GetBatches({GenericParameter::kCpuId, param_.max_bin, hess})) { + histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), nodes, + {}, gpair); + i++; + } + + auto weight = evaluator_.InitRoot(root_sum); + p_tree->Stat(RegTree::kRoot).sum_hess = root_sum.GetHess(); + p_tree->Stat(RegTree::kRoot).base_weight = weight; + (*p_tree)[RegTree::kRoot].SetLeaf(param_.learning_rate * weight); + + auto const &histograms = histogram_builder_.Histogram(); + auto ft = p_fmat->Info().feature_types.ConstHostSpan(); + evaluator_.EvaluateSplits(histograms, feature_values_, ft, *p_tree, &nodes); + monitor_->Stop(__func__); + + return nodes.front(); + } + + void UpdatePredictionCache(const DMatrix *data, linalg::VectorView out_preds) { + monitor_->Start(__func__); + // Caching prediction seems redundant for approx tree method, as sketching takes up + // majority of training time. 
+ CHECK_EQ(out_preds.Size(), data->Info().num_row_); + CHECK(p_last_tree_); + + size_t n_nodes = p_last_tree_->GetNodes().size(); + + auto evaluator = evaluator_.Evaluator(); + auto const &tree = *p_last_tree_; + auto const &snode = evaluator_.Stats(); + for (auto &part : partitioner_) { + CHECK_EQ(part.Size(), n_nodes); + common::BlockedSpace2d space( + part.Size(), [&](size_t node) { return part[node].Size(); }, 1024); + common::ParallelFor2d(space, ctx_->Threads(), [&](size_t nidx, common::Range1d r) { + if (tree[nidx].IsLeaf()) { + const auto rowset = part[nidx]; + auto const &stats = snode.at(nidx); + auto leaf_value = + evaluator.CalcWeight(nidx, param_, GradStats{stats.stats}) * param_.learning_rate; + for (const size_t *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) { + out_preds(*it) += leaf_value; + } + } + }); + } + monitor_->Stop(__func__); + } + + // Construct a work space for building histogram. Eventually we should move this + // function into histogram builder once hist tree method supports external memory. + common::BlockedSpace2d ConstructHistSpace( + std::vector const &nodes_to_build) const { + std::vector partition_size(nodes_to_build.size(), 0); + for (auto const &partition : partitioner_) { + size_t k = 0; + for (auto node : nodes_to_build) { + auto n_rows_in_node = partition.Partitions()[node.nid].Size(); + partition_size[k] = std::max(partition_size[k], n_rows_in_node); + k++; + } + } + common::BlockedSpace2d space{nodes_to_build.size(), + [&](size_t nidx_in_set) { return partition_size[nidx_in_set]; }, + 256}; + return space; + } + + void BuildHistogram(DMatrix *p_fmat, RegTree *p_tree, + std::vector const &valid_candidates, + std::vector const &gpair, common::Span hess) { + monitor_->Start(__func__); + std::vector nodes_to_build; + std::vector nodes_to_sub; + + for (auto const &c : valid_candidates) { + auto left_nidx = (*p_tree)[c.nid].LeftChild(); + auto right_nidx = (*p_tree)[c.nid].RightChild(); + auto fewer_right = c.split.right_sum.GetHess() < c.split.left_sum.GetHess(); + + auto build_nidx = left_nidx; + auto subtract_nidx = right_nidx; + if (fewer_right) { + std::swap(build_nidx, subtract_nidx); + } + nodes_to_build.push_back(CPUExpandEntry{build_nidx, p_tree->GetDepth(build_nidx), {}}); + nodes_to_sub.push_back(CPUExpandEntry{subtract_nidx, p_tree->GetDepth(subtract_nidx), {}}); + } + + size_t i = 0; + auto space = this->ConstructHistSpace(nodes_to_build); + for (auto const &page : + p_fmat->GetBatches({GenericParameter::kCpuId, param_.max_bin, hess})) { + histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), + nodes_to_build, nodes_to_sub, gpair); + i++; + } + monitor_->Stop(__func__); + } + + public: + explicit GloablApproxBuilder(TrainParam param, MetaInfo const &info, GenericParameter const *ctx, + std::shared_ptr column_sampler, ObjInfo task, + common::Monitor *monitor) + : param_{std::move(param)}, + col_sampler_{std::move(column_sampler)}, + evaluator_{param_, info, ctx->Threads(), col_sampler_, task}, + ctx_{ctx}, + monitor_{monitor} {} + + void UpdateTree(RegTree *p_tree, std::vector const &gpair, common::Span hess, + DMatrix *p_fmat) { + p_last_tree_ = p_tree; + this->InitData(p_fmat, hess); + + Driver driver(static_cast(param_.grow_policy)); + auto &tree = *p_tree; + driver.Push({this->InitRoot(p_fmat, gpair, hess, p_tree)}); + bst_node_t num_leaves = 1; + auto expand_set = driver.Pop(); + + while (!expand_set.empty()) { + // candidates that can be further splited. 
+ std::vector valid_candidates; + // candidates that can be applied. + std::vector applied; + for (auto const &candidate : expand_set) { + if (!candidate.IsValid(param_, num_leaves)) { + continue; + } + evaluator_.ApplyTreeSplit(candidate, p_tree); + applied.push_back(candidate); + num_leaves++; + int left_child_nidx = tree[candidate.nid].LeftChild(); + if (CPUExpandEntry::ChildIsValid(param_, p_tree->GetDepth(left_child_nidx), num_leaves)) { + valid_candidates.emplace_back(candidate); + } + } + + monitor_->Start("UpdatePosition"); + size_t i = 0; + for (auto const &page : + p_fmat->GetBatches({GenericParameter::kCpuId, param_.max_bin, hess})) { + partitioner_.at(i).UpdatePosition(ctx_, page, applied, p_tree); + i++; + } + monitor_->Stop("UpdatePosition"); + + std::vector best_splits; + if (!valid_candidates.empty()) { + this->BuildHistogram(p_fmat, p_tree, valid_candidates, gpair, hess); + for (auto const &candidate : valid_candidates) { + int left_child_nidx = tree[candidate.nid].LeftChild(); + int right_child_nidx = tree[candidate.nid].RightChild(); + CPUExpandEntry l_best{left_child_nidx, tree.GetDepth(left_child_nidx), {}}; + CPUExpandEntry r_best{right_child_nidx, tree.GetDepth(right_child_nidx), {}}; + best_splits.push_back(l_best); + best_splits.push_back(r_best); + } + auto const &histograms = histogram_builder_.Histogram(); + auto ft = p_fmat->Info().feature_types.ConstHostSpan(); + monitor_->Start("EvaluateSplits"); + evaluator_.EvaluateSplits(histograms, feature_values_, ft, *p_tree, &best_splits); + monitor_->Stop("EvaluateSplits"); + } + driver.Push(best_splits.begin(), best_splits.end()); + expand_set = driver.Pop(); + } + } +}; + +/** + * \brief Implementation for the approx tree method. It constructs quantile for every + * iteration. + */ +class GlobalApproxUpdater : public TreeUpdater { + TrainParam param_; + common::Monitor monitor_; + CPUHistMakerTrainParam hist_param_; + // specializations for different histogram precision. + std::unique_ptr> f32_impl_; + std::unique_ptr> f64_impl_; + // pointer to the last DMatrix, used for update prediction cache. 
+ DMatrix *cached_{nullptr}; + std::shared_ptr column_sampler_ = + std::make_shared(); + ObjInfo task_; + + public: + explicit GlobalApproxUpdater(ObjInfo task) : task_{task} { monitor_.Init(__func__); } + + void Configure(const Args &args) override { + param_.UpdateAllowUnknown(args); + hist_param_.UpdateAllowUnknown(args); + } + void LoadConfig(Json const &in) override { + auto const &config = get(in); + FromJson(config.at("train_param"), &this->param_); + FromJson(config.at("hist_param"), &this->hist_param_); + } + void SaveConfig(Json *p_out) const override { + auto &out = *p_out; + out["train_param"] = ToJson(param_); + out["hist_param"] = ToJson(hist_param_); + } + + void InitData(TrainParam const ¶m, HostDeviceVector *gpair, + std::vector *sampled) { + auto const &h_gpair = gpair->HostVector(); + sampled->resize(h_gpair.size()); + std::copy(h_gpair.cbegin(), h_gpair.cend(), sampled->begin()); + auto &rnd = common::GlobalRandom(); + if (param.subsample != 1.0) { + CHECK(param.sampling_method != TrainParam::kGradientBased) + << "Gradient based sampling is not supported for approx tree method."; + std::bernoulli_distribution coin_flip(param.subsample); + std::transform(sampled->begin(), sampled->end(), sampled->begin(), [&](GradientPair &g) { + if (coin_flip(rnd)) { + return g; + } else { + return GradientPair{}; + } + }); + } + } + + char const *Name() const override { return "grow_histmaker"; } + + void Update(HostDeviceVector *gpair, DMatrix *m, + const std::vector &trees) override { + float lr = param_.learning_rate; + param_.learning_rate = lr / trees.size(); + + if (hist_param_.single_precision_histogram) { + f32_impl_ = std::make_unique>(param_, m->Info(), tparam_, + column_sampler_, task_, &monitor_); + } else { + f64_impl_ = std::make_unique>(param_, m->Info(), tparam_, + column_sampler_, task_, &monitor_); + } + + std::vector h_gpair; + InitData(param_, gpair, &h_gpair); + // Obtain the hessian values for weighted sketching + std::vector hess(h_gpair.size()); + std::transform(h_gpair.begin(), h_gpair.end(), hess.begin(), + [](auto g) { return g.GetHess(); }); + + cached_ = m; + + for (auto p_tree : trees) { + if (hist_param_.single_precision_histogram) { + this->f32_impl_->UpdateTree(p_tree, h_gpair, hess, m); + } else { + this->f64_impl_->UpdateTree(p_tree, h_gpair, hess, m); + } + } + param_.learning_rate = lr; + } + + bool UpdatePredictionCache(const DMatrix *data, linalg::VectorView out_preds) override { + if (data != cached_ || (!this->f32_impl_ && !this->f64_impl_)) { + return false; + } + + if (hist_param_.single_precision_histogram) { + this->f32_impl_->UpdatePredictionCache(data, out_preds); + } else { + this->f64_impl_->UpdatePredictionCache(data, out_preds); + } + return true; + } +}; + +DMLC_REGISTRY_FILE_TAG(grow_histmaker); + +XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_histmaker") + .describe( + "Tree constructor that uses approximate histogram construction " + "for each node.") + .set_body([](ObjInfo task) { return new GlobalApproxUpdater(task); }); +} // namespace tree +} // namespace xgboost diff --git a/src/tree/updater_histmaker.cc b/src/tree/updater_histmaker.cc index 7552e03034ed..ee4a618bb765 100644 --- a/src/tree/updater_histmaker.cc +++ b/src/tree/updater_histmaker.cc @@ -641,126 +641,10 @@ class CQHistMaker: public HistMaker { std::vector > sketchs_; }; -// global proposal -class GlobalProposalHistMaker: public CQHistMaker { - public: - char const* Name() const override { - return "grow_histmaker"; - } - - protected: - void 
ResetPosAndPropose(const std::vector &gpair, - DMatrix *p_fmat, - const std::vector &fset, - const RegTree &tree) override { - if (this->qexpand_.size() == 1) { - cached_rptr_.clear(); - cached_cut_.clear(); - } - if (cached_rptr_.size() == 0) { - CHECK_EQ(this->qexpand_.size(), 1U); - CQHistMaker::ResetPosAndPropose(gpair, p_fmat, fset, tree); - cached_rptr_ = this->wspace_.rptr; - cached_cut_ = this->wspace_.cut; - } else { - this->wspace_.cut.clear(); - this->wspace_.rptr.clear(); - this->wspace_.rptr.push_back(0); - for (size_t i = 0; i < this->qexpand_.size(); ++i) { - for (size_t j = 0; j < cached_rptr_.size() - 1; ++j) { - this->wspace_.rptr.push_back( - this->wspace_.rptr.back() + cached_rptr_[j + 1] - cached_rptr_[j]); - } - this->wspace_.cut.insert(this->wspace_.cut.end(), cached_cut_.begin(), cached_cut_.end()); - } - CHECK_EQ(this->wspace_.rptr.size(), - (fset.size() + 1) * this->qexpand_.size() + 1); - CHECK_EQ(this->wspace_.rptr.back(), this->wspace_.cut.size()); - } - } - - // code to create histogram - void CreateHist(const std::vector &gpair, - DMatrix *p_fmat, - const std::vector &fset, - const RegTree &tree) override { - const MetaInfo &info = p_fmat->Info(); - // fill in reverse map - this->feat2workindex_.resize(tree.param.num_feature); - this->work_set_ = fset; - std::fill(this->feat2workindex_.begin(), this->feat2workindex_.end(), -1); - for (size_t i = 0; i < fset.size(); ++i) { - this->feat2workindex_[fset[i]] = static_cast(i); - } - // start to work - this->wspace_.Configure(1); - // to gain speedup in recovery - { - this->thread_hist_.resize(omp_get_max_threads()); - - // TWOPASS: use the real set + split set in the column iteration. - this->SetDefaultPostion(p_fmat, tree); - this->work_set_.insert(this->work_set_.end(), this->fsplit_set_.begin(), - this->fsplit_set_.end()); - XGBOOST_PARALLEL_SORT(this->work_set_.begin(), this->work_set_.end(), - std::less<>{}); - this->work_set_.resize( - std::unique(this->work_set_.begin(), this->work_set_.end()) - this->work_set_.begin()); - - // start accumulating statistics - for (const auto &batch : p_fmat->GetBatches()) { - // TWOPASS: use the real set + split set in the column iteration. - this->CorrectNonDefaultPositionByBatch(batch, this->fsplit_set_, tree); - auto page = batch.GetView(); - - // start enumeration - const auto nsize = static_cast(this->work_set_.size()); - dmlc::OMPException exc; -#pragma omp parallel for schedule(dynamic, 1) - for (bst_omp_uint i = 0; i < nsize; ++i) { - exc.Run([&]() { - int fid = this->work_set_[i]; - int offset = this->feat2workindex_[fid]; - if (offset >= 0) { - this->UpdateHistCol(gpair, page[fid], info, tree, - fset, offset, - &this->thread_hist_[omp_get_thread_num()]); - } - }); - } - exc.Rethrow(); - } - - // update node statistics. - this->GetNodeStats(gpair, *p_fmat, tree, - &(this->thread_stats_), &(this->node_stats_)); - for (const int nid : this->qexpand_) { - const int wid = this->node2workindex_[nid]; - this->wspace_.hset[0][fset.size() + wid * (fset.size()+1)] - .data[0] = this->node_stats_[nid]; - } - } - this->histred_.Allreduce(dmlc::BeginPtr(this->wspace_.hset[0].data), - this->wspace_.hset[0].data.size()); - } - - // cached unit pointer - std::vector cached_rptr_; - // cached cut value. - std::vector cached_cut_; -}; - XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker") .describe("Tree constructor that uses approximate histogram construction.") .set_body([](ObjInfo) { return new CQHistMaker(); }); - -// The updater for approx tree method. 
-XGBOOST_REGISTER_TREE_UPDATER(HistMaker, "grow_histmaker") -.describe("Tree constructor that uses approximate global of histogram construction.") -.set_body([](ObjInfo) { - return new GlobalProposalHistMaker(); - }); } // namespace tree } // namespace xgboost diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc index 34c1a52d9dd2..c337312a1154 100644 --- a/tests/cpp/gbm/test_gbtree.cc +++ b/tests/cpp/gbm/test_gbtree.cc @@ -35,7 +35,7 @@ TEST(GBTree, SelectTreeMethod) { gbtree.Configure(args); auto const& tparam = gbtree.GetTrainParam(); gbtree.Configure({{"tree_method", "approx"}}); - ASSERT_EQ(tparam.updater_seq, "grow_histmaker,prune"); + ASSERT_EQ(tparam.updater_seq, "grow_histmaker"); gbtree.Configure({{"tree_method", "exact"}}); ASSERT_EQ(tparam.updater_seq, "grow_colmaker,prune"); gbtree.Configure({{"tree_method", "hist"}}); diff --git a/tests/cpp/tree/test_approx.cc b/tests/cpp/tree/test_approx.cc index 680ed9d4b6cd..57a0cd3545bb 100644 --- a/tests/cpp/tree/test_approx.cc +++ b/tests/cpp/tree/test_approx.cc @@ -72,5 +72,58 @@ TEST(Approx, Partitioner) { } } } + +TEST(Approx, PredictionCache) { + size_t n_samples = 2048, n_features = 13; + auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true); + + { + omp_set_num_threads(1); + GenericParameter ctx; + ctx.InitAllowUnknown(Args{{"nthread", "8"}}); + std::unique_ptr approx{ + TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})}; + RegTree tree; + std::vector trees{&tree}; + auto gpair = GenerateRandomGradients(n_samples); + approx->Configure(Args{{"max_bin", "64"}}); + approx->Update(&gpair, Xy.get(), trees); + HostDeviceVector out_prediction_cached; + out_prediction_cached.Resize(n_samples); + auto cache = linalg::VectorView{ + out_prediction_cached.HostSpan(), {out_prediction_cached.Size()}, GenericParameter::kCpuId}; + ASSERT_TRUE(approx->UpdatePredictionCache(Xy.get(), cache)); + } + + std::unique_ptr learner{Learner::Create({Xy})}; + learner->SetParam("tree_method", "approx"); + learner->SetParam("nthread", "0"); + learner->Configure(); + + for (size_t i = 0; i < 8; ++i) { + learner->UpdateOneIter(i, Xy); + } + + HostDeviceVector out_prediction_cached; + learner->Predict(Xy, false, &out_prediction_cached, 0, 0); + + Json model{Object()}; + learner->SaveModel(&model); + + HostDeviceVector out_prediction; + { + std::unique_ptr learner{Learner::Create({Xy})}; + learner->LoadModel(model); + learner->Predict(Xy, false, &out_prediction, 0, 0); + } + + auto const h_predt_cached = out_prediction_cached.ConstHostSpan(); + auto const h_predt = out_prediction.ConstHostSpan(); + + ASSERT_EQ(h_predt.size(), h_predt_cached.size()); + for (size_t i = 0; i < h_predt.size(); ++i) { + ASSERT_NEAR(h_predt[i], h_predt_cached[i], kRtEps); + } +} } // namespace tree } // namespace xgboost diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index c103b23a78b2..5639c2f003bc 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -315,57 +315,6 @@ TEST(GpuHist, TestHistogramIndex) { TestHistogramIndexImpl(); } -// gamma is an alias of min_split_loss -int32_t TestMinSplitLoss(DMatrix* dmat, float gamma, HostDeviceVector* gpair) { - Args args { - {"max_depth", "1"}, - {"max_leaves", "0"}, - - // Disable all other parameters. 
- {"colsample_bynode", "1"}, - {"colsample_bylevel", "1"}, - {"colsample_bytree", "1"}, - {"min_child_weight", "0.01"}, - {"reg_alpha", "0"}, - {"reg_lambda", "0"}, - {"max_delta_step", "0"}, - - // test gamma - {"gamma", std::to_string(gamma)} - }; - - tree::GPUHistMakerSpecialised hist_maker{ObjInfo{ObjInfo::kRegression}}; - GenericParameter generic_param(CreateEmptyGenericParam(0)); - hist_maker.Configure(args, &generic_param); - - RegTree tree; - hist_maker.Update(gpair, dmat, {&tree}); - - auto n_nodes = tree.NumExtraNodes(); - return n_nodes; -} - -TEST(GpuHist, MinSplitLoss) { - constexpr size_t kRows = 32; - constexpr size_t kCols = 16; - constexpr float kSparsity = 0.6; - auto dmat = RandomDataGenerator(kRows, kCols, kSparsity).Seed(3).GenerateDMatrix(); - auto gpair = GenerateRandomGradients(kRows); - - { - int32_t n_nodes = TestMinSplitLoss(dmat.get(), 0.01, &gpair); - // This is not strictly verified, meaning the numeber `2` is whatever GPU_Hist retured - // when writing this test, and only used for testing larger gamma (below) does prevent - // building tree. - ASSERT_EQ(n_nodes, 2); - } - { - int32_t n_nodes = TestMinSplitLoss(dmat.get(), 100.0, &gpair); - // No new nodes with gamma == 100. - ASSERT_EQ(n_nodes, static_cast(0)); - } -} - void UpdateTree(HostDeviceVector* gpair, DMatrix* dmat, size_t gpu_page_size, RegTree* tree, HostDeviceVector* preds, float subsample = 1.0f, diff --git a/tests/cpp/tree/test_tree_policy.cc b/tests/cpp/tree/test_tree_policy.cc index 68a720a8fba6..65dc975f2319 100644 --- a/tests/cpp/tree/test_tree_policy.cc +++ b/tests/cpp/tree/test_tree_policy.cc @@ -61,7 +61,7 @@ class TestGrowPolicy : public ::testing::Test { } }; -TEST_F(TestGrowPolicy, DISABLED_Approx) { +TEST_F(TestGrowPolicy, Approx) { this->TestTreeGrowPolicy("approx", "depthwise"); this->TestTreeGrowPolicy("approx", "lossguide"); } diff --git a/tests/cpp/tree/test_tree_stat.cc b/tests/cpp/tree/test_tree_stat.cc index 9e2e8c04fc50..772420ce0f23 100644 --- a/tests/cpp/tree/test_tree_stat.cc +++ b/tests/cpp/tree/test_tree_stat.cc @@ -114,4 +114,70 @@ TEST_F(UpdaterEtaTest, Approx) { this->RunTest("grow_histmaker"); } #if defined(XGBOOST_USE_CUDA) TEST_F(UpdaterEtaTest, GpuHist) { this->RunTest("grow_gpu_hist"); } #endif // defined(XGBOOST_USE_CUDA) + +class TestMinSplitLoss : public ::testing::Test { + std::shared_ptr dmat_; + HostDeviceVector gpair_; + + void SetUp() override { + constexpr size_t kRows = 32; + constexpr size_t kCols = 16; + constexpr float kSparsity = 0.6; + dmat_ = RandomDataGenerator(kRows, kCols, kSparsity).Seed(3).GenerateDMatrix(); + gpair_ = GenerateRandomGradients(kRows); + } + + int32_t Update(std::string updater, float gamma) { + Args args{{"max_depth", "1"}, + {"max_leaves", "0"}, + + // Disable all other parameters. 
+ {"colsample_bynode", "1"}, + {"colsample_bylevel", "1"}, + {"colsample_bytree", "1"}, + {"min_child_weight", "0.01"}, + {"reg_alpha", "0"}, + {"reg_lambda", "0"}, + {"max_delta_step", "0"}, + + // test gamma + {"gamma", std::to_string(gamma)}}; + + GenericParameter generic_param(CreateEmptyGenericParam(0)); + auto up = std::unique_ptr{ + TreeUpdater::Create(updater, &generic_param, ObjInfo{ObjInfo::kRegression})}; + up->Configure(args); + + RegTree tree; + up->Update(&gpair_, dmat_.get(), {&tree}); + + auto n_nodes = tree.NumExtraNodes(); + return n_nodes; + } + + public: + void RunTest(std::string updater) { + { + int32_t n_nodes = Update(updater, 0.01); + // This is not strictly verified, meaning the numeber `2` is whatever GPU_Hist retured + // when writing this test, and only used for testing larger gamma (below) does prevent + // building tree. + ASSERT_EQ(n_nodes, 2); + } + { + int32_t n_nodes = Update(updater, 100.0); + // No new nodes with gamma == 100. + ASSERT_EQ(n_nodes, static_cast(0)); + } + } +}; + +/* Exact tree method requires a pruner as an additional updater, so not tested here. */ + +TEST_F(TestMinSplitLoss, Approx) { this->RunTest("grow_histmaker"); } + +TEST_F(TestMinSplitLoss, Hist) { this->RunTest("grow_quantile_histmaker"); } +#if defined(XGBOOST_USE_CUDA) +TEST_F(TestMinSplitLoss, GpuHist) { this->RunTest("grow_gpu_hist"); } +#endif // defined(XGBOOST_USE_CUDA) } // namespace xgboost diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index 505d38778c8b..22799e533b0d 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -7,6 +7,8 @@ sys.path.append("tests/python") import testing as tm +import test_updaters as test_up + parameter_strategy = strategies.fixed_dictionaries({ 'max_depth': strategies.integers(0, 11), @@ -32,6 +34,8 @@ def train_result(param, dmat, num_rounds): class TestGPUUpdaters: + cputest = test_up.TestTreeMethod() + @given(parameter_strategy, strategies.integers(1, 20), tm.dataset_strategy) @settings(deadline=None) def test_gpu_hist(self, param, num_rounds, dataset): @@ -41,51 +45,12 @@ def test_gpu_hist(self, param, num_rounds, dataset): note(result) assert tm.non_increasing(result["train"][dataset.metric]) - def run_categorical_basic(self, rows, cols, rounds, cats): - onehot, label = tm.make_categorical(rows, cols, cats, True) - cat, _ = tm.make_categorical(rows, cols, cats, False) - - by_etl_results = {} - by_builtin_results = {} - - parameters = {"tree_method": "gpu_hist", "predictor": "gpu_predictor"} - - m = xgb.DMatrix(onehot, label, enable_categorical=False) - xgb.train( - parameters, - m, - num_boost_round=rounds, - evals=[(m, "Train")], - evals_result=by_etl_results, - ) - - m = xgb.DMatrix(cat, label, enable_categorical=True) - xgb.train( - parameters, - m, - num_boost_round=rounds, - evals=[(m, "Train")], - evals_result=by_builtin_results, - ) - - # There are guidelines on how to specify tolerance based on considering output as - # random variables. But in here the tree construction is extremely sensitive to - # floating point errors. An 1e-5 error in a histogram bin can lead to an entirely - # different tree. So even though the test is quite lenient, hypothesis can still - # pick up falsifying examples from time to time. 
- np.testing.assert_allclose( - np.array(by_etl_results["Train"]["rmse"]), - np.array(by_builtin_results["Train"]["rmse"]), - rtol=1e-3, - ) - assert tm.non_increasing(by_builtin_results["Train"]["rmse"]) - @given(strategies.integers(10, 400), strategies.integers(3, 8), strategies.integers(1, 2), strategies.integers(4, 7)) @settings(deadline=None) @pytest.mark.skipif(**tm.no_pandas()) def test_categorical(self, rows, cols, rounds, cats): - self.run_categorical_basic(rows, cols, rounds, cats) + self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist") def test_categorical_32_cat(self): '''32 hits the bound of integer bitset, so special test''' @@ -93,7 +58,7 @@ def test_categorical_32_cat(self): cols = 10 cats = 32 rounds = 4 - self.run_categorical_basic(rows, cols, rounds, cats) + self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist") def test_invalid_categorical(self): import cupy as cp diff --git a/tests/python/test_monotone_constraints.py b/tests/python/test_monotone_constraints.py index 21fba734888b..2c538bff989b 100644 --- a/tests/python/test_monotone_constraints.py +++ b/tests/python/test_monotone_constraints.py @@ -63,7 +63,6 @@ def is_correctly_constrained(learner, feature_names=None): class TestMonotoneConstraints: - def test_monotone_constraints_for_exact_tree_method(self): # first check monotonicity for the 'exact' tree method @@ -76,32 +75,23 @@ def test_monotone_constraints_for_exact_tree_method(self): ) assert is_correctly_constrained(constrained_exact_method) - def test_monotone_constraints_for_depthwise_hist_tree_method(self): - - # next check monotonicity for the 'hist' tree method - params_for_constrained_hist_method = { - 'tree_method': 'hist', 'verbosity': 1, - 'monotone_constraints': '(1, -1)' - } - constrained_hist_method = xgb.train( - params_for_constrained_hist_method, training_dset - ) - - assert is_correctly_constrained(constrained_hist_method) - - def test_monotone_constraints_for_lossguide_hist_tree_method(self): - - # next check monotonicity for the 'hist' tree method - params_for_constrained_hist_method = { - 'tree_method': 'hist', 'verbosity': 1, - 'grow_policy': 'lossguide', - 'monotone_constraints': '(1, -1)' + @pytest.mark.parametrize( + "tree_method,policy", + [ + ("hist", "depthwise"), + ("approx", "depthwise"), + ("hist", "lossguide"), + ("approx", "lossguide"), + ], + ) + def test_monotone_constraints(self, tree_method: str, policy: str) -> None: + params_for_constrained = { + "tree_method": tree_method, + "grow_policy": policy, + "monotone_constraints": "(1, -1)", } - constrained_hist_method = xgb.train( - params_for_constrained_hist_method, training_dset - ) - - assert is_correctly_constrained(constrained_hist_method) + constrained = xgb.train(params_for_constrained, training_dset) + assert is_correctly_constrained(constrained) @pytest.mark.parametrize('format', [dict, list]) def test_monotone_constraints_feature_names(self, format): diff --git a/tests/python/test_updaters.py b/tests/python/test_updaters.py index 07e6d44c6f2a..2af485676016 100644 --- a/tests/python/test_updaters.py +++ b/tests/python/test_updaters.py @@ -45,14 +45,20 @@ def test_exact(self, param, num_rounds, dataset): result = train_result(param, dataset.get_dmat(), num_rounds) assert tm.non_increasing(result['train'][dataset.metric]) - @given(exact_parameter_strategy, strategies.integers(1, 20), - tm.dataset_strategy) + @given( + exact_parameter_strategy, + hist_parameter_strategy, + strategies.integers(1, 20), + tm.dataset_strategy, + ) 
@settings(deadline=None) - def test_approx(self, param, num_rounds, dataset): - param['tree_method'] = 'approx' + def test_approx(self, param, hist_param, num_rounds, dataset): + param["tree_method"] = "approx" param = dataset.set_params(param) + param.update(hist_param) result = train_result(param, dataset.get_dmat(), num_rounds) - assert tm.non_increasing(result['train'][dataset.metric], 1e-3) + note(result) + assert tm.non_increasing(result["train"][dataset.metric]) @pytest.mark.skipif(**tm.no_sklearn()) def test_pruner(self): @@ -126,3 +132,53 @@ def test_hist_degenerate_case(self): y = [1000000., 0., 0., 500000.] w = [0, 0, 1, 0] model.fit(X, y, sample_weight=w) + + def run_categorical_basic(self, rows, cols, rounds, cats, tree_method): + onehot, label = tm.make_categorical(rows, cols, cats, True) + cat, _ = tm.make_categorical(rows, cols, cats, False) + + by_etl_results = {} + by_builtin_results = {} + + predictor = "gpu_predictor" if tree_method == "gpu_hist" else None + # Use one-hot exclusively + parameters = { + "tree_method": tree_method, "predictor": predictor, "max_cat_to_onehot": 9999 + } + + m = xgb.DMatrix(onehot, label, enable_categorical=False) + xgb.train( + parameters, + m, + num_boost_round=rounds, + evals=[(m, "Train")], + evals_result=by_etl_results, + ) + + m = xgb.DMatrix(cat, label, enable_categorical=True) + xgb.train( + parameters, + m, + num_boost_round=rounds, + evals=[(m, "Train")], + evals_result=by_builtin_results, + ) + + # There are guidelines on how to specify tolerance based on considering output as + # random variables. But in here the tree construction is extremely sensitive to + # floating point errors. An 1e-5 error in a histogram bin can lead to an entirely + # different tree. So even though the test is quite lenient, hypothesis can still + # pick up falsifying examples from time to time. 
+ np.testing.assert_allclose( + np.array(by_etl_results["Train"]["rmse"]), + np.array(by_builtin_results["Train"]["rmse"]), + rtol=1e-3, + ) + assert tm.non_increasing(by_builtin_results["Train"]["rmse"]) + + @given(strategies.integers(10, 400), strategies.integers(3, 8), + strategies.integers(1, 2), strategies.integers(4, 7)) + @settings(deadline=None) + @pytest.mark.skipif(**tm.no_pandas()) + def test_categorical(self, rows, cols, rounds, cats): + self.run_categorical_basic(rows, cols, rounds, cats, "approx") diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index c03687b16db8..31e006244517 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -1184,9 +1184,13 @@ def runit( for arg in rabit_args: if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'): port_env = arg.decode('utf-8') + if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"): + uri_env = arg.decode("utf-8") port = port_env.split('=') env = os.environ.copy() env[port[0]] = port[1] + uri = uri_env.split("=") + env["DMLC_TRACKER_URI"] = uri[1] return subprocess.run([str(exe), test], env=env, capture_output=True) with LocalCluster(n_workers=4) as cluster: @@ -1210,11 +1214,13 @@ def runit( @pytest.mark.gtest def test_quantile_basic(self) -> None: self.run_quantile('DistributedBasic') + self.run_quantile('SortedDistributedBasic') @pytest.mark.skipif(**tm.no_dask()) @pytest.mark.gtest def test_quantile(self) -> None: self.run_quantile('Distributed') + self.run_quantile('SortedDistributed') @pytest.mark.skipif(**tm.no_dask()) @pytest.mark.gtest @@ -1252,13 +1258,17 @@ def test_feature_weights(self, client: "Client") -> None: for i in range(kCols): fw[i] *= float(i) fw = da.from_array(fw) - poly_increasing = run_feature_weights(X, y, fw, model=xgb.dask.DaskXGBRegressor) + poly_increasing = run_feature_weights( + X, y, fw, "approx", model=xgb.dask.DaskXGBRegressor + ) fw = np.ones(shape=(kCols,)) for i in range(kCols): fw[i] *= float(kCols - i) fw = da.from_array(fw) - poly_decreasing = run_feature_weights(X, y, fw, model=xgb.dask.DaskXGBRegressor) + poly_decreasing = run_feature_weights( + X, y, fw, "approx", model=xgb.dask.DaskXGBRegressor + ) # Approxmated test, this is dependent on the implementation of random # number generator in std library. 
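Illustrative aside (not part of the patch): the feature-weights tests above reduce to roughly the following sklearn-API usage with the newly wired-up `approx` tree method; the array shapes, seed, and weight values are assumptions for demonstration only.

import numpy as np
import xgboost as xgb

rng = np.random.RandomState(1994)
kRows, kCols = 512, 8
X = rng.randn(kRows, kCols)
y = rng.randn(kRows)
# Strictly positive, increasing weights: later columns should be sampled more
# often by colsample_bynode when splitting nodes.
fw = np.arange(1.0, kCols + 1.0)

reg = xgb.XGBRegressor(tree_method="approx", colsample_bynode=0.5)
reg.fit(X, y, feature_weights=fw)
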
diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index aef4657ea1f7..149e77ed72df 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -1031,10 +1031,10 @@ def test_pandas_input(): np.array([0, 1])) -def run_feature_weights(X, y, fw, model=xgb.XGBRegressor): +def run_feature_weights(X, y, fw, tree_method, model=xgb.XGBRegressor): with tempfile.TemporaryDirectory() as tmpdir: colsample_bynode = 0.5 - reg = model(tree_method='hist', colsample_bynode=colsample_bynode) + reg = model(tree_method=tree_method, colsample_bynode=colsample_bynode) reg.fit(X, y, feature_weights=fw) model_path = os.path.join(tmpdir, 'model.json') @@ -1069,7 +1069,8 @@ def run_feature_weights(X, y, fw, model=xgb.XGBRegressor): return w -def test_feature_weights(): +@pytest.mark.parametrize("tree_method", ["approx", "hist"]) +def test_feature_weights(tree_method): kRows = 512 kCols = 64 X = rng.randn(kRows, kCols) @@ -1078,12 +1079,12 @@ def test_feature_weights(): fw = np.ones(shape=(kCols,)) for i in range(kCols): fw[i] *= float(i) - poly_increasing = run_feature_weights(X, y, fw, xgb.XGBRegressor) + poly_increasing = run_feature_weights(X, y, fw, tree_method, xgb.XGBRegressor) fw = np.ones(shape=(kCols,)) for i in range(kCols): fw[i] *= float(kCols - i) - poly_decreasing = run_feature_weights(X, y, fw, xgb.XGBRegressor) + poly_decreasing = run_feature_weights(X, y, fw, tree_method, xgb.XGBRegressor) # Approxmated test, this is dependent on the implementation of random # number generator in std library.
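
Illustrative aside (not part of the patch): combining the pieces above, categorical data with the `approx` tree method would be exercised roughly as sketched below. The synthetic DataFrame, category count, and parameter values are assumptions; demo/guide-python/categorical.py remains the maintained example.

import numpy as np
import pandas as pd
import xgboost as xgb

rng = np.random.default_rng(0)
n = 256
X = pd.DataFrame({
    "num": rng.normal(size=n),
    # pandas "category" dtype is what enable_categorical=True expects.
    "cat": pd.Series(rng.integers(0, 6, size=n)).astype("category"),
})
y = rng.normal(size=n)

dtrain = xgb.DMatrix(X, label=y, enable_categorical=True)
params = {
    "tree_method": "approx",           # categorical splits newly supported here
    "max_cat_to_onehot": 4,            # below this many categories, one-hot based splits are used
    "objective": "reg:squarederror",
}
booster = xgb.train(params, dtrain, num_boost_round=8)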