Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Predict on Ellpack. #5327

Merged
merged 5 commits into from
Feb 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,19 @@ struct BatchParam {
/*! \brief The GPU device to use. */
int gpu_id;
/*! \brief Maximum number of bins per feature for histograms. */
int max_bin;
int max_bin { 0 };
/*! \brief Number of rows in a GPU batch, used for finding quantiles on GPU. */
int gpu_batch_nrows;
/*! \brief Page size for external memory mode. */
size_t gpu_page_size;

BatchParam() = default;
BatchParam(int32_t device, int32_t max_bin, int32_t gpu_batch_nrows,
size_t gpu_page_size = 0) :
gpu_id{device},
max_bin{max_bin},
gpu_batch_nrows{gpu_batch_nrows},
gpu_page_size{gpu_page_size}
{}
inline bool operator!=(const BatchParam& other) const {
return gpu_id != other.gpu_id ||
max_bin != other.max_bin ||
Expand Down Expand Up @@ -438,6 +445,9 @@ class DMatrix {
*/
template<typename T>
BatchSet<T> GetBatches(const BatchParam& param = {});
template <typename T>
bool PageExists() const;

// the following are column meta data, should be able to answer them fast.
/*! \return Whether the data columns single column block. */
virtual bool SingleColBlock() const = 0;
Expand Down Expand Up @@ -493,13 +503,26 @@ class DMatrix {
virtual BatchSet<CSCPage> GetColumnBatches() = 0;
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
virtual BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) = 0;

virtual bool EllpackExists() const = 0;
virtual bool SparsePageExists() const = 0;
};

template<>
inline BatchSet<SparsePage> DMatrix::GetBatches(const BatchParam&) {
return GetRowBatches();
}

template<>
inline bool DMatrix::PageExists<EllpackPage>() const {
return this->EllpackExists();
}

template<>
inline bool DMatrix::PageExists<SparsePage>() const {
return this->SparsePageExists();
}

template<>
inline BatchSet<CSCPage> DMatrix::GetBatches(const BatchParam&) {
return GetColumnBatches();
Expand Down
8 changes: 4 additions & 4 deletions include/xgboost/tree_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class RegTree : public Model {
/*! \brief tree node */
class Node {
public:
Node() {
XGBOOST_DEVICE Node() {
// assert compact alignment
static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info),
"Node: 64 bit align");
Expand Down Expand Up @@ -422,7 +422,7 @@ class RegTree : public Model {
* \param i feature index.
* \return the i-th feature value
*/
bst_float Fvalue(size_t i) const;
bst_float GetFvalue(size_t i) const;
/*!
* \brief check whether i-th entry is missing
* \param i feature index.
Expand Down Expand Up @@ -565,7 +565,7 @@ inline size_t RegTree::FVec::Size() const {
return data_.size();
}

inline bst_float RegTree::FVec::Fvalue(size_t i) const {
inline bst_float RegTree::FVec::GetFvalue(size_t i) const {
return data_[i].fvalue;
}

Expand All @@ -577,7 +577,7 @@ inline int RegTree::GetLeafIndex(const RegTree::FVec& feat) const {
bst_node_t nid = 0;
while (!(*this)[nid].IsLeaf()) {
unsigned split_index = (*this)[nid].SplitIndex();
nid = this->GetNext(nid, feat.Fvalue(split_index), feat.IsMissing(split_index));
nid = this->GetNext(nid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
}
return nid;
}
Expand Down
4 changes: 2 additions & 2 deletions src/data/ellpack_page.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ __global__ void CompressBinEllpackKernel(
common::CompressedByteT* __restrict__ buffer, // gidx_buffer
const size_t* __restrict__ row_ptrs, // row offset of input data
const Entry* __restrict__ entries, // One batch of input data
const float* __restrict__ cuts, // HistogramCuts::cut
const uint32_t* __restrict__ cut_rows, // HistogramCuts::row_ptrs
const float* __restrict__ cuts, // HistogramCuts::cut_values_
const uint32_t* __restrict__ cut_rows, // HistogramCuts::cut_ptrs_
size_t base_row, // batch_row_begin
size_t n_rows,
size_t row_stride,
Expand Down
9 changes: 8 additions & 1 deletion src/data/ellpack_page.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ struct EllpackInfo {
size_t NumSymbols() const {
return n_bins + 1;
}
size_t NumFeatures() const {
return min_fvalue.size();
}
};

/** \brief Struct for accessing and manipulating an ellpack matrix on the
Expand All @@ -89,7 +92,7 @@ struct EllpackMatrix {

// Get a matrix element, uses binary search for look up Return NaN if missing
// Given a row index and a feature index, returns the corresponding cut value
__device__ bst_float GetElement(size_t ridx, size_t fidx) const {
__device__ int32_t GetBinIndex(size_t ridx, size_t fidx) const {
ridx -= base_rowid;
auto row_begin = info.row_stride * ridx;
auto row_end = row_begin + info.row_stride;
Expand All @@ -103,6 +106,10 @@ struct EllpackMatrix {
info.feature_segments[fidx],
info.feature_segments[fidx + 1]);
}
return gidx;
}
__device__ bst_float GetFvalue(size_t ridx, size_t fidx) const {
auto gidx = GetBinIndex(ridx, fidx);
if (gidx == -1) {
return nan("");
}
Expand Down
10 changes: 7 additions & 3 deletions src/data/simple_dmatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,15 @@ BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
}

BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(const BatchParam& param) {
CHECK_GE(param.gpu_id, 0);
CHECK_GE(param.max_bin, 2);
// ELLPACK page doesn't exist, generate it
if (!ellpack_page_) {
if (!(batch_param_ != BatchParam{})) {
CHECK(param != BatchParam{}) << "Batch parameter is not initialized.";
}
if (!ellpack_page_ || (batch_param_ != param && param != BatchParam{})) {
CHECK_GE(param.gpu_id, 0);
CHECK_GE(param.max_bin, 2);
ellpack_page_.reset(new EllpackPage(this, param));
batch_param_ = param;
}
auto begin_iter =
BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_.get()));
Expand Down
8 changes: 8 additions & 0 deletions src/data/simple_dmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ class SimpleDMatrix : public DMatrix {
std::unique_ptr<CSCPage> column_page_;
std::unique_ptr<SortedCSCPage> sorted_column_page_;
std::unique_ptr<EllpackPage> ellpack_page_;
BatchParam batch_param_;

bool EllpackExists() const override {
return static_cast<bool>(ellpack_page_);
}
bool SparsePageExists() const override {
return true;
}
};
} // namespace data
} // namespace xgboost
Expand Down
4 changes: 2 additions & 2 deletions src/data/sparse_page_dmatrix.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright 2014 by Contributors
* Copyright 2014-2020 by Contributors
* \file sparse_page_dmatrix.cc
* \brief The external memory version of Page Iterator.
* \author Tianqi Chen
Expand Down Expand Up @@ -47,7 +47,7 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& par
CHECK_GE(param.gpu_id, 0);
CHECK_GE(param.max_bin, 2);
// Lazily instantiate
if (!ellpack_source_ || batch_param_ != param) {
if (!ellpack_source_ || (batch_param_ != param && param != BatchParam{})) {
ellpack_source_.reset(new EllpackPageSource(this, cache_info_, param));
batch_param_ = param;
}
Expand Down
7 changes: 7 additions & 0 deletions src/data/sparse_page_dmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ class SparsePageDMatrix : public DMatrix {
std::string cache_info_;
// Store column densities to avoid recalculating
std::vector<float> col_density_;

bool EllpackExists() const override {
return static_cast<bool>(ellpack_source_);
}
bool SparsePageExists() const override {
return static_cast<bool>(row_source_);
}
};
} // namespace data
} // namespace xgboost
Expand Down
Loading