Skip to content

Commit

Permalink
Rewrite approx.
Browse files Browse the repository at this point in the history
Save cuts.

Prototype on fetching.

Copy the code.

Simple test.

Add gpair to batch parameter.

Add hessian to batch parameter.

Move.

Pass hessian into sketching.

Extract a push page function.

Make private.

Lint.

Revert debug.

Simple DMatrix.

Regenerate the index.

ama.

Clang tidy.

Retain page.

Fix.

Lint.

Tidy.

Integer backed enum.

Convert to uint32_t.

Prototype for saving gidx.

Save cuts.

Prototype on fetching.

Copy the code.

Simple test.

Add gpair to batch parameter.

Add hessian to batch parameter.

Move.

Pass hessian into sketching.

Extract a push page function.

Make private.

Lint.

Revert debug.

Simple DMatrix.

Initial port.

Pass in hessian.

Init column sampler.

Unused code.

Use ctx.

Merge sampling.

Use ctx in partition.

Fix init root.

Force regenerate the sketch.

Create a ctx.

Get it compile.

Don't use const method.

Use page id.

Pass in base row id.

Pass the cut instead.

Small fixes.

Debug.

Fix bin size.

Debug.

Fixes.

Debug.

Fix empty partition.

Remove comment.

Lint.

Fix tests compilation.

Remove check.

Merge some fixes.

fix.

Fix fetching.

lint.

Extract expand entry.

Lint.

Fix unittests.

Fix windows build.

Fix comparison.

Make const.

Note.

const.

Fix reduce hist.

Fix sparse data.

Avoid implicit conversion.

private.

mem leak.

Remove skip initialization.

Use maximum space.

demo.

lint.

File link tags.

ama.

Fix redefinition.

Fix ranking.

use npy.

Comment.

Tune it down.

Specify the tree method.

Get rid of the duplicated partitioner.

Allocate task.

Tests.

make batches.

Log.

Remove span.

Revert "make batches."

This reverts commit 33f7072.

small cleanup.

Lint.

Revert demo.

Better make batches.

Demo.

Test for grow policy.

Test feature weights.

small cleanup.

Remove iterator in evaluation.

Fix dask test.

Pass n_threads.

Start implementation for categorical data.

Fix.

Add apply split.

Enumerate splits.

Enable sklearn.

Works.

d_step.

update.

Pass feature types into index.

Search cut.

Add test.

As cat.

Fix cut.

Extract some tests.

Fix.

Interesting case.

Add Python tests.

Cleanup.

Revert "Interesting case."

This reverts commit 6bbaac2.

Bin.

Fix.

Dispatch.

Remove subtraction trick.

Lint

Use multiple buffers.

Revert "Use multiple buffers."

This reverts commit 2849f57.

Test for external memory.

Format.

Partition based categorical split.

Remove debug code.

Fix.

Lint.

Fix test.

Fix demo.

Fix.

Add test.

Remove use of omp func.

name.

Fix.

test.

Make LCG impl compliant to std.

Fix test.

Constexpr.

Use unsigned type.

osx

More test.

Rebase error.

Rebase error.
  • Loading branch information
trivialfis committed Nov 7, 2021
1 parent d7d1b6e commit aed9d27
Show file tree
Hide file tree
Showing 29 changed files with 1,403 additions and 470 deletions.
1 change: 1 addition & 0 deletions amalgamation/xgboost-all0.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
#include "../src/tree/updater_refresh.cc"
#include "../src/tree/updater_sync.cc"
#include "../src/tree/updater_histmaker.cc"
#include "../src/tree/updater_approx.cc"
#include "../src/tree/constraints.cc"

