Initial GPU support for the approx tree method. (#9414)
trivialfis authored Jul 31, 2023
1 parent 8f0efb4 commit 912e341
Showing 23 changed files with 639 additions and 360 deletions.
3 changes: 2 additions & 1 deletion doc/parameter.rst
@@ -162,7 +162,8 @@ Parameters for Tree Booster
- ``grow_colmaker``: non-distributed column-based construction of trees.
- ``grow_histmaker``: distributed tree construction with row-based data splitting based on global proposal of histogram counting.
- ``grow_quantile_histmaker``: Grow tree using quantized histogram.
- ``grow_gpu_hist``: Grow tree with GPU. Enabled when ``tree_method`` is set to ``hist`` along with ``device=cuda``.
- ``grow_gpu_hist``: Enabled when ``tree_method`` is set to ``hist`` along with ``device=cuda``.
- ``grow_gpu_approx``: Enabled when ``tree_method`` is set to ``approx`` along with ``device=cuda``.
- ``sync``: synchronizes trees in all distributed nodes.
- ``refresh``: refreshes tree's statistics and/or leaf values based on the current data. Note that no random subsampling of data rows is performed.
- ``prune``: prunes the splits where loss < min_split_loss (or gamma) and nodes that have depth greater than ``max_depth``.
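
For readers trying the change out, here is a minimal sketch of reaching the new updater from the Python package; the dataset and parameter values are illustrative, not part of this commit:

import numpy as np
import xgboost as xgb

# Illustrative regression data.
X = np.random.default_rng(0).normal(size=(1024, 16))
y = X.sum(axis=1)

# With this change, `tree_method="approx"` combined with `device="cuda"`
# dispatches to the new `grow_gpu_approx` updater.
booster = xgb.train(
    {"tree_method": "approx", "device": "cuda"},
    xgb.DMatrix(X, y),
    num_boost_round=10,
)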
34 changes: 17 additions & 17 deletions doc/treemethod.rst
@@ -123,23 +123,23 @@ Feature Matrix
The following table summarizes some differences in supported features between the tree
methods; `T` means supported while `F` means unsupported.

+------------------+-----------+---------------------+---------------------+------------------------+
| | Exact | Approx | Hist | Hist (GPU) |
+==================+===========+=====================+=====================+========================+
| grow_policy | Depthwise | depthwise/lossguide | depthwise/lossguide | depthwise/lossguide |
+------------------+-----------+---------------------+---------------------+------------------------+
| max_leaves | F | T | T | T |
+------------------+-----------+---------------------+---------------------+------------------------+
| sampling method | uniform | uniform | uniform | gradient_based/uniform |
+------------------+-----------+---------------------+---------------------+------------------------+
| categorical data | F | T | T | T |
+------------------+-----------+---------------------+---------------------+------------------------+
| External memory | F | T | T | P |
+------------------+-----------+---------------------+---------------------+------------------------+
| Distributed | F | T | T | T |
+------------------+-----------+---------------------+---------------------+------------------------+

Features/parameters that are not mentioned here are universally supported for all 4 tree
+------------------+-----------+---------------------+------------------------+---------------------+------------------------+
| | Exact | Approx | Approx (GPU) | Hist | Hist (GPU) |
+==================+===========+=====================+========================+=====================+========================+
| grow_policy | Depthwise | depthwise/lossguide | depthwise/lossguide | depthwise/lossguide | depthwise/lossguide |
+------------------+-----------+---------------------+------------------------+---------------------+------------------------+
| max_leaves | F | T | T | T | T |
+------------------+-----------+---------------------+------------------------+---------------------+------------------------+
| sampling method | uniform | uniform | gradient_based/uniform | uniform | gradient_based/uniform |
+------------------+-----------+---------------------+------------------------+---------------------+------------------------+
| categorical data | F | T | T | T | T |
+------------------+-----------+---------------------+------------------------+---------------------+------------------------+
| External memory | F | T | P | T | P |
+------------------+-----------+---------------------+------------------------+---------------------+------------------------+
| Distributed | F | T | T | T | T |
+------------------+-----------+---------------------+------------------------+---------------------+------------------------+

Features/parameters that are not mentioned here are universally supported for all 3 tree
methods (for instance, column sampling and constraints). The `P` in external memory means
special handling. Please note that both categorical data and external memory are
experimental.
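
The updated matrix also lists gradient-based sampling for the approx method on GPU. A sketch of opting into it, with an arbitrary illustrative subsample ratio:

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X = rng.normal(size=(2048, 8))
y = 2.0 * X[:, 0] + rng.normal(size=2048)

params = {
    "tree_method": "approx",
    "device": "cuda",
    # Per the feature matrix above, gradient-based sampling is GPU-only.
    "sampling_method": "gradient_based",
    "subsample": 0.5,
}
booster = xgb.train(params, xgb.DMatrix(X, y), num_boost_round=10)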
140 changes: 139 additions & 1 deletion python-package/xgboost/testing/updater.py
@@ -1,7 +1,7 @@
"""Tests for updaters."""
import json
from functools import partial, update_wrapper
from typing import Any, Dict
from typing import Any, Dict, List

import numpy as np

@@ -256,3 +256,141 @@ def check_get_quantile_cut(tree_method: str) -> None:
    check_get_quantile_cut_device(tree_method, False)
    if use_cupy:
        check_get_quantile_cut_device(tree_method, True)


USE_ONEHOT = np.iinfo(np.int32).max
USE_PART = 1
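# `USE_ONEHOT` makes `max_cat_to_onehot` exceed any category count, forcing
# one-hot splits; `USE_PART = 1` forces partition-based splits instead.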


def check_categorical_ohe(  # pylint: disable=too-many-arguments
    rows: int, cols: int, rounds: int, cats: int, device: str, tree_method: str
) -> None:
    """Test for one-hot encoding with categorical data."""

    onehot, label = tm.make_categorical(rows, cols, cats, True)
    cat, _ = tm.make_categorical(rows, cols, cats, False)

    by_etl_results: Dict[str, Dict[str, List[float]]] = {}
    by_builtin_results: Dict[str, Dict[str, List[float]]] = {}

    parameters: Dict[str, Any] = {
        "tree_method": tree_method,
        # Use one-hot splits exclusively.
        "max_cat_to_onehot": USE_ONEHOT,
        "device": device,
    }

    m = xgb.DMatrix(onehot, label, enable_categorical=False)
    xgb.train(
        parameters,
        m,
        num_boost_round=rounds,
        evals=[(m, "Train")],
        evals_result=by_etl_results,
    )

    m = xgb.DMatrix(cat, label, enable_categorical=True)
    xgb.train(
        parameters,
        m,
        num_boost_round=rounds,
        evals=[(m, "Train")],
        evals_result=by_builtin_results,
    )

    # There are guidelines on how to specify tolerance by treating the output
    # as random variables, but here the tree construction is extremely sensitive
    # to floating-point errors. A 1e-5 error in a histogram bin can lead to an
    # entirely different tree, so even though the test is quite lenient, hypothesis
    # can still pick up falsifying examples from time to time.
    np.testing.assert_allclose(
        np.array(by_etl_results["Train"]["rmse"]),
        np.array(by_builtin_results["Train"]["rmse"]),
        rtol=1e-3,
    )
    assert tm.non_increasing(by_builtin_results["Train"]["rmse"])

    by_grouping: Dict[str, Dict[str, List[float]]] = {}
    # Switch to partition-based splits.
    parameters["max_cat_to_onehot"] = USE_PART
    parameters["reg_lambda"] = 0
    m = xgb.DMatrix(cat, label, enable_categorical=True)
    xgb.train(
        parameters,
        m,
        num_boost_round=rounds,
        evals=[(m, "Train")],
        evals_result=by_grouping,
    )
    rmse_oh = by_builtin_results["Train"]["rmse"]
    rmse_group = by_grouping["Train"]["rmse"]
    # Partition-based splits are always at least as good as one-hot when there's
    # no regularization.
    for a, b in zip(rmse_oh, rmse_group):
        assert a >= b

    parameters["reg_lambda"] = 1.0
    by_grouping = {}
    xgb.train(
        parameters,
        m,
        num_boost_round=32,
        evals=[(m, "Train")],
        evals_result=by_grouping,
    )
    assert tm.non_increasing(by_grouping["Train"]["rmse"]), by_grouping


def check_categorical_missing(
    rows: int, cols: int, cats: int, device: str, tree_method: str
) -> None:
    """Check categorical data with missing values."""
    parameters: Dict[str, Any] = {"tree_method": tree_method, "device": device}
    cat, label = tm.make_categorical(
        rows, n_features=cols, n_categories=cats, onehot=False, sparsity=0.5
    )
    Xy = xgb.DMatrix(cat, label, enable_categorical=True)

    def run(max_cat_to_onehot: int) -> None:
        # Choose between one-hot and partition-based splits via the threshold.
        parameters["max_cat_to_onehot"] = max_cat_to_onehot

        evals_result: Dict[str, Dict] = {}
        booster = xgb.train(
            parameters,
            Xy,
            num_boost_round=16,
            evals=[(Xy, "Train")],
            evals_result=evals_result,
        )
        assert tm.non_increasing(evals_result["Train"]["rmse"])
        y_predt = booster.predict(Xy)

        rmse = tm.root_mean_square(label, y_predt)
        np.testing.assert_allclose(rmse, evals_result["Train"]["rmse"][-1], rtol=2e-5)

    # Test with one-hot splits.
    run(USE_ONEHOT)

    # Test with partition-based splits.
    run(USE_PART)


def train_result(
    param: Dict[str, Any], dmat: xgb.DMatrix, num_rounds: int
) -> Dict[str, Any]:
    """Get training result from parameters and data."""
    result: Dict[str, Any] = {}
    booster = xgb.train(
        param,
        dmat,
        num_rounds,
        evals=[(dmat, "train")],
        verbose_eval=False,
        evals_result=result,
    )
    assert booster.num_features() == dmat.num_col()
    assert booster.num_boosted_rounds() == num_rounds
    assert booster.feature_names == dmat.feature_names
    assert booster.feature_types == dmat.feature_types

    return result
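
A sketch of how these helpers might be invoked for the new GPU approx path; the argument values are illustrative (the repository's tests typically drive the helpers through hypothesis strategies):

check_categorical_ohe(
    rows=128, cols=4, rounds=8, cats=5, device="cuda", tree_method="approx"
)
check_categorical_missing(
    rows=128, cols=4, cats=5, device="cuda", tree_method="approx"
)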
5 changes: 5 additions & 0 deletions src/common/error_msg.h
@@ -89,5 +89,10 @@ void WarnDeprecatedGPUId();
void WarnEmptyDataset();

std::string DeprecatedFunc(StringView old, StringView since, StringView replacement);

constexpr StringView InvalidCUDAOrdinal() {
  return "Invalid device. `device` is required to be CUDA and there must be at least one GPU "
         "available for using GPU.";
}
} // namespace xgboost::error
#endif // XGBOOST_COMMON_ERROR_MSG_H_
6 changes: 3 additions & 3 deletions src/common/ranking_utils.h
@@ -12,7 +12,7 @@
#include <vector> // for vector

#include "dmlc/parameter.h" // for FieldEntry, DMLC_DECLARE_FIELD
#include "error_msg.h" // for GroupWeight, GroupSize
#include "error_msg.h" // for GroupWeight, GroupSize, InvalidCUDAOrdinal
#include "xgboost/base.h" // for XGBOOST_DEVICE, bst_group_t
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for MetaInfo
@@ -240,15 +240,15 @@ class RankingCache {
// The function simply returns an uninitialized buffer as this is only used by the
// objective for creating pairs.
common::Span<std::size_t> SortedIdxY(Context const* ctx, std::size_t n_samples) {
CHECK(ctx->IsCUDA());
CHECK(ctx->IsCUDA()) << error::InvalidCUDAOrdinal();
if (y_sorted_idx_cache_.Empty()) {
y_sorted_idx_cache_.SetDevice(ctx->gpu_id);
y_sorted_idx_cache_.Resize(n_samples);
}
return y_sorted_idx_cache_.DeviceSpan();
}
common::Span<float> RankedY(Context const* ctx, std::size_t n_samples) {
CHECK(ctx->IsCUDA());
CHECK(ctx->IsCUDA()) << error::InvalidCUDAOrdinal();
if (y_ranked_by_model_.Empty()) {
y_ranked_by_model_.SetDevice(ctx->gpu_id);
y_ranked_by_model_.Resize(n_samples);
7 changes: 5 additions & 2 deletions src/data/ellpack_page.cu
@@ -11,7 +11,6 @@
#include "../common/categorical.h"
#include "../common/cuda_context.cuh"
#include "../common/hist_util.cuh"
#include "../common/random.h"
#include "../common/transform_iterator.h" // MakeIndexTransformIter
#include "./ellpack_page.cuh"
#include "device_adapter.cuh" // for HasInfInData
@@ -131,7 +130,11 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, DMatrix* dmat, const BatchP
monitor_.Start("Quantiles");
// Create the quantile sketches for the dmatrix and initialize HistogramCuts.
row_stride = GetRowStride(dmat);
cuts_ = common::DeviceSketch(ctx, dmat, param.max_bin);
if (!param.hess.empty()) {
cuts_ = common::DeviceSketchWithHessian(ctx, dmat, param.max_bin, param.hess);
} else {
cuts_ = common::DeviceSketch(ctx, dmat, param.max_bin);
}
monitor_.Stop("Quantiles");

monitor_.Start("InitCompressedData");
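
The hessian branch above is the heart of GPU approx: when hessian values are supplied, bin boundaries come from a weighted quantile sketch instead of a plain one. A toy single-feature illustration of the idea, not the actual CUDA implementation:

import numpy as np

def weighted_quantile_cuts(values: np.ndarray, hess: np.ndarray, max_bin: int) -> np.ndarray:
    """Toy hessian-weighted cuts: samples with larger hessian pull bin
    boundaries toward themselves. Illustrative only."""
    order = np.argsort(values)
    v, w = values[order], hess[order]
    cdf = np.cumsum(w) / w.sum()
    # Place cut points at evenly spaced levels of the weighted CDF.
    levels = np.linspace(0.0, 1.0, max_bin, endpoint=False)[1:]
    return v[np.searchsorted(cdf, levels)]

rng = np.random.default_rng(0)
cuts = weighted_quantile_cuts(rng.normal(size=1000), rng.uniform(0.1, 1.0, size=1000), 16)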
5 changes: 2 additions & 3 deletions src/data/gradient_index.cc
@@ -7,13 +7,12 @@
#include <algorithm>
#include <limits>
#include <memory>
#include <utility> // std::forward
#include <utility> // for forward

#include "../common/column_matrix.h"
#include "../common/hist_util.h"
#include "../common/numeric.h"
#include "../common/threading_utils.h"
#include "../common/transform_iterator.h" // MakeIndexTransformIter
#include "../common/transform_iterator.h" // for MakeIndexTransformIter

namespace xgboost {

6 changes: 3 additions & 3 deletions src/data/simple_dmatrix.cc
@@ -8,12 +8,12 @@

#include <algorithm>
#include <limits>
#include <numeric> // for accumulate
#include <type_traits>
#include <vector>

#include "../common/error_msg.h" // for InconsistentMaxBin
#include "../common/random.h"
#include "../common/threading_utils.h"
#include "../collective/communicator-inl.h" // for GetWorldSize, GetRank, Allgather
#include "../common/error_msg.h" // for InconsistentMaxBin
#include "./simple_batch_iterator.h"
#include "adapter.h"
#include "batch_utils.h" // for CheckEmpty, RegenGHist
1 change: 0 additions & 1 deletion src/data/sparse_page_dmatrix.cc
@@ -8,7 +8,6 @@
#include "./sparse_page_dmatrix.h"

#include "../collective/communicator-inl.h"
#include "./simple_batch_iterator.h"
#include "batch_utils.h" // for RegenGHist
#include "gradient_index.h"

23 changes: 15 additions & 8 deletions src/data/sparse_page_dmatrix.cu
@@ -1,13 +1,15 @@
/**
* Copyright 2021-2023 by XGBoost contributors
*/
#include <memory>
#include <memory> // for unique_ptr

#include "../common/hist_util.cuh"
#include "batch_utils.h" // for CheckEmpty, RegenGHist
#include "../common/hist_util.h" // for HistogramCuts
#include "batch_utils.h" // for CheckEmpty, RegenGHist
#include "ellpack_page.cuh"
#include "sparse_page_dmatrix.h"
#include "sparse_page_source.h"
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for BatchParam

namespace xgboost::data {
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
@@ -25,8 +27,13 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
cache_info_.erase(id);
MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
std::unique_ptr<common::HistogramCuts> cuts;
cuts =
std::make_unique<common::HistogramCuts>(common::DeviceSketch(ctx, this, param.max_bin, 0));
if (!param.hess.empty()) {
cuts = std::make_unique<common::HistogramCuts>(
common::DeviceSketchWithHessian(ctx, this, param.max_bin, param.hess));
} else {
cuts =
std::make_unique<common::HistogramCuts>(common::DeviceSketch(ctx, this, param.max_bin));
}
this->InitializeSparsePage(ctx); // reset after use.

row_stride = GetRowStride(this);
@@ -35,10 +42,10 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
batch_param_ = param;

auto ft = this->info_.feature_types.ConstDeviceSpan();
ellpack_page_source_.reset(); // release resources.
ellpack_page_source_.reset(new EllpackPageSource(
ellpack_page_source_.reset(); // make sure resource is released before making new ones.
ellpack_page_source_ = std::make_shared<EllpackPageSource>(
this->missing_, ctx->Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id),
param, std::move(cuts), this->IsDense(), row_stride, ft, sparse_page_source_, ctx->gpu_id));
param, std::move(cuts), this->IsDense(), row_stride, ft, sparse_page_source_, ctx->gpu_id);
} else {
CHECK(sparse_page_source_);
ellpack_page_source_->Reset();
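
The hessian-aware sketch also feeds the external-memory path above (the `P` entries in the feature matrix). A hedged sketch of external-memory training with the GPU approx method; the iterator and data are illustrative:

import numpy as np
import xgboost as xgb

class DemoIter(xgb.DataIter):
    """Illustrative batch iterator for external-memory training."""

    def __init__(self, batches):
        self._batches = batches
        self._it = 0
        super().__init__(cache_prefix="demo_cache")

    def next(self, input_data):
        if self._it == len(self._batches):
            return 0  # no batches left
        X, y = self._batches[self._it]
        input_data(data=X, label=y)
        self._it += 1
        return 1

    def reset(self):
        self._it = 0

rng = np.random.default_rng(0)
batches = [(rng.normal(size=(256, 8)), rng.normal(size=256)) for _ in range(4)]
Xy = xgb.DMatrix(DemoIter(batches))
booster = xgb.train({"tree_method": "approx", "device": "cuda"}, Xy, num_boost_round=8)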
7 changes: 4 additions & 3 deletions src/gbm/gbtree.cc
@@ -47,15 +47,16 @@ std::string MapTreeMethodToUpdaters(Context const* ctx, TreeMethod tree_method)
if (ctx->IsCUDA()) {
common::AssertGPUSupport();
}

switch (tree_method) {
case TreeMethod::kAuto: // Use hist as default in 2.0
case TreeMethod::kHist: {
return ctx->DispatchDevice([] { return "grow_quantile_histmaker"; },
[] { return "grow_gpu_hist"; });
}
case TreeMethod::kApprox:
CHECK(ctx->IsCPU()) << "The `approx` tree method is not supported on GPU.";
return "grow_histmaker";
case TreeMethod::kApprox: {
return ctx->DispatchDevice([] { return "grow_histmaker"; }, [] { return "grow_gpu_approx"; });
}
case TreeMethod::kExact:
CHECK(ctx->IsCPU()) << "The `exact` tree method is not supported on GPU.";
return "grow_colmaker,prune";
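
One way to observe this dispatch from Python is to inspect the booster's saved configuration; the exact JSON key path below is an assumption, so treat the snippet as a sketch:

import json
import numpy as np
import xgboost as xgb

X = np.random.default_rng(0).normal(size=(256, 4))
y = X[:, 0]
booster = xgb.train(
    {"tree_method": "approx", "device": "cuda"},
    xgb.DMatrix(X, y),
    num_boost_round=2,
)
config = json.loads(booster.save_config())
# With `device=cuda`, the updater reported here should be `grow_gpu_approx`.
print(config["learner"]["gradient_booster"]["gbtree_train_param"]["updater"])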
8 changes: 3 additions & 5 deletions src/tree/constraints.h
@@ -1,5 +1,5 @@
/*!
* Copyright 2018-2019 by Contributors
/**
* Copyright 2018-2023 by Contributors
*/
#ifndef XGBOOST_TREE_CONSTRAINTS_H_
#define XGBOOST_TREE_CONSTRAINTS_H_
@@ -8,10 +8,8 @@
#include <unordered_set>
#include <vector>

#include "xgboost/span.h"
#include "xgboost/base.h"

#include "param.h"
#include "xgboost/base.h"

namespace xgboost {
/*!