From 4d1801286e6963917f564a75587f04fe213fbea5 Mon Sep 17 00:00:00 2001 From: egorsmir Date: Sun, 26 Jan 2020 22:59:53 +0300 Subject: [PATCH] Optimized ApplySplit and UpdatePredictionCache functions --- src/common/column_matrix.h | 1 + src/common/hist_util.cc | 105 +++++- src/common/hist_util.h | 28 +- src/common/row_set.h | 143 +++++++- src/tree/updater_quantile_hist.cc | 401 ++++++++++++--------- src/tree/updater_quantile_hist.h | 43 +-- tests/cpp/common/test_partition_builder.cc | 76 ++++ 7 files changed, 551 insertions(+), 246 deletions(-) create mode 100755 tests/cpp/common/test_partition_builder.cc diff --git a/src/common/column_matrix.h b/src/common/column_matrix.h index 40136c5c7991..762177212d19 100644 --- a/src/common/column_matrix.h +++ b/src/common/column_matrix.h @@ -37,6 +37,7 @@ class Column { size_t Size() const { return len_; } uint32_t GetGlobalBinIdx(size_t idx) const { return index_base_ + index_[idx]; } uint32_t GetFeatureBinIdx(size_t idx) const { return index_[idx]; } + const uint32_t* GetFeatureBinIdxPtr() const { return index_; } // column.GetFeatureBinIdx(idx) + column.GetBaseIdx(idx) == // column.GetGlobalBinIdx(idx) uint32_t GetBaseIdx() const { return index_base_; } diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index 9815a3b46aba..babf3c980f06 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -662,8 +662,8 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat, /*! * \brief fill a histogram by zeroes */ -void InitilizeHistByZeroes(GHistRow hist) { - memset(hist.data(), '\0', hist.size()*sizeof(tree::GradStats)); +void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end) { + memset(hist.data() + begin, '\0', (end-begin)*sizeof(tree::GradStats)); } /*! @@ -707,40 +707,107 @@ void SubtractionHist(GHistRow dst, const GHistRow src1, const GHistRow src2, } } +template +void BuildHistDenseKernel(const size_t* rid, const float* pgh, const uint32_t* index, + FPType* hist_data, size_t ibegin, size_t iend, size_t n_features, + size_t prefetch_offset, size_t prefetch_step) { + for (size_t i = ibegin; i < iend; ++i) { + const size_t icol_start = rid[i] * n_features; + const size_t idx_gh = 2*rid[i]; + + if (do_prefetch) { + const size_t icol_start_prefetch = rid[i+prefetch_offset] * n_features; + + PREFETCH_READ_T0(pgh + 2*rid[i + prefetch_offset]); + for (size_t j = icol_start_prefetch; j < icol_start_prefetch + n_features; + j += prefetch_step) { + PREFETCH_READ_T0(index + j); + } + } + + for (size_t j = icol_start; j < icol_start + n_features; ++j) { + const uint32_t idx_bin = 2*index[j]; + + hist_data[idx_bin] += pgh[idx_gh]; + hist_data[idx_bin+1] += pgh[idx_gh+1]; + } + } +} + +template +void BuildHistSparseKernel(const size_t* rid, const float* pgh, const uint32_t* index, + FPType* hist_data, const size_t* row_ptr, size_t ibegin, size_t iend, + size_t prefetch_offset, size_t prefetch_step) { + for (size_t i = ibegin; i < iend; ++i) { + const size_t icol_start = row_ptr[rid[i]]; + const size_t icol_end = row_ptr[rid[i]+1]; + const size_t idx_gh = 2*rid[i]; + + if (do_prefetch) { + const size_t icol_start_prftch = row_ptr[rid[i+prefetch_offset]]; + const size_t icol_end_prefect = row_ptr[rid[i+prefetch_offset]+1]; + + PREFETCH_READ_T0(pgh + 2*rid[i + prefetch_offset]); + for (size_t j = icol_start_prftch; j < icol_end_prefect; j+=prefetch_step) { + PREFETCH_READ_T0(index + j); + } + } + + for (size_t j = icol_start; j < icol_end; ++j) { + const uint32_t idx_bin = 2*index[j]; + hist_data[idx_bin] += pgh[idx_gh]; + hist_data[idx_bin+1] += pgh[idx_gh+1]; + } + } +} + +template +void BuildHistKernel(const size_t* rid, const float* pgh, const uint32_t* index, + FPType* hist_data, const size_t* row_ptr, size_t ibegin, size_t iend, + size_t prefetch_offset, size_t prefetch_step, bool isDense) { + if (isDense) { + const size_t n_features = row_ptr[rid[0]+1] - row_ptr[rid[0]]; + BuildHistDenseKernel(rid, pgh, index, hist_data, + ibegin, iend, n_features, prefetch_offset, prefetch_step); + } else { + BuildHistSparseKernel(rid, pgh, index, hist_data, row_ptr, + ibegin, iend, prefetch_offset, prefetch_step); + } +} void GHistBuilder::BuildHist(const std::vector& gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix& gmat, - GHistRow hist) { + GHistRow hist, + bool isDense) { const size_t* rid = row_indices.begin; const size_t nrows = row_indices.Size(); const uint32_t* index = gmat.index.data(); const size_t* row_ptr = gmat.row_ptr.data(); const float* pgh = reinterpret_cast(gpair.data()); - double* hist_data = reinterpret_cast(hist.data()); + using FPType = decltype(tree::GradStats::sum_grad); + FPType* hist_data = reinterpret_cast(hist.data()); const size_t cache_line_size = 64; const size_t prefetch_offset = 10; size_t no_prefetch_size = prefetch_offset + cache_line_size/sizeof(*rid); no_prefetch_size = no_prefetch_size > nrows ? nrows : no_prefetch_size; + const size_t prefetch_step = cache_line_size / sizeof(*index); - for (size_t i = 0; i < nrows; ++i) { - const size_t icol_start = row_ptr[rid[i]]; - const size_t icol_end = row_ptr[rid[i]+1]; - - if (i < nrows - no_prefetch_size) { - PREFETCH_READ_T0(row_ptr + rid[i + prefetch_offset]); - PREFETCH_READ_T0(pgh + 2*rid[i + prefetch_offset]); - } - - for (size_t j = icol_start; j < icol_end; ++j) { - const uint32_t idx_bin = 2*index[j]; - const size_t idx_gh = 2*rid[i]; + // if need to work with all rows from bin-matrix (e.g. root node) + const bool contiguousBlock = (rid[row_indices.Size()-1] - rid[0]) == (row_indices.Size() - 1); - hist_data[idx_bin] += pgh[idx_gh]; - hist_data[idx_bin+1] += pgh[idx_gh+1]; - } + if (contiguousBlock) { + // contiguous memory access, built-in HW prefetching is enough + BuildHistKernel(rid, pgh, index, hist_data, row_ptr, + 0, nrows, prefetch_offset, prefetch_step, isDense); + } else { + BuildHistKernel(rid, pgh, index, hist_data, row_ptr, + 0, nrows - no_prefetch_size, prefetch_offset, prefetch_step, isDense); + // no prefetching to avoid loading extra memory + BuildHistKernel(rid, pgh, index, hist_data, row_ptr, + nrows - no_prefetch_size, nrows, prefetch_offset, prefetch_step, isDense); } } diff --git a/src/common/hist_util.h b/src/common/hist_util.h index affdca17754b..d8a10c65c06a 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -355,7 +355,7 @@ using GHistRow = Span; /*! * \brief fill a histogram by zeros */ -void InitilizeHistByZeroes(GHistRow hist); +void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end); /*! * \brief Increment hist as dst += add in range [begin, end) @@ -454,6 +454,7 @@ class ParallelGHistBuilder { tid_nid_to_hist_.clear(); hist_memory_.clear(); threads_to_nids_map_.clear(); + targeted_hists_ = targeted_hists; CHECK_EQ(nodes, targeted_hists.size()); @@ -478,7 +479,7 @@ class ParallelGHistBuilder { GHistRow hist = hist_memory_[idx]; if (!hist_was_used_[tid * nodes_ + nid]) { - InitilizeHistByZeroes(hist); + InitilizeHistByZeroes(hist, 0, hist.size()); hist_was_used_[tid * nodes_ + nid] = static_cast(true); } @@ -492,16 +493,23 @@ class ParallelGHistBuilder { GHistRow dst = targeted_hists_[nid]; + bool is_updated = false; for (size_t tid = 0; tid < nthreads_; ++tid) { if (hist_was_used_[tid * nodes_ + nid]) { + is_updated = true; const size_t idx = tid_nid_to_hist_.at({tid, nid}); GHistRow src = hist_memory_[idx]; if (dst.data() != src.data()) { IncrementHist(dst, src, begin, end); - } // else src is already targeted hist + } } } + if (!is_updated) { + // In distributed mode - some tree nodes can be empty on local machines, + // So we need just set local hist by zeros in this case + InitilizeHistByZeroes(dst, begin, end); + } } protected: @@ -531,7 +539,7 @@ class ParallelGHistBuilder { size_t hist_allocated_additionally = 0; for (size_t nid = 0; nid < nodes_; ++nid) { - size_t nthreads_for_nid = 0; + int nthreads_for_nid = 0; for (size_t tid = 0; tid < nthreads_; ++tid) { if (threads_to_nids_map_[tid * nodes_ + nid]) { @@ -539,10 +547,11 @@ class ParallelGHistBuilder { } } - CHECK_GT(nthreads_for_nid, 0); - // -1 means that we have one histogram per node already allocated externally, - // which should store final result for the node - hist_allocated_additionally += (nthreads_for_nid - 1); + // In distributed mode - some tree nodes can be empty on local machines, + // set nthreads_for_nid to 0 in this case. + // In another case - allocate additional (nthreads_for_nid - 1) histograms, + // because one is already allocated externally (will store final result for the node). + hist_allocated_additionally += std::max(0, nthreads_for_nid - 1); } for (size_t i = 0; i < hist_allocated_additionally; ++i) { @@ -613,7 +622,8 @@ class GHistBuilder { void BuildHist(const std::vector& gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix& gmat, - GHistRow hist); + GHistRow hist, + bool isDense); // same, with feature grouping void BuildBlockHist(const std::vector& gpair, const RowSetCollection::Elem row_indices, diff --git a/src/common/row_set.h b/src/common/row_set.h index 285988b159c3..dd2c876cc2de 100644 --- a/src/common/row_set.h +++ b/src/common/row_set.h @@ -10,6 +10,7 @@ #include #include #include +#include namespace xgboost { namespace common { @@ -57,6 +58,13 @@ class RowSetCollection { << "access element that is not in the set"; return e; } + + /*! \brief return corresponding element set given the node_id */ + inline Elem& operator[](unsigned node_id) { + Elem& e = elem_of_each_node_[node_id]; + return e; + } + // clear up things inline void Clear() { elem_of_each_node_.clear(); @@ -83,25 +91,18 @@ class RowSetCollection { } // split rowset into two inline void AddSplit(unsigned node_id, - const std::vector& row_split_tloc, unsigned left_node_id, - unsigned right_node_id) { + unsigned right_node_id, + size_t n_left, + size_t n_right) { const Elem e = elem_of_each_node_[node_id]; - const auto nthread = static_cast(row_split_tloc.size()); CHECK(e.begin != nullptr); size_t* all_begin = dmlc::BeginPtr(row_indices_); size_t* begin = all_begin + (e.begin - all_begin); - size_t* it = begin; - for (bst_omp_uint tid = 0; tid < nthread; ++tid) { - std::copy(row_split_tloc[tid].left.begin(), row_split_tloc[tid].left.end(), it); - it += row_split_tloc[tid].left.size(); - } - size_t* split_pt = it; - for (bst_omp_uint tid = 0; tid < nthread; ++tid) { - std::copy(row_split_tloc[tid].right.begin(), row_split_tloc[tid].right.end(), it); - it += row_split_tloc[tid].right.size(); - } + CHECK_EQ(n_left + n_right, e.Size()); + CHECK_LE(begin + n_left, e.end); + CHECK_EQ(begin + n_left + n_right, e.end); if (left_node_id >= elem_of_each_node_.size()) { elem_of_each_node_.resize(left_node_id + 1, Elem(nullptr, nullptr, -1)); @@ -110,8 +111,8 @@ class RowSetCollection { elem_of_each_node_.resize(right_node_id + 1, Elem(nullptr, nullptr, -1)); } - elem_of_each_node_[left_node_id] = Elem(begin, split_pt, left_node_id); - elem_of_each_node_[right_node_id] = Elem(split_pt, e.end, right_node_id); + elem_of_each_node_[left_node_id] = Elem(begin, begin + n_left, left_node_id); + elem_of_each_node_[right_node_id] = Elem(begin + n_left, e.end, right_node_id); elem_of_each_node_[node_id] = Elem(nullptr, nullptr, -1); } @@ -123,6 +124,118 @@ class RowSetCollection { std::vector elem_of_each_node_; }; + +template +class PartitionBuilder { + public: + template + void Init(const size_t n_tasks, size_t n_nodes, Func func) { + node_sizes_.resize(n_nodes); + nodes_.resize(n_nodes+1); + + nodes_[0] = 0; + for (size_t i = 1; i < n_nodes+1; ++i) { + nodes_[i] = nodes_[i-1] + func(i-1); + } + + if (n_tasks > max_n_tasks_) { + blocks_.resize(n_tasks); + max_n_tasks_ = n_tasks; + } + } + + size_t* GetLeftBuffer(int nid, size_t begin, size_t end) { + const size_t task_idx = GetTaskIdx(nid, begin); + CHECK_LE(task_idx, blocks_.size()); + return blocks_[task_idx].left(); + } + + size_t* GetRightBuffer(int nid, size_t begin, size_t end) { + const size_t task_idx = GetTaskIdx(nid, begin); + CHECK_LE(task_idx, blocks_.size()); + return blocks_[task_idx].right(); + } + + void SetNLeftElems(int nid, size_t begin, size_t end, size_t n_left) { + size_t task_idx = GetTaskIdx(nid, begin); + CHECK_LE(task_idx, blocks_.size()); + blocks_[task_idx].n_left = n_left; + } + + void SetNRightElems(int nid, size_t begin, size_t end, size_t n_right) { + size_t task_idx = GetTaskIdx(nid, begin); + CHECK_LE(task_idx, blocks_.size()); + blocks_[task_idx].n_right = n_right; + } + + + size_t GetNLeftElems(int nid) const { + return node_sizes_[nid].first; + } + + size_t GetNRightElems(int nid) const { + return node_sizes_[nid].second; + } + + void CalculateRowOffsets() { + for (size_t i = 0; i < nodes_.size()-1; ++i) { + size_t n_left = 0; + for (size_t j = nodes_[i]; j < nodes_[i+1]; ++j) { + blocks_[j].n_offset_left = n_left; + n_left += blocks_[j].n_left; + } + size_t n_right = 0; + for (size_t j = nodes_[i]; j < nodes_[i+1]; ++j) { + blocks_[j].n_offset_right = n_left + n_right; + n_right += blocks_[j].n_right; + } + node_sizes_[i] = {n_left, n_right}; + } + } + + void MergeToArray(int nid, size_t begin, size_t* rows_indexes) { + size_t task_idx = GetTaskIdx(nid, begin); + + size_t* left_result = rows_indexes + blocks_[task_idx].n_offset_left; + size_t* right_result = rows_indexes + blocks_[task_idx].n_offset_right; + + const size_t* left = blocks_[task_idx].left(); + const size_t* right = blocks_[task_idx].right(); + + std::copy_n(left, blocks_[task_idx].n_left, left_result); + std::copy_n(right, blocks_[task_idx].n_right, right_result); + } + + protected: + size_t GetTaskIdx(int nid, size_t begin) { + return nodes_[nid] + begin/BlockSize; + } + + struct BlockInfo{ + size_t n_left; + size_t n_right; + + size_t n_offset_left; + size_t n_offset_right; + + size_t* left() { + return &left_data_[0]; + } + + size_t* right() { + return &right_data_[0]; + } + private: + alignas(128) size_t left_data_[BlockSize]; + alignas(128) size_t right_data_[BlockSize]; + }; + std::vector> node_sizes_; + std::vector nodes_; + std::vector blocks_; + size_t max_n_tasks_ = 0; +}; + + } // namespace common } // namespace xgboost diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index d1d206587c0b..b00eb1f2b40f 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -239,17 +239,15 @@ void QuantileHistMaker::Builder::BuildNodeStats( builder_monitor_.Stop("BuildNodeStats"); } -void QuantileHistMaker::Builder::EvaluateSplits( - const GHistIndexMatrix &gmat, - const ColumnMatrix &column_matrix, - DMatrix *p_fmat, - RegTree *p_tree, - int *num_leaves, - int depth, - unsigned *timestamp, - std::vector *temp_qexpand_depth) { - EvaluateSplit(qexpand_depth_wise_, gmat, hist_, *p_fmat, *p_tree); - +void QuantileHistMaker::Builder::AddSplitsToTree( + const GHistIndexMatrix &gmat, + DMatrix *p_fmat, + RegTree *p_tree, + int *num_leaves, + int depth, + unsigned *timestamp, + std::vector* nodes_for_apply_split, + std::vector* temp_qexpand_depth) { for (auto const& entry : qexpand_depth_wise_) { int nid = entry.nid; @@ -258,7 +256,17 @@ void QuantileHistMaker::Builder::EvaluateSplits( (param_.max_leaves > 0 && (*num_leaves) == param_.max_leaves)) { (*p_tree)[nid].SetLeaf(snode_[nid].weight * param_.learning_rate); } else { - this->ApplySplit(nid, gmat, column_matrix, hist_, *p_fmat, p_tree); + nodes_for_apply_split->push_back(entry); + + NodeEntry& e = snode_[nid]; + bst_float left_leaf_weight = + spliteval_->ComputeWeight(nid, e.best.left_sum) * param_.learning_rate; + bst_float right_leaf_weight = + spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate; + p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, + e.best.DefaultLeft(), e.weight, left_leaf_weight, + right_leaf_weight, e.best.loss_chg, e.stats.sum_hess); + int left_id = (*p_tree)[nid].LeftChild(); int right_id = (*p_tree)[nid].RightChild(); temp_qexpand_depth->push_back(ExpandEntry(left_id, right_id, @@ -271,6 +279,25 @@ void QuantileHistMaker::Builder::EvaluateSplits( } } + +void QuantileHistMaker::Builder::EvaluateSplits( + const GHistIndexMatrix &gmat, + const ColumnMatrix &column_matrix, + DMatrix *p_fmat, + RegTree *p_tree, + int *num_leaves, + int depth, + unsigned *timestamp, + std::vector *temp_qexpand_depth) { + EvaluateSplit(qexpand_depth_wise_, gmat, hist_, *p_fmat, *p_tree); + + std::vector nodes_for_apply_split; + AddSplitsToTree(gmat, p_fmat, p_tree, num_leaves, depth, timestamp, + &nodes_for_apply_split, temp_qexpand_depth); + + ApplySplit(nodes_for_apply_split, gmat, column_matrix, hist_, *p_fmat, p_tree); +} + // Split nodes to 2 sets depending on amount of rows in each node // Histograms for small nodes will be built explicitly // Histograms for big nodes will be built by 'Subtraction Trick' @@ -382,7 +409,16 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide( || (param_.max_leaves > 0 && num_leaves == param_.max_leaves) ) { (*p_tree)[nid].SetLeaf(snode_[nid].weight * param_.learning_rate); } else { - this->ApplySplit(nid, gmat, column_matrix, hist_, *p_fmat, p_tree); + NodeEntry& e = snode_[nid]; + bst_float left_leaf_weight = + spliteval_->ComputeWeight(nid, e.best.left_sum) * param_.learning_rate; + bst_float right_leaf_weight = + spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate; + p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, + e.best.DefaultLeft(), e.weight, left_leaf_weight, + right_leaf_weight, e.best.loss_chg, e.stats.sum_hess); + + this->ApplySplit({candidate}, gmat, column_matrix, hist_, *p_fmat, p_tree); const int cleft = (*p_tree)[nid].LeftChild(); const int cright = (*p_tree)[nid].RightChild(); @@ -473,7 +509,14 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache( CHECK_GT(out_preds.size(), 0U); - for (const RowSetCollection::Elem rowset : row_set_collection_) { + size_t n_nodes = row_set_collection_.end() - row_set_collection_.begin(); + + common::BlockedSpace2d space(n_nodes, [&](size_t node) { + return row_set_collection_[node].Size(); + }, 1024); + + common::ParallelFor2d(space, this->nthread_, [&](size_t node, common::Range1d r) { + const RowSetCollection::Elem rowset = row_set_collection_[node]; if (rowset.begin != nullptr && rowset.end != nullptr) { int nid = rowset.node_id; bst_float leaf_value; @@ -487,11 +530,11 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache( } leaf_value = (*p_last_tree_)[nid].LeafValue(); - for (const size_t* it = rowset.begin; it < rowset.end; ++it) { + for (const size_t* it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) { out_preds[*it] += leaf_value; } } - } + }); builder_monitor_.Stop("UpdatePredictionCache"); return true; @@ -732,193 +775,191 @@ void QuantileHistMaker::Builder::EvaluateSplit(const std::vector& n builder_monitor_.Stop("EvaluateSplit"); } -void QuantileHistMaker::Builder::ApplySplit(int nid, - const GHistIndexMatrix& gmat, - const ColumnMatrix& column_matrix, - const HistCollection& hist, - const DMatrix& fmat, - RegTree* p_tree) { - builder_monitor_.Start("ApplySplit"); - // TODO(hcho3): support feature sampling by levels - - /* 1. Create child nodes */ - NodeEntry& e = snode_[nid]; - bst_float left_leaf_weight = - spliteval_->ComputeWeight(nid, e.best.left_sum) * param_.learning_rate; - bst_float right_leaf_weight = - spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate; - p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, - e.best.DefaultLeft(), e.weight, left_leaf_weight, - right_leaf_weight, e.best.loss_chg, e.stats.sum_hess); - - /* 2. Categorize member rows */ - const auto nthread = static_cast(this->nthread_); - row_split_tloc_.resize(nthread); - for (bst_omp_uint i = 0; i < nthread; ++i) { - row_split_tloc_[i].left.clear(); - row_split_tloc_[i].right.clear(); - } - const bool default_left = (*p_tree)[nid].DefaultLeft(); - const bst_uint fid = (*p_tree)[nid].SplitIndex(); - const bst_float split_pt = (*p_tree)[nid].SplitCond(); - const uint32_t lower_bound = gmat.cut.Ptrs()[fid]; - const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1]; - int32_t split_cond = -1; - // convert floating-point split_pt into corresponding bin_id - // split_cond = -1 indicates that split_pt is less than all known cut points - CHECK_LT(upper_bound, - static_cast(std::numeric_limits::max())); - for (uint32_t i = lower_bound; i < upper_bound; ++i) { - if (split_pt == gmat.cut.Values()[i]) { - split_cond = static_cast(i); - } - } +template +inline std::pair PartitionDenseKernel(const size_t* rid, + const uint32_t* idx, const uint32_t offset, const int32_t split_cond, + const size_t istart, const size_t iend, size_t* p_left, size_t* p_right) { + size_t ileft = 0; + size_t iright = 0; - const auto& rowset = row_set_collection_[nid]; + const uint32_t max_val = std::numeric_limits::max(); - Column column = column_matrix.GetColumn(fid); - if (column.GetType() == xgboost::common::kDenseColumn) { - ApplySplitDenseData(rowset, gmat, &row_split_tloc_, column, split_cond, - default_left); - } else { - ApplySplitSparseData(rowset, gmat, &row_split_tloc_, column, lower_bound, - upper_bound, split_cond, default_left); + for (size_t i = istart; i < iend; i++) { + if (idx[rid[i]] == max_val) { + if (default_left) { + p_left[ileft++] = rid[i]; + } else { + p_right[iright++] = rid[i]; + } + } else { + if (static_cast(idx[rid[i]] + offset) <= split_cond) { + p_left[ileft++] = rid[i]; + } else { + p_right[iright++] = rid[i]; + } + } } - row_set_collection_.AddSplit( - nid, row_split_tloc_, (*p_tree)[nid].LeftChild(), (*p_tree)[nid].RightChild()); - builder_monitor_.Stop("ApplySplit"); + return { ileft, iright }; } -void QuantileHistMaker::Builder::ApplySplitDenseData( - const RowSetCollection::Elem rowset, - const GHistIndexMatrix& gmat, - std::vector* p_row_split_tloc, - const Column& column, - bst_int split_cond, - bool default_left) { - std::vector& row_split_tloc = *p_row_split_tloc; - constexpr int kUnroll = 8; // loop unrolling factor - const size_t nrows = rowset.end - rowset.begin; - const size_t rest = nrows % kUnroll; - -#pragma omp parallel for num_threads(nthread_) schedule(static) - for (bst_omp_uint i = 0; i < nrows - rest; i += kUnroll) { - const bst_uint tid = omp_get_thread_num(); - auto& left = row_split_tloc[tid].left; - auto& right = row_split_tloc[tid].right; - size_t rid[kUnroll]; - uint32_t rbin[kUnroll]; - for (int k = 0; k < kUnroll; ++k) { - rid[k] = rowset.begin[i + k]; - } - for (int k = 0; k < kUnroll; ++k) { - rbin[k] = column.GetFeatureBinIdx(rid[k]); - } - for (int k = 0; k < kUnroll; ++k) { // NOLINT - if (rbin[k] == std::numeric_limits::max()) { // missing value - if (default_left) { - left.push_back(rid[k]); + +template +inline std::pair PartitionSparseKernel(const size_t* rid, + const uint32_t* idx, const uint32_t offset, const int32_t split_cond, + const size_t istart, const size_t iend, size_t* p_left, size_t* p_right, + bst_uint lower_bound, bst_uint upper_bound, const Column& column) { + + size_t ileft = 0; + size_t iright = 0; + + if (istart < iend) { // ensure that [istart, iend) is nonempty range + // search first nonzero row with index >= rowset[istart] + const size_t* p = std::lower_bound(column.GetRowData(), + column.GetRowData() + column.Size(), + rid[istart]); + + if (p != column.GetRowData() + column.Size() && *p <= rid[iend - 1]) { + size_t cursor = p - column.GetRowData(); + + for (size_t i = istart; i < iend; ++i) { + while (cursor < column.Size() + && column.GetRowIdx(cursor) < rid[i] + && column.GetRowIdx(cursor) <= rid[iend - 1]) { + ++cursor; + } + if (cursor < column.Size() && column.GetRowIdx(cursor) == rid[i]) { + const uint32_t rbin = column.GetFeatureBinIdx(cursor); + if (static_cast(rbin + column.GetBaseIdx()) <= split_cond) { + p_left[ileft++] = rid[i]; + } else { + p_right[iright++] = rid[i]; + } + ++cursor; } else { - right.push_back(rid[k]); + // missing value + if (default_left) { + p_left[ileft++] = rid[i]; + } else { + p_right[iright++] = rid[i]; + } + } + } + } else { // all rows in [istart, iend) have missing values + if (default_left) { + for (size_t i = istart; i < iend; ++i) { + p_left[ileft++] = rid[i]; } } else { - if (static_cast(rbin[k] + column.GetBaseIdx()) <= split_cond) { - left.push_back(rid[k]); - } else { - right.push_back(rid[k]); + for (size_t i = istart; i < iend; ++i) { + p_right[iright++] = rid[i]; } } } } - for (size_t i = nrows - rest; i < nrows; ++i) { - auto& left = row_split_tloc[nthread_-1].left; - auto& right = row_split_tloc[nthread_-1].right; - const size_t rid = rowset.begin[i]; - const uint32_t rbin = column.GetFeatureBinIdx(rid); - if (rbin == std::numeric_limits::max()) { // missing value + + return {ileft, iright}; +} + + +void QuantileHistMaker::Builder::ApplySplit(std::vector nodes, + const GHistIndexMatrix& gmat, + const ColumnMatrix& column_matrix, + const HistCollection& hist, + const DMatrix& fmat, + RegTree* p_tree) { + builder_monitor_.Start("ApplySplit"); + size_t n_nodes = nodes.size(); + std::vector split_conditions(n_nodes); + + for (size_t i = 0; i < nodes.size(); ++i) { + int32_t nid = nodes[i].nid; + const bst_uint fid = (*p_tree)[nid].SplitIndex(); + const bst_float split_pt = (*p_tree)[nid].SplitCond(); + const uint32_t lower_bound = gmat.cut.Ptrs()[fid]; + const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1]; + int32_t split_cond = -1; + // convert floating-point split_pt into corresponding bin_id + // split_cond = -1 indicates that split_pt is less than all known cut points + CHECK_LT(upper_bound, + static_cast(std::numeric_limits::max())); + for (uint32_t i = lower_bound; i < upper_bound; ++i) { + if (split_pt == gmat.cut.Values()[i]) { + split_cond = static_cast(i); + } + } + split_conditions[i] = split_cond; + } + + common::BlockedSpace2d space(n_nodes, [&](size_t node) { + int32_t nid = nodes[node].nid; + return row_set_collection_[nid].Size(); + }, kPartitionBlockSize); + + part_builder_.Init(space.Size(), n_nodes, [&](size_t node) { + int32_t nid = nodes[node].nid; + size_t size = row_set_collection_[nid].Size(); + size_t n_tasks = size / kPartitionBlockSize + !!(size % kPartitionBlockSize); + return n_tasks; + }); + + common::ParallelFor2d(space, this->nthread_, [&](size_t node, common::Range1d r) { + const int32_t nid = nodes[node].nid; + const size_t* rid = row_set_collection_[nid].begin; + size_t* p_left = part_builder_.GetLeftBuffer(node, r.begin(), r.end()); + size_t* p_right = part_builder_.GetRightBuffer(node, r.begin(), r.end()); + + const bst_uint fid = (*p_tree)[nid].SplitIndex(); + const bool default_left = (*p_tree)[nid].DefaultLeft(); + const auto column = column_matrix.GetColumn(fid); + const uint32_t* idx = column.GetFeatureBinIdxPtr(); + const uint32_t offset = column.GetBaseIdx(); + + std::pair pair; + + if (column.GetType() == xgboost::common::kDenseColumn) { if (default_left) { - left.push_back(rid); + pair = PartitionDenseKernel(rid, idx, offset, split_conditions[node], + r.begin(), r.end(), p_left, p_right); } else { - right.push_back(rid); + pair = PartitionDenseKernel(rid, idx, offset, split_conditions[node], + r.begin(), r.end(), p_left, p_right); } } else { - if (static_cast(rbin + column.GetBaseIdx()) <= split_cond) { - left.push_back(rid); + const uint32_t lower_bound = gmat.cut.Ptrs()[fid]; + const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1]; + + if (default_left) { + pair = PartitionSparseKernel(rid, idx, offset, split_conditions[node], + r.begin(), r.end(), p_left, p_right, lower_bound, upper_bound, column); } else { - right.push_back(rid); + pair = PartitionSparseKernel(rid, idx, offset, split_conditions[node], + r.begin(), r.end(), p_left, p_right, lower_bound, upper_bound, column); } } - } -} -void QuantileHistMaker::Builder::ApplySplitSparseData( - const RowSetCollection::Elem rowset, - const GHistIndexMatrix& gmat, - std::vector* p_row_split_tloc, - const Column& column, - bst_uint lower_bound, - bst_uint upper_bound, - bst_int split_cond, - bool default_left) { - std::vector& row_split_tloc = *p_row_split_tloc; - const size_t nrows = rowset.end - rowset.begin; - -#pragma omp parallel num_threads(nthread_) - { - const auto tid = static_cast(omp_get_thread_num()); - const size_t ibegin = tid * nrows / nthread_; - const size_t iend = (tid + 1) * nrows / nthread_; - if (ibegin < iend) { // ensure that [ibegin, iend) is nonempty range - // search first nonzero row with index >= rowset[ibegin] - const size_t* p = std::lower_bound(column.GetRowData(), - column.GetRowData() + column.Size(), - rowset.begin[ibegin]); - - auto& left = row_split_tloc[tid].left; - auto& right = row_split_tloc[tid].right; - if (p != column.GetRowData() + column.Size() && *p <= rowset.begin[iend - 1]) { - size_t cursor = p - column.GetRowData(); + part_builder_.SetNLeftElems(node, r.begin(), r.end(), pair.first); + part_builder_.SetNRightElems(node, r.begin(), r.end(), pair.second); + }); - for (size_t i = ibegin; i < iend; ++i) { - const size_t rid = rowset.begin[i]; - while (cursor < column.Size() - && column.GetRowIdx(cursor) < rid - && column.GetRowIdx(cursor) <= rowset.begin[iend - 1]) { - ++cursor; - } - if (cursor < column.Size() && column.GetRowIdx(cursor) == rid) { - const uint32_t rbin = column.GetFeatureBinIdx(cursor); - if (static_cast(rbin + column.GetBaseIdx()) <= split_cond) { - left.push_back(rid); - } else { - right.push_back(rid); - } - ++cursor; - } else { - // missing value - if (default_left) { - left.push_back(rid); - } else { - right.push_back(rid); - } - } - } - } else { // all rows in [ibegin, iend) have missing values - if (default_left) { - for (size_t i = ibegin; i < iend; ++i) { - const size_t rid = rowset.begin[i]; - left.push_back(rid); - } - } else { - for (size_t i = ibegin; i < iend; ++i) { - const size_t rid = rowset.begin[i]; - right.push_back(rid); - } - } - } - } + part_builder_.CalculateRowOffsets(); + + common::ParallelFor2d(space, this->nthread_, [&](size_t node, common::Range1d r) { + int32_t nid = nodes[node].nid; + part_builder_.MergeToArray(node, r.begin(), + const_cast(row_set_collection_[nid].begin)); + }); + + for (size_t i = 0; i < n_nodes; ++i) { + int32_t nid = nodes[i].nid; + + size_t n_left = part_builder_.GetNLeftElems(i); + size_t n_right = part_builder_.GetNRightElems(i); + + row_set_collection_.AddSplit(nid, (*p_tree)[nid].LeftChild(), + (*p_tree)[nid].RightChild(), n_left, n_right); } + builder_monitor_.Stop("ApplySplit"); } void QuantileHistMaker::Builder::InitNewNode(int nid, diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h index 18dd4ef1baa7..4a0707393a74 100644 --- a/src/tree/updater_quantile_hist.h +++ b/src/tree/updater_quantile_hist.h @@ -161,7 +161,7 @@ class QuantileHistMaker: public TreeUpdater { if (param_.enable_feature_grouping > 0) { hist_builder_.BuildBlockHist(gpair, row_indices, gmatb, hist); } else { - hist_builder_.BuildHist(gpair, row_indices, gmat, hist); + hist_builder_.BuildHist(gpair, row_indices, gmat, hist, data_layout_ != kSparseData); } } @@ -200,28 +200,12 @@ class QuantileHistMaker: public TreeUpdater { const DMatrix& fmat, const RegTree& tree); - void ApplySplit(int nid, - const GHistIndexMatrix& gmat, - const ColumnMatrix& column_matrix, - const HistCollection& hist, - const DMatrix& fmat, - RegTree* p_tree); - - void ApplySplitDenseData(const RowSetCollection::Elem rowset, - const GHistIndexMatrix& gmat, - std::vector* p_row_split_tloc, - const Column& column, - bst_int split_cond, - bool default_left); - - void ApplySplitSparseData(const RowSetCollection::Elem rowset, - const GHistIndexMatrix& gmat, - std::vector* p_row_split_tloc, - const Column& column, - bst_uint lower_bound, - bst_uint upper_bound, - bst_int split_cond, - bool default_left); + void ApplySplit(std::vector nodes, + const GHistIndexMatrix& gmat, + const ColumnMatrix& column_matrix, + const HistCollection& hist, + const DMatrix& fmat, + RegTree* p_tree); void InitNewNode(int nid, const GHistIndexMatrix& gmat, @@ -295,6 +279,16 @@ class QuantileHistMaker: public TreeUpdater { unsigned *timestamp, std::vector *temp_qexpand_depth); + void AddSplitsToTree( + const GHistIndexMatrix &gmat, + DMatrix *p_fmat, + RegTree *p_tree, + int *num_leaves, + int depth, + unsigned *timestamp, + std::vector* nodes_for_apply_split, + std::vector* temp_qexpand_depth); + void ExpandWithLossGuide(const GHistIndexMatrix& gmat, const GHistIndexBlockMatrix& gmatb, const ColumnMatrix& column_matrix, @@ -335,6 +329,9 @@ class QuantileHistMaker: public TreeUpdater { std::unique_ptr spliteval_; FeatureInteractionConstraintHost interaction_constraints_; + static constexpr size_t kPartitionBlockSize = 2048; + common::PartitionBuilder part_builder_; + // back pointers to tree and data matrix const RegTree* p_last_tree_; DMatrix const* const p_last_fmat_; diff --git a/tests/cpp/common/test_partition_builder.cc b/tests/cpp/common/test_partition_builder.cc new file mode 100755 index 000000000000..6b51f2eaadc5 --- /dev/null +++ b/tests/cpp/common/test_partition_builder.cc @@ -0,0 +1,76 @@ +#include +#include +#include +#include + +#include "../../../src/common/row_set.h" +#include "../helpers.h" + +namespace xgboost { +namespace common { + +TEST(PartitionBuilder, BasicTest) { + constexpr size_t kBlockSize = 16; + constexpr size_t kNodes = 5; + constexpr size_t kTasks = 3 + 5 + 10 + 1 + 2; + + std::vector tasks = { 3, 5, 10, 1, 2 }; + + PartitionBuilder builder; + builder.Init(kTasks, kNodes, [&](size_t i) { + return tasks[i]; + }); + + std::vector rows_for_left_node = { 2, 12, 0, 16, 8 }; + + for(size_t nid = 0; nid < kNodes; ++nid) { + size_t value_left = 0; + size_t value_right = 0; + + size_t left_total = tasks[nid] * rows_for_left_node[nid]; + + for(size_t j = 0; j < tasks[nid]; ++j) { + size_t begin = kBlockSize*j; + size_t end = kBlockSize*(j+1); + + auto left = builder.GetLeftBuffer(nid, begin, end); + auto right = builder.GetRightBuffer(nid, begin, end); + + size_t n_left = rows_for_left_node[nid]; + size_t n_right = kBlockSize - rows_for_left_node[nid]; + + for(size_t i = 0; i < n_left; i++) { + left[i] = value_left++; + } + + for(size_t i = 0; i < n_right; i++) { + right[i] = left_total + value_right++; + } + + builder.SetNLeftElems(nid, begin, end, n_left); + builder.SetNRightElems(nid, begin, end, n_right); + } + } + builder.CalculateRowOffsets(); + + std::vector v(*std::max_element(tasks.begin(), tasks.end()) * kBlockSize); + + for(size_t nid = 0; nid < kNodes; ++nid) { + + for(size_t j = 0; j < tasks[nid]; ++j) { + builder.MergeToArray(nid, kBlockSize*j, v.data()); + } + + for(size_t j = 0; j < tasks[nid] * kBlockSize; ++j) { + ASSERT_EQ(v[j], j); + } + size_t n_left = builder.GetNLeftElems(nid); + size_t n_right = builder.GetNRightElems(nid); + + ASSERT_EQ(n_left, rows_for_left_node[nid] * tasks[nid]); + ASSERT_EQ(n_right, (kBlockSize - rows_for_left_node[nid]) * tasks[nid]); + } +} + +} // namespace common +} // namespace xgboost