Skip to content

Commit

Permalink
Use Booster context in DMatrix. (#8896)
Browse files Browse the repository at this point in the history
- Pass context from booster to DMatrix.
- Use context instead of integer for `n_threads`.
- Check the consistency of the `max_bin` configuration.
- Test for all combinations of initialization options.
  • Loading branch information
trivialfis authored Apr 28, 2023
1 parent 1f9a57d commit 08ce495
Show file tree
Hide file tree
Showing 67 changed files with 1,281 additions and 933 deletions.
12 changes: 12 additions & 0 deletions include/xgboost/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,19 @@ struct Context : public XGBoostParameter<Context> {

// True when `gpu_id` equals `kCpuId`.
bool IsCPU() const { return gpu_id == kCpuId; }
// Defined as the negation of IsCPU(): any `gpu_id` other than `kCpuId` counts as CUDA.
bool IsCUDA() const { return !IsCPU(); }

CUDAContext const* CUDACtx() const;
/**
 * \brief Make a CUDA context based on the current context.
 *
 * \param device Value assigned to `gpu_id` of the returned copy, 0 when omitted.
 *
 * \return A copy of this context whose `gpu_id` is `device`; `*this` is not modified.
 */
Context MakeCUDA(std::int32_t device = 0) const {
  Context cuda_ctx = *this;
  cuda_ctx.gpu_id = device;
  return cuda_ctx;
}
/**
 * \brief Make a CPU context based on the current context.
 *
 * \return A copy of this context whose `gpu_id` is `kCpuId`; `*this` is not modified.
 */
Context MakeCPU() const {
  Context cpu_ctx = *this;
  cpu_ctx.gpu_id = kCpuId;
  return cpu_ctx;
}

// declare parameters
DMLC_DECLARE_PARAMETER(Context) {
Expand Down
130 changes: 83 additions & 47 deletions include/xgboost/data.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright (c) 2015-2022 by XGBoost Contributors
/**
* Copyright 2015-2023 by XGBoost Contributors
* \file data.h
* \brief The input data structure of xgboost.
* \author Tianqi Chen
Expand Down Expand Up @@ -238,44 +238,72 @@ struct Entry {
}
};

/*!
* \brief Parameters for constructing batches.
/**
* \brief Parameters for constructing histogram index batches.
*/
struct BatchParam {
/*! \brief The GPU device to use. */
int gpu_id {-1};
/*! \brief Maximum number of bins per feature for histograms. */
/**
* \brief Maximum number of bins per feature for histograms.
*/
bst_bin_t max_bin{0};
/*! \brief Hessian, used for sketching with future approx implementation. */
/**
* \brief Hessian, used for sketching with future approx implementation.
*/
common::Span<float> hess;
/*! \brief Whether should DMatrix regenerate the batch. Only used for GHistIndex. */
bool regen {false};
/*! \brief Parameter used to generate column matrix for hist. */
/**
* \brief Whether we should force DMatrix to regenerate the batch. Only used for
* GHistIndex.
*/
bool regen{false};
/**
* \brief Forbid regenerating the gradient index. Used for internal validation.
*/
bool forbid_regen{false};
/**
* \brief Parameter used to generate column matrix for hist.
*/
double sparse_thresh{std::numeric_limits<double>::quiet_NaN()};

/**
* \brief Exact or others that don't need histogram.
*/
BatchParam() = default;
// GPU Hist
BatchParam(int32_t device, bst_bin_t max_bin)
: gpu_id{device}, max_bin{max_bin} {}
// Hist
/**
* \brief Used by the hist tree method.
*/
BatchParam(bst_bin_t max_bin, double sparse_thresh)
: max_bin{max_bin}, sparse_thresh{sparse_thresh} {}
// Approx
/**
* \brief Get batch with sketch weighted by hessian. The batch will be regenerated if
* the span is changed, so caller should keep the span for each iteration.
* \brief Used by the approx tree method.
*
* Get batch with sketch weighted by hessian. The batch will be regenerated if the
* span is changed, so caller should keep the span for each iteration.
*/
BatchParam(bst_bin_t max_bin, common::Span<float> hessian, bool regenerate)
: max_bin{max_bin}, hess{hessian}, regen{regenerate} {}

bool operator!=(BatchParam const& other) const {
if (hess.empty() && other.hess.empty()) {
return gpu_id != other.gpu_id || max_bin != other.max_bin;
}
return gpu_id != other.gpu_id || max_bin != other.max_bin || hess.data() != other.hess.data();
/**
 * \brief Tell whether `other` differs from this parameter in `max_bin` or `sparse_thresh`.
 *
 * Only those two members are compared; `hess`, `regen` and `forbid_regen` are ignored.
 *
 * \param other The parameter to compare against.
 *
 * \return true if `max_bin` differs, or if `sparse_thresh` differs. Two NaN thresholds are
 *         treated as equal, a NaN and a non-NaN threshold as different.
 */
bool ParamNotEqual(BatchParam const& other) const {
  // Non-floating parameter: a plain comparison is enough.
  if (max_bin != other.max_bin) {
    return true;
  }
  // NaN never compares equal to itself, so the NaN-ness of both sides has to be checked
  // before the values can be compared directly.
  bool const lhs_is_nan = std::isnan(sparse_thresh);
  bool const rhs_is_nan = std::isnan(other.sparse_thresh);
  if (lhs_is_nan || rhs_is_nan) {
    return lhs_is_nan != rhs_is_nan;
  }
  return sparse_thresh != other.sparse_thresh;
}
bool operator==(BatchParam const& other) const {
return !(*this != other);
// True once `max_bin` holds a non-zero value.
bool Initialized() const { return max_bin != 0; }
/**
 * \brief Make a copy of self for DMatrix to describe how its existing index was generated.
 *
 * \return A copy of `*this` with `regen` and `forbid_regen` both reset to false.
 */
BatchParam MakeCache() const {
  BatchParam cached = *this;
  // Neither flag says anything about how the gradient index was generated in the first
  // place, so both are cleared in the cached description.
  cached.forbid_regen = false;
  cached.regen = false;
  return cached;
}
};

Expand Down Expand Up @@ -435,7 +463,7 @@ class EllpackPage {
* This is used in the in-memory case. The ELLPACK page is constructed from an existing DMatrix
* in CSR format.
*/
explicit EllpackPage(DMatrix* dmat, const BatchParam& param);
explicit EllpackPage(Context const* ctx, DMatrix* dmat, const BatchParam& param);

/*! \brief Destructor. */
~EllpackPage();
Expand Down Expand Up @@ -551,7 +579,9 @@ class DMatrix {
template <typename T>
BatchSet<T> GetBatches();
template <typename T>
BatchSet<T> GetBatches(const BatchParam& param);
BatchSet<T> GetBatches(Context const* ctx);
template <typename T>
BatchSet<T> GetBatches(Context const* ctx, const BatchParam& param);
template <typename T>
bool PageExists() const;

Expand Down Expand Up @@ -658,18 +688,19 @@ class DMatrix {

protected:
virtual BatchSet<SparsePage> GetRowBatches() = 0;
virtual BatchSet<CSCPage> GetColumnBatches() = 0;
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
virtual BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) = 0;
virtual BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) = 0;
virtual BatchSet<ExtSparsePage> GetExtBatches(BatchParam const& param) = 0;
virtual BatchSet<CSCPage> GetColumnBatches(Context const* ctx) = 0;
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches(Context const* ctx) = 0;
virtual BatchSet<EllpackPage> GetEllpackBatches(Context const* ctx, BatchParam const& param) = 0;
virtual BatchSet<GHistIndexMatrix> GetGradientIndex(Context const* ctx,
BatchParam const& param) = 0;
virtual BatchSet<ExtSparsePage> GetExtBatches(Context const* ctx, BatchParam const& param) = 0;

virtual bool EllpackExists() const = 0;
virtual bool GHistIndexExists() const = 0;
virtual bool SparsePageExists() const = 0;
};

template<>
/**
 * \brief SparsePage specialization taking no arguments; forwards to GetRowBatches().
 */
template <>
inline BatchSet<SparsePage> DMatrix::GetBatches() {
return GetRowBatches();
}
Expand All @@ -684,34 +715,39 @@ inline bool DMatrix::PageExists<GHistIndexMatrix>() const {
return this->GHistIndexExists();
}

template<>
/**
 * \brief SparsePage specialization; forwards to SparsePageExists().
 */
template <>
inline bool DMatrix::PageExists<SparsePage>() const {
return this->SparsePageExists();
}

template<>
inline BatchSet<CSCPage> DMatrix::GetBatches() {
return GetColumnBatches();
/**
 * \brief SparsePage specialization of the context-taking overload.
 *
 * The context parameter is unnamed and unused; the call forwards to GetRowBatches().
 */
template <>
inline BatchSet<SparsePage> DMatrix::GetBatches(Context const*) {
return GetRowBatches();
}

template<>
inline BatchSet<SortedCSCPage> DMatrix::GetBatches() {
return GetSortedColumnBatches();
/**
 * \brief CSCPage specialization; forwards `ctx` to GetColumnBatches().
 */
template <>
inline BatchSet<CSCPage> DMatrix::GetBatches(Context const* ctx) {
return GetColumnBatches(ctx);
}

template<>
inline BatchSet<EllpackPage> DMatrix::GetBatches(const BatchParam& param) {
return GetEllpackBatches(param);
/**
 * \brief SortedCSCPage specialization; forwards `ctx` to GetSortedColumnBatches().
 */
template <>
inline BatchSet<SortedCSCPage> DMatrix::GetBatches(Context const* ctx) {
return GetSortedColumnBatches(ctx);
}

/**
 * \brief EllpackPage specialization; forwards `ctx` and `param` to GetEllpackBatches().
 */
template <>
inline BatchSet<EllpackPage> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
return GetEllpackBatches(ctx, param);
}

template <>
inline BatchSet<GHistIndexMatrix> DMatrix::GetBatches(const BatchParam& param) {
return GetGradientIndex(param);
// GHistIndexMatrix specialization; forwards `ctx` and `param` to GetGradientIndex().
inline BatchSet<GHistIndexMatrix> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
return GetGradientIndex(ctx, param);
}

template <>
inline BatchSet<ExtSparsePage> DMatrix::GetBatches() {
return GetExtBatches(BatchParam{});
// ExtSparsePage specialization; forwards `ctx` and `param` to GetExtBatches().
inline BatchSet<ExtSparsePage> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
return GetExtBatches(ctx, param);
}
} // namespace xgboost

Expand Down
6 changes: 4 additions & 2 deletions python-package/xgboost/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,13 +317,15 @@ def get_dmat(self) -> xgb.DMatrix:
enable_categorical=True,
)

def get_device_dmat(self) -> xgb.QuantileDMatrix:
def get_device_dmat(self, max_bin: Optional[int]) -> xgb.QuantileDMatrix:
import cupy as cp

w = None if self.w is None else cp.array(self.w)
X = cp.array(self.X, dtype=np.float32)
y = cp.array(self.y, dtype=np.float32)
return xgb.QuantileDMatrix(X, y, weight=w, base_margin=self.margin)
return xgb.QuantileDMatrix(
X, y, weight=w, base_margin=self.margin, max_bin=max_bin
)

def get_external_dmat(self) -> xgb.DMatrix:
n_samples = self.X.shape[0]
Expand Down
74 changes: 47 additions & 27 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,50 @@
*/
#include "xgboost/c_api.h"

#include <rabit/c_api.h>

#include <cstring>
#include <fstream>
#include <memory>
#include <string>
#include <vector>

#include "../collective/communicator-inl.h"
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
#include "../common/charconv.h"
#include "../common/io.h"
#include "../data/adapter.h"
#include "../data/simple_dmatrix.h"
#include "c_api_utils.h"
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "xgboost/global_config.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/json.h"
#include "xgboost/learner.h"
#include "xgboost/logging.h"
#include "xgboost/string_view.h" // StringView
#include "xgboost/version_config.h"
#include <algorithm> // for copy
#include <cinttypes> // for strtoimax
#include <cmath> // for nan
#include <cstring> // for strcmp
#include <fstream> // for operator<<, basic_ostream, ios, stringstream
#include <functional> // for less
#include <limits> // for numeric_limits
#include <map> // for operator!=, _Rb_tree_const_iterator, _Rb_tre...
#include <memory> // for shared_ptr, allocator, __shared_ptr_access
#include <string> // for char_traits, basic_string, operator==, string
#include <system_error> // for errc
#include <utility> // for pair
#include <vector> // for vector

#include "../collective/communicator-inl.h" // for Allreduce, Broadcast, Finalize, GetProcessor...
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "../common/charconv.h" // for from_chars, to_chars, NumericLimits, from_ch...
#include "../common/io.h" // for FileExtension, LoadSequentialFile, MemoryBuf...
#include "../common/threading_utils.h" // for OmpGetNumThreads, ParallelFor
#include "../data/adapter.h" // for ArrayAdapter, DenseAdapter, RecordBatchesIte...
#include "../data/proxy_dmatrix.h" // for DMatrixProxy
#include "../data/simple_dmatrix.h" // for SimpleDMatrix
#include "c_api_error.h" // for xgboost_CHECK_C_ARG_PTR, API_END, API_BEGIN
#include "c_api_utils.h" // for RequiredArg, OptionalArg, GetMissing, CastDM...
#include "dmlc/base.h" // for BeginPtr, DMLC_ATTRIBUTE_UNUSED
#include "dmlc/io.h" // for Stream
#include "dmlc/parameter.h" // for FieldAccessEntry, FieldEntry, ParamManager
#include "dmlc/thread_local.h" // for ThreadLocalStore
#include "rabit/c_api.h" // for RabitLinkTag
#include "rabit/rabit.h" // for CheckPoint, LoadCheckPoint
#include "xgboost/base.h" // for bst_ulong, bst_float, GradientPair, bst_feat...
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for DMatrix, MetaInfo, DataType, ExtSparsePage
#include "xgboost/feature_map.h" // for FeatureMap
#include "xgboost/global_config.h" // for GlobalConfiguration, GlobalConfigThreadLocal...
#include "xgboost/host_device_vector.h" // for HostDeviceVector
#include "xgboost/intrusive_ptr.h" // for xgboost
#include "xgboost/json.h" // for Json, get, Integer, IsA, Boolean, String
#include "xgboost/learner.h" // for Learner, PredictionType
#include "xgboost/logging.h" // for LOG_FATAL, LogMessageFatal, CHECK, LogCheck_EQ
#include "xgboost/predictor.h" // for PredictionCacheEntry
#include "xgboost/span.h" // for Span
#include "xgboost/string_view.h" // for StringView, operator<<
#include "xgboost/version_config.h" // for XGBOOST_VER_MAJOR, XGBOOST_VER_MINOR, XGBOOS...

#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_server.h"
Expand Down Expand Up @@ -341,10 +361,10 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand
API_END();
}

XGB_DLL int XGProxyDMatrixCreate(DMatrixHandle* out) {
XGB_DLL int XGProxyDMatrixCreate(DMatrixHandle *out) {
API_BEGIN();
xgboost_CHECK_C_ARG_PTR(out);
*out = new std::shared_ptr<xgboost::DMatrix>(new xgboost::data::DMatrixProxy);;
*out = new std::shared_ptr<xgboost::DMatrix>(new xgboost::data::DMatrixProxy);
API_END();
}

Expand Down Expand Up @@ -746,7 +766,7 @@ XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config

CHECK_LE(p_m->Info().num_col_, std::numeric_limits<unsigned>::max());

for (auto const &page : p_m->GetBatches<ExtSparsePage>()) {
for (auto const &page : p_m->GetBatches<ExtSparsePage>(p_m->Ctx(), BatchParam{})) {
CHECK(page.page);
auto const &h_offset = page.page->offset.ConstHostVector();
std::copy(h_offset.cbegin(), h_offset.cend(), out_indptr);
Expand Down
5 changes: 5 additions & 0 deletions src/common/error_msg.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,10 @@ constexpr StringView InfInData() {
// Message text reporting that 128-bit floating point is not supported on the current platform.
constexpr StringView NoF128() {
return "128-bit floating point is not supported on current platform.";
}

// Message text reporting that `max_bin` is inconsistent, either between QuantileDMatrix
// objects or with the Booster being trained.
constexpr StringView InconsistentMaxBin() {
return "Inconsistent `max_bin`. `max_bin` should be the same across different QuantileDMatrix, "
"and consistent with the Booster being trained.";
}
} // namespace xgboost::error
#endif // XGBOOST_COMMON_ERROR_MSG_H_
Loading

0 comments on commit 08ce495

Please sign in to comment.