Skip to content

Commit

Permalink
Support optimal partitioning for GPU hist. (#7652)
Browse files Browse the repository at this point in the history
* Implement `MaxCategory` in quantile.
* Implement partition-based split for GPU evaluation.  Currently, it's based on the existing evaluation function.
* Extract an evaluator from GPU Hist to store the needed states.
* Added some CUDA stream/event utilities.
* Update document with references.
* Fixed a bug in approx evaluator where the number of data points is less than the number of categories.
  • Loading branch information
trivialfis authored Feb 14, 2022
1 parent 2369d55 commit 0d0abe1
Show file tree
Hide file tree
Showing 26 changed files with 1,088 additions and 528 deletions.
17 changes: 11 additions & 6 deletions demo/guide-python/cat_in_the_dat.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,12 @@ def load_cat_in_the_dat() -> tuple[pd.DataFrame, pd.Series]:
return X, y


params = {"tree_method": "gpu_hist", "use_label_encoder": False, "n_estimators": 32}
params = {
"tree_method": "gpu_hist",
"use_label_encoder": False,
"n_estimators": 32,
"colsample_bylevel": 0.7,
}


def categorical_model(X: pd.DataFrame, y: pd.Series, output_dir: str) -> None:
Expand All @@ -70,13 +75,13 @@ def categorical_model(X: pd.DataFrame, y: pd.Series, output_dir: str) -> None:
X, y, random_state=1994, test_size=0.2
)
# Specify `enable_categorical`.
clf = xgb.XGBClassifier(**params, enable_categorical=True)
clf.fit(
X_train,
y_train,
eval_set=[(X_test, y_test), (X_train, y_train)],
clf = xgb.XGBClassifier(
**params,
eval_metric="auc",
enable_categorical=True,
max_cat_to_onehot=1, # We use optimal partitioning exclusively
)
clf.fit(X_train, y_train, eval_set=[(X_test, y_test), (X_train, y_train)])
clf.save_model(os.path.join(output_dir, "categorical.json"))

y_score = clf.predict_proba(X_test)[:, 1] # proba of positive samples
Expand Down
17 changes: 10 additions & 7 deletions demo/guide-python/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
=====================================
Experimental support for categorical data. After 1.5 XGBoost `gpu_hist` tree method has
experimental support for one-hot encoding based tree split, and in 1.6 `approx` supported
experimental support for one-hot encoding based tree split, and in 1.6 `approx` support
was added.
In before, users need to run an encoder themselves before passing the data into XGBoost,
which creates a sparse matrix and potentially increase memory usage. This demo showcases
the experimental categorical data support, more advanced features are planned.
Also, see :doc:`the tutorial </tutorials/categorical>` for using XGBoost with categorical data
which creates a sparse matrix and potentially increase memory usage. This demo
showcases the experimental categorical data support, more advanced features are planned.
Also, see :doc:`the tutorial </tutorials/categorical>` for using XGBoost with
categorical data.
.. versionadded:: 1.5.0
Expand Down Expand Up @@ -55,8 +55,11 @@ def main() -> None:
# For scikit-learn interface, the input data must be pandas DataFrame or cudf
# DataFrame with categorical features
X, y = make_categorical(100, 10, 4, False)
# Specify `enable_categorical` to True.
reg = xgb.XGBRegressor(tree_method="gpu_hist", enable_categorical=True)
# Specify `enable_categorical` to True, also we use onehot encoding based split
# here for demonstration. For details see the document of `max_cat_to_onehot`.
reg = xgb.XGBRegressor(
tree_method="gpu_hist", enable_categorical=True, max_cat_to_onehot=5
)
reg.fit(X, y, eval_set=[(X, y)])

# Pass in already encoded data
Expand Down
7 changes: 4 additions & 3 deletions doc/parameter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,8 @@ Additional parameters for ``hist``, ``gpu_hist`` and ``approx`` tree method

- Use single precision to build histograms instead of double precision.

Additional parameters for ``approx`` tree method
================================================
Additional parameters for ``approx`` and ``gpu_hist`` tree method
=================================================================

* ``max_cat_to_onehot``

Expand All @@ -257,7 +257,8 @@ Additional parameters for ``approx`` tree method
- A threshold for deciding whether XGBoost should use one-hot encoding based split for
categorical data. When number of categories is lesser than the threshold then one-hot
encoding is chosen, otherwise the categories will be partitioned into children nodes.
Only relevant for regression and binary classification with `approx` tree method.
Only relevant for regression and binary classification. Also, `approx` or `gpu_hist`
tree method is required.

Additional parameters for Dart Booster (``booster=dart``)
=========================================================
Expand Down
44 changes: 40 additions & 4 deletions doc/tutorials/categorical.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
Categorical Data
################

.. note::

As of XGBoost 1.6, the feature is highly experimental and has limited features

Starting from version 1.5, XGBoost has experimental support for categorical data available
for public testing. At the moment, the support is implemented as one-hot encoding based
categorical tree splits. For numerical data, the split condition is defined as
Expand Down Expand Up @@ -107,6 +111,28 @@ For numerical data, the feature type can be ``"q"`` or ``"float"``, while for ca
feature it's specified as ``"c"``. The Dask module in XGBoost has the same interface so
:class:`dask.Array <dask.Array>` can also be used as categorical data.

********************
Optimal Partitioning
********************

.. versionadded:: 1.6

Optimal partitioning is a technique for partitioning the categorical predictors for each
node split, the proof of optimality for numerical objectives like ``RMSE`` was first
introduced by `[1] <#references>`__. The algorithm is used in decision trees for handling
regression and binary classification tasks `[2] <#references>`__, later LightGBM `[3]
<#references>`__ brought it to the context of gradient boosting trees and now is also
adopted in XGBoost as an optional feature for handling categorical splits. More
specifically, the proof by Fisher `[1] <#references>`__ states that, when trying to
partition a set of discrete values into groups based on the distances between a measure of
these values, one only needs to look at sorted partitions instead of enumerating all
possible permutations. In the context of decision trees, the discrete values are
categories, and the measure is the output leaf value. Intuitively, we want to group the
categories that output similar leaf values. During split finding, we first sort the
gradient histogram to prepare the contiguous partitions then enumerate the splits
according to these sorted values. One of the related parameters for XGBoost is
``max_cat_to_one_hot``, which controls whether one-hot encoding or partitioning should be
used for each feature, see :doc:`/parameter` for details.

*************
Miscellaneous
Expand All @@ -120,10 +146,20 @@ actual number of unique categories. During training this is validated but for p
it's treated as the same as missing value for performance reasons. Lastly, missing values
are treated as the same as numerical features (using the learned split direction).


**********
Next Steps
References
**********

As of XGBoost 1.5, the feature is highly experimental and have limited features like CPU
training is not yet supported. Please see `this issue
<https://github.com/dmlc/xgboost/issues/6503>`_ for progress.
[1] Walter D. Fisher. "`On Grouping for Maximum Homogeneity`_." Journal of the American Statistical Association. Vol. 53, No. 284 (Dec., 1958), pp. 789-798.

[2] Trevor Hastie, Robert Tibshirani, Jerome Friedman. "`The Elements of Statistical Learning`_". Springer Series in Statistics Springer New York Inc. (2001).

[3] Guolin Ke, Qi Meng, Thomas Finley, Taifeng Wang, Wei Chen, Weidong Ma, Qiwei Ye, Tie-Yan Liu. "`LightGBM\: A Highly Efficient Gradient Boosting Decision Tree`_." Advances in Neural Information Processing Systems 30 (NIPS 2017), pp. 3149-3157.


.. _On Grouping for Maximum Homogeneity: https://www.tandfonline.com/doi/abs/10.1080/01621459.1958.10501479

.. _The Elements of Statistical Learning: https://link.springer.com/book/10.1007/978-0-387-84858-7

.. _LightGBM\: A Highly Efficient Gradient Boosting Decision Tree: https://papers.nips.cc/paper/6907-lightgbm-a-highly-efficient-gradient-boosting-decision-tree.pdf
6 changes: 5 additions & 1 deletion include/xgboost/task.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright 2021 by XGBoost Contributors
* Copyright 2021-2022 by XGBoost Contributors
*/
#ifndef XGBOOST_TASK_H_
#define XGBOOST_TASK_H_
Expand Down Expand Up @@ -34,6 +34,10 @@ struct ObjInfo {

explicit ObjInfo(Task t) : task{t} {}
ObjInfo(Task t, bool khess) : task{t}, const_hess{khess} {}

constexpr bool UseOneHot() const {
return (task != ObjInfo::kRegression && task != ObjInfo::kBinary);
}
};
} // namespace xgboost
#endif // XGBOOST_TASK_H_
8 changes: 4 additions & 4 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,10 +581,10 @@ def __init__(
.. versionadded:: 1.3.0
Experimental support of specializing for categorical features. Do not set to
True unless you are interested in development. Currently it's only available
for `gpu_hist` tree method with 1 vs rest (one hot) categorical split. Also,
JSON serialization format is required.
Experimental support of specializing for categorical features. Do not set
to True unless you are interested in development. Currently it's only
available for `gpu_hist` and `approx` tree methods. Also, JSON/UBJSON
serialization format is required. (XGBoost 1.6 for approx)
"""
if group is not None and qid is not None:
Expand Down
13 changes: 8 additions & 5 deletions python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,9 @@ def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]:
.. versionadded:: 1.5.0
Experimental support for categorical data. Do not set to true unless you are
interested in development. Only valid when `gpu_hist` and dataframe are used.
interested in development. Only valid when `gpu_hist` or `approx` is used along
with dataframe as input. Also, JSON/UBJSON serialization format is
required. (XGBoost 1.6 for approx)
max_cat_to_onehot : Optional[int]
Expand All @@ -216,10 +218,11 @@ def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]:
.. note:: This parameter is experimental
A threshold for deciding whether XGBoost should use one-hot encoding based split
for categorical data. When number of categories is lesser than the threshold then
one-hot encoding is chosen, otherwise the categories will be partitioned into
children nodes. Only relevant for regression and binary classification and
`approx` tree method.
for categorical data. When number of categories is lesser than the threshold
then one-hot encoding is chosen, otherwise the categories will be partitioned
into children nodes. Only relevant for regression and binary
classification. Also, ``approx`` or ``gpu_hist`` tree method is required. See
:doc:`Categorical Data </tutorials/categorical>` for details.
eval_metric : Optional[Union[str, List[str], Callable]]
Expand Down
17 changes: 11 additions & 6 deletions src/common/categorical.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@

namespace xgboost {
namespace common {

using CatBitField = LBitField32;
using KCatBitField = CLBitField32;

// Cast the categorical type.
template <typename T>
XGBOOST_DEVICE bst_cat_t AsCat(T const& v) {
Expand Down Expand Up @@ -57,6 +61,11 @@ inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat
if (XGBOOST_EXPECT(validate && (InvalidCat(cat) || cat >= s_cats.Size()), false)) {
return dft_left;
}

auto pos = KCatBitField::ToBitPos(cat);
if (pos.int_pos >= cats.size()) {
return true;
}
return !s_cats.Check(AsCat(cat));
}

Expand All @@ -73,18 +82,14 @@ inline void InvalidCategory() {
/*!
* \brief Whether should we use onehot encoding for categorical data.
*/
inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot, ObjInfo task) {
bool use_one_hot = n_cats < max_cat_to_onehot ||
(task.task != ObjInfo::kRegression && task.task != ObjInfo::kBinary);
XGBOOST_DEVICE inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot, ObjInfo task) {
bool use_one_hot = n_cats < max_cat_to_onehot || task.UseOneHot();
return use_one_hot;
}

struct IsCatOp {
XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; }
};

using CatBitField = LBitField32;
using KCatBitField = CLBitField32;
} // namespace common
} // namespace xgboost

Expand Down
77 changes: 71 additions & 6 deletions src/common/device_helpers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -952,22 +952,22 @@ thrust::device_ptr<T const> tcend(xgboost::HostDeviceVector<T> const& vector) {
}

template <typename T>
thrust::device_ptr<T> tbegin(xgboost::common::Span<T>& span) { // NOLINT
XGBOOST_DEVICE thrust::device_ptr<T> tbegin(xgboost::common::Span<T>& span) { // NOLINT
return thrust::device_ptr<T>(span.data());
}

template <typename T>
thrust::device_ptr<T> tbegin(xgboost::common::Span<T> const& span) { // NOLINT
XGBOOST_DEVICE thrust::device_ptr<T> tbegin(xgboost::common::Span<T> const& span) { // NOLINT
return thrust::device_ptr<T>(span.data());
}

template <typename T>
thrust::device_ptr<T> tend(xgboost::common::Span<T>& span) { // NOLINT
XGBOOST_DEVICE thrust::device_ptr<T> tend(xgboost::common::Span<T>& span) { // NOLINT
return tbegin(span) + span.size();
}

template <typename T>
thrust::device_ptr<T> tend(xgboost::common::Span<T> const& span) { // NOLINT
XGBOOST_DEVICE thrust::device_ptr<T> tend(xgboost::common::Span<T> const& span) { // NOLINT
return tbegin(span) + span.size();
}

Expand All @@ -982,12 +982,12 @@ XGBOOST_DEVICE auto trend(xgboost::common::Span<T> &span) { // NOLINT
}

template <typename T>
thrust::device_ptr<T const> tcbegin(xgboost::common::Span<T> const& span) { // NOLINT
XGBOOST_DEVICE thrust::device_ptr<T const> tcbegin(xgboost::common::Span<T> const& span) { // NOLINT
return thrust::device_ptr<T const>(span.data());
}

template <typename T>
thrust::device_ptr<T const> tcend(xgboost::common::Span<T> const& span) { // NOLINT
XGBOOST_DEVICE thrust::device_ptr<T const> tcend(xgboost::common::Span<T> const& span) { // NOLINT
return tcbegin(span) + span.size();
}

Expand Down Expand Up @@ -1536,4 +1536,69 @@ void SegmentedArgSort(xgboost::common::Span<U> values,
safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice));
}

class CUDAStreamView;

class CUDAEvent {
cudaEvent_t event_{nullptr};

public:
CUDAEvent() { dh::safe_cuda(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); }
~CUDAEvent() {
if (event_) {
dh::safe_cuda(cudaEventDestroy(event_));
}
}

CUDAEvent(CUDAEvent const &that) = delete;
CUDAEvent &operator=(CUDAEvent const &that) = delete;

inline void Record(CUDAStreamView stream); // NOLINT

operator cudaEvent_t() const { return event_; } // NOLINT
};

class CUDAStreamView {
cudaStream_t stream_{nullptr};

public:
explicit CUDAStreamView(cudaStream_t s) : stream_{s} {}
void Wait(CUDAEvent const &e) {
#if defined(__CUDACC_VER_MAJOR__)
#if __CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ == 0
// CUDA == 11.0
dh::safe_cuda(cudaStreamWaitEvent(stream_, cudaEvent_t{e}, 0));
#else
// CUDA > 11.0
dh::safe_cuda(cudaStreamWaitEvent(stream_, cudaEvent_t{e}, cudaEventWaitDefault));
#endif // __CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ == 0:
#else // clang
dh::safe_cuda(cudaStreamWaitEvent(stream_, cudaEvent_t{e}, cudaEventWaitDefault));
#endif // defined(__CUDACC_VER_MAJOR__)
}
operator cudaStream_t() const { // NOLINT
return stream_;
}
void Sync() { dh::safe_cuda(cudaStreamSynchronize(stream_)); }
};

inline void CUDAEvent::Record(CUDAStreamView stream) { // NOLINT
dh::safe_cuda(cudaEventRecord(event_, cudaStream_t{stream}));
}

inline CUDAStreamView DefaultStream() { return CUDAStreamView{cudaStreamLegacy}; }

class CUDAStream {
cudaStream_t stream_;

public:
CUDAStream() {
dh::safe_cuda(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
}
~CUDAStream() {
dh::safe_cuda(cudaStreamDestroy(stream_));
}

CUDAStreamView View() const { return CUDAStreamView{stream_}; }
};
} // namespace dh
Loading

0 comments on commit 0d0abe1

Please sign in to comment.