// linear
Expand Down
4 changes: 2 additions & 2 deletions demo/guide-python/categorical.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Experimental support for categorical data. After 1.5 XGBoost `gpu_hist` tree method
has experimental support for one-hot encoding based tree split.
"""Experimental support for categorical data. After 1.6 XGBoost `gpu_hist` and `approx`
tree method have experimental support for one-hot encoding based tree split.
In before, users need to run an encoder themselves before passing the data into XGBoost,
which creates a sparse matrix and potentially increase memory usage. This demo showcases
Expand Down
16 changes: 13 additions & 3 deletions doc/parameter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -232,15 +232,25 @@ Parameters for Tree Booster
- Constraints for interaction representing permitted interactions. The constraints must
be specified in the form of a nest list, e.g. ``[[0, 1], [2, 3, 4]]``, where each inner
list is a group of indices of features that are allowed to interact with each other.
See tutorial for more information
See tutorial for more information.

Additional parameters for ``hist`` and ``gpu_hist`` tree method
================================================================
Additional parameters for ``hist`` and ``gpu_hist`` and ``approx`` tree method
==============================================================================

* ``single_precision_histogram``, [default=``false``]

- Use single precision to build histograms instead of double precision.

Additional parameters for ``approx`` tree method
================================================

* ``max_cat_to_onehot``

- A threshold for deciding whether XGBoost should use one-hot encoding based split for
categorical data. When number of categories is lesser than the threshold then one-hot
encoding is chosen, otherwise the categories will be partitioned into children nodes.
Only relevant for regression and binary classification and `approx` tree method.

Additional parameters for Dart Booster (``booster=dart``)
=========================================================

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,12 @@ abstract class XGBoostRegressorSuiteBase extends FunSuite with PerTest {
}
}

class XGBoostCpuRegressorSuite extends XGBoostRegressorSuiteBase {
class XGBoostCpuRegressorSuiteApprox extends XGBoostRegressorSuiteBase {
override protected val treeMethod: String = "approx"
}

class XGBoostCpuRegressorSuiteHist extends XGBoostRegressorSuiteBase {
override protected val treeMethod: String = "hist"
}

@GpuTestSuite
Expand Down
16 changes: 14 additions & 2 deletions python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,16 @@ def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]:
callbacks = [xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
save_best=True)]
max_cat_to_onehot : bool
.. versionadded:: 1.6.0
A threshold for deciding whether XGBoost should use one-hot encoding based split
for categorical data. When number of categories is lesser than the threshold then
one-hot encoding is chosen, otherwise the categories will be partitioned into
children nodes. Only relevant for regression and binary classification and
`approx` tree method.
kwargs : dict, optional
Keyword arguments for XGBoost Booster object. Full documentation of
parameters can be found here:
Expand Down Expand Up @@ -484,6 +494,7 @@ def __init__(
eval_metric: Optional[Union[str, List[str], Callable]] = None,
early_stopping_rounds: Optional[int] = None,
callbacks: Optional[List[TrainingCallback]] = None,
max_cat_to_onehot: Optional[int] = None,
**kwargs: Any
) -> None:
if not SKLEARN_INSTALLED:
Expand Down Expand Up @@ -523,6 +534,7 @@ def __init__(
self.eval_metric = eval_metric
self.early_stopping_rounds = early_stopping_rounds
self.callbacks = callbacks
self.max_cat_to_onehot = max_cat_to_onehot
if kwargs:
self.kwargs = kwargs

Expand Down Expand Up @@ -801,8 +813,8 @@ def _duplicated(parameter: str) -> None:
_duplicated("callbacks")
callbacks = self.callbacks if self.callbacks is not None else callbacks

# lastly check categorical data support.
if self.enable_categorical and params.get("tree_method", None) != "gpu_hist":
tree_method = params.get("tree_method", None)
if self.enable_categorical and tree_method not in ("gpu_hist", "approx"):
raise ValueError(
"Experimental support for categorical data is not implemented for"
" current tree method yet."
Expand Down
8 changes: 8 additions & 0 deletions src/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,14 @@ std::vector<Idx> ArgSort(Container const &array, Comp comp = std::less<V>{}) {
XGBOOST_PARALLEL_STABLE_SORT(result.begin(), result.end(), op);
return result;
}

/**
* Last index of a group in a CSR style of index pointer.
*/
template <typename Idx, typename Indptr>
XGBOOST_DEVICE size_t LastOf(Idx group, common::Span<Indptr> indptr) {
return indptr[group + 1] - 1;
}
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_COMMON_H_
172 changes: 92 additions & 80 deletions src/common/hist_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#include "random.h"
#include "column_matrix.h"
#include "quantile.h"
#include "./../tree/updater_quantile_hist.h"
#include "../data/gradient_index.h"

#if defined(XGBOOST_MM_PREFETCH_PRESENT)
Expand Down Expand Up @@ -133,148 +132,161 @@ struct Prefetch {

constexpr size_t Prefetch::kNoPrefetchSize;


template<typename FPType, bool do_prefetch, typename BinIdxType, bool any_missing = true>
void BuildHistKernel(const std::vector<GradientPair>& gpair,
template <typename FPType, bool do_prefetch, typename BinIdxType,
bool first_page, bool any_missing = true>
void BuildHistKernel(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix& gmat,
GHistRow<FPType> hist) {
const GHistIndexMatrix &gmat, GHistRow<FPType> hist) {
const size_t size = row_indices.Size();
const size_t* rid = row_indices.begin;
const float* pgh = reinterpret_cast<const float*>(gpair.data());
const BinIdxType* gradient_index = gmat.index.data<BinIdxType>();
const size_t* row_ptr = gmat.row_ptr.data();
const uint32_t* offsets = gmat.index.Offset();
const size_t n_features = row_ptr[row_indices.begin[0]+1] - row_ptr[row_indices.begin[0]];
FPType* hist_data = reinterpret_cast<FPType*>(hist.data());
const uint32_t two {2}; // Each element from 'gpair' and 'hist' contains
// 2 FP values: gradient and hessian.
// So we need to multiply each row-index/bin-index by 2
// to work with gradient pairs as a singe row FP array
const size_t *rid = row_indices.begin;
auto const *pgh = reinterpret_cast<const float *>(gpair.data());
const BinIdxType *gradient_index = gmat.index.data<BinIdxType>();

auto const &row_ptr = gmat.row_ptr.data();
auto base_rowid = gmat.base_rowid;
const uint32_t *offsets = gmat.index.Offset();
auto get_row_ptr = [&](size_t ridx) {
return first_page ? row_ptr[ridx] : row_ptr[ridx - base_rowid];
};
auto get_rid = [&](size_t ridx) {
return first_page ? ridx : (ridx - base_rowid);
};

const size_t n_features =
get_row_ptr(row_indices.begin[0] + 1) - get_row_ptr(row_indices.begin[0]);
auto hist_data = reinterpret_cast<FPType *>(hist.data());
const uint32_t two{2}; // Each element from 'gpair' and 'hist' contains
// 2 FP values: gradient and hessian.
// So we need to multiply each row-index/bin-index by 2
// to work with gradient pairs as a singe row FP array

for (size_t i = 0; i < size; ++i) {
const size_t icol_start = any_missing ? row_ptr[rid[i]] : rid[i] * n_features;
const size_t icol_end = any_missing ? row_ptr[rid[i]+1] : icol_start + n_features;
const size_t icol_start =
any_missing ? get_row_ptr(rid[i]) : get_rid(rid[i]) * n_features;
const size_t icol_end =
any_missing ? get_row_ptr(rid[i] + 1) : icol_start + n_features;

const size_t row_size = icol_end - icol_start;
const size_t idx_gh = two * rid[i];

if (do_prefetch) {
const size_t icol_start_prefetch = any_missing ? row_ptr[rid[i+Prefetch::kPrefetchOffset]] :
rid[i + Prefetch::kPrefetchOffset] * n_features;
const size_t icol_end_prefetch = any_missing ? row_ptr[rid[i+Prefetch::kPrefetchOffset]+1] :
icol_start_prefetch + n_features;
const size_t icol_start_prefetch =
any_missing
? get_row_ptr(rid[i + Prefetch::kPrefetchOffset])
: get_rid(rid[i + Prefetch::kPrefetchOffset]) * n_features;
const size_t icol_end_prefetch =
any_missing ? get_row_ptr(rid[i + Prefetch::kPrefetchOffset] + 1)
: icol_start_prefetch + n_features;

PREFETCH_READ_T0(pgh + two * rid[i + Prefetch::kPrefetchOffset]);
for (size_t j = icol_start_prefetch; j < icol_end_prefetch;
j+=Prefetch::GetPrefetchStep<uint32_t>()) {
j += Prefetch::GetPrefetchStep<uint32_t>()) {
PREFETCH_READ_T0(gradient_index + j);
}
}
const BinIdxType* gr_index_local = gradient_index + icol_start;
const BinIdxType *gr_index_local = gradient_index + icol_start;

for (size_t j = 0; j < row_size; ++j) {
const uint32_t idx_bin = two * (static_cast<uint32_t>(gr_index_local[j]) + (
any_missing ? 0 : offsets[j]));

hist_data[idx_bin] += pgh[idx_gh];
hist_data[idx_bin+1] += pgh[idx_gh+1];
const uint32_t idx_bin = two * (static_cast<uint32_t>(gr_index_local[j]) +
(any_missing ? 0 : offsets[j]));
hist_data[idx_bin] += pgh[idx_gh];
hist_data[idx_bin + 1] += pgh[idx_gh + 1];
}
}
}

template<typename FPType, bool do_prefetch, bool any_missing>
void BuildHistDispatch(const std::vector<GradientPair>& gpair,
template <typename FPType, bool do_prefetch, bool any_missing>
void BuildHistDispatch(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix& gmat, GHistRow<FPType> hist) {
switch (gmat.index.GetBinTypeSize()) {
const GHistIndexMatrix &gmat, GHistRow<FPType> hist) {
auto first_page = gmat.base_rowid == 0;
if (first_page) {
switch (gmat.index.GetBinTypeSize()) {
case kUint8BinsTypeSize:
BuildHistKernel<FPType, do_prefetch, uint8_t, any_missing>(gpair, row_indices,
gmat, hist);
BuildHistKernel<FPType, do_prefetch, uint8_t, true, any_missing>(
gpair, row_indices, gmat, hist);
break;
case kUint16BinsTypeSize:
BuildHistKernel<FPType, do_prefetch, uint16_t, any_missing>(gpair, row_indices,
gmat, hist);
BuildHistKernel<FPType, do_prefetch, uint16_t, true, any_missing>(
gpair, row_indices, gmat, hist);
break;
case kUint32BinsTypeSize:
BuildHistKernel<FPType, do_prefetch, uint32_t, any_missing>(gpair, row_indices,
gmat, hist);
BuildHistKernel<FPType, do_prefetch, uint32_t, true, any_missing>(
gpair, row_indices, gmat, hist);
break;
default:
CHECK(false); // no default behavior
}
} else {
switch (gmat.index.GetBinTypeSize()) {
case kUint8BinsTypeSize:
BuildHistKernel<FPType, do_prefetch, uint8_t, false, any_missing>(
gpair, row_indices, gmat, hist);
break;
case kUint16BinsTypeSize:
BuildHistKernel<FPType, do_prefetch, uint16_t, false, any_missing>(
gpair, row_indices, gmat, hist);
break;
case kUint32BinsTypeSize:
BuildHistKernel<FPType, do_prefetch, uint32_t, false, any_missing>(
gpair, row_indices, gmat, hist);
break;
default:
CHECK(false); // no default behavior
}
}
}

template <typename GradientSumT>
template <bool any_missing>
void GHistBuilder<GradientSumT>::BuildHist(
const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix &gmat,
GHistRowT hist) {
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
GHistRowT hist) const {
const size_t nrows = row_indices.Size();
const size_t no_prefetch_size = Prefetch::NoPrefetchSize(nrows);

// if need to work with all rows from bin-matrix (e.g. root node)
const bool contiguousBlock = (row_indices.begin[nrows - 1] - row_indices.begin[0]) == (nrows - 1);
const bool contiguousBlock =
(row_indices.begin[nrows - 1] - row_indices.begin[0]) == (nrows - 1);

if (contiguousBlock) {
// contiguous memory access, built-in HW prefetching is enough
BuildHistDispatch<GradientSumT, false, any_missing>(gpair, row_indices, gmat, hist);
BuildHistDispatch<GradientSumT, false, any_missing>(gpair, row_indices,
gmat, hist);
} else {
const RowSetCollection::Elem span1(row_indices.begin, row_indices.end - no_prefetch_size);
const RowSetCollection::Elem span2(row_indices.end - no_prefetch_size, row_indices.end);
const RowSetCollection::Elem span1(row_indices.begin,
row_indices.end - no_prefetch_size);
const RowSetCollection::Elem span2(row_indices.end - no_prefetch_size,
row_indices.end);

BuildHistDispatch<GradientSumT, true, any_missing>(gpair, span1, gmat, hist);
BuildHistDispatch<GradientSumT, true, any_missing>(gpair, span1, gmat,
hist);
// no prefetching to avoid loading extra memory
BuildHistDispatch<GradientSumT, false, any_missing>(gpair, span2, gmat, hist);
BuildHistDispatch<GradientSumT, false, any_missing>(gpair, span2, gmat,
hist);
}
}

template void
GHistBuilder<float>::BuildHist<true>(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix &gmat,
GHistRow<float> hist);
GHistRow<float> hist) const;
template void
GHistBuilder<float>::BuildHist<false>(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix &gmat,
GHistRow<float> hist);
GHistRow<float> hist) const;
template void
GHistBuilder<double>::BuildHist<true>(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix &gmat,
GHistRow<double> hist);
GHistRow<double> hist) const;
template void
GHistBuilder<double>::BuildHist<false>(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix &gmat,
GHistRow<double> hist);

template<typename GradientSumT>
void GHistBuilder<GradientSumT>::SubtractionTrick(GHistRowT self,
GHistRowT sibling,
GHistRowT parent) {
const size_t size = self.size();
CHECK_EQ(sibling.size(), size);
CHECK_EQ(parent.size(), size);

const size_t block_size = 1024; // aproximatly 1024 values per block
size_t n_blocks = size/block_size + !!(size%block_size);

ParallelFor(omp_ulong(n_blocks), [&](omp_ulong iblock) {
const size_t ibegin = iblock*block_size;
const size_t iend = (((iblock+1)*block_size > size) ? size : ibegin + block_size);
SubtractionHist(self, parent, sibling, ibegin, iend);
});
}
template
void GHistBuilder<float>::SubtractionTrick(GHistRow<float> self,
GHistRow<float> sibling,
GHistRow<float> parent);
template
void GHistBuilder<double>::SubtractionTrick(GHistRow<double> self,
GHistRow<double> sibling,
GHistRow<double> parent);

GHistRow<double> hist) const;
} // namespace common
} // namespace xgboost
Loading

0 comments on commit aed9d27

Please sign in to comment.