Skip to content

Commit

Permalink
Use Booster context in DMatrix.
Browse files Browse the repository at this point in the history
CPU build.

fix.

Cleanup the batch param.

tidy.

subtle.

comments.

check.

More tests.

Clarify the check.

Extract.

Extract.

fixes.

Fix rebase.

Fix rebase.

Code comment.
  • Loading branch information
trivialfis committed Apr 28, 2023
1 parent fb94126 commit 9b0f676
Show file tree
Hide file tree
Showing 67 changed files with 1,281 additions and 933 deletions.
12 changes: 12 additions & 0 deletions include/xgboost/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,19 @@ struct Context : public XGBoostParameter<Context> {

bool IsCPU() const { return gpu_id == kCpuId; }
bool IsCUDA() const { return !IsCPU(); }

CUDAContext const* CUDACtx() const;
// Make a CUDA context based on the current context.
Context MakeCUDA(std::int32_t device = 0) const {
Context ctx = *this;
ctx.gpu_id = device;
return ctx;
}
Context MakeCPU() const {
Context ctx = *this;
ctx.gpu_id = kCpuId;
return ctx;
}

// declare parameters
DMLC_DECLARE_PARAMETER(Context) {
Expand Down
130 changes: 83 additions & 47 deletions include/xgboost/data.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright (c) 2015-2022 by XGBoost Contributors
/**
* Copyright 2015-2023 by XGBoost Contributors
* \file data.h
* \brief The input data structure of xgboost.
* \author Tianqi Chen
Expand Down Expand Up @@ -238,44 +238,72 @@ struct Entry {
}
};

/*!
* \brief Parameters for constructing batches.
/**
* \brief Parameters for constructing histogram index batches.
*/
struct BatchParam {
/*! \brief The GPU device to use. */
int gpu_id {-1};
/*! \brief Maximum number of bins per feature for histograms. */
/**
* \brief Maximum number of bins per feature for histograms.
*/
bst_bin_t max_bin{0};
/*! \brief Hessian, used for sketching with future approx implementation. */
/**
* \brief Hessian, used for sketching with future approx implementation.
*/
common::Span<float> hess;
/*! \brief Whether should DMatrix regenerate the batch. Only used for GHistIndex. */
bool regen {false};
/*! \brief Parameter used to generate column matrix for hist. */
/**
* \brief Whether should we force DMatrix to regenerate the batch. Only used for
* GHistIndex.
*/
bool regen{false};
/**
* \brief Forbid regenerating the gradient index. Used for internal validation.
*/
bool forbid_regen{false};
/**
* \brief Parameter used to generate column matrix for hist.
*/
double sparse_thresh{std::numeric_limits<double>::quiet_NaN()};

/**
* \brief Exact or others that don't need histogram.
*/
BatchParam() = default;
// GPU Hist
BatchParam(int32_t device, bst_bin_t max_bin)
: gpu_id{device}, max_bin{max_bin} {}
// Hist
/**
* \brief Used by the hist tree method.
*/
BatchParam(bst_bin_t max_bin, double sparse_thresh)
: max_bin{max_bin}, sparse_thresh{sparse_thresh} {}
// Approx
/**
* \brief Get batch with sketch weighted by hessian. The batch will be regenerated if
* the span is changed, so caller should keep the span for each iteration.
* \brief Used by the approx tree method.
*
* Get batch with sketch weighted by hessian. The batch will be regenerated if the
* span is changed, so caller should keep the span for each iteration.
*/
BatchParam(bst_bin_t max_bin, common::Span<float> hessian, bool regenerate)
: max_bin{max_bin}, hess{hessian}, regen{regenerate} {}

bool operator!=(BatchParam const& other) const {
if (hess.empty() && other.hess.empty()) {
return gpu_id != other.gpu_id || max_bin != other.max_bin;
}
return gpu_id != other.gpu_id || max_bin != other.max_bin || hess.data() != other.hess.data();
bool ParamNotEqual(BatchParam const& other) const {
// Check non-floating parameters.
bool cond = max_bin != other.max_bin;
// Check sparse thresh.
bool l_nan = std::isnan(sparse_thresh);
bool r_nan = std::isnan(other.sparse_thresh);
bool st_chg = (l_nan != r_nan) || (!l_nan && !r_nan && (sparse_thresh != other.sparse_thresh));
cond |= st_chg;

return cond;
}
bool operator==(BatchParam const& other) const {
return !(*this != other);
bool Initialized() const { return max_bin != 0; }
/**
* \brief Make a copy of self for DMatrix to describe how its existing index was generated.
*/
BatchParam MakeCache() const {
auto p = *this;
// These parameters have nothing to do with how the gradient index was generated in the
// first place.
p.regen = false;
p.forbid_regen = false;
return p;
}
};

Expand Down Expand Up @@ -435,7 +463,7 @@ class EllpackPage {
* This is used in the in-memory case. The ELLPACK page is constructed from an existing DMatrix
* in CSR format.
*/
explicit EllpackPage(DMatrix* dmat, const BatchParam& param);
explicit EllpackPage(Context const* ctx, DMatrix* dmat, const BatchParam& param);

/*! \brief Destructor. */
~EllpackPage();
Expand Down Expand Up @@ -551,7 +579,9 @@ class DMatrix {
template <typename T>
BatchSet<T> GetBatches();
template <typename T>
BatchSet<T> GetBatches(const BatchParam& param);
BatchSet<T> GetBatches(Context const* ctx);
template <typename T>
BatchSet<T> GetBatches(Context const* ctx, const BatchParam& param);
template <typename T>
bool PageExists() const;

Expand Down Expand Up @@ -662,18 +692,19 @@ class DMatrix {

protected:
virtual BatchSet<SparsePage> GetRowBatches() = 0;
virtual BatchSet<CSCPage> GetColumnBatches() = 0;
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
virtual BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) = 0;
virtual BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) = 0;
virtual BatchSet<ExtSparsePage> GetExtBatches(BatchParam const& param) = 0;
virtual BatchSet<CSCPage> GetColumnBatches(Context const* ctx) = 0;
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches(Context const* ctx) = 0;
virtual BatchSet<EllpackPage> GetEllpackBatches(Context const* ctx, BatchParam const& param) = 0;
virtual BatchSet<GHistIndexMatrix> GetGradientIndex(Context const* ctx,
BatchParam const& param) = 0;
virtual BatchSet<ExtSparsePage> GetExtBatches(Context const* ctx, BatchParam const& param) = 0;

virtual bool EllpackExists() const = 0;
virtual bool GHistIndexExists() const = 0;
virtual bool SparsePageExists() const = 0;
};

template<>
template <>
inline BatchSet<SparsePage> DMatrix::GetBatches() {
return GetRowBatches();
}
Expand All @@ -688,34 +719,39 @@ inline bool DMatrix::PageExists<GHistIndexMatrix>() const {
return this->GHistIndexExists();
}

template<>
template <>
inline bool DMatrix::PageExists<SparsePage>() const {
return this->SparsePageExists();
}

template<>
inline BatchSet<CSCPage> DMatrix::GetBatches() {
return GetColumnBatches();
template <>
inline BatchSet<SparsePage> DMatrix::GetBatches(Context const*) {
return GetRowBatches();
}

template<>
inline BatchSet<SortedCSCPage> DMatrix::GetBatches() {
return GetSortedColumnBatches();
template <>
inline BatchSet<CSCPage> DMatrix::GetBatches(Context const* ctx) {
return GetColumnBatches(ctx);
}

template<>
inline BatchSet<EllpackPage> DMatrix::GetBatches(const BatchParam& param) {
return GetEllpackBatches(param);
template <>
inline BatchSet<SortedCSCPage> DMatrix::GetBatches(Context const* ctx) {
return GetSortedColumnBatches(ctx);
}

template <>
inline BatchSet<EllpackPage> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
return GetEllpackBatches(ctx, param);
}

template <>
inline BatchSet<GHistIndexMatrix> DMatrix::GetBatches(const BatchParam& param) {
return GetGradientIndex(param);
inline BatchSet<GHistIndexMatrix> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
return GetGradientIndex(ctx, param);
}

template <>
inline BatchSet<ExtSparsePage> DMatrix::GetBatches() {
return GetExtBatches(BatchParam{});
inline BatchSet<ExtSparsePage> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
return GetExtBatches(ctx, param);
}
} // namespace xgboost

Expand Down
6 changes: 4 additions & 2 deletions python-package/xgboost/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,13 +317,15 @@ def get_dmat(self) -> xgb.DMatrix:
enable_categorical=True,
)

def get_device_dmat(self) -> xgb.QuantileDMatrix:
def get_device_dmat(self, max_bin: Optional[int]) -> xgb.QuantileDMatrix:
import cupy as cp

w = None if self.w is None else cp.array(self.w)
X = cp.array(self.X, dtype=np.float32)
y = cp.array(self.y, dtype=np.float32)
return xgb.QuantileDMatrix(X, y, weight=w, base_margin=self.margin)
return xgb.QuantileDMatrix(
X, y, weight=w, base_margin=self.margin, max_bin=max_bin
)

def get_external_dmat(self) -> xgb.DMatrix:
n_samples = self.X.shape[0]
Expand Down
74 changes: 47 additions & 27 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,50 @@
*/
#include "xgboost/c_api.h"

#include <rabit/c_api.h>

#include <cstring>
#include <fstream>
#include <memory>
#include <string>
#include <vector>

#include "../collective/communicator-inl.h"
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
#include "../common/charconv.h"
#include "../common/io.h"
#include "../data/adapter.h"
#include "../data/simple_dmatrix.h"
#include "c_api_utils.h"
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "xgboost/global_config.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/json.h"
#include "xgboost/learner.h"
#include "xgboost/logging.h"
#include "xgboost/string_view.h" // StringView
#include "xgboost/version_config.h"
#include <algorithm> // for copy
#include <cinttypes> // for strtoimax
#include <cmath> // for nan
#include <cstring> // for strcmp
#include <fstream> // for operator<<, basic_ostream, ios, stringstream
#include <functional> // for less
#include <limits> // for numeric_limits
#include <map> // for operator!=, _Rb_tree_const_iterator, _Rb_tre...
#include <memory> // for shared_ptr, allocator, __shared_ptr_access
#include <string> // for char_traits, basic_string, operator==, string
#include <system_error> // for errc
#include <utility> // for pair
#include <vector> // for vector

#include "../collective/communicator-inl.h" // for Allreduce, Broadcast, Finalize, GetProcessor...
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "../common/charconv.h" // for from_chars, to_chars, NumericLimits, from_ch...
#include "../common/io.h" // for FileExtension, LoadSequentialFile, MemoryBuf...
#include "../common/threading_utils.h" // for OmpGetNumThreads, ParallelFor
#include "../data/adapter.h" // for ArrayAdapter, DenseAdapter, RecordBatchesIte...
#include "../data/proxy_dmatrix.h" // for DMatrixProxy
#include "../data/simple_dmatrix.h" // for SimpleDMatrix
#include "c_api_error.h" // for xgboost_CHECK_C_ARG_PTR, API_END, API_BEGIN
#include "c_api_utils.h" // for RequiredArg, OptionalArg, GetMissing, CastDM...
#include "dmlc/base.h" // for BeginPtr, DMLC_ATTRIBUTE_UNUSED
#include "dmlc/io.h" // for Stream
#include "dmlc/parameter.h" // for FieldAccessEntry, FieldEntry, ParamManager
#include "dmlc/thread_local.h" // for ThreadLocalStore
#include "rabit/c_api.h" // for RabitLinkTag
#include "rabit/rabit.h" // for CheckPoint, LoadCheckPoint
#include "xgboost/base.h" // for bst_ulong, bst_float, GradientPair, bst_feat...
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for DMatrix, MetaInfo, DataType, ExtSparsePage
#include "xgboost/feature_map.h" // for FeatureMap
#include "xgboost/global_config.h" // for GlobalConfiguration, GlobalConfigThreadLocal...
#include "xgboost/host_device_vector.h" // for HostDeviceVector
#include "xgboost/intrusive_ptr.h" // for xgboost
#include "xgboost/json.h" // for Json, get, Integer, IsA, Boolean, String
#include "xgboost/learner.h" // for Learner, PredictionType
#include "xgboost/logging.h" // for LOG_FATAL, LogMessageFatal, CHECK, LogCheck_EQ
#include "xgboost/predictor.h" // for PredictionCacheEntry
#include "xgboost/span.h" // for Span
#include "xgboost/string_view.h" // for StringView, operator<<
#include "xgboost/version_config.h" // for XGBOOST_VER_MAJOR, XGBOOST_VER_MINOR, XGBOOS...

#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_server.h"
Expand Down Expand Up @@ -341,10 +361,10 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand
API_END();
}

XGB_DLL int XGProxyDMatrixCreate(DMatrixHandle* out) {
XGB_DLL int XGProxyDMatrixCreate(DMatrixHandle *out) {
API_BEGIN();
xgboost_CHECK_C_ARG_PTR(out);
*out = new std::shared_ptr<xgboost::DMatrix>(new xgboost::data::DMatrixProxy);;
*out = new std::shared_ptr<xgboost::DMatrix>(new xgboost::data::DMatrixProxy);
API_END();
}

Expand Down Expand Up @@ -746,7 +766,7 @@ XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config

CHECK_LE(p_m->Info().num_col_, std::numeric_limits<unsigned>::max());

for (auto const &page : p_m->GetBatches<ExtSparsePage>()) {
for (auto const &page : p_m->GetBatches<ExtSparsePage>(p_m->Ctx(), BatchParam{})) {
CHECK(page.page);
auto const &h_offset = page.page->offset.ConstHostVector();
std::copy(h_offset.cbegin(), h_offset.cend(), out_indptr);
Expand Down
5 changes: 5 additions & 0 deletions src/common/error_msg.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,10 @@ constexpr StringView LabelScoreSize() {
constexpr StringView InfInData() {
return "Input data contains `inf` or a value too large, while `missing` is not set to `inf`";
}

constexpr StringView InconsistentMaxBin() {
return "Inconsistent `max_bin`. `max_bin` should be the same across different QuantileDMatrix, "
"and consistent with the Booster being trained.";
}
} // namespace xgboost::error
#endif // XGBOOST_COMMON_ERROR_MSG_H_
Loading

0 comments on commit 9b0f676

Please sign in to comment.