From 912e341d575f107be1cc2631271fd0737b75dfba Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 31 Jul 2023 15:50:28 +0800 Subject: [PATCH] Initial GPU support for the approx tree method. (#9414) --- doc/parameter.rst | 3 +- doc/treemethod.rst | 34 +-- python-package/xgboost/testing/updater.py | 140 ++++++++++- src/common/error_msg.h | 5 + src/common/ranking_utils.h | 6 +- src/data/ellpack_page.cu | 7 +- src/data/gradient_index.cc | 5 +- src/data/simple_dmatrix.cc | 6 +- src/data/sparse_page_dmatrix.cc | 1 - src/data/sparse_page_dmatrix.cu | 23 +- src/gbm/gbtree.cc | 7 +- src/tree/constraints.h | 8 +- src/tree/gpu_hist/gradient_based_sampler.cu | 28 +-- src/tree/updater_approx.cc | 1 - src/tree/updater_gpu_hist.cu | 231 +++++++++++++----- tests/cpp/tree/test_gpu_hist.cu | 8 +- tests/cpp/tree/test_prediction_cache.cc | 55 +++-- tests/cpp/tree/test_regen.cc | 59 ++++- tests/cpp/tree/test_tree_policy.cc | 46 ++-- tests/cpp/tree/test_tree_stat.cc | 32 ++- tests/python-gpu/test_gpu_updaters.py | 130 +++++++--- tests/python/test_updaters.py | 140 +---------- .../test_gpu_with_dask/test_gpu_with_dask.py | 24 +- 23 files changed, 639 insertions(+), 360 deletions(-) diff --git a/doc/parameter.rst b/doc/parameter.rst index 2072c4b7551f..6f767c80def2 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -162,7 +162,8 @@ Parameters for Tree Booster - ``grow_colmaker``: non-distributed column-based construction of trees. - ``grow_histmaker``: distributed tree construction with row-based data splitting based on global proposal of histogram counting. - ``grow_quantile_histmaker``: Grow tree using quantized histogram. - - ``grow_gpu_hist``: Grow tree with GPU. Enabled when ``tree_method`` is set to ``hist`` along with ``device=cuda``. + - ``grow_gpu_hist``: Enabled when ``tree_method`` is set to ``hist`` along with ``device=cuda``. + - ``grow_gpu_approx``: Enabled when ``tree_method`` is set to ``approx`` along with ``device=cuda``. - ``sync``: synchronizes trees in all distributed nodes. - ``refresh``: refreshes tree's statistics and/or leaf values based on the current data. Note that no random subsampling of data rows is performed. - ``prune``: prunes the splits where loss < min_split_loss (or gamma) and nodes that have depth greater than ``max_depth``. diff --git a/doc/treemethod.rst b/doc/treemethod.rst index 4dfb107a0c30..1f83401fe2a9 100644 --- a/doc/treemethod.rst +++ b/doc/treemethod.rst @@ -123,23 +123,23 @@ Feature Matrix Following table summarizes some differences in supported features between 4 tree methods, `T` means supported while `F` means unsupported. 
-+------------------+-----------+---------------------+---------------------+------------------------+
-|                  | Exact     | Approx              | Hist                | Hist (GPU)             |
-+==================+===========+=====================+=====================+========================+
-| grow_policy      | Depthwise | depthwise/lossguide | depthwise/lossguide | depthwise/lossguide    |
-+------------------+-----------+---------------------+---------------------+------------------------+
-| max_leaves       | F         | T                   | T                   | T                      |
-+------------------+-----------+---------------------+---------------------+------------------------+
-| sampling method  | uniform   | uniform             | uniform             | gradient_based/uniform |
-+------------------+-----------+---------------------+---------------------+------------------------+
-| categorical data | F         | T                   | T                   | T                      |
-+------------------+-----------+---------------------+---------------------+------------------------+
-| External memory  | F         | T                   | T                   | P                      |
-+------------------+-----------+---------------------+---------------------+------------------------+
-| Distributed      | F         | T                   | T                   | T                      |
-+------------------+-----------+---------------------+---------------------+------------------------+
-
-Features/parameters that are not mentioned here are universally supported for all 4 tree
++------------------+-----------+---------------------+------------------------+---------------------+------------------------+
+|                  | Exact     | Approx              | Approx (GPU)           | Hist                | Hist (GPU)             |
++==================+===========+=====================+========================+=====================+========================+
+| grow_policy      | Depthwise | depthwise/lossguide | depthwise/lossguide    | depthwise/lossguide | depthwise/lossguide    |
++------------------+-----------+---------------------+------------------------+---------------------+------------------------+
+| max_leaves       | F         | T                   | T                      | T                   | T                      |
++------------------+-----------+---------------------+------------------------+---------------------+------------------------+
+| sampling method  | uniform   | uniform             | gradient_based/uniform | uniform             | gradient_based/uniform |
++------------------+-----------+---------------------+------------------------+---------------------+------------------------+
+| categorical data | F         | T                   | T                      | T                   | T                      |
++------------------+-----------+---------------------+------------------------+---------------------+------------------------+
+| External memory  | F         | T                   | P                      | T                   | P                      |
++------------------+-----------+---------------------+------------------------+---------------------+------------------------+
+| Distributed      | F         | T                   | T                      | T                   | T                      |
++------------------+-----------+---------------------+------------------------+---------------------+------------------------+
+
+Features/parameters that are not mentioned here are universally supported for all of the above tree
 methods (for instance, column sampling and constraints). The `P` in external memory means special handling. Please note that both categorical data and external memory are experimental.
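
Selecting the new GPU-based ``approx`` updater follows the same pattern as GPU ``hist``: set ``tree_method`` to ``approx`` together with ``device=cuda``. A minimal usage sketch is below; the synthetic data and the particular parameter values are illustrative only.

```python
import numpy as np
import xgboost as xgb

# Illustrative regression data; any dataset works the same way.
rng = np.random.default_rng(0)
X = rng.normal(size=(1024, 16))
y = rng.normal(size=1024)

dtrain = xgb.DMatrix(X, label=y)
params = {
    "tree_method": "approx",  # global approximate (sketching) method
    "device": "cuda",         # dispatches to the new grow_gpu_approx updater
    "max_bin": 256,
}
booster = xgb.train(params, dtrain, num_boost_round=10, evals=[(dtrain, "train")])
```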
diff --git a/python-package/xgboost/testing/updater.py b/python-package/xgboost/testing/updater.py index 62df8ec2ea23..af5acf428758 100644 --- a/python-package/xgboost/testing/updater.py +++ b/python-package/xgboost/testing/updater.py @@ -1,7 +1,7 @@ """Tests for updaters.""" import json from functools import partial, update_wrapper -from typing import Any, Dict +from typing import Any, Dict, List import numpy as np @@ -256,3 +256,141 @@ def check_get_quantile_cut(tree_method: str) -> None: check_get_quantile_cut_device(tree_method, False) if use_cupy: check_get_quantile_cut_device(tree_method, True) + + +USE_ONEHOT = np.iinfo(np.int32).max +USE_PART = 1 + + +def check_categorical_ohe( # pylint: disable=too-many-arguments + rows: int, cols: int, rounds: int, cats: int, device: str, tree_method: str +) -> None: + "Test for one-hot encoding with categorical data." + + onehot, label = tm.make_categorical(rows, cols, cats, True) + cat, _ = tm.make_categorical(rows, cols, cats, False) + + by_etl_results: Dict[str, Dict[str, List[float]]] = {} + by_builtin_results: Dict[str, Dict[str, List[float]]] = {} + + parameters: Dict[str, Any] = { + "tree_method": tree_method, + # Use one-hot exclusively + "max_cat_to_onehot": USE_ONEHOT, + "device": device, + } + + m = xgb.DMatrix(onehot, label, enable_categorical=False) + xgb.train( + parameters, + m, + num_boost_round=rounds, + evals=[(m, "Train")], + evals_result=by_etl_results, + ) + + m = xgb.DMatrix(cat, label, enable_categorical=True) + xgb.train( + parameters, + m, + num_boost_round=rounds, + evals=[(m, "Train")], + evals_result=by_builtin_results, + ) + + # There are guidelines on how to specify tolerance based on considering output + # as random variables. But in here the tree construction is extremely sensitive + # to floating point errors. An 1e-5 error in a histogram bin can lead to an + # entirely different tree. So even though the test is quite lenient, hypothesis + # can still pick up falsifying examples from time to time. + np.testing.assert_allclose( + np.array(by_etl_results["Train"]["rmse"]), + np.array(by_builtin_results["Train"]["rmse"]), + rtol=1e-3, + ) + assert tm.non_increasing(by_builtin_results["Train"]["rmse"]) + + by_grouping: Dict[str, Dict[str, List[float]]] = {} + # switch to partition-based splits + parameters["max_cat_to_onehot"] = USE_PART + parameters["reg_lambda"] = 0 + m = xgb.DMatrix(cat, label, enable_categorical=True) + xgb.train( + parameters, + m, + num_boost_round=rounds, + evals=[(m, "Train")], + evals_result=by_grouping, + ) + rmse_oh = by_builtin_results["Train"]["rmse"] + rmse_group = by_grouping["Train"]["rmse"] + # always better or equal to onehot when there's no regularization. 
+ for a, b in zip(rmse_oh, rmse_group): + assert a >= b + + parameters["reg_lambda"] = 1.0 + by_grouping = {} + xgb.train( + parameters, + m, + num_boost_round=32, + evals=[(m, "Train")], + evals_result=by_grouping, + ) + assert tm.non_increasing(by_grouping["Train"]["rmse"]), by_grouping + + +def check_categorical_missing( + rows: int, cols: int, cats: int, device: str, tree_method: str +) -> None: + """Check categorical data with missing values.""" + parameters: Dict[str, Any] = {"tree_method": tree_method, "device": device} + cat, label = tm.make_categorical( + rows, n_features=cols, n_categories=cats, onehot=False, sparsity=0.5 + ) + Xy = xgb.DMatrix(cat, label, enable_categorical=True) + + def run(max_cat_to_onehot: int) -> None: + # Test with onehot splits + parameters["max_cat_to_onehot"] = max_cat_to_onehot + + evals_result: Dict[str, Dict] = {} + booster = xgb.train( + parameters, + Xy, + num_boost_round=16, + evals=[(Xy, "Train")], + evals_result=evals_result, + ) + assert tm.non_increasing(evals_result["Train"]["rmse"]) + y_predt = booster.predict(Xy) + + rmse = tm.root_mean_square(label, y_predt) + np.testing.assert_allclose(rmse, evals_result["Train"]["rmse"][-1], rtol=2e-5) + + # Test with OHE split + run(USE_ONEHOT) + + # Test with partition-based split + run(USE_PART) + + +def train_result( + param: Dict[str, Any], dmat: xgb.DMatrix, num_rounds: int +) -> Dict[str, Any]: + """Get training result from parameters and data.""" + result: Dict[str, Any] = {} + booster = xgb.train( + param, + dmat, + num_rounds, + evals=[(dmat, "train")], + verbose_eval=False, + evals_result=result, + ) + assert booster.num_features() == dmat.num_col() + assert booster.num_boosted_rounds() == num_rounds + assert booster.feature_names == dmat.feature_names + assert booster.feature_types == dmat.feature_types + + return result diff --git a/src/common/error_msg.h b/src/common/error_msg.h index 8bdc8599921f..1af4b7c88063 100644 --- a/src/common/error_msg.h +++ b/src/common/error_msg.h @@ -89,5 +89,10 @@ void WarnDeprecatedGPUId(); void WarnEmptyDataset(); std::string DeprecatedFunc(StringView old, StringView since, StringView replacement); + +constexpr StringView InvalidCUDAOrdinal() { + return "Invalid device. `device` is required to be CUDA and there must be at least one GPU " + "available for using GPU."; +} } // namespace xgboost::error #endif // XGBOOST_COMMON_ERROR_MSG_H_ diff --git a/src/common/ranking_utils.h b/src/common/ranking_utils.h index 7d11de048d3d..75622bd84b60 100644 --- a/src/common/ranking_utils.h +++ b/src/common/ranking_utils.h @@ -12,7 +12,7 @@ #include // for vector #include "dmlc/parameter.h" // for FieldEntry, DMLC_DECLARE_FIELD -#include "error_msg.h" // for GroupWeight, GroupSize +#include "error_msg.h" // for GroupWeight, GroupSize, InvalidCUDAOrdinal #include "xgboost/base.h" // for XGBOOST_DEVICE, bst_group_t #include "xgboost/context.h" // for Context #include "xgboost/data.h" // for MetaInfo @@ -240,7 +240,7 @@ class RankingCache { // The function simply returns a uninitialized buffer as this is only used by the // objective for creating pairs. 
common::Span SortedIdxY(Context const* ctx, std::size_t n_samples) { - CHECK(ctx->IsCUDA()); + CHECK(ctx->IsCUDA()) << error::InvalidCUDAOrdinal(); if (y_sorted_idx_cache_.Empty()) { y_sorted_idx_cache_.SetDevice(ctx->gpu_id); y_sorted_idx_cache_.Resize(n_samples); @@ -248,7 +248,7 @@ class RankingCache { return y_sorted_idx_cache_.DeviceSpan(); } common::Span RankedY(Context const* ctx, std::size_t n_samples) { - CHECK(ctx->IsCUDA()); + CHECK(ctx->IsCUDA()) << error::InvalidCUDAOrdinal(); if (y_ranked_by_model_.Empty()) { y_ranked_by_model_.SetDevice(ctx->gpu_id); y_ranked_by_model_.Resize(n_samples); diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index 7097df405f54..3690213765f0 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -11,7 +11,6 @@ #include "../common/categorical.h" #include "../common/cuda_context.cuh" #include "../common/hist_util.cuh" -#include "../common/random.h" #include "../common/transform_iterator.h" // MakeIndexTransformIter #include "./ellpack_page.cuh" #include "device_adapter.cuh" // for HasInfInData @@ -131,7 +130,11 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, DMatrix* dmat, const BatchP monitor_.Start("Quantiles"); // Create the quantile sketches for the dmatrix and initialize HistogramCuts. row_stride = GetRowStride(dmat); - cuts_ = common::DeviceSketch(ctx, dmat, param.max_bin); + if (!param.hess.empty()) { + cuts_ = common::DeviceSketchWithHessian(ctx, dmat, param.max_bin, param.hess); + } else { + cuts_ = common::DeviceSketch(ctx, dmat, param.max_bin); + } monitor_.Stop("Quantiles"); monitor_.Start("InitCompressedData"); diff --git a/src/data/gradient_index.cc b/src/data/gradient_index.cc index 1ee1bd60ba09..a2b3f3e54fa0 100644 --- a/src/data/gradient_index.cc +++ b/src/data/gradient_index.cc @@ -7,13 +7,12 @@ #include #include #include -#include // std::forward +#include // for forward #include "../common/column_matrix.h" #include "../common/hist_util.h" #include "../common/numeric.h" -#include "../common/threading_utils.h" -#include "../common/transform_iterator.h" // MakeIndexTransformIter +#include "../common/transform_iterator.h" // for MakeIndexTransformIter namespace xgboost { diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index 5a2f6f8df16c..85ede3258fae 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -8,12 +8,12 @@ #include #include +#include // for accumulate #include #include -#include "../common/error_msg.h" // for InconsistentMaxBin -#include "../common/random.h" -#include "../common/threading_utils.h" +#include "../collective/communicator-inl.h" // for GetWorldSize, GetRank, Allgather +#include "../common/error_msg.h" // for InconsistentMaxBin #include "./simple_batch_iterator.h" #include "adapter.h" #include "batch_utils.h" // for CheckEmpty, RegenGHist diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc index ec9c90b1041a..042a75c56a85 100644 --- a/src/data/sparse_page_dmatrix.cc +++ b/src/data/sparse_page_dmatrix.cc @@ -8,7 +8,6 @@ #include "./sparse_page_dmatrix.h" #include "../collective/communicator-inl.h" -#include "./simple_batch_iterator.h" #include "batch_utils.h" // for RegenGHist #include "gradient_index.h" diff --git a/src/data/sparse_page_dmatrix.cu b/src/data/sparse_page_dmatrix.cu index 1d9af9f06d25..9d4c633871df 100644 --- a/src/data/sparse_page_dmatrix.cu +++ b/src/data/sparse_page_dmatrix.cu @@ -1,13 +1,15 @@ /** * Copyright 2021-2023 by XGBoost contributors */ -#include +#include // for unique_ptr 
#include "../common/hist_util.cuh" -#include "batch_utils.h" // for CheckEmpty, RegenGHist +#include "../common/hist_util.h" // for HistogramCuts +#include "batch_utils.h" // for CheckEmpty, RegenGHist #include "ellpack_page.cuh" #include "sparse_page_dmatrix.h" -#include "sparse_page_source.h" +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for BatchParam namespace xgboost::data { BatchSet SparsePageDMatrix::GetEllpackBatches(Context const* ctx, @@ -25,8 +27,13 @@ BatchSet SparsePageDMatrix::GetEllpackBatches(Context const* ctx, cache_info_.erase(id); MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_); std::unique_ptr cuts; - cuts = - std::make_unique(common::DeviceSketch(ctx, this, param.max_bin, 0)); + if (!param.hess.empty()) { + cuts = std::make_unique( + common::DeviceSketchWithHessian(ctx, this, param.max_bin, param.hess)); + } else { + cuts = + std::make_unique(common::DeviceSketch(ctx, this, param.max_bin)); + } this->InitializeSparsePage(ctx); // reset after use. row_stride = GetRowStride(this); @@ -35,10 +42,10 @@ BatchSet SparsePageDMatrix::GetEllpackBatches(Context const* ctx, batch_param_ = param; auto ft = this->info_.feature_types.ConstDeviceSpan(); - ellpack_page_source_.reset(); // release resources. - ellpack_page_source_.reset(new EllpackPageSource( + ellpack_page_source_.reset(); // make sure resource is released before making new ones. + ellpack_page_source_ = std::make_shared( this->missing_, ctx->Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id), - param, std::move(cuts), this->IsDense(), row_stride, ft, sparse_page_source_, ctx->gpu_id)); + param, std::move(cuts), this->IsDense(), row_stride, ft, sparse_page_source_, ctx->gpu_id); } else { CHECK(sparse_page_source_); ellpack_page_source_->Reset(); diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 8b456af66931..e3df3862915c 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -47,15 +47,16 @@ std::string MapTreeMethodToUpdaters(Context const* ctx, TreeMethod tree_method) if (ctx->IsCUDA()) { common::AssertGPUSupport(); } + switch (tree_method) { case TreeMethod::kAuto: // Use hist as default in 2.0 case TreeMethod::kHist: { return ctx->DispatchDevice([] { return "grow_quantile_histmaker"; }, [] { return "grow_gpu_hist"; }); } - case TreeMethod::kApprox: - CHECK(ctx->IsCPU()) << "The `approx` tree method is not supported on GPU."; - return "grow_histmaker"; + case TreeMethod::kApprox: { + return ctx->DispatchDevice([] { return "grow_histmaker"; }, [] { return "grow_gpu_approx"; }); + } case TreeMethod::kExact: CHECK(ctx->IsCPU()) << "The `exact` tree method is not supported on GPU."; return "grow_colmaker,prune"; diff --git a/src/tree/constraints.h b/src/tree/constraints.h index 580576a5889d..3789d2a24ddb 100644 --- a/src/tree/constraints.h +++ b/src/tree/constraints.h @@ -1,5 +1,5 @@ -/*! - * Copyright 2018-2019 by Contributors +/** + * Copyright 2018-2023 by Contributors */ #ifndef XGBOOST_TREE_CONSTRAINTS_H_ #define XGBOOST_TREE_CONSTRAINTS_H_ @@ -8,10 +8,8 @@ #include #include -#include "xgboost/span.h" -#include "xgboost/base.h" - #include "param.h" +#include "xgboost/base.h" namespace xgboost { /*! 
diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 5f763fb933bf..1082f89550eb 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -8,10 +8,10 @@ #include #include +#include // for size_t #include #include -#include "../../common/compressed_iterator.h" #include "../../common/cuda_context.cuh" // for CUDAContext #include "../../common/random.h" #include "../param.h" @@ -202,27 +202,27 @@ ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(size_t n_rows, GradientBasedSample ExternalMemoryUniformSampling::Sample(Context const* ctx, common::Span gpair, DMatrix* dmat) { + auto cuctx = ctx->CUDACtx(); // Set gradient pair to 0 with p = 1 - subsample - thrust::replace_if(dh::tbegin(gpair), dh::tend(gpair), - thrust::counting_iterator(0), - BernoulliTrial(common::GlobalRandom()(), subsample_), - GradientPair()); + thrust::replace_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), + thrust::counting_iterator(0), + BernoulliTrial(common::GlobalRandom()(), subsample_), GradientPair{}); // Count the sampled rows. - size_t sample_rows = thrust::count_if(dh::tbegin(gpair), dh::tend(gpair), IsNonZero()); + size_t sample_rows = + thrust::count_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), IsNonZero{}); // Compact gradient pairs. gpair_.resize(sample_rows); - thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero()); + thrust::copy_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero{}); // Index the sample rows. - thrust::transform(dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(), IsNonZero()); - thrust::exclusive_scan(sample_row_index_.begin(), sample_row_index_.end(), + thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(), + IsNonZero()); + thrust::exclusive_scan(cuctx->CTP(), sample_row_index_.begin(), sample_row_index_.end(), sample_row_index_.begin()); - thrust::transform(dh::tbegin(gpair), dh::tend(gpair), - sample_row_index_.begin(), - sample_row_index_.begin(), - ClearEmptyRows()); + thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(), + sample_row_index_.begin(), ClearEmptyRows()); auto batch_iterator = dmat->GetBatches(ctx, batch_param_); auto first_page = (*batch_iterator.begin()).Impl(); @@ -232,7 +232,7 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(Context const* ctx, first_page->row_stride, sample_rows)); // Compact the ELLPACK pages into the single sample page. 
- thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0); + thrust::fill(cuctx->CTP(), dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0); for (auto& batch : batch_iterator) { page_->Compact(ctx->gpu_id, batch.Impl(), dh::ToSpan(sample_row_index_)); } diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index 78506305faa0..7b50206212e1 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -11,7 +11,6 @@ #include "../common/random.h" #include "../data/gradient_index.h" #include "common_row_partitioner.h" -#include "constraints.h" #include "driver.h" #include "hist/evaluate_splits.h" #include "hist/histogram.h" diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index e2a863e3d2d2..56d7d2a89550 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -31,7 +31,6 @@ #include "gpu_hist/histogram.cuh" #include "gpu_hist/row_partitioner.cuh" #include "param.h" -#include "split_evaluator.h" #include "updater_gpu_common.cuh" #include "xgboost/base.h" #include "xgboost/context.h" @@ -49,13 +48,30 @@ DMLC_REGISTRY_FILE_TAG(updater_gpu_hist); #endif // !defined(GTEST_TEST) // training parameters specific to this algorithm -struct GPUHistMakerTrainParam - : public XGBoostParameter { +struct GPUHistMakerTrainParam : public XGBoostParameter { bool debug_synchronize; // declare parameters DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) { - DMLC_DECLARE_FIELD(debug_synchronize).set_default(false).describe( - "Check if all distributed tree are identical after tree construction."); + DMLC_DECLARE_FIELD(debug_synchronize) + .set_default(false) + .describe("Check if all distributed tree are identical after tree construction."); + } + + // Only call this method for testing + void CheckTreesSynchronized(RegTree const* local_tree) const { + if (this->debug_synchronize) { + std::string s_model; + common::MemoryBufferStream fs(&s_model); + int rank = collective::GetRank(); + if (rank == 0) { + local_tree->Save(&fs); + } + fs.Seek(0); + collective::Broadcast(&s_model, 0); + RegTree reference_tree{}; // rank 0 tree + reference_tree.Load(&fs); + CHECK(*local_tree == reference_tree); + } } }; #if !defined(GTEST_TEST) @@ -170,16 +186,15 @@ class DeviceHistogramStorage { }; // Manage memory for a single GPU -template struct GPUHistMakerDevice { private: GPUHistEvaluator evaluator_; Context const* ctx_; + std::shared_ptr column_sampler_; public: EllpackPageImpl const* page{nullptr}; common::Span feature_types; - BatchParam batch_param; std::unique_ptr row_partitioner; DeviceHistogramStorage<> hist{}; @@ -199,7 +214,6 @@ struct GPUHistMakerDevice { dh::PinnedMemory pinned2; common::Monitor monitor; - common::ColumnSampler column_sampler; FeatureInteractionConstraintDevice interaction_constraints; std::unique_ptr sampler; @@ -208,22 +222,22 @@ struct GPUHistMakerDevice { GPUHistMakerDevice(Context const* ctx, bool is_external_memory, common::Span _feature_types, bst_row_t _n_rows, - TrainParam _param, uint32_t column_sampler_seed, uint32_t n_features, - BatchParam _batch_param) + TrainParam _param, std::shared_ptr column_sampler, + uint32_t n_features, BatchParam batch_param) : evaluator_{_param, n_features, ctx->gpu_id}, ctx_(ctx), feature_types{_feature_types}, param(std::move(_param)), - column_sampler(column_sampler_seed), - interaction_constraints(param, n_features), - batch_param(std::move(_batch_param)) { - sampler.reset(new GradientBasedSampler(ctx, _n_rows, batch_param, param.subsample, - 
param.sampling_method, is_external_memory)); + column_sampler_(std::move(column_sampler)), + interaction_constraints(param, n_features) { + sampler = std::make_unique(ctx, _n_rows, batch_param, param.subsample, + param.sampling_method, is_external_memory); if (!param.monotone_constraints.empty()) { // Copy assigning an empty vector causes an exception in MSVC debug builds monotone_constraints = param.monotone_constraints; } + CHECK(column_sampler_); monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(ctx_->gpu_id)); } @@ -234,16 +248,16 @@ struct GPUHistMakerDevice { CHECK(page); feature_groups.reset(new FeatureGroups(page->Cuts(), page->is_dense, dh::MaxSharedMemoryOptin(ctx_->gpu_id), - sizeof(GradientSumT))); + sizeof(GradientPairPrecise))); } } // Reset values for each update iteration void Reset(HostDeviceVector* dh_gpair, DMatrix* dmat, int64_t num_columns) { auto const& info = dmat->Info(); - this->column_sampler.Init(ctx_, num_columns, info.feature_weights.HostVector(), - param.colsample_bynode, param.colsample_bylevel, - param.colsample_bytree); + this->column_sampler_->Init(ctx_, num_columns, info.feature_weights.HostVector(), + param.colsample_bynode, param.colsample_bylevel, + param.colsample_bytree); dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); this->interaction_constraints.Reset(); @@ -275,8 +289,8 @@ struct GPUHistMakerDevice { GPUExpandEntry EvaluateRootSplit(GradientPairInt64 root_sum) { int nidx = RegTree::kRoot; GPUTrainingParam gpu_param(param); - auto sampled_features = column_sampler.GetFeatureSet(0); - sampled_features->SetDevice(ctx_->gpu_id); + auto sampled_features = column_sampler_->GetFeatureSet(0); + sampled_features->SetDevice(ctx_->Device()); common::Span feature_set = interaction_constraints.Query(sampled_features->DeviceSpan(), nidx); auto matrix = page->GetDeviceAccessor(ctx_->gpu_id); @@ -316,13 +330,13 @@ struct GPUHistMakerDevice { int right_nidx = tree[candidate.nid].RightChild(); nidx[i * 2] = left_nidx; nidx[i * 2 + 1] = right_nidx; - auto left_sampled_features = column_sampler.GetFeatureSet(tree.GetDepth(left_nidx)); - left_sampled_features->SetDevice(ctx_->gpu_id); + auto left_sampled_features = column_sampler_->GetFeatureSet(tree.GetDepth(left_nidx)); + left_sampled_features->SetDevice(ctx_->Device()); feature_sets.emplace_back(left_sampled_features); common::Span left_feature_set = interaction_constraints.Query(left_sampled_features->DeviceSpan(), left_nidx); - auto right_sampled_features = column_sampler.GetFeatureSet(tree.GetDepth(right_nidx)); - right_sampled_features->SetDevice(ctx_->gpu_id); + auto right_sampled_features = column_sampler_->GetFeatureSet(tree.GetDepth(right_nidx)); + right_sampled_features->SetDevice(ctx_->Device()); feature_sets.emplace_back(right_sampled_features); common::Span right_feature_set = interaction_constraints.Query(right_sampled_features->DeviceSpan(), @@ -657,7 +671,6 @@ struct GPUHistMakerDevice { evaluator_.ApplyTreeSplit(candidate, p_tree); const auto& parent = tree[candidate.nid]; - std::size_t max_nidx = std::max(parent.LeftChild(), parent.RightChild()); interaction_constraints.Split(candidate.nid, parent.SplitIndex(), parent.LeftChild(), parent.RightChild()); } @@ -693,9 +706,8 @@ struct GPUHistMakerDevice { return root_entry; } - void UpdateTree(HostDeviceVector* gpair_all, DMatrix* p_fmat, - ObjInfo const* task, RegTree* p_tree, - HostDeviceVector* p_out_position) { + void UpdateTree(HostDeviceVector* gpair_all, DMatrix* p_fmat, ObjInfo const* task, + RegTree* p_tree, HostDeviceVector* 
p_out_position) { auto& tree = *p_tree; // Process maximum 32 nodes at a time Driver driver(param, 32); @@ -720,7 +732,6 @@ struct GPUHistMakerDevice { std::copy_if(expand_set.begin(), expand_set.end(), std::back_inserter(filtered_expand_set), [&](const auto& e) { return driver.IsChildValid(e); }); - auto new_candidates = pinned.GetSpan(filtered_expand_set.size() * 2, GPUExpandEntry()); @@ -753,8 +764,7 @@ class GPUHistMaker : public TreeUpdater { using GradientSumT = GradientPairPrecise; public: - explicit GPUHistMaker(Context const* ctx, ObjInfo const* task) - : TreeUpdater(ctx), task_{task} {}; + explicit GPUHistMaker(Context const* ctx, ObjInfo const* task) : TreeUpdater(ctx), task_{task} {}; void Configure(const Args& args) override { // Used in test to count how many configurations are performed LOG(DEBUG) << "[GPU Hist]: Configure"; @@ -786,13 +796,10 @@ class GPUHistMaker : public TreeUpdater { // build tree try { - size_t t_idx{0}; + std::size_t t_idx{0}; for (xgboost::RegTree* tree : trees) { this->UpdateTree(param, gpair, dmat, tree, &out_position[t_idx]); - - if (hist_maker_param_.debug_synchronize) { - this->CheckTreesSynchronized(tree); - } + this->hist_maker_param_.CheckTreesSynchronized(tree); ++t_idx; } dh::safe_cuda(cudaGetLastError()); @@ -809,13 +816,14 @@ class GPUHistMaker : public TreeUpdater { // Synchronise the column sampling seed uint32_t column_sampling_seed = common::GlobalRandom()(); collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0); + this->column_sampler_ = std::make_shared(column_sampling_seed); auto batch_param = BatchParam{param->max_bin, TrainParam::DftSparseThreshold()}; dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); info_->feature_types.SetDevice(ctx_->gpu_id); - maker.reset(new GPUHistMakerDevice( + maker = std::make_unique( ctx_, !dmat->SingleColBlock(), info_->feature_types.ConstDeviceSpan(), info_->num_row_, - *param, column_sampling_seed, info_->num_col_, batch_param)); + *param, column_sampler_, info_->num_col_, batch_param); p_last_fmat_ = dmat; initialised_ = true; @@ -830,21 +838,6 @@ class GPUHistMaker : public TreeUpdater { p_last_tree_ = p_tree; } - // Only call this method for testing - void CheckTreesSynchronized(RegTree* local_tree) const { - std::string s_model; - common::MemoryBufferStream fs(&s_model); - int rank = collective::GetRank(); - if (rank == 0) { - local_tree->Save(&fs); - } - fs.Seek(0); - collective::Broadcast(&s_model, 0); - RegTree reference_tree{}; // rank 0 tree - reference_tree.Load(&fs); - CHECK(*local_tree == reference_tree); - } - void UpdateTree(TrainParam const* param, HostDeviceVector* gpair, DMatrix* p_fmat, RegTree* p_tree, HostDeviceVector* p_out_position) { monitor_.Start("InitData"); @@ -868,7 +861,7 @@ class GPUHistMaker : public TreeUpdater { MetaInfo* info_{}; // NOLINT - std::unique_ptr> maker; // NOLINT + std::unique_ptr maker; // NOLINT [[nodiscard]] char const* Name() const override { return "grow_gpu_hist"; } [[nodiscard]] bool HasNodePosition() const override { return true; } @@ -883,6 +876,7 @@ class GPUHistMaker : public TreeUpdater { ObjInfo const* task_{nullptr}; common::Monitor monitor_; + std::shared_ptr column_sampler_; }; #if !defined(GTEST_TEST) @@ -892,4 +886,131 @@ XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist") return new GPUHistMaker(ctx, task); }); #endif // !defined(GTEST_TEST) + +class GPUGlobalApproxMaker : public TreeUpdater { + public: + explicit GPUGlobalApproxMaker(Context const* ctx, ObjInfo const* task) + : TreeUpdater(ctx), task_{task} 
{}; + void Configure(Args const& args) override { + // Used in test to count how many configurations are performed + LOG(DEBUG) << "[GPU Approx]: Configure"; + hist_maker_param_.UpdateAllowUnknown(args); + dh::CheckComputeCapability(); + initialised_ = false; + + monitor_.Init(this->Name()); + } + + void LoadConfig(Json const& in) override { + auto const& config = get(in); + FromJson(config.at("approx_train_param"), &this->hist_maker_param_); + initialised_ = false; + } + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + out["approx_train_param"] = ToJson(hist_maker_param_); + } + ~GPUGlobalApproxMaker() override { dh::GlobalMemoryLogger().Log(); } + + void Update(TrainParam const* param, HostDeviceVector* gpair, DMatrix* p_fmat, + common::Span> out_position, + const std::vector& trees) override { + monitor_.Start("Update"); + + this->InitDataOnce(p_fmat); + // build tree + hess_.resize(gpair->Size()); + auto hess = dh::ToSpan(hess_); + + gpair->SetDevice(ctx_->Device()); + auto d_gpair = gpair->ConstDeviceSpan(); + auto cuctx = ctx_->CUDACtx(); + thrust::transform(cuctx->CTP(), dh::tcbegin(d_gpair), dh::tcend(d_gpair), dh::tbegin(hess), + [=] XGBOOST_DEVICE(GradientPair const& g) { return g.GetHess(); }); + + auto const& info = p_fmat->Info(); + info.feature_types.SetDevice(ctx_->Device()); + auto batch = BatchParam{param->max_bin, hess, !task_->const_hess}; + maker_ = std::make_unique( + ctx_, !p_fmat->SingleColBlock(), info.feature_types.ConstDeviceSpan(), info.num_row_, + *param, column_sampler_, info.num_col_, batch); + + std::size_t t_idx{0}; + for (xgboost::RegTree* tree : trees) { + this->UpdateTree(gpair, p_fmat, tree, &out_position[t_idx]); + this->hist_maker_param_.CheckTreesSynchronized(tree); + ++t_idx; + } + + monitor_.Stop("Update"); + } + + void InitDataOnce(DMatrix* p_fmat) { + if (this->initialised_) { + return; + } + + monitor_.Start(__func__); + CHECK(ctx_->IsCUDA()) << error::InvalidCUDAOrdinal(); + // Synchronise the column sampling seed + uint32_t column_sampling_seed = common::GlobalRandom()(); + collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0); + this->column_sampler_ = std::make_shared(column_sampling_seed); + + p_last_fmat_ = p_fmat; + initialised_ = true; + monitor_.Stop(__func__); + } + + void InitData(DMatrix* p_fmat, RegTree const* p_tree) { + this->InitDataOnce(p_fmat); + p_last_tree_ = p_tree; + } + + void UpdateTree(HostDeviceVector* gpair, DMatrix* p_fmat, RegTree* p_tree, + HostDeviceVector* p_out_position) { + monitor_.Start("InitData"); + this->InitData(p_fmat, p_tree); + monitor_.Stop("InitData"); + + gpair->SetDevice(ctx_->gpu_id); + maker_->UpdateTree(gpair, p_fmat, task_, p_tree, p_out_position); + } + + bool UpdatePredictionCache(const DMatrix* data, + linalg::MatrixView p_out_preds) override { + if (maker_ == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) { + return false; + } + monitor_.Start("UpdatePredictionCache"); + bool result = maker_->UpdatePredictionCache(p_out_preds, p_last_tree_); + monitor_.Stop("UpdatePredictionCache"); + return result; + } + + [[nodiscard]] char const* Name() const override { return "grow_gpu_approx"; } + [[nodiscard]] bool HasNodePosition() const override { return true; } + + private: + bool initialised_{false}; + + GPUHistMakerTrainParam hist_maker_param_; + dh::device_vector hess_; + std::shared_ptr column_sampler_; + std::unique_ptr maker_; + + DMatrix* p_last_fmat_{nullptr}; + RegTree const* p_last_tree_{nullptr}; + ObjInfo const* task_{nullptr}; + + 
common::Monitor monitor_; +}; + +#if !defined(GTEST_TEST) +XGBOOST_REGISTER_TREE_UPDATER(GPUApproxMaker, "grow_gpu_approx") + .describe("Grow tree with GPU.") + .set_body([](Context const* ctx, ObjInfo const* task) { + return new GPUGlobalApproxMaker(ctx, task); + }); +#endif // !defined(GTEST_TEST) } // namespace xgboost::tree diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index b250cd2ab3f3..2bd47d42c688 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -13,10 +13,7 @@ #include "../../../src/common/common.h" #include "../../../src/data/ellpack_page.cuh" // for EllpackPageImpl #include "../../../src/data/ellpack_page.h" // for EllpackPage -#include "../../../src/data/sparse_page_source.h" -#include "../../../src/tree/constraints.cuh" #include "../../../src/tree/param.h" // for TrainParam -#include "../../../src/tree/updater_gpu_common.cuh" #include "../../../src/tree/updater_gpu_hist.cu" #include "../filesystem.h" // dmlc::TemporaryDirectory #include "../helpers.h" @@ -94,8 +91,9 @@ void TestBuildHist(bool use_shared_memory_histograms) { auto page = BuildEllpackPage(kNRows, kNCols); BatchParam batch_param{}; Context ctx{MakeCUDACtx(0)}; - GPUHistMakerDevice maker(&ctx, /*is_external_memory=*/false, {}, kNRows, param, - kNCols, kNCols, batch_param); + auto cs = std::make_shared(0); + GPUHistMakerDevice maker(&ctx, /*is_external_memory=*/false, {}, kNRows, param, cs, kNCols, + batch_param); xgboost::SimpleLCG gen; xgboost::SimpleRealUniformDistribution dist(0.0f, 1.0f); HostDeviceVector gpair(kNRows); diff --git a/tests/cpp/tree/test_prediction_cache.cc b/tests/cpp/tree/test_prediction_cache.cc index e60d9cd8add7..333f1eccc9df 100644 --- a/tests/cpp/tree/test_prediction_cache.cc +++ b/tests/cpp/tree/test_prediction_cache.cc @@ -24,15 +24,11 @@ class TestPredictionCache : public ::testing::Test { Xy_ = RandomDataGenerator{n_samples_, n_features, 0}.Targets(n_targets).GenerateDMatrix(true); } - void RunLearnerTest(std::string updater_name, float subsample, std::string const& grow_policy, - std::string const& strategy) { + void RunLearnerTest(Context const* ctx, std::string updater_name, float subsample, + std::string const& grow_policy, std::string const& strategy) { std::unique_ptr learner{Learner::Create({Xy_})}; - if (updater_name == "grow_gpu_hist") { - // gpu_id setup - learner->SetParam("tree_method", "gpu_hist"); - } else { - learner->SetParam("updater", updater_name); - } + learner->SetParam("device", ctx->DeviceName()); + learner->SetParam("updater", updater_name); learner->SetParam("multi_strategy", strategy); learner->SetParam("grow_policy", grow_policy); learner->SetParam("subsample", std::to_string(subsample)); @@ -65,20 +61,14 @@ class TestPredictionCache : public ::testing::Test { } } - void RunTest(std::string const& updater_name, std::string const& strategy) { + void RunTest(Context* ctx, std::string const& updater_name, std::string const& strategy) { { - Context ctx; - ctx.InitAllowUnknown(Args{{"nthread", "8"}}); - if (updater_name == "grow_gpu_hist") { - ctx = ctx.MakeCUDA(0); - } else { - ctx = ctx.MakeCPU(); - } + ctx->InitAllowUnknown(Args{{"nthread", "8"}}); ObjInfo task{ObjInfo::kRegression}; - std::unique_ptr updater{TreeUpdater::Create(updater_name, &ctx, &task)}; + std::unique_ptr updater{TreeUpdater::Create(updater_name, ctx, &task)}; RegTree tree; - std::vector trees{&tree}; + std::vector trees{&tree}; auto gpair = GenerateRandomGradients(n_samples_); tree::TrainParam param; 
param.UpdateAllowUnknown(Args{{"max_bin", "64"}}); @@ -86,33 +76,46 @@ class TestPredictionCache : public ::testing::Test { std::vector> position(1); updater->Update(¶m, &gpair, Xy_.get(), position, trees); HostDeviceVector out_prediction_cached; - out_prediction_cached.SetDevice(ctx.gpu_id); + out_prediction_cached.SetDevice(ctx->Device()); out_prediction_cached.Resize(n_samples_); auto cache = - linalg::MakeTensorView(&ctx, &out_prediction_cached, out_prediction_cached.Size(), 1); + linalg::MakeTensorView(ctx, &out_prediction_cached, out_prediction_cached.Size(), 1); ASSERT_TRUE(updater->UpdatePredictionCache(Xy_.get(), cache)); } for (auto policy : {"depthwise", "lossguide"}) { for (auto subsample : {1.0f, 0.4f}) { - this->RunLearnerTest(updater_name, subsample, policy, strategy); - this->RunLearnerTest(updater_name, subsample, policy, strategy); + this->RunLearnerTest(ctx, updater_name, subsample, policy, strategy); + this->RunLearnerTest(ctx, updater_name, subsample, policy, strategy); } } } }; -TEST_F(TestPredictionCache, Approx) { this->RunTest("grow_histmaker", "one_output_per_tree"); } +TEST_F(TestPredictionCache, Approx) { + Context ctx; + this->RunTest(&ctx, "grow_histmaker", "one_output_per_tree"); +} TEST_F(TestPredictionCache, Hist) { - this->RunTest("grow_quantile_histmaker", "one_output_per_tree"); + Context ctx; + this->RunTest(&ctx, "grow_quantile_histmaker", "one_output_per_tree"); } TEST_F(TestPredictionCache, HistMulti) { - this->RunTest("grow_quantile_histmaker", "multi_output_tree"); + Context ctx; + this->RunTest(&ctx, "grow_quantile_histmaker", "multi_output_tree"); } #if defined(XGBOOST_USE_CUDA) -TEST_F(TestPredictionCache, GpuHist) { this->RunTest("grow_gpu_hist", "one_output_per_tree"); } +TEST_F(TestPredictionCache, GpuHist) { + auto ctx = MakeCUDACtx(0); + this->RunTest(&ctx, "grow_gpu_hist", "one_output_per_tree"); +} + +TEST_F(TestPredictionCache, GpuApprox) { + auto ctx = MakeCUDACtx(0); + this->RunTest(&ctx, "grow_gpu_approx", "one_output_per_tree"); +} #endif // defined(XGBOOST_USE_CUDA) } // namespace xgboost diff --git a/tests/cpp/tree/test_regen.cc b/tests/cpp/tree/test_regen.cc index d0fe5b449ee2..837159329761 100644 --- a/tests/cpp/tree/test_regen.cc +++ b/tests/cpp/tree/test_regen.cc @@ -62,8 +62,10 @@ class RegenTest : public ::testing::Test { auto constexpr Iter() const { return 4; } template - size_t TestTreeMethod(std::string tree_method, std::string obj, bool reset = true) const { + size_t TestTreeMethod(Context const* ctx, std::string tree_method, std::string obj, + bool reset = true) const { auto learner = std::unique_ptr{Learner::Create({p_fmat_})}; + learner->SetParam("device", ctx->DeviceName()); learner->SetParam("tree_method", tree_method); learner->SetParam("objective", obj); learner->Configure(); @@ -87,40 +89,71 @@ class RegenTest : public ::testing::Test { } // anonymous namespace TEST_F(RegenTest, Approx) { - auto n = this->TestTreeMethod("approx", "reg:squarederror"); + Context ctx; + auto n = this->TestTreeMethod(&ctx, "approx", "reg:squarederror"); ASSERT_EQ(n, 1); - n = this->TestTreeMethod("approx", "reg:logistic"); + n = this->TestTreeMethod(&ctx, "approx", "reg:logistic"); ASSERT_EQ(n, this->Iter()); } TEST_F(RegenTest, Hist) { - auto n = this->TestTreeMethod("hist", "reg:squarederror"); + Context ctx; + auto n = this->TestTreeMethod(&ctx, "hist", "reg:squarederror"); ASSERT_EQ(n, 1); - n = this->TestTreeMethod("hist", "reg:logistic"); + n = this->TestTreeMethod(&ctx, "hist", "reg:logistic"); ASSERT_EQ(n, 1); } 
TEST_F(RegenTest, Mixed) { - auto n = this->TestTreeMethod("hist", "reg:squarederror", false); + Context ctx; + auto n = this->TestTreeMethod(&ctx, "hist", "reg:squarederror", false); ASSERT_EQ(n, 1); - n = this->TestTreeMethod("approx", "reg:logistic", true); + n = this->TestTreeMethod(&ctx, "approx", "reg:logistic", true); ASSERT_EQ(n, this->Iter() + 1); - n = this->TestTreeMethod("approx", "reg:logistic", false); + n = this->TestTreeMethod(&ctx, "approx", "reg:logistic", false); ASSERT_EQ(n, this->Iter()); - n = this->TestTreeMethod("hist", "reg:squarederror", true); + n = this->TestTreeMethod(&ctx, "hist", "reg:squarederror", true); ASSERT_EQ(n, this->Iter() + 1); } #if defined(XGBOOST_USE_CUDA) +TEST_F(RegenTest, GpuApprox) { + auto ctx = MakeCUDACtx(0); + auto n = this->TestTreeMethod(&ctx, "approx", "reg:squarederror", true); + ASSERT_EQ(n, 1); + n = this->TestTreeMethod(&ctx, "approx", "reg:logistic", false); + ASSERT_EQ(n, this->Iter()); + + n = this->TestTreeMethod(&ctx, "approx", "reg:logistic", true); + ASSERT_EQ(n, this->Iter() * 2); +} + TEST_F(RegenTest, GpuHist) { - auto n = this->TestTreeMethod("gpu_hist", "reg:squarederror"); + auto ctx = MakeCUDACtx(0); + auto n = this->TestTreeMethod(&ctx, "hist", "reg:squarederror", true); ASSERT_EQ(n, 1); - n = this->TestTreeMethod("gpu_hist", "reg:logistic", false); + n = this->TestTreeMethod(&ctx, "hist", "reg:logistic", false); ASSERT_EQ(n, 1); - n = this->TestTreeMethod("hist", "reg:logistic"); - ASSERT_EQ(n, 2); + { + Context ctx; + n = this->TestTreeMethod(&ctx, "hist", "reg:logistic"); + ASSERT_EQ(n, 2); + } +} + +TEST_F(RegenTest, GpuMixed) { + auto ctx = MakeCUDACtx(0); + auto n = this->TestTreeMethod(&ctx, "hist", "reg:squarederror", false); + ASSERT_EQ(n, 1); + n = this->TestTreeMethod(&ctx, "approx", "reg:logistic", true); + ASSERT_EQ(n, this->Iter() + 1); + + n = this->TestTreeMethod(&ctx, "approx", "reg:logistic", false); + ASSERT_EQ(n, this->Iter()); + n = this->TestTreeMethod(&ctx, "hist", "reg:squarederror", true); + ASSERT_EQ(n, this->Iter() + 1); } #endif // defined(XGBOOST_USE_CUDA) } // namespace xgboost diff --git a/tests/cpp/tree/test_tree_policy.cc b/tests/cpp/tree/test_tree_policy.cc index 15f4cd31bc99..50563be1d218 100644 --- a/tests/cpp/tree/test_tree_policy.cc +++ b/tests/cpp/tree/test_tree_policy.cc @@ -20,10 +20,11 @@ class TestGrowPolicy : public ::testing::Test { true); } - std::unique_ptr TrainOneIter(std::string tree_method, std::string policy, - int32_t max_leaves, int32_t max_depth) { + std::unique_ptr TrainOneIter(Context const* ctx, std::string tree_method, + std::string policy, int32_t max_leaves, int32_t max_depth) { std::unique_ptr learner{Learner::Create({this->Xy_})}; learner->SetParam("tree_method", tree_method); + learner->SetParam("device", ctx->DeviceName()); if (max_leaves >= 0) { learner->SetParam("max_leaves", std::to_string(max_leaves)); } @@ -63,7 +64,7 @@ class TestGrowPolicy : public ::testing::Test { if (max_leaves == 0 && max_depth == 0) { // unconstrainted - if (tree_method != "gpu_hist") { + if (ctx->IsCPU()) { // GPU pre-allocates for all nodes. 
learner->UpdateOneIter(0, Xy_); } @@ -86,23 +87,23 @@ class TestGrowPolicy : public ::testing::Test { return learner; } - void TestCombination(std::string tree_method) { + void TestCombination(Context const* ctx, std::string tree_method) { for (auto policy : {"depthwise", "lossguide"}) { // -1 means default for (auto leaves : {-1, 0, 3}) { for (auto depth : {-1, 0, 3}) { - this->TrainOneIter(tree_method, policy, leaves, depth); + this->TrainOneIter(ctx, tree_method, policy, leaves, depth); } } } } - void TestTreeGrowPolicy(std::string tree_method, std::string policy) { + void TestTreeGrowPolicy(Context const* ctx, std::string tree_method, std::string policy) { { /** * max_leaves */ - auto learner = this->TrainOneIter(tree_method, policy, 16, -1); + auto learner = this->TrainOneIter(ctx, tree_method, policy, 16, -1); Json model{Object{}}; learner->SaveModel(&model); @@ -115,7 +116,7 @@ class TestGrowPolicy : public ::testing::Test { /** * max_depth */ - auto learner = this->TrainOneIter(tree_method, policy, -1, 3); + auto learner = this->TrainOneIter(ctx, tree_method, policy, -1, 3); Json model{Object{}}; learner->SaveModel(&model); @@ -133,25 +134,36 @@ class TestGrowPolicy : public ::testing::Test { }; TEST_F(TestGrowPolicy, Approx) { - this->TestTreeGrowPolicy("approx", "depthwise"); - this->TestTreeGrowPolicy("approx", "lossguide"); + Context ctx; + this->TestTreeGrowPolicy(&ctx, "approx", "depthwise"); + this->TestTreeGrowPolicy(&ctx, "approx", "lossguide"); - this->TestCombination("approx"); + this->TestCombination(&ctx, "approx"); } TEST_F(TestGrowPolicy, Hist) { - this->TestTreeGrowPolicy("hist", "depthwise"); - this->TestTreeGrowPolicy("hist", "lossguide"); + Context ctx; + this->TestTreeGrowPolicy(&ctx, "hist", "depthwise"); + this->TestTreeGrowPolicy(&ctx, "hist", "lossguide"); - this->TestCombination("hist"); + this->TestCombination(&ctx, "hist"); } #if defined(XGBOOST_USE_CUDA) TEST_F(TestGrowPolicy, GpuHist) { - this->TestTreeGrowPolicy("gpu_hist", "depthwise"); - this->TestTreeGrowPolicy("gpu_hist", "lossguide"); + auto ctx = MakeCUDACtx(0); + this->TestTreeGrowPolicy(&ctx, "hist", "depthwise"); + this->TestTreeGrowPolicy(&ctx, "hist", "lossguide"); - this->TestCombination("gpu_hist"); + this->TestCombination(&ctx, "hist"); +} + +TEST_F(TestGrowPolicy, GpuApprox) { + auto ctx = MakeCUDACtx(0); + this->TestTreeGrowPolicy(&ctx, "approx", "depthwise"); + this->TestTreeGrowPolicy(&ctx, "approx", "lossguide"); + + this->TestCombination(&ctx, "approx"); } #endif // defined(XGBOOST_USE_CUDA) } // namespace xgboost diff --git a/tests/cpp/tree/test_tree_stat.cc b/tests/cpp/tree/test_tree_stat.cc index fb64e3a7838e..d125c84d55b0 100644 --- a/tests/cpp/tree/test_tree_stat.cc +++ b/tests/cpp/tree/test_tree_stat.cc @@ -135,7 +135,7 @@ class TestMinSplitLoss : public ::testing::Test { gpair_ = GenerateRandomGradients(kRows); } - std::int32_t Update(std::string updater, float gamma) { + std::int32_t Update(Context const* ctx, std::string updater, float gamma) { Args args{{"max_depth", "1"}, {"max_leaves", "0"}, @@ -154,8 +154,7 @@ class TestMinSplitLoss : public ::testing::Test { param.UpdateAllowUnknown(args); ObjInfo task{ObjInfo::kRegression}; - Context ctx{MakeCUDACtx(updater == "grow_gpu_hist" ? 
0 : Context::kCpuId)}; - auto up = std::unique_ptr{TreeUpdater::Create(updater, &ctx, &task)}; + auto up = std::unique_ptr{TreeUpdater::Create(updater, ctx, &task)}; up->Configure({}); RegTree tree; @@ -167,16 +166,16 @@ class TestMinSplitLoss : public ::testing::Test { } public: - void RunTest(std::string updater) { + void RunTest(Context const* ctx, std::string updater) { { - int32_t n_nodes = Update(updater, 0.01); + int32_t n_nodes = Update(ctx, updater, 0.01); // This is not strictly verified, meaning the numeber `2` is whatever GPU_Hist retured // when writing this test, and only used for testing larger gamma (below) does prevent // building tree. ASSERT_EQ(n_nodes, 2); } { - int32_t n_nodes = Update(updater, 100.0); + int32_t n_nodes = Update(ctx, updater, 100.0); // No new nodes with gamma == 100. ASSERT_EQ(n_nodes, static_cast(0)); } @@ -185,10 +184,25 @@ class TestMinSplitLoss : public ::testing::Test { /* Exact tree method requires a pruner as an additional updater, so not tested here. */ -TEST_F(TestMinSplitLoss, Approx) { this->RunTest("grow_histmaker"); } +TEST_F(TestMinSplitLoss, Approx) { + Context ctx; + this->RunTest(&ctx, "grow_histmaker"); +} + +TEST_F(TestMinSplitLoss, Hist) { + Context ctx; + this->RunTest(&ctx, "grow_quantile_histmaker"); +} -TEST_F(TestMinSplitLoss, Hist) { this->RunTest("grow_quantile_histmaker"); } #if defined(XGBOOST_USE_CUDA) -TEST_F(TestMinSplitLoss, GpuHist) { this->RunTest("grow_gpu_hist"); } +TEST_F(TestMinSplitLoss, GpuHist) { + auto ctx = MakeCUDACtx(0); + this->RunTest(&ctx, "grow_gpu_hist"); +} + +TEST_F(TestMinSplitLoss, GpuApprox) { + auto ctx = MakeCUDACtx(0); + this->RunTest(&ctx, "grow_gpu_approx"); +} #endif // defined(XGBOOST_USE_CUDA) } // namespace xgboost diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index 7fea42f608d1..653a99f3a725 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -7,11 +7,18 @@ import xgboost as xgb from xgboost import testing as tm -from xgboost.testing.params import cat_parameter_strategy, hist_parameter_strategy +from xgboost.testing.params import ( + cat_parameter_strategy, + exact_parameter_strategy, + hist_parameter_strategy, +) from xgboost.testing.updater import ( + check_categorical_missing, + check_categorical_ohe, check_get_quantile_cut, check_init_estimation, check_quantile_loss, + train_result, ) sys.path.append("tests/python") @@ -20,22 +27,6 @@ pytestmark = tm.timeout(30) -def train_result(param, dmat: xgb.DMatrix, num_rounds: int) -> dict: - result: xgb.callback.TrainingCallback.EvalsLog = {} - booster = xgb.train( - param, - dmat, - num_rounds, - [(dmat, "train")], - verbose_eval=False, - evals_result=result, - ) - assert booster.num_features() == dmat.num_col() - assert booster.num_boosted_rounds() == num_rounds - - return result - - class TestGPUUpdatersMulti: @given( hist_parameter_strategy, strategies.integers(1, 20), tm.multi_dataset_strategy @@ -53,14 +44,45 @@ class TestGPUUpdaters: cputest = test_up.TestTreeMethod() @given( - hist_parameter_strategy, strategies.integers(1, 20), tm.make_dataset_strategy() + exact_parameter_strategy, + hist_parameter_strategy, + strategies.integers(1, 20), + tm.make_dataset_strategy(), ) @settings(deadline=None, max_examples=50, print_blob=True) - def test_gpu_hist(self, param, num_rounds, dataset): - param["tree_method"] = "gpu_hist" + def test_gpu_hist( + self, + param: Dict[str, Any], + hist_param: Dict[str, Any], + num_rounds: int, + dataset: tm.TestDataset, 
+ ) -> None: + param.update({"tree_method": "hist", "device": "cuda"}) + param.update(hist_param) param = dataset.set_params(param) result = train_result(param, dataset.get_dmat(), num_rounds) - note(result) + note(str(result)) + assert tm.non_increasing(result["train"][dataset.metric]) + + @given( + exact_parameter_strategy, + hist_parameter_strategy, + strategies.integers(1, 20), + tm.make_dataset_strategy(), + ) + @settings(deadline=None, print_blob=True) + def test_gpu_approx( + self, + param: Dict[str, Any], + hist_param: Dict[str, Any], + num_rounds: int, + dataset: tm.TestDataset, + ) -> None: + param.update({"tree_method": "approx", "device": "cuda"}) + param.update(hist_param) + param = dataset.set_params(param) + result = train_result(param, dataset.get_dmat(), num_rounds) + note(str(result)) assert tm.non_increasing(result["train"][dataset.metric]) @given(tm.sparse_datasets_strategy) @@ -69,23 +91,27 @@ def test_sparse(self, dataset): param = {"tree_method": "hist", "max_bin": 64} hist_result = train_result(param, dataset.get_dmat(), 16) note(hist_result) - assert tm.non_increasing(hist_result['train'][dataset.metric]) + assert tm.non_increasing(hist_result["train"][dataset.metric]) param = {"tree_method": "gpu_hist", "max_bin": 64} gpu_hist_result = train_result(param, dataset.get_dmat(), 16) note(gpu_hist_result) - assert tm.non_increasing(gpu_hist_result['train'][dataset.metric]) + assert tm.non_increasing(gpu_hist_result["train"][dataset.metric]) np.testing.assert_allclose( hist_result["train"]["rmse"], gpu_hist_result["train"]["rmse"], rtol=1e-2 ) - @given(strategies.integers(10, 400), strategies.integers(3, 8), - strategies.integers(1, 2), strategies.integers(4, 7)) + @given( + strategies.integers(10, 400), + strategies.integers(3, 8), + strategies.integers(1, 2), + strategies.integers(4, 7), + ) @settings(deadline=None, max_examples=20, print_blob=True) @pytest.mark.skipif(**tm.no_pandas()) def test_categorical_ohe(self, rows, cols, rounds, cats): - self.cputest.run_categorical_ohe(rows, cols, rounds, cats, "gpu_hist") + check_categorical_ohe(rows, cols, rounds, cats, "cuda", "hist") @given( tm.categorical_dataset_strategy, @@ -95,7 +121,7 @@ def test_categorical_ohe(self, rows, cols, rounds, cats): ) @settings(deadline=None, max_examples=20, print_blob=True) @pytest.mark.skipif(**tm.no_pandas()) - def test_categorical( + def test_categorical_hist( self, dataset: tm.TestDataset, hist_parameters: Dict[str, Any], @@ -103,7 +129,30 @@ def test_categorical( n_rounds: int, ) -> None: cat_parameters.update(hist_parameters) - cat_parameters["tree_method"] = "gpu_hist" + cat_parameters["tree_method"] = "hist" + cat_parameters["device"] = "cuda" + + results = train_result(cat_parameters, dataset.get_dmat(), n_rounds) + tm.non_increasing(results["train"]["rmse"]) + + @given( + tm.categorical_dataset_strategy, + hist_parameter_strategy, + cat_parameter_strategy, + strategies.integers(4, 32), + ) + @settings(deadline=None, max_examples=20, print_blob=True) + @pytest.mark.skipif(**tm.no_pandas()) + def test_categorical_approx( + self, + dataset: tm.TestDataset, + hist_parameters: Dict[str, Any], + cat_parameters: Dict[str, Any], + n_rounds: int, + ) -> None: + cat_parameters.update(hist_parameters) + cat_parameters["tree_method"] = "approx" + cat_parameters["device"] = "cuda" results = train_result(cat_parameters, dataset.get_dmat(), n_rounds) tm.non_increasing(results["train"]["rmse"]) @@ -129,24 +178,25 @@ def test_categorical_ames_housing( @given( strategies.integers(10, 400), 
strategies.integers(3, 8), - strategies.integers(4, 7) + strategies.integers(4, 7), ) @settings(deadline=None, max_examples=20, print_blob=True) @pytest.mark.skipif(**tm.no_pandas()) def test_categorical_missing(self, rows, cols, cats): - self.cputest.run_categorical_missing(rows, cols, cats, "gpu_hist") + check_categorical_missing(rows, cols, cats, "cuda", "approx") + check_categorical_missing(rows, cols, cats, "cuda", "hist") @pytest.mark.skipif(**tm.no_pandas()) def test_max_cat(self) -> None: self.cputest.run_max_cat("gpu_hist") def test_categorical_32_cat(self): - '''32 hits the bound of integer bitset, so special test''' + """32 hits the bound of integer bitset, so special test""" rows = 1000 cols = 10 cats = 32 rounds = 4 - self.cputest.run_categorical_ohe(rows, cols, rounds, cats, "gpu_hist") + check_categorical_ohe(rows, cols, rounds, cats, "cuda", "hist") @pytest.mark.skipif(**tm.no_cupy()) def test_invalid_category(self): @@ -164,15 +214,15 @@ def test_gpu_hist_device_dmatrix( ) -> None: # We cannot handle empty dataset yet assume(len(dataset.y) > 0) - param['tree_method'] = 'gpu_hist' + param["tree_method"] = "gpu_hist" param = dataset.set_params(param) result = train_result( param, dataset.get_device_dmat(max_bin=param.get("max_bin", None)), - num_rounds + num_rounds, ) note(result) - assert tm.non_increasing(result['train'][dataset.metric], tolerance=1e-3) + assert tm.non_increasing(result["train"][dataset.metric], tolerance=1e-3) @given( hist_parameter_strategy, @@ -185,12 +235,12 @@ def test_external_memory(self, param, num_rounds, dataset): return # We cannot handle empty dataset yet assume(len(dataset.y) > 0) - param['tree_method'] = 'gpu_hist' + param["tree_method"] = "gpu_hist" param = dataset.set_params(param) m = dataset.get_external_dmat() external_result = train_result(param, m, num_rounds) del m - assert tm.non_increasing(external_result['train'][dataset.metric]) + assert tm.non_increasing(external_result["train"][dataset.metric]) def test_empty_dmatrix_prediction(self): # FIXME(trivialfis): This should be done with all updaters @@ -207,7 +257,7 @@ def test_empty_dmatrix_prediction(self): dtrain, verbose_eval=True, num_boost_round=6, - evals=[(dtrain, 'Train')] + evals=[(dtrain, "Train")], ) kRows = 100 @@ -222,10 +272,10 @@ def test_empty_dmatrix_prediction(self): @given(tm.make_dataset_strategy(), strategies.integers(0, 10)) @settings(deadline=None, max_examples=10, print_blob=True) def test_specified_gpu_id_gpu_update(self, dataset, gpu_id): - param = {'tree_method': 'gpu_hist', 'gpu_id': gpu_id} + param = {"tree_method": "gpu_hist", "gpu_id": gpu_id} param = dataset.set_params(param) result = train_result(param, dataset.get_dmat(), 10) - assert tm.non_increasing(result['train'][dataset.metric]) + assert tm.non_increasing(result["train"][dataset.metric]) @pytest.mark.skipif(**tm.no_sklearn()) @pytest.mark.parametrize("weighted", [True, False]) diff --git a/tests/python/test_updaters.py b/tests/python/test_updaters.py index 029911bf0deb..5374a2891382 100644 --- a/tests/python/test_updaters.py +++ b/tests/python/test_updaters.py @@ -1,6 +1,6 @@ import json from string import ascii_lowercase -from typing import Any, Dict, List +from typing import Any, Dict import numpy as np import pytest @@ -15,30 +15,15 @@ hist_parameter_strategy, ) from xgboost.testing.updater import ( + check_categorical_missing, + check_categorical_ohe, check_get_quantile_cut, check_init_estimation, check_quantile_loss, + train_result, ) -def train_result(param, dmat, num_rounds): - result = 
{} - booster = xgb.train( - param, - dmat, - num_rounds, - evals=[(dmat, "train")], - verbose_eval=False, - evals_result=result, - ) - assert booster.num_features() == dmat.num_col() - assert booster.num_boosted_rounds() == num_rounds - assert booster.feature_names == dmat.feature_names - assert booster.feature_types == dmat.feature_types - - return result - - class TestTreeMethodMulti: @given( exact_parameter_strategy, strategies.integers(1, 20), tm.multi_dataset_strategy @@ -281,115 +266,6 @@ def run_max_cat(self, tree_method: str) -> None: def test_max_cat(self, tree_method) -> None: self.run_max_cat(tree_method) - def run_categorical_missing( - self, rows: int, cols: int, cats: int, tree_method: str - ) -> None: - parameters: Dict[str, Any] = {"tree_method": tree_method} - cat, label = tm.make_categorical( - rows, n_features=cols, n_categories=cats, onehot=False, sparsity=0.5 - ) - Xy = xgb.DMatrix(cat, label, enable_categorical=True) - - def run(max_cat_to_onehot: int): - # Test with onehot splits - parameters["max_cat_to_onehot"] = max_cat_to_onehot - - evals_result: Dict[str, Dict] = {} - booster = xgb.train( - parameters, - Xy, - num_boost_round=16, - evals=[(Xy, "Train")], - evals_result=evals_result - ) - assert tm.non_increasing(evals_result["Train"]["rmse"]) - y_predt = booster.predict(Xy) - - rmse = tm.root_mean_square(label, y_predt) - np.testing.assert_allclose( - rmse, evals_result["Train"]["rmse"][-1], rtol=2e-5 - ) - - # Test with OHE split - run(self.USE_ONEHOT) - - # Test with partition-based split - run(self.USE_PART) - - def run_categorical_ohe( - self, rows: int, cols: int, rounds: int, cats: int, tree_method: str - ) -> None: - onehot, label = tm.make_categorical(rows, cols, cats, True) - cat, _ = tm.make_categorical(rows, cols, cats, False) - - by_etl_results: Dict[str, Dict[str, List[float]]] = {} - by_builtin_results: Dict[str, Dict[str, List[float]]] = {} - - parameters: Dict[str, Any] = { - "tree_method": tree_method, - # Use one-hot exclusively - "max_cat_to_onehot": self.USE_ONEHOT - } - - m = xgb.DMatrix(onehot, label, enable_categorical=False) - xgb.train( - parameters, - m, - num_boost_round=rounds, - evals=[(m, "Train")], - evals_result=by_etl_results, - ) - - m = xgb.DMatrix(cat, label, enable_categorical=True) - xgb.train( - parameters, - m, - num_boost_round=rounds, - evals=[(m, "Train")], - evals_result=by_builtin_results, - ) - - # There are guidelines on how to specify tolerance based on considering output - # as random variables. But in here the tree construction is extremely sensitive - # to floating point errors. An 1e-5 error in a histogram bin can lead to an - # entirely different tree. So even though the test is quite lenient, hypothesis - # can still pick up falsifying examples from time to time. - np.testing.assert_allclose( - np.array(by_etl_results["Train"]["rmse"]), - np.array(by_builtin_results["Train"]["rmse"]), - rtol=1e-3, - ) - assert tm.non_increasing(by_builtin_results["Train"]["rmse"]) - - by_grouping: Dict[str, Dict[str, List[float]]] = {} - # switch to partition-based splits - parameters["max_cat_to_onehot"] = self.USE_PART - parameters["reg_lambda"] = 0 - m = xgb.DMatrix(cat, label, enable_categorical=True) - xgb.train( - parameters, - m, - num_boost_round=rounds, - evals=[(m, "Train")], - evals_result=by_grouping, - ) - rmse_oh = by_builtin_results["Train"]["rmse"] - rmse_group = by_grouping["Train"]["rmse"] - # always better or equal to onehot when there's no regularization. 
- for a, b in zip(rmse_oh, rmse_group): - assert a >= b - - parameters["reg_lambda"] = 1.0 - by_grouping = {} - xgb.train( - parameters, - m, - num_boost_round=32, - evals=[(m, "Train")], - evals_result=by_grouping, - ) - assert tm.non_increasing(by_grouping["Train"]["rmse"]), by_grouping - @given(strategies.integers(10, 400), strategies.integers(3, 8), strategies.integers(1, 2), strategies.integers(4, 7)) @settings(deadline=None, print_blob=True) @@ -397,8 +273,8 @@ def run_categorical_ohe( def test_categorical_ohe( self, rows: int, cols: int, rounds: int, cats: int ) -> None: - self.run_categorical_ohe(rows, cols, rounds, cats, "approx") - self.run_categorical_ohe(rows, cols, rounds, cats, "hist") + check_categorical_ohe(rows, cols, rounds, cats, "cpu", "approx") + check_categorical_ohe(rows, cols, rounds, cats, "cpu", "hist") @given( tm.categorical_dataset_strategy, @@ -454,8 +330,8 @@ def test_categorical_ames_housing( @settings(deadline=None, print_blob=True) @pytest.mark.skipif(**tm.no_pandas()) def test_categorical_missing(self, rows, cols, cats): - self.run_categorical_missing(rows, cols, cats, "approx") - self.run_categorical_missing(rows, cols, cats, "hist") + check_categorical_missing(rows, cols, cats, "cpu", "approx") + check_categorical_missing(rows, cols, cats, "cpu", "hist") def run_adaptive(self, tree_method, weighted) -> None: rng = np.random.RandomState(1994) diff --git a/tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py b/tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py index 9386486de238..4cc9345799a0 100644 --- a/tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py +++ b/tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py @@ -154,7 +154,6 @@ def run_gpu_hist( DMatrixT: Type, client: Client, ) -> None: - params["tree_method"] = "hist" params["device"] = "cuda" params = dataset.set_params(params) # It doesn't make sense to distribute a completely @@ -275,8 +274,31 @@ def test_gpu_hist( dmatrix_type: type, local_cuda_client: Client, ) -> None: + params["tree_method"] = "hist" run_gpu_hist(params, num_rounds, dataset, dmatrix_type, local_cuda_client) + @given( + params=hist_parameter_strategy, + num_rounds=strategies.integers(1, 20), + dataset=tm.make_dataset_strategy(), + ) + @settings( + deadline=duration(seconds=120), + max_examples=20, + suppress_health_check=suppress, + print_blob=True, + ) + @pytest.mark.skipif(**tm.no_cupy()) + def test_gpu_approx( + self, + params: Dict, + num_rounds: int, + dataset: tm.TestDataset, + local_cuda_client: Client, + ) -> None: + params["tree_method"] = "approx" + run_gpu_hist(params, num_rounds, dataset, dxgb.DaskDMatrix, local_cuda_client) + def test_empty_quantile_dmatrix(self, local_cuda_client: Client) -> None: client = local_cuda_client X, y = make_categorical(client, 1, 30, 13)
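
Note (not part of the patch): the hunks above exercise the new GPU path for the approx tree method by setting ``tree_method="approx"`` together with ``device="cuda"``, the same combination used by ``test_gpu_approx`` and ``test_categorical_approx``. The following is a minimal usage sketch of that parameter combination; the synthetic data, ``max_bin`` value, and round count are illustrative assumptions, not values taken from the patch.

    # Minimal sketch: train with the approx tree method on the GPU,
    # mirroring the parameters exercised by the new tests.
    import numpy as np
    import xgboost as xgb

    rng = np.random.RandomState(1994)
    X = rng.randn(1024, 16)
    y = rng.randn(1024)
    dtrain = xgb.DMatrix(X, label=y)

    params = {"tree_method": "approx", "device": "cuda", "max_bin": 256}
    evals_result = {}
    booster = xgb.train(
        params,
        dtrain,
        num_boost_round=16,
        evals=[(dtrain, "Train")],
        evals_result=evals_result,
    )
    # The new test_gpu_approx cases assert that this training metric is
    # non-increasing (via tm.non_increasing); printing it here just shows
    # the same curve for inspection.
    print(evals_result["Train"]["rmse"])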