From 51b50dadabdf93809f03dffaa30b98855acf187c Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Wed, 20 Nov 2019 14:43:00 -0800 Subject: [PATCH 01/48] add skeleton gradient-based sampler --- src/tree/gpu_hist/gradient_based_sampler.cu | 15 +++++++++++ src/tree/gpu_hist/gradient_based_sampler.cuh | 27 +++++++++++++++++++ tests/cpp/helpers.h | 13 +++++++++ .../gpu_hist/test_gradient_based_sampler.cu | 25 +++++++++++++++++ tests/cpp/tree/test_gpu_hist.cu | 13 --------- 5 files changed, 80 insertions(+), 13 deletions(-) create mode 100644 src/tree/gpu_hist/gradient_based_sampler.cu create mode 100644 src/tree/gpu_hist/gradient_based_sampler.cuh create mode 100644 tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu new file mode 100644 index 000000000000..f98a7b8a6130 --- /dev/null +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -0,0 +1,15 @@ +/*! + * Copyright 2019 by XGBoost Contributors + */ +#include "gradient_based_sampler.cuh" + +namespace xgboost { +namespace tree { + +void GradientBasedSampler::Sample(HostDeviceVector* gpair, + DMatrix* dmat, + size_t sample_rows) { + +} +}; // namespace tree +}; // namespace xgboost diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh new file mode 100644 index 000000000000..ec798262ca28 --- /dev/null +++ b/src/tree/gpu_hist/gradient_based_sampler.cuh @@ -0,0 +1,27 @@ +/*! + * Copyright 2019 by XGBoost Contributors + */ +#pragma once +#include + +namespace xgboost { +namespace tree { + +/*! \brief Draw a sample of rows from a DMatrix. + * + * Use Poisson sampling to draw a probability proportional to size (pps) sample of rows from a + * DMatrix, where "size" is the absolute value of the gradient. + * + * \see Ke, G., Meng, Q., Finley, T., Wang, T., Chen, W., Ma, W., ... & Liu, T. Y. (2017). + * Lightgbm: A highly efficient gradient boosting decision tree. In Advances in Neural Information + * Processing Systems (pp. 3146-3154). + * \see Zhu, R. (2016). Gradient-based sampling: An adaptive importance sampling for least-squares. + * In Advances in Neural Information Processing Systems (pp. 406-414). + * \see Ohlsson, E. (1998). Sequential poisson sampling. Journal of official Statistics, 14(2), 149. + */ +class GradientBasedSampler { + public: + void Sample(HostDeviceVector* gpair, DMatrix* dmat, size_t sample_rows); +}; +}; // namespace tree +}; // namespace xgboost diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index b5bbf0ed7b94..d5eb9eb49c7d 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -214,6 +214,19 @@ inline GenericParameter CreateEmptyGenericParam(int gpu_id) { return tparam; } +inline HostDeviceVector GenerateRandomGradients(const size_t n_rows) { + xgboost::SimpleLCG gen; + xgboost::SimpleRealUniformDistribution dist(0.0f, 1.0f); + std::vector h_gpair(n_rows); + for (auto &gpair : h_gpair) { + bst_float grad = dist(&gen); + bst_float hess = dist(&gen); + gpair = GradientPair(grad, hess); + } + HostDeviceVector gpair(h_gpair); + return gpair; +} + #if defined(__CUDACC__) namespace { class HistogramCutsWrapper : public common::HistogramCuts { diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu new file mode 100644 index 000000000000..baef4dde46c0 --- /dev/null +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -0,0 +1,25 @@ +#include + +#include "../../../../src/tree/gpu_hist/gradient_based_sampler.cuh" +#include "../../helpers.h" + +namespace xgboost { +namespace tree { + +TEST(GradientBasedSampler, Sample) { + constexpr size_t kRows = 2048; + constexpr size_t kCols = 16; + constexpr size_t kSampleRows = 512; + constexpr size_t kPageSize = 1024; + + // Create a DMatrix with multiple batches. + dmlc::TemporaryDirectory tmpdir; + std::unique_ptr + dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir)); + auto gpair = GenerateRandomGradients(kRows); + + GradientBasedSampler sampler; + sampler.Sample(&gpair, dmat.get(), kSampleRows); +} +}; // namespace tree +}; // namespace xgboost diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 38d455c68fd3..3254d12fc61c 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -319,19 +319,6 @@ int32_t TestMinSplitLoss(DMatrix* dmat, float gamma, HostDeviceVector GenerateRandomGradients(const size_t n_rows) { - xgboost::SimpleLCG gen; - xgboost::SimpleRealUniformDistribution dist(0.0f, 1.0f); - std::vector h_gpair(n_rows); - for (auto &gpair : h_gpair) { - bst_float grad = dist(&gen); - bst_float hess = dist(&gen); - gpair = GradientPair(grad, hess); - } - HostDeviceVector gpair(h_gpair); - return gpair; -} - TEST(GpuHist, MinSplitLoss) { constexpr size_t kRows = 32; constexpr size_t kCols = 16; From ebd67c13cd06d1ad7698c628c8f8bdbb873ccb69 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Wed, 20 Nov 2019 16:10:54 -0800 Subject: [PATCH 02/48] change gpu_hist to use sampling in external memory mode --- src/tree/gpu_hist/gradient_based_sampler.cu | 10 +- src/tree/gpu_hist/gradient_based_sampler.cuh | 11 ++ src/tree/gpu_hist/row_partitioner.cuh | 1 - src/tree/updater_gpu_hist.cu | 111 ++++++++----------- tests/cpp/tree/test_gpu_hist.cu | 2 +- 5 files changed, 64 insertions(+), 71 deletions(-) diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index f98a7b8a6130..06a21ae56e4b 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -6,10 +6,16 @@ namespace xgboost { namespace tree { +GradientBasedSample GradientBasedSampler::Sample(HostDeviceVector* gpair, + DMatrix* dmat, + BatchParam batch_param) { + auto page = (*dmat->GetBatches(batch_param).begin()).Impl(); + return {page, gpair}; +} + void GradientBasedSampler::Sample(HostDeviceVector* gpair, DMatrix* dmat, - size_t sample_rows) { + size_t sample_rows) {} -} }; // namespace tree }; // namespace xgboost diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh index ec798262ca28..02a96465462f 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cuh +++ b/src/tree/gpu_hist/gradient_based_sampler.cuh @@ -7,6 +7,13 @@ namespace xgboost { namespace tree { +struct GradientBasedSample { + /*!\brief The sample rows in ELLPACK format. */ + EllpackPageImpl* page; + /*!\brief Rescaled gradient pairs for the sampled rows. */ + HostDeviceVector* gpair; +}; + /*! \brief Draw a sample of rows from a DMatrix. * * Use Poisson sampling to draw a probability proportional to size (pps) sample of rows from a @@ -21,6 +28,10 @@ namespace tree { */ class GradientBasedSampler { public: + GradientBasedSample Sample(HostDeviceVector* gpair, + DMatrix* dmat, + BatchParam batch_param); + void Sample(HostDeviceVector* gpair, DMatrix* dmat, size_t sample_rows); }; }; // namespace tree diff --git a/src/tree/gpu_hist/row_partitioner.cuh b/src/tree/gpu_hist/row_partitioner.cuh index 2d9500faf4a7..4818d71abc9f 100644 --- a/src/tree/gpu_hist/row_partitioner.cuh +++ b/src/tree/gpu_hist/row_partitioner.cuh @@ -125,7 +125,6 @@ class RowPartitioner { idx += segment.begin; RowIndexT ridx = d_ridx[idx]; bst_node_t new_position = op(ridx); // new node id - if (new_position == kIgnoredTreePosition) return; KERNEL_CHECK(new_position == left_nidx || new_position == right_nidx); AtomicIncrement(d_left_count, new_position == left_nidx); d_position[idx] = new_position; diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index c06f646562b0..9a3e49d58480 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -28,6 +28,7 @@ #include "param.h" #include "updater_gpu_common.cuh" #include "constraints.cuh" +#include "gpu_hist/gradient_based_sampler.cuh" #include "gpu_hist/row_partitioner.cuh" namespace xgboost { @@ -479,6 +480,9 @@ struct GPUHistMakerDevice { std::function>; std::unique_ptr qexpand; + bool use_gradient_based_sampling {false}; + std::unique_ptr gradient_based_sampler; + GPUHistMakerDevice(int _device_id, EllpackPageImpl* _page, bst_uint _n_rows, @@ -493,7 +497,11 @@ struct GPUHistMakerDevice { prediction_cache_initialised(false), column_sampler(column_sampler_seed), interaction_constraints(param, n_features), - batch_param(_batch_param) { + batch_param(_batch_param), + use_gradient_based_sampling(_page->matrix.n_rows != _n_rows) { + if (use_gradient_based_sampling) { + gradient_based_sampler.reset(new GradientBasedSampler()); + } monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id)); } @@ -527,7 +535,7 @@ struct GPUHistMakerDevice { // Reset values for each update iteration // Note that the column sampler must be passed by value because it is not // thread safe - void Reset(HostDeviceVector* dh_gpair, int64_t num_columns) { + void Reset(HostDeviceVector* dh_gpair, DMatrix* dmat, int64_t num_columns) { if (param.grow_policy == TrainParam::kLossGuide) { qexpand.reset(new ExpandQueue(LossGuide)); } else { @@ -539,13 +547,21 @@ struct GPUHistMakerDevice { this->interaction_constraints.Reset(); std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), GradientPair()); + + if (use_gradient_based_sampling) { + auto sample = gradient_based_sampler->Sample(dh_gpair, dmat, batch_param); + page = sample.page; + gpair = sample.gpair->DeviceSpan(); + n_rows = page->matrix.n_rows; + } else { + dh::safe_cuda(cudaMemcpyAsync( + gpair.data(), dh_gpair->ConstDevicePointer(), + gpair.size() * sizeof(GradientPair), cudaMemcpyHostToHost)); + SubsampleGradientPair(device_id, gpair, param.subsample); + } + row_partitioner.reset(); // Release the device memory first before reallocating row_partitioner.reset(new RowPartitioner(device_id, n_rows)); - - dh::safe_cuda(cudaMemcpyAsync( - gpair.data(), dh_gpair->ConstDevicePointer(), - gpair.size() * sizeof(GradientPair), cudaMemcpyHostToHost)); - SubsampleGradientPair(device_id, gpair, param.subsample); hist.Reset(); } @@ -631,14 +647,6 @@ struct GPUHistMakerDevice { return std::vector(result_all.begin(), result_all.end()); } - // Build gradient histograms for a given node across all the batches in the DMatrix. - void BuildHistBatches(int nidx, DMatrix* p_fmat) { - for (auto& batch : p_fmat->GetBatches(batch_param)) { - page = batch.Impl(); - BuildHist(nidx); - } - } - void BuildHist(int nidx) { hist.AllocateHistogram(nidx); auto d_node_hist = hist.GetNodeHistogram(nidx); @@ -686,10 +694,7 @@ struct GPUHistMakerDevice { row_partitioner->UpdatePosition( nidx, split_node.LeftChild(), split_node.RightChild(), - [=] __device__(size_t ridx) { - if (!d_matrix.IsInRange(ridx)) { - return RowPartitioner::kIgnoredTreePosition; - } + [=] __device__(bst_uint ridx) { // given a row index, returns the node id it belongs to bst_float cut_value = d_matrix.GetElement(ridx, split_node.SplitIndex()); @@ -718,6 +723,10 @@ struct GPUHistMakerDevice { d_nodes.size() * sizeof(RegTree::Node), cudaMemcpyHostToDevice)); + if (row_partitioner->GetRows().size() != n_rows) { + row_partitioner.reset(); // Release the device memory first before reallocating + row_partitioner.reset(new RowPartitioner(device_id, p_fmat->Info().num_row_)); + } for (auto& batch : p_fmat->GetBatches(batch_param)) { page = batch.Impl(); auto d_matrix = page->matrix; @@ -796,7 +805,8 @@ struct GPUHistMakerDevice { /** * \brief Build GPU local histograms for the left and right child of some parent node */ - void BuildHistLeftRight(const ExpandEntry &candidate, int nidx_left, int nidx_right) { + void BuildHistLeftRight(const ExpandEntry &candidate, int nidx_left, + int nidx_right, dh::AllReducer* reducer) { auto build_hist_nidx = nidx_left; auto subtraction_trick_nidx = nidx_right; @@ -808,34 +818,6 @@ struct GPUHistMakerDevice { } this->BuildHist(build_hist_nidx); - - // Check whether we can use the subtraction trick to calculate the other - bool do_subtraction_trick = this->CanDoSubtractionTrick( - candidate.nid, build_hist_nidx, subtraction_trick_nidx); - - if (!do_subtraction_trick) { - // Calculate other histogram manually - this->BuildHist(subtraction_trick_nidx); - } - } - - /** - * \brief AllReduce GPU histograms for the left and right child of some parent node. - */ - void ReduceHistLeftRight(const ExpandEntry& candidate, - int nidx_left, - int nidx_right, - dh::AllReducer* reducer) { - auto build_hist_nidx = nidx_left; - auto subtraction_trick_nidx = nidx_right; - - // Decide whether to build the left histogram or right histogram - // Use sum of Hessian as a heuristic to select node with fewest training instances - bool fewer_right = candidate.split.right_sum.GetHess() < candidate.split.left_sum.GetHess(); - if (fewer_right) { - std::swap(build_hist_nidx, subtraction_trick_nidx); - } - this->AllReduceHist(build_hist_nidx, reducer); // Check whether we can use the subtraction trick to calculate the other @@ -848,6 +830,7 @@ struct GPUHistMakerDevice { subtraction_trick_nidx); } else { // Calculate other histogram manually + this->BuildHist(subtraction_trick_nidx); this->AllReduceHist(subtraction_trick_nidx, reducer); } } @@ -888,7 +871,7 @@ struct GPUHistMakerDevice { tree[candidate.nid].RightChild()); } - void InitRoot(RegTree* p_tree, HostDeviceVector* gpair_all, DMatrix* p_fmat, + void InitRoot(RegTree* p_tree, HostDeviceVector* gpair_all, dh::AllReducer* reducer, int64_t num_columns) { constexpr int kRootNIdx = 0; @@ -904,7 +887,7 @@ struct GPUHistMakerDevice { node_sum_gradients_d.data(), sizeof(GradientPair), cudaMemcpyDeviceToHost)); - this->BuildHistBatches(kRootNIdx, p_fmat); + this->BuildHist(kRootNIdx); this->AllReduceHist(kRootNIdx, reducer); // Remember root stats @@ -927,11 +910,11 @@ struct GPUHistMakerDevice { auto& tree = *p_tree; monitor.StartCuda("Reset"); - this->Reset(gpair_all, p_fmat->Info().num_col_); + this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_); monitor.StopCuda("Reset"); monitor.StartCuda("InitRoot"); - this->InitRoot(p_tree, gpair_all, p_fmat, reducer, p_fmat->Info().num_col_); + this->InitRoot(p_tree, gpair_all, reducer, p_fmat->Info().num_col_); monitor.StopCuda("InitRoot"); auto timestamp = qexpand->size(); @@ -950,21 +933,15 @@ struct GPUHistMakerDevice { int left_child_nidx = tree[candidate.nid].LeftChild(); int right_child_nidx = tree[candidate.nid].RightChild(); // Only create child entries if needed - if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx), num_leaves)) { - for (auto& batch : p_fmat->GetBatches(batch_param)) { - page = batch.Impl(); - - monitor.StartCuda("UpdatePosition"); - this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]); - monitor.StopCuda("UpdatePosition"); - - monitor.StartCuda("BuildHist"); - this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx); - monitor.StopCuda("BuildHist"); - } - monitor.StartCuda("ReduceHist"); - this->ReduceHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer); - monitor.StopCuda("ReduceHist"); + if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx), + num_leaves)) { + monitor.StartCuda("UpdatePosition"); + this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]); + monitor.StopCuda("UpdatePosition"); + + monitor.StartCuda("BuildHist"); + this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer); + monitor.StopCuda("BuildHist"); monitor.StartCuda("EvaluateSplits"); auto splits = this->EvaluateSplits({left_child_nidx, right_child_nidx}, diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 3254d12fc61c..c555b7941fe9 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -378,7 +378,7 @@ void UpdateTree(HostDeviceVector* gpair, hist_maker.UpdatePredictionCache(dmat, preds); } -TEST(GpuHist, ExternalMemory) { +TEST(GpuHist, DISABLED_ExternalMemory) { constexpr size_t kRows = 6; constexpr size_t kCols = 2; constexpr size_t kPageSize = 1; From a2c446c80cf781febd82d3ea377dd3218cd8cf1f Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Wed, 20 Nov 2019 16:51:09 -0800 Subject: [PATCH 03/48] add failing tests --- src/tree/gpu_hist/gradient_based_sampler.cu | 7 ++----- src/tree/gpu_hist/gradient_based_sampler.cuh | 5 ++--- .../gpu_hist/test_gradient_based_sampler.cu | 18 +++++++++++++++++- tests/cpp/tree/test_gpu_hist.cu | 2 +- 4 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 06a21ae56e4b..79dede6c471b 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -8,14 +8,11 @@ namespace tree { GradientBasedSample GradientBasedSampler::Sample(HostDeviceVector* gpair, DMatrix* dmat, - BatchParam batch_param) { + BatchParam batch_param, + size_t sample_rows) { auto page = (*dmat->GetBatches(batch_param).begin()).Impl(); return {page, gpair}; } -void GradientBasedSampler::Sample(HostDeviceVector* gpair, - DMatrix* dmat, - size_t sample_rows) {} - }; // namespace tree }; // namespace xgboost diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh index 02a96465462f..9ec407d09b4b 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cuh +++ b/src/tree/gpu_hist/gradient_based_sampler.cuh @@ -30,9 +30,8 @@ class GradientBasedSampler { public: GradientBasedSample Sample(HostDeviceVector* gpair, DMatrix* dmat, - BatchParam batch_param); - - void Sample(HostDeviceVector* gpair, DMatrix* dmat, size_t sample_rows); + BatchParam batch_param, + size_t sample_rows = 0); }; }; // namespace tree }; // namespace xgboost diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu index baef4dde46c0..172bcf924d68 100644 --- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -1,5 +1,6 @@ #include +#include "../../../../src/data/ellpack_page.cuh" #include "../../../../src/tree/gpu_hist/gradient_based_sampler.cuh" #include "../../helpers.h" @@ -19,7 +20,22 @@ TEST(GradientBasedSampler, Sample) { auto gpair = GenerateRandomGradients(kRows); GradientBasedSampler sampler; - sampler.Sample(&gpair, dmat.get(), kSampleRows); + BatchParam param{0, 256, 0, kPageSize}; + auto sample = sampler.Sample(&gpair, dmat.get(), param, kSampleRows); + auto page = sample.page; + auto scaled_gpair = sample.gpair; + EXPECT_NEAR(page->matrix.n_rows, kSampleRows, 5); + EXPECT_EQ(page->matrix.n_rows, scaled_gpair->Size()); + + float gradients = 0; + for (auto gp : gpair.ConstHostVector()) { + gradients += gp.GetGrad(); + } + float scaled_gradients = 0; + for (auto gp : scaled_gpair->ConstHostVector()) { + scaled_gradients += gp.GetGrad(); + } + EXPECT_FLOAT_EQ(gradients, scaled_gradients); } }; // namespace tree }; // namespace xgboost diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index c555b7941fe9..3254d12fc61c 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -378,7 +378,7 @@ void UpdateTree(HostDeviceVector* gpair, hist_maker.UpdatePredictionCache(dmat, preds); } -TEST(GpuHist, DISABLED_ExternalMemory) { +TEST(GpuHist, ExternalMemory) { constexpr size_t kRows = 6; constexpr size_t kCols = 2; constexpr size_t kPageSize = 1; From 9e91871607073927da42368e20b49bbd322483a0 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 21 Nov 2019 15:54:21 -0800 Subject: [PATCH 04/48] wip: poisson sampling --- src/tree/gpu_hist/gradient_based_sampler.cu | 36 +++++++++++++++++++ src/tree/updater_gpu_hist.cu | 10 ++---- .../gpu_hist/test_gradient_based_sampler.cu | 1 + 3 files changed, 40 insertions(+), 7 deletions(-) diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 79dede6c471b..439f4c467c4d 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -1,15 +1,51 @@ /*! * Copyright 2019 by XGBoost Contributors */ +#include +#include +#include + +#include + +#include "../../common/device_helpers.cuh" #include "gradient_based_sampler.cuh" namespace xgboost { namespace tree { +/*! \brief A functor that returns the absolute value of gradient from a gradient pair. */ +struct abs_grad : public thrust::unary_function { + __device__ + float operator()(const GradientPair& gpair) const { + return fabsf(gpair.GetGrad()); + } +}; + +struct sample_and_scale : public thrust::unary_function { + const size_t expected_sample_rows; + const float sum_abs_gradient; + + sample_and_scale(size_t _expected_sample_rows, float _sum_abs_gradient) + : expected_sample_rows(_expected_sample_rows), sum_abs_gradient(_sum_abs_gradient) {} + + __device__ + GradientPair operator()(const GradientPair& gpair) { + return GradientPair(); + } +}; + GradientBasedSample GradientBasedSampler::Sample(HostDeviceVector* gpair, DMatrix* dmat, BatchParam batch_param, size_t sample_rows) { + float sum_abs_gradient = thrust::transform_reduce( + dh::tbegin(*gpair), dh::tend(*gpair), abs_grad(), 0.0f, thrust::plus()); + + HostDeviceVector scaled_gpair(gpair->Size()); + scaled_gpair.SetDevice(batch_param.gpu_id); + thrust::transform(dh::tbegin(*gpair), dh::tend(*gpair), dh::tbegin(scaled_gpair), + sample_and_scale(sample_rows, sum_abs_gradient)); + auto page = (*dmat->GetBatches(batch_param).begin()).Impl(); return {page, gpair}; } diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 9a3e49d58480..626dfa4d258c 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -871,14 +871,10 @@ struct GPUHistMakerDevice { tree[candidate.nid].RightChild()); } - void InitRoot(RegTree* p_tree, HostDeviceVector* gpair_all, - dh::AllReducer* reducer, int64_t num_columns) { + void InitRoot(RegTree* p_tree, dh::AllReducer* reducer, int64_t num_columns) { constexpr int kRootNIdx = 0; - const auto &gpair = gpair_all->DeviceSpan(); - - dh::SumReduction(temp_memory, gpair, node_sum_gradients_d, - gpair.size()); + dh::SumReduction(temp_memory, gpair, node_sum_gradients_d, gpair.size()); reducer->AllReduceSum( reinterpret_cast(node_sum_gradients_d.data()), reinterpret_cast(node_sum_gradients_d.data()), 2); @@ -914,7 +910,7 @@ struct GPUHistMakerDevice { monitor.StopCuda("Reset"); monitor.StartCuda("InitRoot"); - this->InitRoot(p_tree, gpair_all, reducer, p_fmat->Info().num_col_); + this->InitRoot(p_tree, reducer, p_fmat->Info().num_col_); monitor.StopCuda("InitRoot"); auto timestamp = qexpand->size(); diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu index 172bcf924d68..7f576ce38c8c 100644 --- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -18,6 +18,7 @@ TEST(GradientBasedSampler, Sample) { std::unique_ptr dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir)); auto gpair = GenerateRandomGradients(kRows); + gpair.SetDevice(0); GradientBasedSampler sampler; BatchParam param{0, 256, 0, kPageSize}; From ed322cc6ea44c7f7581c13c5e2242d3c64a89e5b Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 22 Nov 2019 16:10:29 -0800 Subject: [PATCH 05/48] sample and scale gradient pairs --- include/xgboost/base.h | 14 +++++ src/tree/gpu_hist/gradient_based_sampler.cu | 59 ++++++++++++++----- src/tree/gpu_hist/gradient_based_sampler.cuh | 4 +- src/tree/updater_gpu_hist.cu | 9 +-- .../gpu_hist/test_gradient_based_sampler.cu | 19 +++--- 5 files changed, 75 insertions(+), 30 deletions(-) diff --git a/include/xgboost/base.h b/include/xgboost/base.h index 1a4df84c0f12..8b551ed325df 100644 --- a/include/xgboost/base.h +++ b/include/xgboost/base.h @@ -193,6 +193,20 @@ class GradientPairInternal { return g; } + XGBOOST_DEVICE GradientPairInternal operator*(float multiplier) const { + GradientPairInternal g; + g.grad_ = grad_ * multiplier; + g.hess_ = hess_ * multiplier; + return g; + } + + XGBOOST_DEVICE GradientPairInternal operator/(float divider) const { + GradientPairInternal g; + g.grad_ = grad_ / divider; + g.hess_ = hess_ / divider; + return g; + } + XGBOOST_DEVICE explicit GradientPairInternal(int value) { *this = GradientPairInternal(static_cast(value), static_cast(value)); diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 439f4c467c4d..c035be3b06f7 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -1,13 +1,15 @@ /*! * Copyright 2019 by XGBoost Contributors */ -#include #include +#include +#include #include #include #include "../../common/device_helpers.cuh" +#include "../../common/random.h" #include "gradient_based_sampler.cuh" namespace xgboost { @@ -15,22 +17,44 @@ namespace tree { /*! \brief A functor that returns the absolute value of gradient from a gradient pair. */ struct abs_grad : public thrust::unary_function { - __device__ - float operator()(const GradientPair& gpair) const { + XGBOOST_DEVICE float operator()(const GradientPair& gpair) const { return fabsf(gpair.GetGrad()); } }; -struct sample_and_scale : public thrust::unary_function { - const size_t expected_sample_rows; +/*! \brief A functor that samples and scales a gradient pair. + * + * Sampling probability is proportional to the absolute value of the gradient. If selected, the + * gradient pair is re-scaled proportional to (1 / probability). + */ +struct sample_and_scale : public thrust::binary_function { + const size_t sample_rows; const float sum_abs_gradient; + const uint32_t seed; + + XGBOOST_DEVICE sample_and_scale(size_t _sample_rows, float _sum_abs_gradient, size_t _seed) + : sample_rows(_sample_rows), sum_abs_gradient(_sum_abs_gradient), seed(_seed) {} - sample_and_scale(size_t _expected_sample_rows, float _sum_abs_gradient) - : expected_sample_rows(_expected_sample_rows), sum_abs_gradient(_sum_abs_gradient) {} + XGBOOST_DEVICE GradientPair operator()(const GradientPair& gpair, size_t i) { + thrust::default_random_engine rng(seed); + thrust::uniform_real_distribution dist; + rng.discard(i); + float p = sample_rows * fabsf(gpair.GetGrad()) / sum_abs_gradient; + if (p > 1.0f) { + p = 1.0f; + } + if (dist(rng) <= p) { + return gpair / p; + } else { + return GradientPair(); + } + } +}; - __device__ - GradientPair operator()(const GradientPair& gpair) { - return GradientPair(); +/*! \brief A functor that returns true if the gradient pair is non-zero. */ +struct is_non_zero : public thrust::unary_function { + XGBOOST_DEVICE bool operator()(const GradientPair& gpair) const { + return gpair.GetGrad() != 0 || gpair.GetHess() != 0; } }; @@ -41,13 +65,18 @@ GradientBasedSample GradientBasedSampler::Sample(HostDeviceVector* float sum_abs_gradient = thrust::transform_reduce( dh::tbegin(*gpair), dh::tend(*gpair), abs_grad(), 0.0f, thrust::plus()); - HostDeviceVector scaled_gpair(gpair->Size()); - scaled_gpair.SetDevice(batch_param.gpu_id); - thrust::transform(dh::tbegin(*gpair), dh::tend(*gpair), dh::tbegin(scaled_gpair), - sample_and_scale(sample_rows, sum_abs_gradient)); + thrust::transform(dh::tbegin(*gpair), dh::tend(*gpair), + thrust::counting_iterator(0), + dh::tbegin(*gpair), + sample_and_scale(sample_rows, sum_abs_gradient, common::GlobalRandom()())); + + size_t out_size = thrust::count_if(dh::tbegin(*gpair), dh::tend(*gpair), is_non_zero()); + + HostDeviceVector out_gpair(out_size, GradientPair(), batch_param.gpu_id); + thrust::copy_if(dh::tbegin(*gpair), dh::tend(*gpair), dh::tbegin(out_gpair), is_non_zero()); auto page = (*dmat->GetBatches(batch_param).begin()).Impl(); - return {page, gpair}; + return {page, out_gpair}; } }; // namespace tree diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh index 9ec407d09b4b..9273e2193e84 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cuh +++ b/src/tree/gpu_hist/gradient_based_sampler.cuh @@ -8,10 +8,10 @@ namespace xgboost { namespace tree { struct GradientBasedSample { - /*!\brief The sample rows in ELLPACK format. */ + /*!\brief Sampled rows in ELLPACK format. */ EllpackPageImpl* page; /*!\brief Rescaled gradient pairs for the sampled rows. */ - HostDeviceVector* gpair; + HostDeviceVector gpair; }; /*! \brief Draw a sample of rows from a DMatrix. diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 626dfa4d258c..ba6ac9e9d02f 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -548,15 +548,16 @@ struct GPUHistMakerDevice { std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), GradientPair()); + dh::safe_cuda(cudaMemcpyAsync( + gpair.data(), dh_gpair->ConstDevicePointer(), + gpair.size() * sizeof(GradientPair), cudaMemcpyDeviceToDevice)); + if (use_gradient_based_sampling) { auto sample = gradient_based_sampler->Sample(dh_gpair, dmat, batch_param); page = sample.page; - gpair = sample.gpair->DeviceSpan(); + gpair = sample.gpair.DeviceSpan(); n_rows = page->matrix.n_rows; } else { - dh::safe_cuda(cudaMemcpyAsync( - gpair.data(), dh_gpair->ConstDevicePointer(), - gpair.size() * sizeof(GradientPair), cudaMemcpyHostToHost)); SubsampleGradientPair(device_id, gpair, param.subsample); } diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu index 7f576ce38c8c..5f8d98367948 100644 --- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -24,19 +24,20 @@ TEST(GradientBasedSampler, Sample) { BatchParam param{0, 256, 0, kPageSize}; auto sample = sampler.Sample(&gpair, dmat.get(), param, kSampleRows); auto page = sample.page; - auto scaled_gpair = sample.gpair; - EXPECT_NEAR(page->matrix.n_rows, kSampleRows, 5); - EXPECT_EQ(page->matrix.n_rows, scaled_gpair->Size()); + auto sampled_gpair = sample.gpair; + EXPECT_NEAR(sampled_gpair.Size(), kSampleRows, 12); + EXPECT_NEAR(page->matrix.n_rows, kSampleRows, 12); + EXPECT_EQ(page->matrix.n_rows, sampled_gpair.Size()); - float gradients = 0; + float sum_gradients = 0; for (auto gp : gpair.ConstHostVector()) { - gradients += gp.GetGrad(); + sum_gradients += gp.GetGrad(); } - float scaled_gradients = 0; - for (auto gp : scaled_gpair->ConstHostVector()) { - scaled_gradients += gp.GetGrad(); + float sum_sampled_gradients = 0; + for (auto gp : sampled_gpair.ConstHostVector()) { + sum_sampled_gradients += gp.GetGrad(); } - EXPECT_FLOAT_EQ(gradients, scaled_gradients); + EXPECT_FLOAT_EQ(sum_gradients, sum_sampled_gradients); } }; // namespace tree }; // namespace xgboost From 93770142bdd4927d3584aa96ca3a9fa985a2a8f2 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 2 Dec 2019 16:09:32 -0800 Subject: [PATCH 06/48] calculate max number of sample rows --- src/common/compressed_iterator.h | 23 ++++++++++++++++++-- src/data/ellpack_page.cu | 9 +++----- src/data/ellpack_page.cuh | 4 ++++ src/tree/gpu_hist/gradient_based_sampler.cu | 9 ++++++++ src/tree/gpu_hist/gradient_based_sampler.cuh | 4 ++++ tests/cpp/common/test_compressed_iterator.cc | 21 ++++++++++++++++++ 6 files changed, 62 insertions(+), 8 deletions(-) diff --git a/src/common/compressed_iterator.h b/src/common/compressed_iterator.h index 0f6e93695436..df34ec15666d 100644 --- a/src/common/compressed_iterator.h +++ b/src/common/compressed_iterator.h @@ -63,7 +63,7 @@ class CompressedBufferWriter { * \fn static size_t CompressedBufferWriter::CalculateBufferSize(int * num_elements, int num_symbols) * - * \brief Calculates number of bytes requiredm for a given number of elements + * \brief Calculates number of bytes required for a given number of elements * and a symbol range. * * \author Rory @@ -74,7 +74,6 @@ class CompressedBufferWriter { * * \return The calculated buffer size. */ - static size_t CalculateBufferSize(size_t num_elements, size_t num_symbols) { const int bits_per_byte = 8; size_t compressed_size = static_cast(std::ceil( @@ -83,6 +82,26 @@ class CompressedBufferWriter { return compressed_size + detail::kPadding; } + /** + * \brief Calculates maximum number of rows that can fit in a given number of bytes. + * \param num_bytes Number of bytes. + * \param num_symbols Max number of symbols (alphabet size). + * \param row_stride Number of features per row. + * \param extra_bytes_per_row Extra number of bytes needed per row. + * \return The calculated number of rows. + */ + static size_t CalculateMaxRows(size_t num_bytes, + size_t num_symbols, + size_t row_stride, + size_t extra_bytes_per_row) { + const int bits_per_byte = 8; + size_t usable_bits = (num_bytes - detail::kPadding) * bits_per_byte; + size_t extra_bits = extra_bytes_per_row * bits_per_byte; + size_t symbol_bits = row_stride * detail::SymbolBits(num_symbols); + size_t num_rows = static_cast(std::floor(usable_bits / (extra_bits + symbol_bits))); + return num_rows; + } + template void WriteSymbol(CompressedByteT *buffer, T symbol, size_t offset) { const int bits_per_byte = 8; diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index a58c74e6f60d..e35866bb8436 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -123,7 +123,7 @@ void EllpackPageImpl::InitInfo(int device, // Initialize the buffer to stored compressed features. void EllpackPageImpl::InitCompressedData(int device, size_t num_rows) { - size_t num_symbols = matrix.info.n_bins + 1; + size_t num_symbols = matrix.info.NumSymbols(); // Required buffer size for storing data matrix in ELLPack format. size_t compressed_size_bytes = common::CompressedBufferWriter::CalculateBufferSize( @@ -149,7 +149,6 @@ void EllpackPageImpl::CreateHistIndices(int device, const auto& offset_vec = row_batch.offset.ConstHostVector(); - int num_symbols = matrix.info.n_bins + 1; // bin and compress entries in batches of rows size_t gpu_batch_nrows = std::min( dh::TotalMemory(device) / (16 * row_stride * sizeof(Entry)), @@ -193,7 +192,7 @@ void EllpackPageImpl::CreateHistIndices(int device, 1); dh::LaunchKernel {grid3, block3} ( CompressBinEllpackKernel, - common::CompressedBufferWriter(num_symbols), + common::CompressedBufferWriter(matrix.info.NumSymbols()), gidx_buffer.data(), row_ptrs.data().get(), entries_d.data().get(), @@ -254,11 +253,9 @@ void EllpackPageImpl::CompressSparsePage(int device) { // Return the memory cost for storing the compressed features. size_t EllpackPageImpl::MemCostBytes() const { - size_t num_symbols = matrix.info.n_bins + 1; - // Required buffer size for storing data matrix in ELLPack format. size_t compressed_size_bytes = common::CompressedBufferWriter::CalculateBufferSize( - matrix.info.row_stride * matrix.n_rows, num_symbols); + matrix.info.row_stride * matrix.n_rows, matrix.info.NumSymbols()); return compressed_size_bytes; } diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh index 47fd98910a95..4bea9ca145ca 100644 --- a/src/data/ellpack_page.cuh +++ b/src/data/ellpack_page.cuh @@ -71,6 +71,10 @@ struct EllpackInfo { size_t row_stride, const common::HistogramCuts& hmat, dh::BulkAllocator* ba); + + inline size_t NumSymbols() const { + return n_bins + 1; + } }; /** \brief Struct for accessing and manipulating an ellpack matrix on the diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index c035be3b06f7..7fe572f1d271 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -8,6 +8,7 @@ #include +#include "../../common/compressed_iterator.h" #include "../../common/device_helpers.cuh" #include "../../common/random.h" #include "gradient_based_sampler.cuh" @@ -15,6 +16,14 @@ namespace xgboost { namespace tree { +size_t GradientBasedSampler::MaxSampleRows(int device, const EllpackInfo& info) { + size_t available_memory = dh::AvailableMemory(device); + size_t usable_memory = available_memory * 0.95; + size_t max_rows = common::CompressedBufferWriter::CalculateMaxRows( + usable_memory, info.NumSymbols(), info.row_stride, 2 * sizeof(float)); + return max_rows; +} + /*! \brief A functor that returns the absolute value of gradient from a gradient pair. */ struct abs_grad : public thrust::unary_function { XGBOOST_DEVICE float operator()(const GradientPair& gpair) const { diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh index 9273e2193e84..1bb982c189f4 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cuh +++ b/src/tree/gpu_hist/gradient_based_sampler.cuh @@ -4,6 +4,8 @@ #pragma once #include +#include "../../data/ellpack_page.cuh" + namespace xgboost { namespace tree { @@ -28,6 +30,8 @@ struct GradientBasedSample { */ class GradientBasedSampler { public: + size_t MaxSampleRows(int device, const EllpackInfo& info); + GradientBasedSample Sample(HostDeviceVector* gpair, DMatrix* dmat, BatchParam batch_param, diff --git a/tests/cpp/common/test_compressed_iterator.cc b/tests/cpp/common/test_compressed_iterator.cc index 93243c0b336e..ea50e887e723 100644 --- a/tests/cpp/common/test_compressed_iterator.cc +++ b/tests/cpp/common/test_compressed_iterator.cc @@ -51,5 +51,26 @@ TEST(CompressedIterator, Test) { } } +TEST(CompressedIterator, CalculateMaxRows) { + const size_t num_bytes = 12652838912; + const size_t num_symbols = 256 * 100 + 1; + const size_t row_stride = 100; + const size_t extra_bytes = 8; + size_t num_rows = + CompressedBufferWriter::CalculateMaxRows(num_bytes, num_symbols, row_stride, extra_bytes); + EXPECT_EQ(num_rows, 64720403); + + // The calculated # rows should fit within the given number of bytes. + size_t buffer_size = CompressedBufferWriter::CalculateBufferSize(num_rows * 100, num_symbols); + size_t extras = extra_bytes * num_rows; + EXPECT_LE(buffer_size + extras, num_bytes); + + // An extra row wouldn't fit. + num_rows++; + buffer_size = CompressedBufferWriter::CalculateBufferSize(num_rows * 100, num_symbols); + extras = extra_bytes * num_rows; + EXPECT_GT(buffer_size + extras, num_bytes); +} + } // namespace common } // namespace xgboost From f26415c6d2b8579e4356bcab266de0e24554dce8 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 3 Dec 2019 17:04:00 -0800 Subject: [PATCH 07/48] add sampler constructor --- src/data/ellpack_page.cu | 15 ++++- src/data/ellpack_page.cuh | 8 +++ src/tree/gpu_hist/gradient_based_sampler.cu | 56 ++++++++++++++----- src/tree/gpu_hist/gradient_based_sampler.cuh | 30 ++++++++-- src/tree/updater_gpu_hist.cu | 17 +++--- tests/cpp/common/test_compressed_iterator.cc | 10 ++-- .../gpu_hist/test_gradient_based_sampler.cu | 20 ++++--- 7 files changed, 114 insertions(+), 42 deletions(-) diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index e35866bb8436..5b3844bd41f0 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -64,6 +64,20 @@ __global__ void CompressBinEllpackKernel( wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature); } +// Construct an empty ELLPACK matrix. +EllpackPageImpl::EllpackPageImpl(int device, EllpackInfo info, size_t n_rows) { + monitor_.Init("ellpack_page"); + dh::safe_cuda(cudaSetDevice(device)); + + matrix.info = info; + matrix.base_rowid = 0; + matrix.n_rows = n_rows; + + monitor_.StartCuda("InitCompressedData"); + InitCompressedData(device, n_rows); + monitor_.StopCuda("InitCompressedData"); +} + // Construct an ELLPACK matrix in memory. EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param) { monitor_.Init("ellpack_page"); @@ -277,5 +291,4 @@ void EllpackPageImpl::InitDevice(int device, EllpackInfo info) { device_initialized_ = true; } - } // namespace xgboost diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh index 4bea9ca145ca..989f61af10ce 100644 --- a/src/data/ellpack_page.cuh +++ b/src/data/ellpack_page.cuh @@ -204,6 +204,14 @@ class EllpackPageImpl { */ EllpackPageImpl() = default; + /*! + * \brief Constructor from an existing EllpackInfo. + * + * This is used in the sampling case. The ELLPACK page is constructed from an existing EllpackInfo + * and the given number of rows. + */ + explicit EllpackPageImpl(int device, EllpackInfo info, size_t n_rows); + /*! * \brief Constructor from an existing DMatrix. * diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 7fe572f1d271..31c41d18f801 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -9,18 +9,41 @@ #include #include "../../common/compressed_iterator.h" -#include "../../common/device_helpers.cuh" #include "../../common/random.h" #include "gradient_based_sampler.cuh" namespace xgboost { namespace tree { -size_t GradientBasedSampler::MaxSampleRows(int device, const EllpackInfo& info) { - size_t available_memory = dh::AvailableMemory(device); +GradientBasedSampler::GradientBasedSampler(BatchParam batch_param, + EllpackInfo info, + size_t n_rows, + size_t sample_rows) + : batch_param_(batch_param), info_(info), sample_rows_(sample_rows) { + monitor_.Init("gradient_based_sampler"); + + if (sample_rows_ == 0) { + sample_rows_ = MaxSampleRows(); + } + if (sample_rows_ >= n_rows) { + is_sampling_ = false; + sample_rows_ = n_rows; + } else { + is_sampling_ = true; + } + + page_.reset(new EllpackPageImpl(batch_param.gpu_id, info, sample_rows_)); + if (is_sampling_) { + ba_.Allocate(batch_param.gpu_id, &gpair_, sample_rows_); + } +} + +size_t GradientBasedSampler::MaxSampleRows() { + size_t available_memory = dh::AvailableMemory(batch_param_.gpu_id); size_t usable_memory = available_memory * 0.95; + size_t gpair_bytes = sizeof(GradientPair); size_t max_rows = common::CompressedBufferWriter::CalculateMaxRows( - usable_memory, info.NumSymbols(), info.row_stride, 2 * sizeof(float)); + usable_memory, info_.NumSymbols(), info_.row_stride, gpair_bytes); return max_rows; } @@ -68,24 +91,29 @@ struct is_non_zero : public thrust::unary_function { }; GradientBasedSample GradientBasedSampler::Sample(HostDeviceVector* gpair, - DMatrix* dmat, - BatchParam batch_param, - size_t sample_rows) { + DMatrix* dmat) { + if (!is_sampling_) { + auto page = (*dmat->GetBatches(batch_param_).begin()).Impl(); + auto out_gpair = gpair->DeviceSpan(); + return {sample_rows_, page, out_gpair}; + } + float sum_abs_gradient = thrust::transform_reduce( dh::tbegin(*gpair), dh::tend(*gpair), abs_grad(), 0.0f, thrust::plus()); thrust::transform(dh::tbegin(*gpair), dh::tend(*gpair), thrust::counting_iterator(0), dh::tbegin(*gpair), - sample_and_scale(sample_rows, sum_abs_gradient, common::GlobalRandom()())); - - size_t out_size = thrust::count_if(dh::tbegin(*gpair), dh::tend(*gpair), is_non_zero()); + sample_and_scale(sample_rows_, sum_abs_gradient, common::GlobalRandom()())); - HostDeviceVector out_gpair(out_size, GradientPair(), batch_param.gpu_id); - thrust::copy_if(dh::tbegin(*gpair), dh::tend(*gpair), dh::tbegin(out_gpair), is_non_zero()); + thrust::copy_if(thrust::device, + dh::tbegin(*gpair), + dh::tend(*gpair), + gpair_.begin(), + is_non_zero()); - auto page = (*dmat->GetBatches(batch_param).begin()).Impl(); - return {page, out_gpair}; + auto page = (*dmat->GetBatches(batch_param_).begin()).Impl(); + return {sample_rows_, page, gpair_}; } }; // namespace tree diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh index 1bb982c189f4..53c0bb9c336e 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cuh +++ b/src/tree/gpu_hist/gradient_based_sampler.cuh @@ -2,18 +2,23 @@ * Copyright 2019 by XGBoost Contributors */ #pragma once +#include #include +#include +#include "../../common/device_helpers.cuh" #include "../../data/ellpack_page.cuh" namespace xgboost { namespace tree { struct GradientBasedSample { + /*!\brief Number of sampled rows. */ + size_t sample_rows; /*!\brief Sampled rows in ELLPACK format. */ EllpackPageImpl* page; /*!\brief Rescaled gradient pairs for the sampled rows. */ - HostDeviceVector gpair; + common::Span gpair; }; /*! \brief Draw a sample of rows from a DMatrix. @@ -30,12 +35,25 @@ struct GradientBasedSample { */ class GradientBasedSampler { public: - size_t MaxSampleRows(int device, const EllpackInfo& info); + explicit GradientBasedSampler(BatchParam batch_param, + EllpackInfo info, + size_t n_rows, + size_t sample_rows = 0); - GradientBasedSample Sample(HostDeviceVector* gpair, - DMatrix* dmat, - BatchParam batch_param, - size_t sample_rows = 0); + /*! \brief Returns the max number of rows that can fit in available GPU memory. */ + size_t MaxSampleRows(); + + GradientBasedSample Sample(HostDeviceVector* gpair, DMatrix* dmat); + + private: + common::Monitor monitor_; + dh::BulkAllocator ba_; + BatchParam batch_param_; + EllpackInfo info_; + bool is_sampling_; + size_t sample_rows_; + std::unique_ptr page_; + common::Span gpair_; }; }; // namespace tree }; // namespace xgboost diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index ba6ac9e9d02f..6e3586ac55df 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -481,7 +481,7 @@ struct GPUHistMakerDevice { std::unique_ptr qexpand; bool use_gradient_based_sampling {false}; - std::unique_ptr gradient_based_sampler; + std::unique_ptr sampler; GPUHistMakerDevice(int _device_id, EllpackPageImpl* _page, @@ -500,7 +500,7 @@ struct GPUHistMakerDevice { batch_param(_batch_param), use_gradient_based_sampling(_page->matrix.n_rows != _n_rows) { if (use_gradient_based_sampling) { - gradient_based_sampler.reset(new GradientBasedSampler()); + sampler.reset(new GradientBasedSampler(batch_param, page->matrix.info, n_rows)); } monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id)); } @@ -548,16 +548,15 @@ struct GPUHistMakerDevice { std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), GradientPair()); - dh::safe_cuda(cudaMemcpyAsync( - gpair.data(), dh_gpair->ConstDevicePointer(), - gpair.size() * sizeof(GradientPair), cudaMemcpyDeviceToDevice)); - if (use_gradient_based_sampling) { - auto sample = gradient_based_sampler->Sample(dh_gpair, dmat, batch_param); + auto sample = sampler->Sample(dh_gpair, dmat); + n_rows = sample.sample_rows; page = sample.page; - gpair = sample.gpair.DeviceSpan(); - n_rows = page->matrix.n_rows; + gpair = sample.gpair; } else { + dh::safe_cuda(cudaMemcpyAsync( + gpair.data(), dh_gpair->ConstDevicePointer(), + gpair.size() * sizeof(GradientPair), cudaMemcpyDeviceToDevice)); SubsampleGradientPair(device_id, gpair, param.subsample); } diff --git a/tests/cpp/common/test_compressed_iterator.cc b/tests/cpp/common/test_compressed_iterator.cc index ea50e887e723..bfdae6814d9b 100644 --- a/tests/cpp/common/test_compressed_iterator.cc +++ b/tests/cpp/common/test_compressed_iterator.cc @@ -53,23 +53,23 @@ TEST(CompressedIterator, Test) { TEST(CompressedIterator, CalculateMaxRows) { const size_t num_bytes = 12652838912; - const size_t num_symbols = 256 * 100 + 1; const size_t row_stride = 100; + const size_t num_symbols = 256 * row_stride + 1; const size_t extra_bytes = 8; size_t num_rows = CompressedBufferWriter::CalculateMaxRows(num_bytes, num_symbols, row_stride, extra_bytes); EXPECT_EQ(num_rows, 64720403); // The calculated # rows should fit within the given number of bytes. - size_t buffer_size = CompressedBufferWriter::CalculateBufferSize(num_rows * 100, num_symbols); + size_t buffer = CompressedBufferWriter::CalculateBufferSize(num_rows * row_stride, num_symbols); size_t extras = extra_bytes * num_rows; - EXPECT_LE(buffer_size + extras, num_bytes); + EXPECT_LE(buffer + extras, num_bytes); // An extra row wouldn't fit. num_rows++; - buffer_size = CompressedBufferWriter::CalculateBufferSize(num_rows * 100, num_symbols); + buffer = CompressedBufferWriter::CalculateBufferSize(num_rows * row_stride, num_symbols); extras = extra_bytes * num_rows; - EXPECT_GT(buffer_size + extras, num_bytes); + EXPECT_GT(buffer + extras, num_bytes); } } // namespace common diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu index 5f8d98367948..c07e6c42a97a 100644 --- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -20,24 +20,30 @@ TEST(GradientBasedSampler, Sample) { auto gpair = GenerateRandomGradients(kRows); gpair.SetDevice(0); - GradientBasedSampler sampler; BatchParam param{0, 256, 0, kPageSize}; - auto sample = sampler.Sample(&gpair, dmat.get(), param, kSampleRows); - auto page = sample.page; + auto page = (*dmat->GetBatches(param).begin()).Impl(); + + GradientBasedSampler sampler(param, page->matrix.info, kRows, kSampleRows); + auto sample = sampler.Sample(&gpair, dmat.get()); + page = sample.page; auto sampled_gpair = sample.gpair; - EXPECT_NEAR(sampled_gpair.Size(), kSampleRows, 12); - EXPECT_NEAR(page->matrix.n_rows, kSampleRows, 12); - EXPECT_EQ(page->matrix.n_rows, sampled_gpair.Size()); + EXPECT_EQ(sampled_gpair.size(), kSampleRows); + EXPECT_EQ(page->matrix.n_rows, kSampleRows); + EXPECT_EQ(page->matrix.n_rows, sampled_gpair.size()); float sum_gradients = 0; for (auto gp : gpair.ConstHostVector()) { sum_gradients += gp.GetGrad(); } + float sum_sampled_gradients = 0; - for (auto gp : sampled_gpair.ConstHostVector()) { + std::vector sampled_gpair_h(sampled_gpair.size()); + thrust::copy(sampled_gpair.begin(), sampled_gpair.end(), sampled_gpair_h.begin()); + for (auto gp : sampled_gpair_h) { sum_sampled_gradients += gp.GetGrad(); } EXPECT_FLOAT_EQ(sum_gradients, sum_sampled_gradients); } + }; // namespace tree }; // namespace xgboost From f2dd9280030961bb9cdf5d55e8e2ea8428f9c9bc Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Wed, 4 Dec 2019 14:59:30 -0800 Subject: [PATCH 08/48] collect all pages in memory if they fit --- include/xgboost/base.h | 4 ++ src/common/compressed_iterator.h | 2 +- src/data/ellpack_page.cu | 28 +++++++++ src/data/ellpack_page.cuh | 10 ++++ src/tree/gpu_hist/gradient_based_sampler.cu | 18 +++++- src/tree/gpu_hist/gradient_based_sampler.cuh | 5 ++ .../gpu_hist/test_gradient_based_sampler.cu | 60 +++++++++++++++++-- 7 files changed, 119 insertions(+), 8 deletions(-) diff --git a/include/xgboost/base.h b/include/xgboost/base.h index 8b551ed325df..6eb8c6065458 100644 --- a/include/xgboost/base.h +++ b/include/xgboost/base.h @@ -207,6 +207,10 @@ class GradientPairInternal { return g; } + XGBOOST_DEVICE bool operator==(const GradientPairInternal &rhs) const { + return grad_ == rhs.grad_ && hess_ == rhs.hess_; + } + XGBOOST_DEVICE explicit GradientPairInternal(int value) { *this = GradientPairInternal(static_cast(value), static_cast(value)); diff --git a/src/common/compressed_iterator.h b/src/common/compressed_iterator.h index df34ec15666d..a414ca3d32a2 100644 --- a/src/common/compressed_iterator.h +++ b/src/common/compressed_iterator.h @@ -207,7 +207,7 @@ class CompressedIterator { public: CompressedIterator() : buffer_(nullptr), symbol_bits_(0), offset_(0) {} - CompressedIterator(CompressedByteT *buffer, int num_symbols) + CompressedIterator(CompressedByteT *buffer, size_t num_symbols) : buffer_(buffer), offset_(0) { symbol_bits_ = detail::SymbolBits(num_symbols); } diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index 5b3844bd41f0..6835b8618423 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -110,6 +110,34 @@ EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param) { monitor_.StopCuda("BinningCompression"); } +struct CopyPageFunction { + common::CompressedBufferWriter cbw; + common::CompressedByteT* dst_data_d; + common::CompressedIterator src_iterator_d; + size_t offset; + + CopyPageFunction(EllpackPageImpl* dst, EllpackPageImpl* src, size_t offset) + : cbw{dst->matrix.info.NumSymbols()}, + dst_data_d{dst->gidx_buffer.data()}, + src_iterator_d{src->gidx_buffer.data(), src->matrix.info.NumSymbols()}, + offset(offset) {} + + __device__ void operator()(size_t i) { + cbw.AtomicWriteSymbol(dst_data_d, src_iterator_d[i], i + offset); + } +}; + +size_t EllpackPageImpl::Copy(int device, EllpackPageImpl* page, size_t offset) { + monitor_.StartCuda("Copy"); + size_t num_elements = page->matrix.n_rows * page->matrix.info.row_stride; + CHECK_EQ(matrix.info.row_stride, page->matrix.info.row_stride); + CHECK_EQ(matrix.info.NumSymbols(), page->matrix.info.NumSymbols()); + CHECK_GE(matrix.n_rows * matrix.info.row_stride, offset + num_elements); + dh::LaunchN(device, num_elements, CopyPageFunction(this, page, offset)); + monitor_.StopCuda("Copy"); + return num_elements; +} + // Construct an EllpackInfo based on histogram cuts of features. EllpackInfo::EllpackInfo(int device, bool is_dense, diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh index 989f61af10ce..f77974312c4e 100644 --- a/src/data/ellpack_page.cuh +++ b/src/data/ellpack_page.cuh @@ -220,6 +220,16 @@ class EllpackPageImpl { */ explicit EllpackPageImpl(DMatrix* dmat, const BatchParam& parm); + + /*! \brief Copy the elements of the given ELLPACK page into this page. + * + * @param device The GPU device to use. + * @param page The ELLPACK page to copy from. + * @param offset The number of elements to skip before copying. + * @returns The number of elements copied. + */ + size_t Copy(int device, EllpackPageImpl* page, size_t offset); + /*! * \brief Initialize the EllpackInfo contained in the EllpackMatrix. * diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 31c41d18f801..0fa1b06c4189 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -93,9 +93,9 @@ struct is_non_zero : public thrust::unary_function { GradientBasedSample GradientBasedSampler::Sample(HostDeviceVector* gpair, DMatrix* dmat) { if (!is_sampling_) { - auto page = (*dmat->GetBatches(batch_param_).begin()).Impl(); + CollectPages(dmat); auto out_gpair = gpair->DeviceSpan(); - return {sample_rows_, page, out_gpair}; + return {sample_rows_, page_.get(), out_gpair}; } float sum_abs_gradient = thrust::transform_reduce( @@ -116,5 +116,19 @@ GradientBasedSample GradientBasedSampler::Sample(HostDeviceVector* return {sample_rows_, page, gpair_}; } +void GradientBasedSampler::CollectPages(DMatrix* dmat) { + if (page_collected_) { + return; + } + + size_t offset = 0; + for (auto& batch : dmat->GetBatches(batch_param_)) { + auto page = batch.Impl(); + size_t num_elements = page_->Copy(batch_param_.gpu_id, page, offset); + offset += num_elements; + } + page_collected_ = true; +} + }; // namespace tree }; // namespace xgboost diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh index 53c0bb9c336e..c4f180bc5205 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cuh +++ b/src/tree/gpu_hist/gradient_based_sampler.cuh @@ -43,8 +43,12 @@ class GradientBasedSampler { /*! \brief Returns the max number of rows that can fit in available GPU memory. */ size_t MaxSampleRows(); + /*! \brief Sample from a DMatrix based on the given gradients. */ GradientBasedSample Sample(HostDeviceVector* gpair, DMatrix* dmat); + /*! \brief Collect all the rows from a DMatrix into a single ELLPACK page. */ + void CollectPages(DMatrix* dmat); + private: common::Monitor monitor_; dh::BulkAllocator ba_; @@ -54,6 +58,7 @@ class GradientBasedSampler { size_t sample_rows_; std::unique_ptr page_; common::Span gpair_; + bool page_collected_{false}; }; }; // namespace tree }; // namespace xgboost diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu index c07e6c42a97a..7b2c9291f764 100644 --- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -1,4 +1,5 @@ #include +#include #include "../../../../src/data/ellpack_page.cuh" #include "../../../../src/tree/gpu_hist/gradient_based_sampler.cuh" @@ -7,7 +8,7 @@ namespace xgboost { namespace tree { -TEST(GradientBasedSampler, Sample) { +TEST(GradientBasedSampler, DISABLED_Sample) { constexpr size_t kRows = 2048; constexpr size_t kCols = 16; constexpr size_t kSampleRows = 512; @@ -25,11 +26,11 @@ TEST(GradientBasedSampler, Sample) { GradientBasedSampler sampler(param, page->matrix.info, kRows, kSampleRows); auto sample = sampler.Sample(&gpair, dmat.get()); - page = sample.page; + auto sampled_page = sample.page; auto sampled_gpair = sample.gpair; + EXPECT_EQ(sample.sample_rows, kSampleRows); + EXPECT_EQ(sampled_page->matrix.n_rows, kSampleRows); EXPECT_EQ(sampled_gpair.size(), kSampleRows); - EXPECT_EQ(page->matrix.n_rows, kSampleRows); - EXPECT_EQ(page->matrix.n_rows, sampled_gpair.size()); float sum_gradients = 0; for (auto gp : gpair.ConstHostVector()) { @@ -38,12 +39,61 @@ TEST(GradientBasedSampler, Sample) { float sum_sampled_gradients = 0; std::vector sampled_gpair_h(sampled_gpair.size()); - thrust::copy(sampled_gpair.begin(), sampled_gpair.end(), sampled_gpair_h.begin()); + dh::CopyDeviceSpanToVector(&sampled_gpair_h, sampled_gpair); for (auto gp : sampled_gpair_h) { sum_sampled_gradients += gp.GetGrad(); } EXPECT_FLOAT_EQ(sum_gradients, sum_sampled_gradients); } +TEST(GradientBasedSampler, FullSample) { + constexpr size_t kRows = 1024; + constexpr size_t kCols = 4; + constexpr size_t kSampleRows = 4096; + constexpr size_t kPageSize = 1024; + + // Create a DMatrix with multiple batches. + dmlc::TemporaryDirectory tmpdir; + std::unique_ptr + dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir)); + auto gpair = GenerateRandomGradients(kRows); + gpair.SetDevice(0); + + BatchParam param{0, 256, 0, kPageSize}; + auto page = (*dmat->GetBatches(param).begin()).Impl(); + + GradientBasedSampler sampler(param, page->matrix.info, kRows, kSampleRows); + auto sample = sampler.Sample(&gpair, dmat.get()); + auto sampled_page = sample.page; + auto sampled_gpair = sample.gpair; + EXPECT_EQ(sample.sample_rows, kRows); + EXPECT_EQ(sampled_gpair.size(), kRows); + EXPECT_EQ(sampled_page->matrix.n_rows, kRows); + + auto gpair_h = gpair.ConstHostVector(); + std::vector sampled_gpair_h(sampled_gpair.size()); + dh::CopyDeviceSpanToVector(&sampled_gpair_h, sampled_gpair); + EXPECT_EQ(gpair_h, sampled_gpair_h); + + std::vector buffer(sampled_page->gidx_buffer.size()); + dh::CopyDeviceSpanToVector(&buffer, sampled_page->gidx_buffer); + common::CompressedIterator + ci(buffer.data(), sampled_page->matrix.info.NumSymbols()); + + size_t offset = 0; + for (auto& batch : dmat->GetBatches(param)) { + auto page = batch.Impl(); + std::vector page_buffer(page->gidx_buffer.size()); + dh::CopyDeviceSpanToVector(&page_buffer, page->gidx_buffer); + common::CompressedIterator + page_ci(page_buffer.data(), page->matrix.info.NumSymbols()); + size_t num_elements = page->matrix.n_rows * page->matrix.info.row_stride; + for (size_t i = 0; i < num_elements; i++) { + EXPECT_EQ(ci[i + offset], page_ci[i]); + } + offset += num_elements; + } +} + }; // namespace tree }; // namespace xgboost From 2e5494c38bda408458563e356a549c4dc83c321b Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Wed, 4 Dec 2019 16:10:50 -0800 Subject: [PATCH 09/48] optimize finalize position --- src/tree/updater_gpu_hist.cu | 61 ++++++++++++++++++++---------------- 1 file changed, 34 insertions(+), 27 deletions(-) diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 6e3586ac55df..ee772044c46a 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -723,39 +723,46 @@ struct GPUHistMakerDevice { d_nodes.size() * sizeof(RegTree::Node), cudaMemcpyHostToDevice)); - if (row_partitioner->GetRows().size() != n_rows) { + if (row_partitioner->GetRows().size() != p_fmat->Info().num_row_) { row_partitioner.reset(); // Release the device memory first before reallocating row_partitioner.reset(new RowPartitioner(device_id, p_fmat->Info().num_row_)); } - for (auto& batch : p_fmat->GetBatches(batch_param)) { - page = batch.Impl(); - auto d_matrix = page->matrix; - row_partitioner->FinalisePosition( - [=] __device__(size_t row_id, int position) { - if (!d_matrix.IsInRange(row_id)) { - return RowPartitioner::kIgnoredTreePosition; - } - auto node = d_nodes[position]; - - while (!node.IsLeaf()) { - bst_float element = d_matrix.GetElement(row_id, node.SplitIndex()); - // Missing value - if (isnan(element)) { - position = node.DefaultChild(); - } else { - if (element <= node.SplitCond()) { - position = node.LeftChild(); - } else { - position = node.RightChild(); - } - } - node = d_nodes[position]; - } - return position; - }); + if (page->matrix.n_rows == p_fmat->Info().num_row_) { + FinalisePositionInPage(page, d_nodes); + } else { + for (auto& batch : p_fmat->GetBatches(batch_param)) { + FinalisePositionInPage(batch.Impl(), d_nodes); + } } } + void FinalisePositionInPage(EllpackPageImpl* page, const common::Span d_nodes) { + auto d_matrix = page->matrix; + row_partitioner->FinalisePosition( + [=] __device__(size_t row_id, int position) { + if (!d_matrix.IsInRange(row_id)) { + return RowPartitioner::kIgnoredTreePosition; + } + auto node = d_nodes[position]; + + while (!node.IsLeaf()) { + bst_float element = d_matrix.GetElement(row_id, node.SplitIndex()); + // Missing value + if (isnan(element)) { + position = node.DefaultChild(); + } else { + if (element <= node.SplitCond()) { + position = node.LeftChild(); + } else { + position = node.RightChild(); + } + } + node = d_nodes[position]; + } + return position; + }); + } + void UpdatePredictionCache(bst_float* out_preds_d) { dh::safe_cuda(cudaSetDevice(device_id)); if (!prediction_cache_initialised) { From 7fca606b7248d543b0bf3178bd28a9457dec2e93 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 5 Dec 2019 16:23:18 -0800 Subject: [PATCH 10/48] done with sampling --- src/data/ellpack_page.cu | 37 ++++++++ src/data/ellpack_page.cuh | 9 +- src/tree/gpu_hist/gradient_based_sampler.cu | 87 +++++++++++++++---- src/tree/gpu_hist/gradient_based_sampler.cuh | 4 +- .../gpu_hist/test_gradient_based_sampler.cu | 4 +- 5 files changed, 119 insertions(+), 22 deletions(-) diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index 6835b8618423..9fee95cf6009 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -138,6 +138,43 @@ size_t EllpackPageImpl::Copy(int device, EllpackPageImpl* page, size_t offset) { return num_elements; } +struct CompactPageFunction { + common::CompressedBufferWriter cbw; + common::CompressedByteT* dst_data_d; + common::CompressedIterator src_iterator_d; + common::Span row_indexes; + size_t base_rowid; + size_t row_stride; + + CompactPageFunction(EllpackPageImpl* dst, EllpackPageImpl* src, common::Span row_indexes) + : cbw{dst->matrix.info.NumSymbols()}, + dst_data_d{dst->gidx_buffer.data()}, + src_iterator_d{src->gidx_buffer.data(), src->matrix.info.NumSymbols()}, + row_indexes(row_indexes), + base_rowid{src->matrix.base_rowid}, + row_stride{src->matrix.info.row_stride} {} + + __device__ void operator()(size_t i) { + size_t row = base_rowid + i; + size_t row_index = row_indexes[row]; + if (row_index == SIZE_MAX) return; + size_t dst_offset = row_index * row_stride; + size_t src_offset = i * row_stride; + for (size_t j = 0; j < row_stride; j++) { + cbw.AtomicWriteSymbol(dst_data_d, src_iterator_d[src_offset], dst_offset + j); + } + } +}; + +void EllpackPageImpl::Compact(int device, EllpackPageImpl* page, common::Span row_indexes) { + monitor_.StartCuda("Compact"); + CHECK_EQ(matrix.info.row_stride, page->matrix.info.row_stride); + CHECK_EQ(matrix.info.NumSymbols(), page->matrix.info.NumSymbols()); + CHECK_LE(page->matrix.base_rowid + page->matrix.n_rows, row_indexes.size()); + dh::LaunchN(device, page->matrix.n_rows, CompactPageFunction(this, page, row_indexes)); + monitor_.StopCuda("Compact"); +} + // Construct an EllpackInfo based on histogram cuts of features. EllpackInfo::EllpackInfo(int device, bool is_dense, diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh index f77974312c4e..2ff46f1a9380 100644 --- a/src/data/ellpack_page.cuh +++ b/src/data/ellpack_page.cuh @@ -220,7 +220,6 @@ class EllpackPageImpl { */ explicit EllpackPageImpl(DMatrix* dmat, const BatchParam& parm); - /*! \brief Copy the elements of the given ELLPACK page into this page. * * @param device The GPU device to use. @@ -230,6 +229,14 @@ class EllpackPageImpl { */ size_t Copy(int device, EllpackPageImpl* page, size_t offset); + /*! \brief Compact the given ELLPACK page into the current page. + * + * @param device The GPU device to use. + * @param page The ELLPACK page to compact from. + * @param row_indexes Row indexes for the compacted page. + */ + void Compact(int device, EllpackPageImpl* page, common::Span row_indexes); + /*! * \brief Initialize the EllpackInfo contained in the EllpackMatrix. * diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 0fa1b06c4189..331d11e49e81 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -7,6 +7,7 @@ #include #include +#include #include "../../common/compressed_iterator.h" #include "../../common/random.h" @@ -28,22 +29,27 @@ GradientBasedSampler::GradientBasedSampler(BatchParam batch_param, if (sample_rows_ >= n_rows) { is_sampling_ = false; sample_rows_ = n_rows; + LOG(CONSOLE) << "Keeping " << sample_rows_ << " in GPU memory, not sampling"; } else { is_sampling_ = true; + LOG(CONSOLE) << "Sampling " << sample_rows_ << " rows"; } page_.reset(new EllpackPageImpl(batch_param.gpu_id, info, sample_rows_)); if (is_sampling_) { - ba_.Allocate(batch_param.gpu_id, &gpair_, sample_rows_); + gpair_.SetDevice(batch_param_.gpu_id); + gpair_.Resize(sample_rows_, GradientPair()); + sample_row_index_.SetDevice(batch_param_.gpu_id); + sample_row_index_.Resize(n_rows, 0); } } size_t GradientBasedSampler::MaxSampleRows() { size_t available_memory = dh::AvailableMemory(batch_param_.gpu_id); size_t usable_memory = available_memory * 0.95; - size_t gpair_bytes = sizeof(GradientPair); + size_t extra_bytes = sizeof(GradientPair) + sizeof(size_t); size_t max_rows = common::CompressedBufferWriter::CalculateMaxRows( - usable_memory, info_.NumSymbols(), info_.row_stride, gpair_bytes); + usable_memory, info_.NumSymbols(), info_.row_stride, extra_bytes); return max_rows; } @@ -54,17 +60,16 @@ struct abs_grad : public thrust::unary_function { } }; -/*! \brief A functor that samples and scales a gradient pair. +/*! \brief A functor that samples a gradient pair. * - * Sampling probability is proportional to the absolute value of the gradient. If selected, the - * gradient pair is re-scaled proportional to (1 / probability). + * Sampling probability is proportional to the absolute value of the gradient. */ -struct sample_and_scale : public thrust::binary_function { +struct sample_gradient : public thrust::binary_function { const size_t sample_rows; const float sum_abs_gradient; const uint32_t seed; - XGBOOST_DEVICE sample_and_scale(size_t _sample_rows, float _sum_abs_gradient, size_t _seed) + XGBOOST_DEVICE sample_gradient(size_t _sample_rows, float _sum_abs_gradient, size_t _seed) : sample_rows(_sample_rows), sum_abs_gradient(_sum_abs_gradient), seed(_seed) {} XGBOOST_DEVICE GradientPair operator()(const GradientPair& gpair, size_t i) { @@ -76,7 +81,7 @@ struct sample_and_scale : public thrust::binary_function { } }; +/*! \brief A functor that clears the row indexes with empty gradient. */ +struct clear_empty_rows : public thrust::binary_function { + const size_t max_rows; + + XGBOOST_DEVICE clear_empty_rows(size_t max_rows) : max_rows(max_rows) {} + + XGBOOST_DEVICE size_t operator()(const GradientPair& gpair, size_t row_index) const { + if ((gpair.GetGrad() != 0 || gpair.GetHess() != 0) && row_index < max_rows) { + return row_index; + } else { + return SIZE_MAX; + } + } +}; + +/*! \brief A functor that trims extra sampled rows. */ +struct trim_extra_rows : public thrust::binary_function { + XGBOOST_DEVICE GradientPair operator()(const GradientPair& gpair, size_t row_index) const { + if (row_index == SIZE_MAX) { + return GradientPair(); + } else { + return gpair; + } + } +}; + GradientBasedSample GradientBasedSampler::Sample(HostDeviceVector* gpair, DMatrix* dmat) { + // If there is enough space for all rows, just collect them in a single ELLPACK page and return. if (!is_sampling_) { CollectPages(dmat); auto out_gpair = gpair->DeviceSpan(); return {sample_rows_, page_.get(), out_gpair}; } + // Sum the absolute value of gradients as the denominator to normalize the probability. float sum_abs_gradient = thrust::transform_reduce( dh::tbegin(*gpair), dh::tend(*gpair), abs_grad(), 0.0f, thrust::plus()); + // Poisson sampling of the gradient pairs based on the absolute value of the gradient. thrust::transform(dh::tbegin(*gpair), dh::tend(*gpair), thrust::counting_iterator(0), dh::tbegin(*gpair), - sample_and_scale(sample_rows_, sum_abs_gradient, common::GlobalRandom()())); + sample_gradient(sample_rows_, sum_abs_gradient, common::GlobalRandom()())); + + // Map the original row index to the sample row index. + sample_row_index_.Fill(0); + thrust::transform(dh::tbegin(*gpair), dh::tend(*gpair), + dh::tbegin(sample_row_index_), + is_non_zero()); + thrust::exclusive_scan(dh::tbegin(sample_row_index_), dh::tend(sample_row_index_), + dh::tbegin(sample_row_index_)); + thrust::transform(dh::tbegin(*gpair), dh::tend(*gpair), + dh::tbegin(sample_row_index_), + dh::tbegin(sample_row_index_), + clear_empty_rows(sample_rows_)); - thrust::copy_if(thrust::device, - dh::tbegin(*gpair), - dh::tend(*gpair), - gpair_.begin(), - is_non_zero()); + // Zero out the gradient pairs if there are more rows than desired. + thrust::transform(dh::tbegin(*gpair), dh::tend(*gpair), + dh::tbegin(sample_row_index_), + dh::tbegin(*gpair), + trim_extra_rows()); + + // Compact the non-zero gradient pairs. + thrust::copy_if(dh::tbegin(*gpair), dh::tend(*gpair), dh::tbegin(gpair_), is_non_zero()); + + // Compact the ELLPACK pages into the single sample page. + for (auto& batch : dmat->GetBatches(batch_param_)) { + page_->Compact(batch_param_.gpu_id, batch.Impl(), sample_row_index_.DeviceSpan()); + } - auto page = (*dmat->GetBatches(batch_param_).begin()).Impl(); - return {sample_rows_, page, gpair_}; + return {sample_rows_, page_.get(), gpair_.DeviceSpan()}; } void GradientBasedSampler::CollectPages(DMatrix* dmat) { diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh index c4f180bc5205..c9f4c04c310b 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cuh +++ b/src/tree/gpu_hist/gradient_based_sampler.cuh @@ -51,13 +51,13 @@ class GradientBasedSampler { private: common::Monitor monitor_; - dh::BulkAllocator ba_; BatchParam batch_param_; EllpackInfo info_; bool is_sampling_; size_t sample_rows_; std::unique_ptr page_; - common::Span gpair_; + HostDeviceVector gpair_; + HostDeviceVector sample_row_index_; bool page_collected_{false}; }; }; // namespace tree diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu index 7b2c9291f764..3c47485dfe27 100644 --- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -8,10 +8,10 @@ namespace xgboost { namespace tree { -TEST(GradientBasedSampler, DISABLED_Sample) { +TEST(GradientBasedSampler, Sample) { constexpr size_t kRows = 2048; constexpr size_t kCols = 16; - constexpr size_t kSampleRows = 512; + constexpr size_t kSampleRows = 1024; constexpr size_t kPageSize = 1024; // Create a DMatrix with multiple batches. From f661df0d6cbcf94c72fd2e1d2f56dc46c3c18285 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 5 Dec 2019 16:50:38 -0800 Subject: [PATCH 11/48] add some docs --- src/data/ellpack_page.cu | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index 9fee95cf6009..de3ce492460c 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -64,7 +64,7 @@ __global__ void CompressBinEllpackKernel( wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature); } -// Construct an empty ELLPACK matrix. +// Construct an ELLPACK matrix with the given number of empty rows. EllpackPageImpl::EllpackPageImpl(int device, EllpackInfo info, size_t n_rows) { monitor_.Init("ellpack_page"); dh::safe_cuda(cudaSetDevice(device)); @@ -110,10 +110,12 @@ EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param) { monitor_.StopCuda("BinningCompression"); } +// A functor that copies the data from one EllpackPage to another. struct CopyPageFunction { common::CompressedBufferWriter cbw; common::CompressedByteT* dst_data_d; common::CompressedIterator src_iterator_d; + // The number of elements to skip. size_t offset; CopyPageFunction(EllpackPageImpl* dst, EllpackPageImpl* src, size_t offset) @@ -127,6 +129,7 @@ struct CopyPageFunction { } }; +// Copy the data from the given EllpackPage to the current page. size_t EllpackPageImpl::Copy(int device, EllpackPageImpl* page, size_t offset) { monitor_.StartCuda("Copy"); size_t num_elements = page->matrix.n_rows * page->matrix.info.row_stride; @@ -138,6 +141,7 @@ size_t EllpackPageImpl::Copy(int device, EllpackPageImpl* page, size_t offset) { return num_elements; } +// A functor that compacts the rows from one EllpackPage into another. struct CompactPageFunction { common::CompressedBufferWriter cbw; common::CompressedByteT* dst_data_d; @@ -166,6 +170,7 @@ struct CompactPageFunction { } }; +// Compacts the data from the given EllpackPage into the current page. void EllpackPageImpl::Compact(int device, EllpackPageImpl* page, common::Span row_indexes) { monitor_.StartCuda("Compact"); CHECK_EQ(matrix.info.row_stride, page->matrix.info.row_stride); From 14361af50cfe7d0890a457c82e3e81752851a1c8 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 6 Dec 2019 11:12:10 -0800 Subject: [PATCH 12/48] formatting --- src/tree/gpu_hist/gradient_based_sampler.cu | 34 ++++++++++++--------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 331d11e49e81..0ba2df779147 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -5,10 +5,11 @@ #include #include #include - #include #include +#include + #include "../../common/compressed_iterator.h" #include "../../common/random.h" #include "gradient_based_sampler.cuh" @@ -54,7 +55,7 @@ size_t GradientBasedSampler::MaxSampleRows() { } /*! \brief A functor that returns the absolute value of gradient from a gradient pair. */ -struct abs_grad : public thrust::unary_function { +struct AbsoluteGradient : public thrust::unary_function { XGBOOST_DEVICE float operator()(const GradientPair& gpair) const { return fabsf(gpair.GetGrad()); } @@ -64,12 +65,12 @@ struct abs_grad : public thrust::unary_function { * * Sampling probability is proportional to the absolute value of the gradient. */ -struct sample_gradient : public thrust::binary_function { +struct PoissonSampling : public thrust::binary_function { const size_t sample_rows; const float sum_abs_gradient; const uint32_t seed; - XGBOOST_DEVICE sample_gradient(size_t _sample_rows, float _sum_abs_gradient, size_t _seed) + XGBOOST_DEVICE PoissonSampling(size_t _sample_rows, float _sum_abs_gradient, size_t _seed) : sample_rows(_sample_rows), sum_abs_gradient(_sum_abs_gradient), seed(_seed) {} XGBOOST_DEVICE GradientPair operator()(const GradientPair& gpair, size_t i) { @@ -89,17 +90,17 @@ struct sample_gradient : public thrust::binary_function { +struct IsNonZero : public thrust::unary_function { XGBOOST_DEVICE bool operator()(const GradientPair& gpair) const { return gpair.GetGrad() != 0 || gpair.GetHess() != 0; } }; /*! \brief A functor that clears the row indexes with empty gradient. */ -struct clear_empty_rows : public thrust::binary_function { +struct ClearEmptyRows : public thrust::binary_function { const size_t max_rows; - XGBOOST_DEVICE clear_empty_rows(size_t max_rows) : max_rows(max_rows) {} + XGBOOST_DEVICE ClearEmptyRows(size_t max_rows) : max_rows(max_rows) {} XGBOOST_DEVICE size_t operator()(const GradientPair& gpair, size_t row_index) const { if ((gpair.GetGrad() != 0 || gpair.GetHess() != 0) && row_index < max_rows) { @@ -111,7 +112,7 @@ struct clear_empty_rows : public thrust::binary_function { +struct TrimExtraRows : public thrust::binary_function { XGBOOST_DEVICE GradientPair operator()(const GradientPair& gpair, size_t row_index) const { if (row_index == SIZE_MAX) { return GradientPair(); @@ -131,35 +132,38 @@ GradientBasedSample GradientBasedSampler::Sample(HostDeviceVector* } // Sum the absolute value of gradients as the denominator to normalize the probability. - float sum_abs_gradient = thrust::transform_reduce( - dh::tbegin(*gpair), dh::tend(*gpair), abs_grad(), 0.0f, thrust::plus()); + float sum_abs_gradient = thrust::transform_reduce(dh::tbegin(*gpair), + dh::tend(*gpair), + AbsoluteGradient(), + 0.0f, + thrust::plus()); // Poisson sampling of the gradient pairs based on the absolute value of the gradient. thrust::transform(dh::tbegin(*gpair), dh::tend(*gpair), thrust::counting_iterator(0), dh::tbegin(*gpair), - sample_gradient(sample_rows_, sum_abs_gradient, common::GlobalRandom()())); + PoissonSampling(sample_rows_, sum_abs_gradient, common::GlobalRandom()())); // Map the original row index to the sample row index. sample_row_index_.Fill(0); thrust::transform(dh::tbegin(*gpair), dh::tend(*gpair), dh::tbegin(sample_row_index_), - is_non_zero()); + IsNonZero()); thrust::exclusive_scan(dh::tbegin(sample_row_index_), dh::tend(sample_row_index_), dh::tbegin(sample_row_index_)); thrust::transform(dh::tbegin(*gpair), dh::tend(*gpair), dh::tbegin(sample_row_index_), dh::tbegin(sample_row_index_), - clear_empty_rows(sample_rows_)); + ClearEmptyRows(sample_rows_)); // Zero out the gradient pairs if there are more rows than desired. thrust::transform(dh::tbegin(*gpair), dh::tend(*gpair), dh::tbegin(sample_row_index_), dh::tbegin(*gpair), - trim_extra_rows()); + TrimExtraRows()); // Compact the non-zero gradient pairs. - thrust::copy_if(dh::tbegin(*gpair), dh::tend(*gpair), dh::tbegin(gpair_), is_non_zero()); + thrust::copy_if(dh::tbegin(*gpair), dh::tend(*gpair), dh::tbegin(gpair_), IsNonZero()); // Compact the ELLPACK pages into the single sample page. for (auto& batch : dmat->GetBatches(batch_param_)) { From 70276f116e06d71a2a59a365cc00251f3a934c80 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 6 Dec 2019 11:23:58 -0800 Subject: [PATCH 13/48] explicit constructor --- src/tree/gpu_hist/gradient_based_sampler.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 0ba2df779147..24add7ba1f50 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -100,7 +100,7 @@ struct IsNonZero : public thrust::unary_function { struct ClearEmptyRows : public thrust::binary_function { const size_t max_rows; - XGBOOST_DEVICE ClearEmptyRows(size_t max_rows) : max_rows(max_rows) {} + XGBOOST_DEVICE explicit ClearEmptyRows(size_t max_rows) : max_rows(max_rows) {} XGBOOST_DEVICE size_t operator()(const GradientPair& gpair, size_t row_index) const { if ((gpair.GetGrad() != 0 || gpair.GetHess() != 0) && row_index < max_rows) { From d7770b4b7677e24d8295079e4daba27ab5204557 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 6 Dec 2019 12:16:09 -0800 Subject: [PATCH 14/48] no need for gmock --- tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu index 3c47485dfe27..09b91675d6d1 100644 --- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -1,5 +1,4 @@ #include -#include #include "../../../../src/data/ellpack_page.cuh" #include "../../../../src/tree/gpu_hist/gradient_based_sampler.cuh" From 6a29c3895675dcd7fb86c42e4e8a9a6a7924bd45 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 6 Dec 2019 14:15:28 -0800 Subject: [PATCH 15/48] test ellpackpage copy and compact --- tests/cpp/data/test_ellpack_page.cu | 115 ++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) diff --git a/tests/cpp/data/test_ellpack_page.cu b/tests/cpp/data/test_ellpack_page.cu index 6dd97ef7ca05..8b1d8b6c76d7 100644 --- a/tests/cpp/data/test_ellpack_page.cu +++ b/tests/cpp/data/test_ellpack_page.cu @@ -81,4 +81,119 @@ TEST(EllpackPage, BuildGidxSparse) { } } +struct ReadRowFunction { + EllpackMatrix matrix; + int row; + bst_float* row_data_d; + ReadRowFunction(EllpackMatrix matrix, int row, bst_float* row_data_d) + : matrix(std::move(matrix)), row(row), row_data_d(row_data_d) {} + + __device__ void operator()(size_t col) { + auto value = matrix.GetElement(row, col); + if (isnan(value)) { + value = -1; + } + row_data_d[col] = value; + } +}; + +TEST(EllpackPage, Copy) { + constexpr size_t kRows = 1024; + constexpr size_t kCols = 16; + constexpr size_t kPageSize = 1024; + + // Create a DMatrix with multiple batches. + dmlc::TemporaryDirectory tmpdir; + std::unique_ptr + dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir)); + BatchParam param{0, 256, 0, kPageSize}; + auto page = (*dmat->GetBatches(param).begin()).Impl(); + + // Create an empty result page. + EllpackPageImpl result(0, page->matrix.info, kRows); + + // Copy batch pages into the result page. + size_t offset = 0; + for (auto& batch : dmat->GetBatches(param)) { + size_t num_elements = result.Copy(0, batch.Impl(), offset); + offset += num_elements; + } + + size_t current_row = 0; + thrust::device_vector row_d(kCols); + thrust::device_vector row_result_d(kCols); + std::vector row(kCols); + std::vector row_result(kCols); + for (auto& page : dmat->GetBatches(param)) { + auto impl = page.Impl(); + EXPECT_EQ(impl->matrix.base_rowid, current_row); + + for (size_t i = 0; i < impl->Size(); i++) { + dh::LaunchN(0, kCols, ReadRowFunction(impl->matrix, current_row, row_d.data().get())); + thrust::copy(row_d.begin(), row_d.end(), row.begin()); + + dh::LaunchN(0, kCols, ReadRowFunction(result.matrix, current_row, row_result_d.data().get())); + thrust::copy(row_result_d.begin(), row_result_d.end(), row_result.begin()); + + EXPECT_EQ(row, row_result); + current_row++; + } + } +} + +TEST(EllpackPage, Compact) { + constexpr size_t kRows = 16; + constexpr size_t kCols = 16; + constexpr size_t kPageSize = 16; + constexpr size_t kCompactedRows = 8; + + // Create a DMatrix with multiple batches. + dmlc::TemporaryDirectory tmpdir; + std::unique_ptr + dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir)); + BatchParam param{0, 256, 0, kPageSize}; + auto page = (*dmat->GetBatches(param).begin()).Impl(); + + // Create an empty result page. + EllpackPageImpl result(0, page->matrix.info, kCompactedRows); + + // Compact batch pages into the result page. + std::vector row_indexes_h { + SIZE_MAX, 0, 1, 2, SIZE_MAX, 3, SIZE_MAX, 4, 5, SIZE_MAX, 6, SIZE_MAX, 7, SIZE_MAX, SIZE_MAX, + SIZE_MAX}; + thrust::device_vector row_indexes_d = row_indexes_h; + common::Span row_indexes_span(row_indexes_d.data().get(), kRows); + for (auto& batch : dmat->GetBatches(param)) { + result.Compact(0, batch.Impl(), row_indexes_span); + } + + size_t current_row = 0; + thrust::device_vector row_d(kCols); + thrust::device_vector row_result_d(kCols); + std::vector row(kCols); + std::vector row_result(kCols); + for (auto& page : dmat->GetBatches(param)) { + auto impl = page.Impl(); + EXPECT_EQ(impl->matrix.base_rowid, current_row); + + for (size_t i = 0; i < impl->Size(); i++) { + size_t compacted_row = row_indexes_h[current_row]; + if (compacted_row == SIZE_MAX) { + current_row++; + continue; + } + + dh::LaunchN(0, kCols, ReadRowFunction(impl->matrix, current_row, row_d.data().get())); + thrust::copy(row_d.begin(), row_d.end(), row.begin()); + + dh::LaunchN(0, kCols, + ReadRowFunction(result.matrix, compacted_row, row_result_d.data().get())); + thrust::copy(row_result_d.begin(), row_result_d.end(), row_result.begin()); + + EXPECT_EQ(row, row_result); + current_row++; + } + } +} + } // namespace xgboost From 967ff16333e0fa9f9aa1e4ea17f5b26a8d423148 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 6 Dec 2019 15:56:04 -0800 Subject: [PATCH 16/48] use subsample to control gradient sampler --- src/common/device_helpers.cuh | 20 ++++++++++++ src/tree/gpu_hist/gradient_based_sampler.cu | 31 ++++++++++--------- src/tree/gpu_hist/gradient_based_sampler.cuh | 4 +-- src/tree/updater_gpu_hist.cu | 21 ++++++++----- .../gpu_hist/test_gradient_based_sampler.cu | 12 +++---- tests/cpp/tree/test_gpu_hist.cu | 9 ++++-- 6 files changed, 63 insertions(+), 34 deletions(-) diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 42f027d10e72..7b08770fdce2 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -1198,6 +1198,26 @@ thrust::device_ptr tcend(xgboost::HostDeviceVector const& vector) { return tcbegin(vector) + vector.Size(); } +template +thrust::device_ptr tbegin(xgboost::common::Span& span) { // NOLINT + return thrust::device_ptr(span.data()); +} + +template +thrust::device_ptr tend(xgboost::common::Span& span) { // // NOLINT + return tbegin(span) + span.size(); +} + +template +thrust::device_ptr tcbegin(xgboost::common::Span const& span) { + return thrust::device_ptr(span.data()); +} + +template +thrust::device_ptr tcend(xgboost::common::Span const& span) { + return tcbegin(span) + span.size(); +} + template class LauncherItr { public: diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 24add7ba1f50..8f7b0e1f4509 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -20,12 +20,14 @@ namespace tree { GradientBasedSampler::GradientBasedSampler(BatchParam batch_param, EllpackInfo info, size_t n_rows, - size_t sample_rows) - : batch_param_(batch_param), info_(info), sample_rows_(sample_rows) { + float subsample) + : batch_param_(batch_param), info_(info) { monitor_.Init("gradient_based_sampler"); - if (sample_rows_ == 0) { + if (subsample == 0.0f || subsample == 1.0f) { sample_rows_ = MaxSampleRows(); + } else { + sample_rows_ = n_rows * subsample; } if (sample_rows_ >= n_rows) { is_sampling_ = false; @@ -122,48 +124,47 @@ struct TrimExtraRows : public thrust::binary_function* gpair, +GradientBasedSample GradientBasedSampler::Sample(common::Span gpair, DMatrix* dmat) { // If there is enough space for all rows, just collect them in a single ELLPACK page and return. if (!is_sampling_) { CollectPages(dmat); - auto out_gpair = gpair->DeviceSpan(); - return {sample_rows_, page_.get(), out_gpair}; + return {sample_rows_, page_.get(), gpair}; } // Sum the absolute value of gradients as the denominator to normalize the probability. - float sum_abs_gradient = thrust::transform_reduce(dh::tbegin(*gpair), - dh::tend(*gpair), + float sum_abs_gradient = thrust::transform_reduce(dh::tbegin(gpair), + dh::tend(gpair), AbsoluteGradient(), 0.0f, thrust::plus()); // Poisson sampling of the gradient pairs based on the absolute value of the gradient. - thrust::transform(dh::tbegin(*gpair), dh::tend(*gpair), + thrust::transform(dh::tbegin(gpair), dh::tend(gpair), thrust::counting_iterator(0), - dh::tbegin(*gpair), + dh::tbegin(gpair), PoissonSampling(sample_rows_, sum_abs_gradient, common::GlobalRandom()())); // Map the original row index to the sample row index. sample_row_index_.Fill(0); - thrust::transform(dh::tbegin(*gpair), dh::tend(*gpair), + thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(sample_row_index_), IsNonZero()); thrust::exclusive_scan(dh::tbegin(sample_row_index_), dh::tend(sample_row_index_), dh::tbegin(sample_row_index_)); - thrust::transform(dh::tbegin(*gpair), dh::tend(*gpair), + thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(sample_row_index_), dh::tbegin(sample_row_index_), ClearEmptyRows(sample_rows_)); // Zero out the gradient pairs if there are more rows than desired. - thrust::transform(dh::tbegin(*gpair), dh::tend(*gpair), + thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(sample_row_index_), - dh::tbegin(*gpair), + dh::tbegin(gpair), TrimExtraRows()); // Compact the non-zero gradient pairs. - thrust::copy_if(dh::tbegin(*gpair), dh::tend(*gpair), dh::tbegin(gpair_), IsNonZero()); + thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(gpair_), IsNonZero()); // Compact the ELLPACK pages into the single sample page. for (auto& batch : dmat->GetBatches(batch_param_)) { diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh index c9f4c04c310b..9abf3c4151a1 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cuh +++ b/src/tree/gpu_hist/gradient_based_sampler.cuh @@ -38,13 +38,13 @@ class GradientBasedSampler { explicit GradientBasedSampler(BatchParam batch_param, EllpackInfo info, size_t n_rows, - size_t sample_rows = 0); + float subsample = 1.0f); /*! \brief Returns the max number of rows that can fit in available GPU memory. */ size_t MaxSampleRows(); /*! \brief Sample from a DMatrix based on the given gradients. */ - GradientBasedSample Sample(HostDeviceVector* gpair, DMatrix* dmat); + GradientBasedSample Sample(common::Span gpair, DMatrix* dmat); /*! \brief Collect all the rows from a DMatrix into a single ELLPACK page. */ void CollectPages(DMatrix* dmat); diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 7f04bd0426fc..98bf39f09ce5 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -453,6 +453,7 @@ struct GPUHistMakerDevice { /*! \brief Gradient pair for each row. */ common::Span gpair; + common::Span sampled_gpair; common::Span monotone_constraints; common::Span prediction_cache; @@ -501,7 +502,10 @@ struct GPUHistMakerDevice { batch_param(_batch_param), use_gradient_based_sampling(_page->matrix.n_rows != _n_rows) { if (use_gradient_based_sampling) { - sampler.reset(new GradientBasedSampler(batch_param, page->matrix.info, n_rows)); + sampler.reset(new GradientBasedSampler(batch_param, + page->matrix.info, + n_rows, + param.subsample)); } monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id)); } @@ -549,16 +553,17 @@ struct GPUHistMakerDevice { std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), GradientPair()); + dh::safe_cuda(cudaMemcpyAsync( + gpair.data(), dh_gpair->ConstDevicePointer(), + gpair.size() * sizeof(GradientPair), cudaMemcpyDeviceToDevice)); if (use_gradient_based_sampling) { - auto sample = sampler->Sample(dh_gpair, dmat); + auto sample = sampler->Sample(gpair, dmat); n_rows = sample.sample_rows; page = sample.page; - gpair = sample.gpair; + sampled_gpair = sample.gpair; } else { - dh::safe_cuda(cudaMemcpyAsync( - gpair.data(), dh_gpair->ConstDevicePointer(), - gpair.size() * sizeof(GradientPair), cudaMemcpyDeviceToDevice)); SubsampleGradientPair(device_id, gpair, param.subsample); + sampled_gpair = gpair; } row_partitioner.reset(); // Release the device memory first before reallocating @@ -652,7 +657,7 @@ struct GPUHistMakerDevice { hist.AllocateHistogram(nidx); auto d_node_hist = hist.GetNodeHistogram(nidx); auto d_ridx = row_partitioner->GetRows(nidx); - auto d_gpair = gpair.data(); + auto d_gpair = sampled_gpair.data(); auto n_elements = d_ridx.size() * page->matrix.info.row_stride; @@ -882,7 +887,7 @@ struct GPUHistMakerDevice { void InitRoot(RegTree* p_tree, dh::AllReducer* reducer, int64_t num_columns) { constexpr int kRootNIdx = 0; - dh::SumReduction(temp_memory, gpair, node_sum_gradients_d, gpair.size()); + dh::SumReduction(temp_memory, sampled_gpair, node_sum_gradients_d, sampled_gpair.size()); reducer->AllReduceSum( reinterpret_cast(node_sum_gradients_d.data()), reinterpret_cast(node_sum_gradients_d.data()), 2); diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu index 09b91675d6d1..3c5466a438fe 100644 --- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -10,7 +10,8 @@ namespace tree { TEST(GradientBasedSampler, Sample) { constexpr size_t kRows = 2048; constexpr size_t kCols = 16; - constexpr size_t kSampleRows = 1024; + constexpr float kSubsample = 0.5; + constexpr size_t kSampleRows = kRows * kSubsample; constexpr size_t kPageSize = 1024; // Create a DMatrix with multiple batches. @@ -23,8 +24,8 @@ TEST(GradientBasedSampler, Sample) { BatchParam param{0, 256, 0, kPageSize}; auto page = (*dmat->GetBatches(param).begin()).Impl(); - GradientBasedSampler sampler(param, page->matrix.info, kRows, kSampleRows); - auto sample = sampler.Sample(&gpair, dmat.get()); + GradientBasedSampler sampler(param, page->matrix.info, kRows, kSubsample); + auto sample = sampler.Sample(gpair.DeviceSpan(), dmat.get()); auto sampled_page = sample.page; auto sampled_gpair = sample.gpair; EXPECT_EQ(sample.sample_rows, kSampleRows); @@ -48,7 +49,6 @@ TEST(GradientBasedSampler, Sample) { TEST(GradientBasedSampler, FullSample) { constexpr size_t kRows = 1024; constexpr size_t kCols = 4; - constexpr size_t kSampleRows = 4096; constexpr size_t kPageSize = 1024; // Create a DMatrix with multiple batches. @@ -61,8 +61,8 @@ TEST(GradientBasedSampler, FullSample) { BatchParam param{0, 256, 0, kPageSize}; auto page = (*dmat->GetBatches(param).begin()).Impl(); - GradientBasedSampler sampler(param, page->matrix.info, kRows, kSampleRows); - auto sample = sampler.Sample(&gpair, dmat.get()); + GradientBasedSampler sampler(param, page->matrix.info, kRows); + auto sample = sampler.Sample(gpair.DeviceSpan(), dmat.get()); auto sampled_page = sample.page; auto sampled_gpair = sample.gpair; EXPECT_EQ(sample.sample_rows, kRows); diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index ac9f8c1cdada..9a82174688bc 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -346,7 +346,8 @@ void UpdateTree(HostDeviceVector* gpair, DMatrix* dmat, size_t gpu_page_size, RegTree* tree, - HostDeviceVector* preds) { + HostDeviceVector* preds, + float subsample = 1.0f) { constexpr size_t kMaxBin = 2; if (gpu_page_size > 0) { @@ -367,7 +368,8 @@ void UpdateTree(HostDeviceVector* gpair, {"max_bin", std::to_string(kMaxBin)}, {"min_child_weight", "0.0"}, {"reg_alpha", "0"}, - {"reg_lambda", "0"} + {"reg_lambda", "0"}, + {"subsample", std::to_string(subsample)}, }; tree::GPUHistMakerSpecialised hist_maker; @@ -383,6 +385,7 @@ TEST(GpuHist, ExternalMemory) { constexpr size_t kRows = 6; constexpr size_t kCols = 2; constexpr size_t kPageSize = 1; + constexpr float kSubsample = 0.99; // Create an in-memory DMatrix. std::unique_ptr dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true)); @@ -402,7 +405,7 @@ TEST(GpuHist, ExternalMemory) { // Build another tree using multiple ELLPACK pages. RegTree tree_ext; HostDeviceVector preds_ext(kRows, 0.0, 0); - UpdateTree(&gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext); + UpdateTree(&gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext, kSubsample); // Make sure the predictions are the same. auto preds_h = preds.ConstHostVector(); From 73ba5af5c408ab7a8d012f945326879aaeae5668 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 10 Dec 2019 13:40:49 -0800 Subject: [PATCH 17/48] implement sequential poisson sampling --- src/tree/gpu_hist/gradient_based_sampler.cu | 175 +++++++++++++----- src/tree/gpu_hist/gradient_based_sampler.cuh | 26 ++- .../gpu_hist/test_gradient_based_sampler.cu | 81 ++++---- 3 files changed, 196 insertions(+), 86 deletions(-) diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 8f7b0e1f4509..cc3f266a9fac 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -21,7 +21,7 @@ GradientBasedSampler::GradientBasedSampler(BatchParam batch_param, EllpackInfo info, size_t n_rows, float subsample) - : batch_param_(batch_param), info_(info) { + : batch_param_(batch_param), info_(info), sampling_method_(kDefaultSamplingMethod) { monitor_.Init("gradient_based_sampler"); if (subsample == 0.0f || subsample == 1.0f) { @@ -29,28 +29,32 @@ GradientBasedSampler::GradientBasedSampler(BatchParam batch_param, } else { sample_rows_ = n_rows * subsample; } + if (sample_rows_ >= n_rows) { - is_sampling_ = false; + sampling_method_ = kNoSampling; sample_rows_ = n_rows; LOG(CONSOLE) << "Keeping " << sample_rows_ << " in GPU memory, not sampling"; } else { - is_sampling_ = true; LOG(CONSOLE) << "Sampling " << sample_rows_ << " rows"; } page_.reset(new EllpackPageImpl(batch_param.gpu_id, info, sample_rows_)); - if (is_sampling_) { - gpair_.SetDevice(batch_param_.gpu_id); - gpair_.Resize(sample_rows_, GradientPair()); - sample_row_index_.SetDevice(batch_param_.gpu_id); - sample_row_index_.Resize(n_rows, 0); + if (sampling_method_ != kNoSampling) { + ba_.Allocate(batch_param_.gpu_id, + &gpair_, sample_rows_, + &row_weight_, n_rows, + &row_index_, n_rows, + &sample_row_index_, n_rows); + thrust::copy(thrust::counting_iterator(0), + thrust::counting_iterator(n_rows), + dh::tbegin(row_index_)); } } size_t GradientBasedSampler::MaxSampleRows() { size_t available_memory = dh::AvailableMemory(batch_param_.gpu_id); size_t usable_memory = available_memory * 0.95; - size_t extra_bytes = sizeof(GradientPair) + sizeof(size_t); + size_t extra_bytes = sizeof(GradientPair) + sizeof(float) + 2 * sizeof(size_t); size_t max_rows = common::CompressedBufferWriter::CalculateMaxRows( usable_memory, info_.NumSymbols(), info_.row_stride, extra_bytes); return max_rows; @@ -63,16 +67,66 @@ struct AbsoluteGradient : public thrust::unary_function { } }; +/*! \brief A functor that returns true if the gradient pair is non-zero. */ +struct IsNonZero : public thrust::unary_function { + XGBOOST_DEVICE bool operator()(const GradientPair& gpair) const { + return gpair.GetGrad() != 0 || gpair.GetHess() != 0; + } +}; + +GradientBasedSample GradientBasedSampler::Sample(common::Span gpair, + DMatrix* dmat, + SamplingMethod sampling_method) { + if (sampling_method_ != kNoSampling) { + sampling_method_ = sampling_method; + } + + switch (sampling_method_) { + case kNoSampling: + return NoSampling(gpair, dmat); + case kPoissonSampling: + return PoissonSampling(gpair, dmat); + case kSequentialPoissonSampling: + return SequentialPoissonSampling(gpair, dmat); + case kUniformSampling: + return UniformSampling(gpair, dmat); + default: + LOG(FATAL) << "unknown sampling method"; + return {sample_rows_, page_.get(), gpair}; + } +} + +GradientBasedSample GradientBasedSampler::NoSampling(common::Span gpair, + DMatrix* dmat) { + CollectPages(dmat); + return {sample_rows_, page_.get(), gpair}; +} + +void GradientBasedSampler::CollectPages(DMatrix* dmat) { + if (page_collected_) { + return; + } + + size_t offset = 0; + for (auto& batch : dmat->GetBatches(batch_param_)) { + auto page = batch.Impl(); + size_t num_elements = page_->Copy(batch_param_.gpu_id, page, offset); + offset += num_elements; + } + page_collected_ = true; +} + /*! \brief A functor that samples a gradient pair. * * Sampling probability is proportional to the absolute value of the gradient. */ -struct PoissonSampling : public thrust::binary_function { +struct PoissonSamplingFunction + : public thrust::binary_function { const size_t sample_rows; const float sum_abs_gradient; const uint32_t seed; - XGBOOST_DEVICE PoissonSampling(size_t _sample_rows, float _sum_abs_gradient, size_t _seed) + XGBOOST_DEVICE PoissonSamplingFunction(size_t _sample_rows, float _sum_abs_gradient, size_t _seed) : sample_rows(_sample_rows), sum_abs_gradient(_sum_abs_gradient), seed(_seed) {} XGBOOST_DEVICE GradientPair operator()(const GradientPair& gpair, size_t i) { @@ -91,13 +145,6 @@ struct PoissonSampling : public thrust::binary_function { - XGBOOST_DEVICE bool operator()(const GradientPair& gpair) const { - return gpair.GetGrad() != 0 || gpair.GetHess() != 0; - } -}; - /*! \brief A functor that clears the row indexes with empty gradient. */ struct ClearEmptyRows : public thrust::binary_function { const size_t max_rows; @@ -124,29 +171,23 @@ struct TrimExtraRows : public thrust::binary_function gpair, - DMatrix* dmat) { - // If there is enough space for all rows, just collect them in a single ELLPACK page and return. - if (!is_sampling_) { - CollectPages(dmat); - return {sample_rows_, page_.get(), gpair}; - } - +GradientBasedSample GradientBasedSampler::PoissonSampling(common::Span gpair, + DMatrix* dmat) { // Sum the absolute value of gradients as the denominator to normalize the probability. - float sum_abs_gradient = thrust::transform_reduce(dh::tbegin(gpair), - dh::tend(gpair), + float sum_abs_gradient = thrust::transform_reduce(dh::tbegin(gpair), dh::tend(gpair), AbsoluteGradient(), - 0.0f, - thrust::plus()); + 0.0f, thrust::plus()); // Poisson sampling of the gradient pairs based on the absolute value of the gradient. thrust::transform(dh::tbegin(gpair), dh::tend(gpair), thrust::counting_iterator(0), dh::tbegin(gpair), - PoissonSampling(sample_rows_, sum_abs_gradient, common::GlobalRandom()())); + PoissonSamplingFunction(sample_rows_, + sum_abs_gradient, + common::GlobalRandom()())); // Map the original row index to the sample row index. - sample_row_index_.Fill(0); + thrust::fill(dh::tbegin(sample_row_index_), dh::tend(sample_row_index_), 0); thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(sample_row_index_), IsNonZero()); @@ -168,25 +209,75 @@ GradientBasedSample GradientBasedSampler::Sample(common::Span gpai // Compact the ELLPACK pages into the single sample page. for (auto& batch : dmat->GetBatches(batch_param_)) { - page_->Compact(batch_param_.gpu_id, batch.Impl(), sample_row_index_.DeviceSpan()); + page_->Compact(batch_param_.gpu_id, batch.Impl(), sample_row_index_); } - return {sample_rows_, page_.get(), gpair_.DeviceSpan()}; + return {sample_rows_, page_.get(), gpair_}; } -void GradientBasedSampler::CollectPages(DMatrix* dmat) { - if (page_collected_) { - return; +/*! \brief A functor that samples gradient pairs using sequential Poisson sampling. + * + * Sampling probability is proportional to the absolute value of the gradient. + */ +struct SequentialPoissonSamplingFunction + : public thrust::binary_function { + const uint32_t seed; + + XGBOOST_DEVICE explicit SequentialPoissonSamplingFunction(size_t _seed) : seed(_seed) {} + + XGBOOST_DEVICE float operator()(const GradientPair& gpair, size_t i) { + if (gpair.GetGrad() == 0) { + return FLT_MAX; + } + thrust::default_random_engine rng(seed); + thrust::uniform_real_distribution dist; + rng.discard(i); + return dist(rng) / fabsf(gpair.GetGrad()); } +}; - size_t offset = 0; +GradientBasedSample GradientBasedSampler::SequentialPoissonSampling( + common::Span gpair, DMatrix* dmat) { + // Transform the gradient to weight = random(0, 1) / abs(grad). + thrust::transform(dh::tbegin(gpair), dh::tend(gpair), + thrust::counting_iterator(0), + dh::tbegin(row_weight_), + SequentialPoissonSamplingFunction(common::GlobalRandom()())); + + // Sort the gradient pairs and row indexes by weight. + thrust::sort_by_key(dh::tbegin(row_weight_), dh::tend(row_weight_), + thrust::make_zip_iterator(thrust::make_tuple(dh::tbegin(gpair), + dh::tbegin(row_index_)))); + + // Clear the gradient pairs not in the sample. + thrust::fill(dh::tbegin(gpair) + sample_rows_, dh::tend(gpair), GradientPair()); + + // Index the sample rows. + thrust::copy(thrust::counting_iterator(0), + thrust::counting_iterator(sample_rows_), + dh::tbegin(sample_row_index_)); + thrust::fill(dh::tbegin(sample_row_index_) + sample_rows_, dh::tend(sample_row_index_), SIZE_MAX); + + // Sort the gradient pairs and sample row indexed by the original row index. + thrust::sort_by_key(dh::tbegin(row_index_), dh::tend(row_index_), + thrust::make_zip_iterator(thrust::make_tuple(dh::tbegin(gpair), + dh::tbegin(sample_row_index_)))); + + // Compact the non-zero gradient pairs. + thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(gpair_), IsNonZero()); + + // Compact the ELLPACK pages into the single sample page. for (auto& batch : dmat->GetBatches(batch_param_)) { - auto page = batch.Impl(); - size_t num_elements = page_->Copy(batch_param_.gpu_id, page, offset); - offset += num_elements; + page_->Compact(batch_param_.gpu_id, batch.Impl(), sample_row_index_); } - page_collected_ = true; + + return {sample_rows_, page_.get(), gpair_}; } +GradientBasedSample GradientBasedSampler::UniformSampling(common::Span gpair, + DMatrix* dmat) { + // TODO(rongou): implement this. + return {sample_rows_, page_.get(), gpair_}; +} }; // namespace tree }; // namespace xgboost diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh index 9abf3c4151a1..db4d538c9971 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cuh +++ b/src/tree/gpu_hist/gradient_based_sampler.cuh @@ -35,6 +35,13 @@ struct GradientBasedSample { */ class GradientBasedSampler { public: + enum SamplingMethod { + kNoSampling, + kPoissonSampling, + kSequentialPoissonSampling, + kUniformSampling, + }; + explicit GradientBasedSampler(BatchParam batch_param, EllpackInfo info, size_t n_rows, @@ -44,20 +51,31 @@ class GradientBasedSampler { size_t MaxSampleRows(); /*! \brief Sample from a DMatrix based on the given gradients. */ - GradientBasedSample Sample(common::Span gpair, DMatrix* dmat); + GradientBasedSample Sample(common::Span gpair, DMatrix* dmat, + SamplingMethod sampling_method = kDefaultSamplingMethod); /*! \brief Collect all the rows from a DMatrix into a single ELLPACK page. */ void CollectPages(DMatrix* dmat); private: + static const SamplingMethod kDefaultSamplingMethod = kSequentialPoissonSampling; + + GradientBasedSample NoSampling(common::Span gpair, DMatrix* dmat); + GradientBasedSample PoissonSampling(common::Span gpair, DMatrix* dmat); + GradientBasedSample SequentialPoissonSampling(common::Span gpair, DMatrix* dmat); + GradientBasedSample UniformSampling(common::Span gpair, DMatrix* dmat); + common::Monitor monitor_; + dh::BulkAllocator ba_; BatchParam batch_param_; EllpackInfo info_; - bool is_sampling_; + SamplingMethod sampling_method_; size_t sample_rows_; std::unique_ptr page_; - HostDeviceVector gpair_; - HostDeviceVector sample_row_index_; + common::Span gpair_; + common::Span row_weight_; + common::Span row_index_; + common::Span sample_row_index_; bool page_collected_{false}; }; }; // namespace tree diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu index 3c5466a438fe..cee59672c433 100644 --- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -7,46 +7,7 @@ namespace xgboost { namespace tree { -TEST(GradientBasedSampler, Sample) { - constexpr size_t kRows = 2048; - constexpr size_t kCols = 16; - constexpr float kSubsample = 0.5; - constexpr size_t kSampleRows = kRows * kSubsample; - constexpr size_t kPageSize = 1024; - - // Create a DMatrix with multiple batches. - dmlc::TemporaryDirectory tmpdir; - std::unique_ptr - dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir)); - auto gpair = GenerateRandomGradients(kRows); - gpair.SetDevice(0); - - BatchParam param{0, 256, 0, kPageSize}; - auto page = (*dmat->GetBatches(param).begin()).Impl(); - - GradientBasedSampler sampler(param, page->matrix.info, kRows, kSubsample); - auto sample = sampler.Sample(gpair.DeviceSpan(), dmat.get()); - auto sampled_page = sample.page; - auto sampled_gpair = sample.gpair; - EXPECT_EQ(sample.sample_rows, kSampleRows); - EXPECT_EQ(sampled_page->matrix.n_rows, kSampleRows); - EXPECT_EQ(sampled_gpair.size(), kSampleRows); - - float sum_gradients = 0; - for (auto gp : gpair.ConstHostVector()) { - sum_gradients += gp.GetGrad(); - } - - float sum_sampled_gradients = 0; - std::vector sampled_gpair_h(sampled_gpair.size()); - dh::CopyDeviceSpanToVector(&sampled_gpair_h, sampled_gpair); - for (auto gp : sampled_gpair_h) { - sum_sampled_gradients += gp.GetGrad(); - } - EXPECT_FLOAT_EQ(sum_gradients, sum_sampled_gradients); -} - -TEST(GradientBasedSampler, FullSample) { +TEST(GradientBasedSampler, NoSampling) { constexpr size_t kRows = 1024; constexpr size_t kCols = 4; constexpr size_t kPageSize = 1024; @@ -94,5 +55,45 @@ TEST(GradientBasedSampler, FullSample) { } } +TEST(GradientBasedSampler, PoissonSampling) { + constexpr size_t kRows = 2048; + constexpr size_t kCols = 16; + constexpr float kSubsample = 0.5; + constexpr size_t kSampleRows = kRows * kSubsample; + constexpr size_t kPageSize = 1024; + + // Create a DMatrix with multiple batches. + dmlc::TemporaryDirectory tmpdir; + std::unique_ptr + dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir)); + auto gpair = GenerateRandomGradients(kRows); + gpair.SetDevice(0); + + BatchParam param{0, 256, 0, kPageSize}; + auto page = (*dmat->GetBatches(param).begin()).Impl(); + + GradientBasedSampler sampler(param, page->matrix.info, kRows, kSubsample); + auto sample = sampler.Sample(gpair.DeviceSpan(), dmat.get(), + GradientBasedSampler::kPoissonSampling); + auto sampled_page = sample.page; + auto sampled_gpair = sample.gpair; + EXPECT_EQ(sample.sample_rows, kSampleRows); + EXPECT_EQ(sampled_page->matrix.n_rows, kSampleRows); + EXPECT_EQ(sampled_gpair.size(), kSampleRows); + + float sum_gradients = 0; + for (auto gp : gpair.ConstHostVector()) { + sum_gradients += gp.GetGrad(); + } + + float sum_sampled_gradients = 0; + std::vector sampled_gpair_h(sampled_gpair.size()); + dh::CopyDeviceSpanToVector(&sampled_gpair_h, sampled_gpair); + for (auto gp : sampled_gpair_h) { + sum_sampled_gradients += gp.GetGrad(); + } + EXPECT_FLOAT_EQ(sum_gradients, sum_sampled_gradients); +} + }; // namespace tree }; // namespace xgboost From d2f2f698ddefe0a722a78361f7eab069312625a5 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 10 Dec 2019 16:08:29 -0800 Subject: [PATCH 18/48] fix compact bug --- src/data/ellpack_page.cu | 2 +- src/tree/gpu_hist/gradient_based_sampler.cu | 16 +++++++++++----- tests/cpp/data/test_ellpack_page.cu | 4 ++-- tests/cpp/tree/test_gpu_hist.cu | 2 +- 4 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index de3ce492460c..8c8404c1f604 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -165,7 +165,7 @@ struct CompactPageFunction { size_t dst_offset = row_index * row_stride; size_t src_offset = i * row_stride; for (size_t j = 0; j < row_stride; j++) { - cbw.AtomicWriteSymbol(dst_data_d, src_iterator_d[src_offset], dst_offset + j); + cbw.AtomicWriteSymbol(dst_data_d, src_iterator_d[src_offset + j], dst_offset + j); } } }; diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index cc3f266a9fac..f2842ae57125 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -252,11 +252,9 @@ GradientBasedSample GradientBasedSampler::SequentialPoissonSampling( // Clear the gradient pairs not in the sample. thrust::fill(dh::tbegin(gpair) + sample_rows_, dh::tend(gpair), GradientPair()); - // Index the sample rows. - thrust::copy(thrust::counting_iterator(0), - thrust::counting_iterator(sample_rows_), - dh::tbegin(sample_row_index_)); - thrust::fill(dh::tbegin(sample_row_index_) + sample_rows_, dh::tend(sample_row_index_), SIZE_MAX); + // Mask the sample rows. + thrust::fill(dh::tbegin(sample_row_index_), dh::tbegin(sample_row_index_) + sample_rows_, 1); + thrust::fill(dh::tbegin(sample_row_index_) + sample_rows_, dh::tend(sample_row_index_), 0); // Sort the gradient pairs and sample row indexed by the original row index. thrust::sort_by_key(dh::tbegin(row_index_), dh::tend(row_index_), @@ -266,6 +264,14 @@ GradientBasedSample GradientBasedSampler::SequentialPoissonSampling( // Compact the non-zero gradient pairs. thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(gpair_), IsNonZero()); + // Index the sample rows. + thrust::exclusive_scan(dh::tbegin(sample_row_index_), dh::tend(sample_row_index_), + dh::tbegin(sample_row_index_)); + thrust::transform(dh::tbegin(gpair), dh::tend(gpair), + dh::tbegin(sample_row_index_), + dh::tbegin(sample_row_index_), + ClearEmptyRows(sample_rows_)); + // Compact the ELLPACK pages into the single sample page. for (auto& batch : dmat->GetBatches(batch_param_)) { page_->Compact(batch_param_.gpu_id, batch.Impl(), sample_row_index_); diff --git a/tests/cpp/data/test_ellpack_page.cu b/tests/cpp/data/test_ellpack_page.cu index 8b1d8b6c76d7..6479e9feea93 100644 --- a/tests/cpp/data/test_ellpack_page.cu +++ b/tests/cpp/data/test_ellpack_page.cu @@ -143,8 +143,8 @@ TEST(EllpackPage, Copy) { TEST(EllpackPage, Compact) { constexpr size_t kRows = 16; - constexpr size_t kCols = 16; - constexpr size_t kPageSize = 16; + constexpr size_t kCols = 2; + constexpr size_t kPageSize = 1; constexpr size_t kCompactedRows = 8; // Create a DMatrix with multiple batches. diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 9a82174688bc..7cfa12d1f431 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -411,7 +411,7 @@ TEST(GpuHist, ExternalMemory) { auto preds_h = preds.ConstHostVector(); auto preds_ext_h = preds_ext.ConstHostVector(); for (int i = 0; i < kRows; i++) { - ASSERT_FLOAT_EQ(preds_h[i], preds_ext_h[i]); + EXPECT_NEAR(preds_h[i], preds_ext_h[i], 2e-6); } } From 827c9880d43987ec7049a87c06ad9f4ffc2e2617 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 10 Dec 2019 22:52:33 -0800 Subject: [PATCH 19/48] fix cpp test --- src/tree/gpu_hist/gradient_based_sampler.cu | 2 +- src/tree/updater_gpu_hist.cu | 15 +++++---------- tests/cpp/tree/test_gpu_hist.cu | 9 +++++---- 3 files changed, 11 insertions(+), 15 deletions(-) diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index f2842ae57125..4e0d022c1d1e 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -54,7 +54,7 @@ GradientBasedSampler::GradientBasedSampler(BatchParam batch_param, size_t GradientBasedSampler::MaxSampleRows() { size_t available_memory = dh::AvailableMemory(batch_param_.gpu_id); size_t usable_memory = available_memory * 0.95; - size_t extra_bytes = sizeof(GradientPair) + sizeof(float) + 2 * sizeof(size_t); + size_t extra_bytes = sizeof(float) + 2 * sizeof(size_t); size_t max_rows = common::CompressedBufferWriter::CalculateMaxRows( usable_memory, info_.NumSymbols(), info_.row_stride, extra_bytes); return max_rows; diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 98bf39f09ce5..fc1f829fec2e 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -453,7 +453,6 @@ struct GPUHistMakerDevice { /*! \brief Gradient pair for each row. */ common::Span gpair; - common::Span sampled_gpair; common::Span monotone_constraints; common::Span prediction_cache; @@ -553,17 +552,14 @@ struct GPUHistMakerDevice { std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), GradientPair()); - dh::safe_cuda(cudaMemcpyAsync( - gpair.data(), dh_gpair->ConstDevicePointer(), - gpair.size() * sizeof(GradientPair), cudaMemcpyDeviceToDevice)); if (use_gradient_based_sampling) { - auto sample = sampler->Sample(gpair, dmat); + auto sample = sampler->Sample(dh_gpair->DeviceSpan(), dmat); n_rows = sample.sample_rows; page = sample.page; - sampled_gpair = sample.gpair; + gpair = sample.gpair; } else { + gpair = dh_gpair->DeviceSpan(); SubsampleGradientPair(device_id, gpair, param.subsample); - sampled_gpair = gpair; } row_partitioner.reset(); // Release the device memory first before reallocating @@ -657,7 +653,7 @@ struct GPUHistMakerDevice { hist.AllocateHistogram(nidx); auto d_node_hist = hist.GetNodeHistogram(nidx); auto d_ridx = row_partitioner->GetRows(nidx); - auto d_gpair = sampled_gpair.data(); + auto d_gpair = gpair.data(); auto n_elements = d_ridx.size() * page->matrix.info.row_stride; @@ -887,7 +883,7 @@ struct GPUHistMakerDevice { void InitRoot(RegTree* p_tree, dh::AllReducer* reducer, int64_t num_columns) { constexpr int kRootNIdx = 0; - dh::SumReduction(temp_memory, sampled_gpair, node_sum_gradients_d, sampled_gpair.size()); + dh::SumReduction(temp_memory, gpair, node_sum_gradients_d, gpair.size()); reducer->AllReduceSum( reinterpret_cast(node_sum_gradients_d.data()), reinterpret_cast(node_sum_gradients_d.data()), 2); @@ -982,7 +978,6 @@ inline void GPUHistMakerDevice::InitHistogram() { param.max_leaves > 0 ? param.max_leaves * 2 : MaxNodesDepth(param.max_depth); ba.Allocate(device_id, - &gpair, n_rows, &prediction_cache, n_rows, &node_sum_gradients_d, max_nodes, &monotone_constraints, param.monotone_constraints.size()); diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 7cfa12d1f431..e5761cf41fb6 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -89,12 +89,13 @@ void TestBuildHist(bool use_shared_memory_histograms) { xgboost::SimpleLCG gen; xgboost::SimpleRealUniformDistribution dist(0.0f, 1.0f); - std::vector h_gpair(kNRows); - for (auto &gpair : h_gpair) { + HostDeviceVector gpair(kNRows); + for (auto &gp : gpair.HostVector()) { bst_float grad = dist(&gen); bst_float hess = dist(&gen); - gpair = GradientPair(grad, hess); + gp = GradientPair(grad, hess); } + gpair.SetDevice(0); thrust::host_vector h_gidx_buffer (page->gidx_buffer.size()); @@ -105,7 +106,7 @@ void TestBuildHist(bool use_shared_memory_histograms) { maker.row_partitioner.reset(new RowPartitioner(0, kNRows)); maker.hist.AllocateHistogram(0); - dh::CopyVectorToDeviceSpan(maker.gpair, h_gpair); + maker.gpair = gpair.DeviceSpan(); maker.use_shared_memory_histograms = use_shared_memory_histograms; maker.BuildHist(0); From 857c9c73778d5058126d6b9f75209fcf0a0a26ab Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Wed, 11 Dec 2019 16:01:45 -0800 Subject: [PATCH 20/48] finally working --- src/data/ellpack_page.cu | 4 +--- src/tree/gpu_hist/gradient_based_sampler.cu | 3 ++- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index 8c8404c1f604..ab3380fc06f2 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -214,9 +214,7 @@ void EllpackPageImpl::InitCompressedData(int device, size_t num_rows) { matrix.info.row_stride * num_rows, num_symbols); ba_.Allocate(device, &gidx_buffer, compressed_size_bytes); - thrust::fill( - thrust::device_pointer_cast(gidx_buffer.data()), - thrust::device_pointer_cast(gidx_buffer.data() + gidx_buffer.size()), 0); + thrust::fill(dh::tbegin(gidx_buffer), dh::tend(gidx_buffer), 0); matrix.gidx_iter = common::CompressedIterator(gidx_buffer.data(), num_symbols); } diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 4e0d022c1d1e..cf3b0ff0f47b 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -33,7 +33,7 @@ GradientBasedSampler::GradientBasedSampler(BatchParam batch_param, if (sample_rows_ >= n_rows) { sampling_method_ = kNoSampling; sample_rows_ = n_rows; - LOG(CONSOLE) << "Keeping " << sample_rows_ << " in GPU memory, not sampling"; + LOG(CONSOLE) << "Keeping " << sample_rows_ << " rows in GPU memory, not sampling"; } else { LOG(CONSOLE) << "Sampling " << sample_rows_ << " rows"; } @@ -273,6 +273,7 @@ GradientBasedSample GradientBasedSampler::SequentialPoissonSampling( ClearEmptyRows(sample_rows_)); // Compact the ELLPACK pages into the single sample page. + thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0); for (auto& batch : dmat->GetBatches(batch_param_)) { page_->Compact(batch_param_.gpu_id, batch.Impl(), sample_row_index_); } From 7de162013fbd79fce0af0eed1c76a4d2c54f5dc2 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 13 Dec 2019 12:06:42 -0800 Subject: [PATCH 21/48] add uniform sampling --- src/tree/gpu_hist/gradient_based_sampler.cu | 181 +++++------------- src/tree/gpu_hist/gradient_based_sampler.cuh | 31 +-- .../gpu_hist/test_gradient_based_sampler.cu | 8 +- 3 files changed, 71 insertions(+), 149 deletions(-) diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index cf3b0ff0f47b..32910dfc0092 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -20,8 +20,9 @@ namespace tree { GradientBasedSampler::GradientBasedSampler(BatchParam batch_param, EllpackInfo info, size_t n_rows, - float subsample) - : batch_param_(batch_param), info_(info), sampling_method_(kDefaultSamplingMethod) { + float subsample, + SamplingMethod sampling_method) + : batch_param_(batch_param), info_(info), sampling_method_(sampling_method) { monitor_.Init("gradient_based_sampler"); if (subsample == 0.0f || subsample == 1.0f) { @@ -60,32 +61,11 @@ size_t GradientBasedSampler::MaxSampleRows() { return max_rows; } -/*! \brief A functor that returns the absolute value of gradient from a gradient pair. */ -struct AbsoluteGradient : public thrust::unary_function { - XGBOOST_DEVICE float operator()(const GradientPair& gpair) const { - return fabsf(gpair.GetGrad()); - } -}; - -/*! \brief A functor that returns true if the gradient pair is non-zero. */ -struct IsNonZero : public thrust::unary_function { - XGBOOST_DEVICE bool operator()(const GradientPair& gpair) const { - return gpair.GetGrad() != 0 || gpair.GetHess() != 0; - } -}; - GradientBasedSample GradientBasedSampler::Sample(common::Span gpair, - DMatrix* dmat, - SamplingMethod sampling_method) { - if (sampling_method_ != kNoSampling) { - sampling_method_ = sampling_method; - } - + DMatrix* dmat) { switch (sampling_method_) { case kNoSampling: return NoSampling(gpair, dmat); - case kPoissonSampling: - return PoissonSampling(gpair, dmat); case kSequentialPoissonSampling: return SequentialPoissonSampling(gpair, dmat); case kUniformSampling: @@ -96,12 +76,6 @@ GradientBasedSample GradientBasedSampler::Sample(common::Span gpai } } -GradientBasedSample GradientBasedSampler::NoSampling(common::Span gpair, - DMatrix* dmat) { - CollectPages(dmat); - return {sample_rows_, page_.get(), gpair}; -} - void GradientBasedSampler::CollectPages(DMatrix* dmat) { if (page_collected_) { return; @@ -116,32 +90,33 @@ void GradientBasedSampler::CollectPages(DMatrix* dmat) { page_collected_ = true; } -/*! \brief A functor that samples a gradient pair. - * - * Sampling probability is proportional to the absolute value of the gradient. - */ -struct PoissonSamplingFunction - : public thrust::binary_function { - const size_t sample_rows; - const float sum_abs_gradient; +GradientBasedSample GradientBasedSampler::NoSampling(common::Span gpair, + DMatrix* dmat) { + CollectPages(dmat); + return {sample_rows_, page_.get(), gpair}; +} + +/*! \brief A functor that calculate the weight of each row as random(0, 1) / abs(grad). */ +struct CalculateWeight : public thrust::binary_function { const uint32_t seed; - XGBOOST_DEVICE PoissonSamplingFunction(size_t _sample_rows, float _sum_abs_gradient, size_t _seed) - : sample_rows(_sample_rows), sum_abs_gradient(_sum_abs_gradient), seed(_seed) {} + XGBOOST_DEVICE explicit CalculateWeight(size_t _seed) : seed(_seed) {} - XGBOOST_DEVICE GradientPair operator()(const GradientPair& gpair, size_t i) { + XGBOOST_DEVICE float operator()(const GradientPair& gpair, size_t i) { + if (gpair.GetGrad() == 0) { + return FLT_MAX; + } thrust::default_random_engine rng(seed); thrust::uniform_real_distribution dist; rng.discard(i); - float p = sample_rows * fabsf(gpair.GetGrad()) / sum_abs_gradient; - if (p > 1.0f) { - p = 1.0f; - } - if (dist(rng) <= p) { - return gpair; - } else { - return GradientPair(); - } + return dist(rng) / fabsf(gpair.GetGrad()); + } +}; + +/*! \brief A functor that returns true if the gradient pair is non-zero. */ +struct IsNonZero : public thrust::unary_function { + XGBOOST_DEVICE bool operator()(const GradientPair& gpair) const { + return gpair.GetGrad() != 0 || gpair.GetHess() != 0; } }; @@ -160,90 +135,18 @@ struct ClearEmptyRows : public thrust::binary_function { - XGBOOST_DEVICE GradientPair operator()(const GradientPair& gpair, size_t row_index) const { - if (row_index == SIZE_MAX) { - return GradientPair(); - } else { - return gpair; - } - } -}; - -GradientBasedSample GradientBasedSampler::PoissonSampling(common::Span gpair, - DMatrix* dmat) { - // Sum the absolute value of gradients as the denominator to normalize the probability. - float sum_abs_gradient = thrust::transform_reduce(dh::tbegin(gpair), dh::tend(gpair), - AbsoluteGradient(), - 0.0f, thrust::plus()); - - // Poisson sampling of the gradient pairs based on the absolute value of the gradient. - thrust::transform(dh::tbegin(gpair), dh::tend(gpair), - thrust::counting_iterator(0), - dh::tbegin(gpair), - PoissonSamplingFunction(sample_rows_, - sum_abs_gradient, - common::GlobalRandom()())); - - // Map the original row index to the sample row index. - thrust::fill(dh::tbegin(sample_row_index_), dh::tend(sample_row_index_), 0); - thrust::transform(dh::tbegin(gpair), dh::tend(gpair), - dh::tbegin(sample_row_index_), - IsNonZero()); - thrust::exclusive_scan(dh::tbegin(sample_row_index_), dh::tend(sample_row_index_), - dh::tbegin(sample_row_index_)); - thrust::transform(dh::tbegin(gpair), dh::tend(gpair), - dh::tbegin(sample_row_index_), - dh::tbegin(sample_row_index_), - ClearEmptyRows(sample_rows_)); - - // Zero out the gradient pairs if there are more rows than desired. - thrust::transform(dh::tbegin(gpair), dh::tend(gpair), - dh::tbegin(sample_row_index_), - dh::tbegin(gpair), - TrimExtraRows()); - - // Compact the non-zero gradient pairs. - thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(gpair_), IsNonZero()); - - // Compact the ELLPACK pages into the single sample page. - for (auto& batch : dmat->GetBatches(batch_param_)) { - page_->Compact(batch_param_.gpu_id, batch.Impl(), sample_row_index_); - } - - return {sample_rows_, page_.get(), gpair_}; -} - -/*! \brief A functor that samples gradient pairs using sequential Poisson sampling. - * - * Sampling probability is proportional to the absolute value of the gradient. - */ -struct SequentialPoissonSamplingFunction - : public thrust::binary_function { - const uint32_t seed; - - XGBOOST_DEVICE explicit SequentialPoissonSamplingFunction(size_t _seed) : seed(_seed) {} - - XGBOOST_DEVICE float operator()(const GradientPair& gpair, size_t i) { - if (gpair.GetGrad() == 0) { - return FLT_MAX; - } - thrust::default_random_engine rng(seed); - thrust::uniform_real_distribution dist; - rng.discard(i); - return dist(rng) / fabsf(gpair.GetGrad()); - } -}; - GradientBasedSample GradientBasedSampler::SequentialPoissonSampling( common::Span gpair, DMatrix* dmat) { // Transform the gradient to weight = random(0, 1) / abs(grad). thrust::transform(dh::tbegin(gpair), dh::tend(gpair), thrust::counting_iterator(0), dh::tbegin(row_weight_), - SequentialPoissonSamplingFunction(common::GlobalRandom()())); + CalculateWeight(common::GlobalRandom()())); + return WeightedSampling(gpair, dmat); +} +GradientBasedSample GradientBasedSampler::WeightedSampling( + common::Span gpair, DMatrix* dmat) { // Sort the gradient pairs and row indexes by weight. thrust::sort_by_key(dh::tbegin(row_weight_), dh::tend(row_weight_), thrust::make_zip_iterator(thrust::make_tuple(dh::tbegin(gpair), @@ -256,7 +159,7 @@ GradientBasedSample GradientBasedSampler::SequentialPoissonSampling( thrust::fill(dh::tbegin(sample_row_index_), dh::tbegin(sample_row_index_) + sample_rows_, 1); thrust::fill(dh::tbegin(sample_row_index_) + sample_rows_, dh::tend(sample_row_index_), 0); - // Sort the gradient pairs and sample row indexed by the original row index. + // Sort the gradient pairs and sample row indexes by the original row index. thrust::sort_by_key(dh::tbegin(row_index_), dh::tend(row_index_), thrust::make_zip_iterator(thrust::make_tuple(dh::tbegin(gpair), dh::tbegin(sample_row_index_)))); @@ -281,10 +184,28 @@ GradientBasedSample GradientBasedSampler::SequentialPoissonSampling( return {sample_rows_, page_.get(), gpair_}; } +/*! \brief A functor that returns random weights. */ +struct RandomWeight : public thrust::unary_function { + const uint32_t seed; + + XGBOOST_DEVICE explicit RandomWeight(size_t _seed) : seed(_seed) {} + + XGBOOST_DEVICE float operator()(size_t i) { + thrust::default_random_engine rng(seed); + thrust::uniform_real_distribution dist; + rng.discard(i); + return dist(rng); + } +}; + GradientBasedSample GradientBasedSampler::UniformSampling(common::Span gpair, DMatrix* dmat) { - // TODO(rongou): implement this. - return {sample_rows_, page_.get(), gpair_}; + // Generate random weights. + thrust::transform(thrust::counting_iterator(0), + thrust::counting_iterator(0) + gpair.size(), + dh::tbegin(row_weight_), + RandomWeight(common::GlobalRandom()())); + return WeightedSampling(gpair, dmat); } }; // namespace tree }; // namespace xgboost diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh index db4d538c9971..fcddb7bcfbb7 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cuh +++ b/src/tree/gpu_hist/gradient_based_sampler.cuh @@ -22,9 +22,6 @@ struct GradientBasedSample { }; /*! \brief Draw a sample of rows from a DMatrix. - * - * Use Poisson sampling to draw a probability proportional to size (pps) sample of rows from a - * DMatrix, where "size" is the absolute value of the gradient. * * \see Ke, G., Meng, Q., Finley, T., Wang, T., Chen, W., Ma, W., ... & Liu, T. Y. (2017). * Lightgbm: A highly efficient gradient boosting decision tree. In Advances in Neural Information @@ -36,35 +33,39 @@ struct GradientBasedSample { class GradientBasedSampler { public: enum SamplingMethod { + /*! \brief When all rows can fit in GPU memory, no sampling is needed. */ kNoSampling, - kPoissonSampling, + /*! \brief Fixed-sized random sampling, weighted by the absolute value of the gradient. */ kSequentialPoissonSampling, - kUniformSampling, + /*! \brief This is for comparison purposes only, not recommended for real use. */ + kUniformSampling }; explicit GradientBasedSampler(BatchParam batch_param, EllpackInfo info, size_t n_rows, - float subsample = 1.0f); - - /*! \brief Returns the max number of rows that can fit in available GPU memory. */ - size_t MaxSampleRows(); + float subsample = 1.0f, + SamplingMethod sampling_method = kDefaultSamplingMethod); /*! \brief Sample from a DMatrix based on the given gradients. */ - GradientBasedSample Sample(common::Span gpair, DMatrix* dmat, - SamplingMethod sampling_method = kDefaultSamplingMethod); - - /*! \brief Collect all the rows from a DMatrix into a single ELLPACK page. */ - void CollectPages(DMatrix* dmat); + GradientBasedSample Sample(common::Span gpair, DMatrix* dmat); private: static const SamplingMethod kDefaultSamplingMethod = kSequentialPoissonSampling; GradientBasedSample NoSampling(common::Span gpair, DMatrix* dmat); - GradientBasedSample PoissonSampling(common::Span gpair, DMatrix* dmat); GradientBasedSample SequentialPoissonSampling(common::Span gpair, DMatrix* dmat); GradientBasedSample UniformSampling(common::Span gpair, DMatrix* dmat); + /*! \brief Returns the max number of rows that can fit in available GPU memory. */ + size_t MaxSampleRows(); + + /*! \brief Collect all the rows from a DMatrix into a single ELLPACK page. */ + void CollectPages(DMatrix* dmat); + + /*! \brief Do weighted sampling after the row weights are calculated. */ + GradientBasedSample WeightedSampling(common::Span gpair, DMatrix* dmat); + common::Monitor monitor_; dh::BulkAllocator ba_; BatchParam batch_param_; diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu index cee59672c433..466457de1a0f 100644 --- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -55,7 +55,7 @@ TEST(GradientBasedSampler, NoSampling) { } } -TEST(GradientBasedSampler, PoissonSampling) { +TEST(GradientBasedSampler, SequentialPoissonSampling) { constexpr size_t kRows = 2048; constexpr size_t kCols = 16; constexpr float kSubsample = 0.5; @@ -72,9 +72,9 @@ TEST(GradientBasedSampler, PoissonSampling) { BatchParam param{0, 256, 0, kPageSize}; auto page = (*dmat->GetBatches(param).begin()).Impl(); - GradientBasedSampler sampler(param, page->matrix.info, kRows, kSubsample); - auto sample = sampler.Sample(gpair.DeviceSpan(), dmat.get(), - GradientBasedSampler::kPoissonSampling); + GradientBasedSampler sampler(param, page->matrix.info, kRows, kSubsample, + GradientBasedSampler::kSequentialPoissonSampling); + auto sample = sampler.Sample(gpair.DeviceSpan(), dmat.get()); auto sampled_page = sample.page; auto sampled_gpair = sample.gpair; EXPECT_EQ(sample.sample_rows, kSampleRows); From d3a3dbf33d0f496a87da58398002c3433ab9eb86 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 13 Dec 2019 16:16:00 -0800 Subject: [PATCH 22/48] better estimate of sample rows --- src/tree/gpu_hist/gradient_based_sampler.cu | 10 ++++++---- src/tree/gpu_hist/gradient_based_sampler.cuh | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 32910dfc0092..d642d11ace60 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -26,7 +26,7 @@ GradientBasedSampler::GradientBasedSampler(BatchParam batch_param, monitor_.Init("gradient_based_sampler"); if (subsample == 0.0f || subsample == 1.0f) { - sample_rows_ = MaxSampleRows(); + sample_rows_ = MaxSampleRows(n_rows); } else { sample_rows_ = n_rows * subsample; } @@ -52,10 +52,12 @@ GradientBasedSampler::GradientBasedSampler(BatchParam batch_param, } } -size_t GradientBasedSampler::MaxSampleRows() { +size_t GradientBasedSampler::MaxSampleRows(size_t n_rows) { size_t available_memory = dh::AvailableMemory(batch_param_.gpu_id); - size_t usable_memory = available_memory * 0.95; - size_t extra_bytes = sizeof(float) + 2 * sizeof(size_t); + // Subtract row_weight_, row_index_, and sample_row_index_. + available_memory -= n_rows * (sizeof(float) + 2 * sizeof(size_t)); + size_t usable_memory = available_memory * 0.5; + size_t extra_bytes = sizeof(GradientPair); size_t max_rows = common::CompressedBufferWriter::CalculateMaxRows( usable_memory, info_.NumSymbols(), info_.row_stride, extra_bytes); return max_rows; diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh index fcddb7bcfbb7..e503e292242e 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cuh +++ b/src/tree/gpu_hist/gradient_based_sampler.cuh @@ -58,7 +58,7 @@ class GradientBasedSampler { GradientBasedSample UniformSampling(common::Span gpair, DMatrix* dmat); /*! \brief Returns the max number of rows that can fit in available GPU memory. */ - size_t MaxSampleRows(); + size_t MaxSampleRows(size_t n_rows); /*! \brief Collect all the rows from a DMatrix into a single ELLPACK page. */ void CollectPages(DMatrix* dmat); From f7286230b028dfa0f50f4ee9a47c36cb22cd3d74 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 16 Dec 2019 11:31:48 -0800 Subject: [PATCH 23/48] more agressive memory allocation --- src/tree/gpu_hist/gradient_based_sampler.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index d642d11ace60..cd4916e5963a 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -56,7 +56,7 @@ size_t GradientBasedSampler::MaxSampleRows(size_t n_rows) { size_t available_memory = dh::AvailableMemory(batch_param_.gpu_id); // Subtract row_weight_, row_index_, and sample_row_index_. available_memory -= n_rows * (sizeof(float) + 2 * sizeof(size_t)); - size_t usable_memory = available_memory * 0.5; + size_t usable_memory = available_memory * 0.7; size_t extra_bytes = sizeof(GradientPair); size_t max_rows = common::CompressedBufferWriter::CalculateMaxRows( usable_memory, info_.NumSymbols(), info_.row_stride, extra_bytes); From c206880e22d605cf27ab25a52e9e9a8f1025affb Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 16 Dec 2019 12:38:53 -0800 Subject: [PATCH 24/48] add some documentation --- include/xgboost/base.h | 14 -------- src/data/ellpack_page.cu | 37 +++++++++++-------- src/data/ellpack_page.cuh | 1 + src/tree/gpu_hist/gradient_based_sampler.cu | 38 +++++++++++++------- src/tree/gpu_hist/gradient_based_sampler.cuh | 2 +- 5 files changed, 50 insertions(+), 42 deletions(-) diff --git a/include/xgboost/base.h b/include/xgboost/base.h index 739b97ef60db..bb8e5dd155e6 100644 --- a/include/xgboost/base.h +++ b/include/xgboost/base.h @@ -193,20 +193,6 @@ class GradientPairInternal { return g; } - XGBOOST_DEVICE GradientPairInternal operator*(float multiplier) const { - GradientPairInternal g; - g.grad_ = grad_ * multiplier; - g.hess_ = hess_ * multiplier; - return g; - } - - XGBOOST_DEVICE GradientPairInternal operator/(float divider) const { - GradientPairInternal g; - g.grad_ = grad_ / divider; - g.hess_ = hess_ / divider; - return g; - } - XGBOOST_DEVICE bool operator==(const GradientPairInternal &rhs) const { return grad_ == rhs.grad_ && hess_ == rhs.hess_; } diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index ab3380fc06f2..760d47b06f8d 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -111,21 +111,21 @@ EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param) { } // A functor that copies the data from one EllpackPage to another. -struct CopyPageFunction { +struct CopyPage { common::CompressedBufferWriter cbw; common::CompressedByteT* dst_data_d; common::CompressedIterator src_iterator_d; // The number of elements to skip. size_t offset; - CopyPageFunction(EllpackPageImpl* dst, EllpackPageImpl* src, size_t offset) + CopyPage(EllpackPageImpl* dst, EllpackPageImpl* src, size_t offset) : cbw{dst->matrix.info.NumSymbols()}, dst_data_d{dst->gidx_buffer.data()}, src_iterator_d{src->gidx_buffer.data(), src->matrix.info.NumSymbols()}, offset(offset) {} - __device__ void operator()(size_t i) { - cbw.AtomicWriteSymbol(dst_data_d, src_iterator_d[i], i + offset); + __device__ void operator()(size_t element_id) { + cbw.AtomicWriteSymbol(dst_data_d, src_iterator_d[element_id], element_id + offset); } }; @@ -136,21 +136,30 @@ size_t EllpackPageImpl::Copy(int device, EllpackPageImpl* page, size_t offset) { CHECK_EQ(matrix.info.row_stride, page->matrix.info.row_stride); CHECK_EQ(matrix.info.NumSymbols(), page->matrix.info.NumSymbols()); CHECK_GE(matrix.n_rows * matrix.info.row_stride, offset + num_elements); - dh::LaunchN(device, num_elements, CopyPageFunction(this, page, offset)); + dh::LaunchN(device, num_elements, CopyPage(this, page, offset)); monitor_.StopCuda("Copy"); return num_elements; } // A functor that compacts the rows from one EllpackPage into another. -struct CompactPageFunction { +struct CompactPage { common::CompressedBufferWriter cbw; common::CompressedByteT* dst_data_d; common::CompressedIterator src_iterator_d; + /*! \brief An array that maps the rows from the full DMatrix to the compacted page. + * + * The total size is the number of rows in the original, uncompacted DMatrix. Elements are the + * row ids in the compacted page. Rows not needed are set to SIZE_MAX. + * + * An example compacting 16 rows to 8 rows: + * [SIZE_MAX, 0, 1, SIZE_MAX, SIZE_MAX, 2, SIZE_MAX, 3, 4, 5, SIZE_MAX, 6, SIZE_MAX, 7, SIZE_MAX, + * SIZE_MAX] + */ common::Span row_indexes; size_t base_rowid; size_t row_stride; - CompactPageFunction(EllpackPageImpl* dst, EllpackPageImpl* src, common::Span row_indexes) + CompactPage(EllpackPageImpl* dst, EllpackPageImpl* src, common::Span row_indexes) : cbw{dst->matrix.info.NumSymbols()}, dst_data_d{dst->gidx_buffer.data()}, src_iterator_d{src->gidx_buffer.data(), src->matrix.info.NumSymbols()}, @@ -158,12 +167,12 @@ struct CompactPageFunction { base_rowid{src->matrix.base_rowid}, row_stride{src->matrix.info.row_stride} {} - __device__ void operator()(size_t i) { - size_t row = base_rowid + i; - size_t row_index = row_indexes[row]; - if (row_index == SIZE_MAX) return; - size_t dst_offset = row_index * row_stride; - size_t src_offset = i * row_stride; + __device__ void operator()(size_t row_id) { + size_t src_row = base_rowid + row_id; + size_t dst_row = row_indexes[src_row]; + if (dst_row == SIZE_MAX) return; + size_t dst_offset = dst_row * row_stride; + size_t src_offset = row_id * row_stride; for (size_t j = 0; j < row_stride; j++) { cbw.AtomicWriteSymbol(dst_data_d, src_iterator_d[src_offset + j], dst_offset + j); } @@ -176,7 +185,7 @@ void EllpackPageImpl::Compact(int device, EllpackPageImpl* page, common::Spanmatrix.info.row_stride); CHECK_EQ(matrix.info.NumSymbols(), page->matrix.info.NumSymbols()); CHECK_LE(page->matrix.base_rowid + page->matrix.n_rows, row_indexes.size()); - dh::LaunchN(device, page->matrix.n_rows, CompactPageFunction(this, page, row_indexes)); + dh::LaunchN(device, page->matrix.n_rows, CompactPage(this, page, row_indexes)); monitor_.StopCuda("Compact"); } diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh index 2ff46f1a9380..c85738c4da39 100644 --- a/src/data/ellpack_page.cuh +++ b/src/data/ellpack_page.cuh @@ -72,6 +72,7 @@ struct EllpackInfo { const common::HistogramCuts& hmat, dh::BulkAllocator* ba); + /*! \brief Return the total number of symbols (total number of bins plus 1 for not found). */ inline size_t NumSymbols() const { return n_bins + 1; } diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index cd4916e5963a..ae6ebae43a16 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -4,7 +4,6 @@ #include #include #include -#include #include #include @@ -25,12 +24,15 @@ GradientBasedSampler::GradientBasedSampler(BatchParam batch_param, : batch_param_(batch_param), info_(info), sampling_method_(sampling_method) { monitor_.Init("gradient_based_sampler"); + // If `subsample` is not specified, try to figure out how many rows to sample based on available + // free GPU memory. if (subsample == 0.0f || subsample == 1.0f) { sample_rows_ = MaxSampleRows(n_rows); } else { sample_rows_ = n_rows * subsample; } + // If there is enough GPU memory to keep all the rows, don't sample. if (sample_rows_ >= n_rows) { sampling_method_ = kNoSampling; sample_rows_ = n_rows; @@ -39,7 +41,10 @@ GradientBasedSampler::GradientBasedSampler(BatchParam batch_param, LOG(CONSOLE) << "Sampling " << sample_rows_ << " rows"; } + // Create a new ELLPACK page with empty rows. page_.reset(new EllpackPageImpl(batch_param.gpu_id, info, sample_rows_)); + + // Allocate GPU memory for sampling. if (sampling_method_ != kNoSampling) { ba_.Allocate(batch_param_.gpu_id, &gpair_, sample_rows_, @@ -52,6 +57,7 @@ GradientBasedSampler::GradientBasedSampler(BatchParam batch_param, } } +// Determine the maximum number of rows that can fit into the available GPU memory. size_t GradientBasedSampler::MaxSampleRows(size_t n_rows) { size_t available_memory = dh::AvailableMemory(batch_param_.gpu_id); // Subtract row_weight_, row_index_, and sample_row_index_. @@ -63,21 +69,30 @@ size_t GradientBasedSampler::MaxSampleRows(size_t n_rows) { return max_rows; } +// Sample a DMatrix based on the given gradient pairs. GradientBasedSample GradientBasedSampler::Sample(common::Span gpair, DMatrix* dmat) { + monitor_.StartCuda("Sample"); + GradientBasedSample sample; switch (sampling_method_) { case kNoSampling: - return NoSampling(gpair, dmat); + sample = NoSampling(gpair, dmat); + break; case kSequentialPoissonSampling: - return SequentialPoissonSampling(gpair, dmat); + sample = SequentialPoissonSampling(gpair, dmat); + break; case kUniformSampling: - return UniformSampling(gpair, dmat); + sample = UniformSampling(gpair, dmat); + break; default: LOG(FATAL) << "unknown sampling method"; - return {sample_rows_, page_.get(), gpair}; + sample = {sample_rows_, page_.get(), gpair}; } + monitor_.StopCuda("Sample"); + return sample; } +// When not sampling, collect all the external memory ELLPACK pages into a single in-memory page. void GradientBasedSampler::CollectPages(DMatrix* dmat) { if (page_collected_) { return; @@ -98,7 +113,7 @@ GradientBasedSample GradientBasedSampler::NoSampling(common::Span return {sample_rows_, page_.get(), gpair}; } -/*! \brief A functor that calculate the weight of each row as random(0, 1) / abs(grad). */ +/*! \brief A functor that calculates the weight of each row as random(0, 1) / abs(grad). */ struct CalculateWeight : public thrust::binary_function { const uint32_t seed; @@ -124,12 +139,8 @@ struct IsNonZero : public thrust::unary_function { /*! \brief A functor that clears the row indexes with empty gradient. */ struct ClearEmptyRows : public thrust::binary_function { - const size_t max_rows; - - XGBOOST_DEVICE explicit ClearEmptyRows(size_t max_rows) : max_rows(max_rows) {} - XGBOOST_DEVICE size_t operator()(const GradientPair& gpair, size_t row_index) const { - if ((gpair.GetGrad() != 0 || gpair.GetHess() != 0) && row_index < max_rows) { + if (gpair.GetGrad() != 0 || gpair.GetHess() != 0) { return row_index; } else { return SIZE_MAX; @@ -147,6 +158,7 @@ GradientBasedSample GradientBasedSampler::SequentialPoissonSampling( return WeightedSampling(gpair, dmat); } +// Perform sampling after the weights are calculated. GradientBasedSample GradientBasedSampler::WeightedSampling( common::Span gpair, DMatrix* dmat) { // Sort the gradient pairs and row indexes by weight. @@ -154,7 +166,7 @@ GradientBasedSample GradientBasedSampler::WeightedSampling( thrust::make_zip_iterator(thrust::make_tuple(dh::tbegin(gpair), dh::tbegin(row_index_)))); - // Clear the gradient pairs not in the sample. + // Clear the gradient pairs not included in the sample. thrust::fill(dh::tbegin(gpair) + sample_rows_, dh::tend(gpair), GradientPair()); // Mask the sample rows. @@ -175,7 +187,7 @@ GradientBasedSample GradientBasedSampler::WeightedSampling( thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(sample_row_index_), dh::tbegin(sample_row_index_), - ClearEmptyRows(sample_rows_)); + ClearEmptyRows()); // Compact the ELLPACK pages into the single sample page. thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0); diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh index e503e292242e..baeb6b8f0313 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cuh +++ b/src/tree/gpu_hist/gradient_based_sampler.cuh @@ -17,7 +17,7 @@ struct GradientBasedSample { size_t sample_rows; /*!\brief Sampled rows in ELLPACK format. */ EllpackPageImpl* page; - /*!\brief Rescaled gradient pairs for the sampled rows. */ + /*!\brief Gradient pairs for the sampled rows. */ common::Span gpair; }; From 5652b3b07fa50fce28d3434d3d6b919e01de1993 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 20 Dec 2019 11:14:31 -0800 Subject: [PATCH 25/48] use mvs --- src/tree/gpu_hist/gradient_based_sampler.cu | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index ae6ebae43a16..630f20825ac8 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -113,20 +113,27 @@ GradientBasedSample GradientBasedSampler::NoSampling(common::Span return {sample_rows_, page_.get(), gpair}; } -/*! \brief A functor that calculates the weight of each row as random(0, 1) / abs(grad). */ +/*! \brief A functor that calculates the weight of each row. + * + * The approach here is based on Minimal Variance Sampling (MVS), with lambda set to 0.1. + * + * \see Ibragimov, B., & Gusev, G. (2019). Minimal Variance Sampling in Stochastic Gradient + * Boosting. In Advances in Neural Information Processing Systems (pp. 15061-15071). + */ struct CalculateWeight : public thrust::binary_function { const uint32_t seed; + const float lambda{0.1}; XGBOOST_DEVICE explicit CalculateWeight(size_t _seed) : seed(_seed) {} XGBOOST_DEVICE float operator()(const GradientPair& gpair, size_t i) { - if (gpair.GetGrad() == 0) { + if (gpair.GetGrad() == 0 && gpair.GetHess() == 0) { return FLT_MAX; } thrust::default_random_engine rng(seed); thrust::uniform_real_distribution dist; rng.discard(i); - return dist(rng) / fabsf(gpair.GetGrad()); + return dist(rng) / sqrtf(powf(gpair.GetGrad(), 2) + lambda * powf(gpair.GetHess(), 2)); } }; From a5b57b136045ac407152ecc51ffa307afd1d7f40 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 20 Dec 2019 11:23:30 -0800 Subject: [PATCH 26/48] fix windows --- src/tree/gpu_hist/gradient_based_sampler.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 630f20825ac8..3620745fdb89 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -122,7 +122,7 @@ GradientBasedSample GradientBasedSampler::NoSampling(common::Span */ struct CalculateWeight : public thrust::binary_function { const uint32_t seed; - const float lambda{0.1}; + const float lambda{0.1f}; XGBOOST_DEVICE explicit CalculateWeight(size_t _seed) : seed(_seed) {} From 62a9eadd169a248e49bbbb0419fe0567fddbcaca Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 20 Dec 2019 14:13:43 -0800 Subject: [PATCH 27/48] address review comments --- src/data/ellpack_page.cuh | 2 +- src/tree/gpu_hist/gradient_based_sampler.cu | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh index c85738c4da39..fcf89ab8fe98 100644 --- a/src/data/ellpack_page.cuh +++ b/src/data/ellpack_page.cuh @@ -73,7 +73,7 @@ struct EllpackInfo { dh::BulkAllocator* ba); /*! \brief Return the total number of symbols (total number of bins plus 1 for not found). */ - inline size_t NumSymbols() const { + size_t NumSymbols() const { return n_bins + 1; } }; diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 3620745fdb89..c94fd8aa5d9b 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -61,7 +61,12 @@ GradientBasedSampler::GradientBasedSampler(BatchParam batch_param, size_t GradientBasedSampler::MaxSampleRows(size_t n_rows) { size_t available_memory = dh::AvailableMemory(batch_param_.gpu_id); // Subtract row_weight_, row_index_, and sample_row_index_. - available_memory -= n_rows * (sizeof(float) + 2 * sizeof(size_t)); + size_t index_bytes = n_rows * (sizeof(float) + 2 * sizeof(size_t)); + CHECK_GT(available_memory, index_bytes) + << "not enough GPU memory for indexes, " + << "available: " << available_memory + << "indexes: " << index_bytes; + available_memory -= index_bytes; size_t usable_memory = available_memory * 0.7; size_t extra_bytes = sizeof(GradientPair); size_t max_rows = common::CompressedBufferWriter::CalculateMaxRows( @@ -223,7 +228,7 @@ GradientBasedSample GradientBasedSampler::UniformSampling(common::Span(0), - thrust::counting_iterator(0) + gpair.size(), + thrust::counting_iterator(gpair.size()), dh::tbegin(row_weight_), RandomWeight(common::GlobalRandom()())); return WeightedSampling(gpair, dmat); From 2409ecb0706e214bdc3693264fc7f8f505d04f25 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Wed, 8 Jan 2020 11:39:31 -0800 Subject: [PATCH 28/48] add sampling method param --- src/tree/param.h | 11 +++++++++++ src/tree/updater_basemaker-inl.h | 2 ++ src/tree/updater_colmaker.cc | 2 ++ src/tree/updater_quantile_hist.cc | 2 ++ 4 files changed, 17 insertions(+) diff --git a/src/tree/param.h b/src/tree/param.h index 3d991a6f1738..da63895a5ef8 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -50,6 +50,9 @@ struct TrainParam : public XGBoostParameter { float max_delta_step; // whether we want to do subsample float subsample; + // sampling method + enum SamplingMethod { kUniform = 0, kGradientBased = 1 }; + int sampling_method; // whether to subsample columns in each split (node) float colsample_bynode; // whether to subsample columns in each level @@ -144,6 +147,14 @@ struct TrainParam : public XGBoostParameter { .set_range(0.0f, 1.0f) .set_default(1.0f) .describe("Row subsample ratio of training instance."); + DMLC_DECLARE_FIELD(sampling_method) + .set_default(kUniform) + .add_enum("uniform", kUniform) + .add_enum("gradient_based", kGradientBased) + .describe( + "Sampling method. 0: select random training instances uniformly. " + "1: select random training instances with higher probability when the " + "gradient and hessian are larger. (cf. CatBoost)"); DMLC_DECLARE_FIELD(colsample_bynode) .set_range(0.0f, 1.0f) .set_default(1.0f) diff --git a/src/tree/updater_basemaker-inl.h b/src/tree/updater_basemaker-inl.h index f806628dd023..2ef24b11facc 100644 --- a/src/tree/updater_basemaker-inl.h +++ b/src/tree/updater_basemaker-inl.h @@ -148,6 +148,8 @@ class BaseMaker: public TreeUpdater { } // mark subsample if (param_.subsample < 1.0f) { + CHECK_EQ(param_.sampling_method, TrainParam::kUniform) + << "Only uniform sampling is supported"; std::bernoulli_distribution coin_flip(param_.subsample); auto& rnd = common::GlobalRandom(); for (size_t i = 0; i < position_.size(); ++i) { diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index 62a63dec37ac..02f8700c1c84 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -198,6 +198,8 @@ class ColMaker: public TreeUpdater { } // mark subsample if (param_.subsample < 1.0f) { + CHECK_EQ(param_.sampling_method, TrainParam::kUniform) + << "Only uniform sampling is supported"; std::bernoulli_distribution coin_flip(param_.subsample); auto& rnd = common::GlobalRandom(); for (size_t ridx = 0; ridx < position_.size(); ++ridx) { diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index b28c2f9707da..3c9d0b31eb94 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -441,6 +441,8 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat, // mark subsample and build list of member rows if (param_.subsample < 1.0f) { + CHECK_EQ(param_.sampling_method, TrainParam::kUniform) + << "Only uniform sampling is supported"; std::bernoulli_distribution coin_flip(param_.subsample); auto& rnd = common::GlobalRandom(); size_t j = 0; From 3b16e6609c3f8d5997abdf541edf1f9cb9f6b547 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 9 Jan 2020 13:54:36 -0800 Subject: [PATCH 29/48] gradient-based sampling in in-memory mode --- src/tree/gpu_hist/gradient_based_sampler.cu | 271 +++++++++++------- src/tree/gpu_hist/gradient_based_sampler.cuh | 38 +-- src/tree/updater_gpu_common.cuh | 36 --- src/tree/updater_gpu_hist.cu | 24 +- .../gpu_hist/test_gradient_based_sampler.cu | 8 +- 5 files changed, 187 insertions(+), 190 deletions(-) diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index c94fd8aa5d9b..e28e8bcba2fe 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -16,87 +17,73 @@ namespace xgboost { namespace tree { -GradientBasedSampler::GradientBasedSampler(BatchParam batch_param, - EllpackInfo info, +GradientBasedSampler::GradientBasedSampler(EllpackPageImpl* page, size_t n_rows, + BatchParam batch_param, float subsample, - SamplingMethod sampling_method) - : batch_param_(batch_param), info_(info), sampling_method_(sampling_method) { + int sampling_method) + : original_page_(page), + batch_param_(batch_param), + is_external_memory_(page->matrix.n_rows != n_rows), + subsample_(subsample), + is_sampling_(subsample < 1.0), + sampling_method_(sampling_method), + sample_rows_(n_rows * subsample) { monitor_.Init("gradient_based_sampler"); - // If `subsample` is not specified, try to figure out how many rows to sample based on available - // free GPU memory. - if (subsample == 0.0f || subsample == 1.0f) { - sample_rows_ = MaxSampleRows(n_rows); - } else { - sample_rows_ = n_rows * subsample; - } - - // If there is enough GPU memory to keep all the rows, don't sample. - if (sample_rows_ >= n_rows) { - sampling_method_ = kNoSampling; - sample_rows_ = n_rows; - LOG(CONSOLE) << "Keeping " << sample_rows_ << " rows in GPU memory, not sampling"; - } else { - LOG(CONSOLE) << "Sampling " << sample_rows_ << " rows"; - } - - // Create a new ELLPACK page with empty rows. - page_.reset(new EllpackPageImpl(batch_param.gpu_id, info, sample_rows_)); - - // Allocate GPU memory for sampling. - if (sampling_method_ != kNoSampling) { - ba_.Allocate(batch_param_.gpu_id, - &gpair_, sample_rows_, - &row_weight_, n_rows, - &row_index_, n_rows, - &sample_row_index_, n_rows); - thrust::copy(thrust::counting_iterator(0), - thrust::counting_iterator(n_rows), - dh::tbegin(row_index_)); + if (is_external_memory_) { + // Create a new ELLPACK page with empty rows. + page_.reset(new EllpackPageImpl(batch_param.gpu_id, + original_page_->matrix.info, + sample_rows_)); + // Allocate GPU memory for sampling. + if (is_sampling_) { + ba_.Allocate(batch_param_.gpu_id, + &gpair_, sample_rows_, + &row_weight_, n_rows, + &row_index_, n_rows, + &sample_row_index_, n_rows); + thrust::copy(thrust::counting_iterator(0), + thrust::counting_iterator(n_rows), + dh::tbegin(row_index_)); + } } } -// Determine the maximum number of rows that can fit into the available GPU memory. -size_t GradientBasedSampler::MaxSampleRows(size_t n_rows) { - size_t available_memory = dh::AvailableMemory(batch_param_.gpu_id); - // Subtract row_weight_, row_index_, and sample_row_index_. - size_t index_bytes = n_rows * (sizeof(float) + 2 * sizeof(size_t)); - CHECK_GT(available_memory, index_bytes) - << "not enough GPU memory for indexes, " - << "available: " << available_memory - << "indexes: " << index_bytes; - available_memory -= index_bytes; - size_t usable_memory = available_memory * 0.7; - size_t extra_bytes = sizeof(GradientPair); - size_t max_rows = common::CompressedBufferWriter::CalculateMaxRows( - usable_memory, info_.NumSymbols(), info_.row_stride, extra_bytes); - return max_rows; -} - // Sample a DMatrix based on the given gradient pairs. GradientBasedSample GradientBasedSampler::Sample(common::Span gpair, DMatrix* dmat) { monitor_.StartCuda("Sample"); GradientBasedSample sample; - switch (sampling_method_) { - case kNoSampling: - sample = NoSampling(gpair, dmat); - break; - case kSequentialPoissonSampling: - sample = SequentialPoissonSampling(gpair, dmat); - break; - case kUniformSampling: - sample = UniformSampling(gpair, dmat); - break; - default: - LOG(FATAL) << "unknown sampling method"; - sample = {sample_rows_, page_.get(), gpair}; + if (is_sampling_) { + switch (sampling_method_) { + case TrainParam::kUniform: + sample = UniformSampling(gpair, dmat); + break; + case TrainParam::kGradientBased: + sample = GradientBasedSampling(gpair, dmat); + break; + default: + LOG(FATAL) << "unknown sampling method"; + sample = {0, nullptr, gpair}; + } + } else { + sample = NoSampling(gpair, dmat); } monitor_.StopCuda("Sample"); return sample; } +GradientBasedSample GradientBasedSampler::NoSampling(common::Span gpair, + DMatrix* dmat) { + if (is_external_memory_) { + CollectPages(dmat); + return {dmat->Info().num_row_, page_.get(), gpair}; + } else { + return {dmat->Info().num_row_, original_page_, gpair}; + } +} + // When not sampling, collect all the external memory ELLPACK pages into a single in-memory page. void GradientBasedSampler::CollectPages(DMatrix* dmat) { if (page_collected_) { @@ -112,36 +99,131 @@ void GradientBasedSampler::CollectPages(DMatrix* dmat) { page_collected_ = true; } -GradientBasedSample GradientBasedSampler::NoSampling(common::Span gpair, - DMatrix* dmat) { - CollectPages(dmat); - return {sample_rows_, page_.get(), gpair}; +/*! \brief A functor that returns random weights. */ +struct RandomWeight : public thrust::unary_function { + uint32_t seed; + + XGBOOST_DEVICE explicit RandomWeight(size_t _seed) : seed(_seed) {} + + XGBOOST_DEVICE float operator()(size_t i) const { + thrust::default_random_engine rng(seed); + thrust::uniform_real_distribution dist; + rng.discard(i); + return dist(rng); + } +}; + +/*! \brief A functor that performs Bernoulli sampling on gradient pairs. */ +struct BernoulliSampling : public thrust::binary_function { + float p; + RandomWeight rnd; + + XGBOOST_DEVICE BernoulliSampling(float _p, RandomWeight _rnd) : p(_p), rnd(_rnd) {} + + XGBOOST_DEVICE GradientPair operator()(const GradientPair& gpair, size_t i) const { + if (rnd(i) <= p) { + return gpair; + } else { + return GradientPair(); + } + } +}; + +GradientBasedSample GradientBasedSampler::UniformSampling(common::Span gpair, + DMatrix* dmat) { + RandomWeight rnd(common::GlobalRandom()()); + if (is_external_memory_) { + // Generate random weights. + thrust::transform(thrust::counting_iterator(0), + thrust::counting_iterator(gpair.size()), + dh::tbegin(row_weight_), + rnd); + return SequentialPoissonSampling(gpair, dmat); + } else { + // Set gradient pair to 0 with p = 1 - subsample + thrust::transform(dh::tbegin(gpair), dh::tend(gpair), + thrust::counting_iterator(0), + dh::tbegin(gpair), + BernoulliSampling(subsample_, rnd)); + return {dmat->Info().num_row_, original_page_, gpair}; + } } -/*! \brief A functor that calculates the weight of each row. +/*! \brief A functor that combines the gradient pair into a single float. * * The approach here is based on Minimal Variance Sampling (MVS), with lambda set to 0.1. * * \see Ibragimov, B., & Gusev, G. (2019). Minimal Variance Sampling in Stochastic Gradient * Boosting. In Advances in Neural Information Processing Systems (pp. 15061-15071). */ +struct CombineGradientPair : public thrust::unary_function { + static constexpr float lambda = 0.1f; + + XGBOOST_DEVICE float operator()(const GradientPair& gpair) const { + return sqrtf(powf(gpair.GetGrad(), 2) + lambda * powf(gpair.GetHess(), 2)); + } +}; + +/*! \brief A functor that calculates the weight of each row. + */ struct CalculateWeight : public thrust::binary_function { - const uint32_t seed; - const float lambda{0.1f}; + RandomWeight rnd; + CombineGradientPair combine{}; - XGBOOST_DEVICE explicit CalculateWeight(size_t _seed) : seed(_seed) {} + XGBOOST_DEVICE explicit CalculateWeight(RandomWeight _rnd) : rnd(_rnd) {} XGBOOST_DEVICE float operator()(const GradientPair& gpair, size_t i) { if (gpair.GetGrad() == 0 && gpair.GetHess() == 0) { return FLT_MAX; } - thrust::default_random_engine rng(seed); - thrust::uniform_real_distribution dist; - rng.discard(i); - return dist(rng) / sqrtf(powf(gpair.GetGrad(), 2) + lambda * powf(gpair.GetHess(), 2)); + return rnd(i) / combine(gpair); } }; +/*! \brief A functor that performs Poisson sampling with probability proportional to the combined + * gradient pair. + */ +struct PoissonSampling : public thrust::binary_function { + size_t sample_rows; + float normalization; + const RandomWeight rnd; + const CombineGradientPair combine{}; + + XGBOOST_DEVICE PoissonSampling(size_t _sample_rows, float _normalization, RandomWeight _rnd) + : sample_rows(_sample_rows), normalization(_normalization), rnd(_rnd) {} + + XGBOOST_DEVICE GradientPair operator()(const GradientPair& gpair, size_t i) const { + if (rnd(i) <= sample_rows * combine(gpair) / normalization) { + return gpair; + } else { + return GradientPair(); + } + } +}; + +GradientBasedSample GradientBasedSampler::GradientBasedSampling( + common::Span gpair, DMatrix* dmat) { + RandomWeight rnd(common::GlobalRandom()()); + if (is_external_memory_) { + thrust::transform(dh::tbegin(gpair), dh::tend(gpair), + thrust::counting_iterator(0), + dh::tbegin(row_weight_), + CalculateWeight(rnd)); + return SequentialPoissonSampling(gpair, dmat); + } else { + float normalization = thrust::transform_reduce(dh::tbegin(gpair), dh::tend(gpair), + CombineGradientPair(), + 0.0f, + thrust::plus()); + // Set gradient pair to 0 with p = 1 - p_i + thrust::transform(dh::tbegin(gpair), dh::tend(gpair), + thrust::counting_iterator(0), + dh::tbegin(gpair), + PoissonSampling(sample_rows_, normalization, rnd)); + return {dmat->Info().num_row_, original_page_, gpair}; + } +} + /*! \brief A functor that returns true if the gradient pair is non-zero. */ struct IsNonZero : public thrust::unary_function { XGBOOST_DEVICE bool operator()(const GradientPair& gpair) const { @@ -160,19 +242,9 @@ struct ClearEmptyRows : public thrust::binary_function gpair, DMatrix* dmat) { - // Transform the gradient to weight = random(0, 1) / abs(grad). - thrust::transform(dh::tbegin(gpair), dh::tend(gpair), - thrust::counting_iterator(0), - dh::tbegin(row_weight_), - CalculateWeight(common::GlobalRandom()())); - return WeightedSampling(gpair, dmat); -} - // Perform sampling after the weights are calculated. -GradientBasedSample GradientBasedSampler::WeightedSampling( - common::Span gpair, DMatrix* dmat) { +GradientBasedSample GradientBasedSampler::SequentialPoissonSampling( + common::Span gpair, DMatrix* dmat) { // Sort the gradient pairs and row indexes by weight. thrust::sort_by_key(dh::tbegin(row_weight_), dh::tend(row_weight_), thrust::make_zip_iterator(thrust::make_tuple(dh::tbegin(gpair), @@ -210,28 +282,5 @@ GradientBasedSample GradientBasedSampler::WeightedSampling( return {sample_rows_, page_.get(), gpair_}; } -/*! \brief A functor that returns random weights. */ -struct RandomWeight : public thrust::unary_function { - const uint32_t seed; - - XGBOOST_DEVICE explicit RandomWeight(size_t _seed) : seed(_seed) {} - - XGBOOST_DEVICE float operator()(size_t i) { - thrust::default_random_engine rng(seed); - thrust::uniform_real_distribution dist; - rng.discard(i); - return dist(rng); - } -}; - -GradientBasedSample GradientBasedSampler::UniformSampling(common::Span gpair, - DMatrix* dmat) { - // Generate random weights. - thrust::transform(thrust::counting_iterator(0), - thrust::counting_iterator(gpair.size()), - dh::tbegin(row_weight_), - RandomWeight(common::GlobalRandom()())); - return WeightedSampling(gpair, dmat); -} }; // namespace tree }; // namespace xgboost diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh index baeb6b8f0313..ff69f1bf122e 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cuh +++ b/src/tree/gpu_hist/gradient_based_sampler.cuh @@ -32,45 +32,35 @@ struct GradientBasedSample { */ class GradientBasedSampler { public: - enum SamplingMethod { - /*! \brief When all rows can fit in GPU memory, no sampling is needed. */ - kNoSampling, - /*! \brief Fixed-sized random sampling, weighted by the absolute value of the gradient. */ - kSequentialPoissonSampling, - /*! \brief This is for comparison purposes only, not recommended for real use. */ - kUniformSampling - }; - explicit GradientBasedSampler(BatchParam batch_param, - EllpackInfo info, - size_t n_rows, - float subsample = 1.0f, - SamplingMethod sampling_method = kDefaultSamplingMethod); + GradientBasedSampler(EllpackPageImpl* page, + size_t n_rows, + BatchParam batch_param, + float subsample, + int sampling_method); - /*! \brief Sample from a DMatrix based on the given gradients. */ + /*! \brief Sample from a DMatrix based on the given gradient pairs. */ GradientBasedSample Sample(common::Span gpair, DMatrix* dmat); private: - static const SamplingMethod kDefaultSamplingMethod = kSequentialPoissonSampling; - GradientBasedSample NoSampling(common::Span gpair, DMatrix* dmat); - GradientBasedSample SequentialPoissonSampling(common::Span gpair, DMatrix* dmat); GradientBasedSample UniformSampling(common::Span gpair, DMatrix* dmat); - - /*! \brief Returns the max number of rows that can fit in available GPU memory. */ - size_t MaxSampleRows(size_t n_rows); + GradientBasedSample GradientBasedSampling(common::Span gpair, DMatrix* dmat); /*! \brief Collect all the rows from a DMatrix into a single ELLPACK page. */ void CollectPages(DMatrix* dmat); - /*! \brief Do weighted sampling after the row weights are calculated. */ - GradientBasedSample WeightedSampling(common::Span gpair, DMatrix* dmat); + /*! \brief Fixed-size Poisson sampling after the row weights are calculated. */ + GradientBasedSample SequentialPoissonSampling(common::Span gpair, DMatrix* dmat); common::Monitor monitor_; dh::BulkAllocator ba_; + EllpackPageImpl* original_page_; + float subsample_; + bool is_external_memory_; + bool is_sampling_; BatchParam batch_param_; - EllpackInfo info_; - SamplingMethod sampling_method_; + int sampling_method_; size_t sample_rows_; std::unique_ptr page_; common::Span gpair_; diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh index dacb32f0aac4..b64eeaae6e70 100644 --- a/src/tree/updater_gpu_common.cuh +++ b/src/tree/updater_gpu_common.cuh @@ -187,41 +187,5 @@ XGBOOST_DEVICE inline int MaxNodesDepth(int depth) { return (1 << (depth + 1)) - 1; } -/* - * Random - */ -struct BernoulliRng { - float p; - uint32_t seed; - - XGBOOST_DEVICE BernoulliRng(float p, size_t seed_) : p(p) { - seed = static_cast(seed_); - } - - XGBOOST_DEVICE bool operator()(const int i) const { - thrust::default_random_engine rng(seed); - thrust::uniform_real_distribution dist; - rng.discard(i); - return dist(rng) <= p; - } -}; - -// Set gradient pair to 0 with p = 1 - subsample -inline void SubsampleGradientPair(int device_idx, - common::Span d_gpair, - float subsample, int offset = 0) { - if (subsample == 1.0) { - return; - } - - BernoulliRng rng(subsample, common::GlobalRandom()()); - - dh::LaunchN(device_idx, d_gpair.size(), [=] XGBOOST_DEVICE(int i) { - if (!rng(i + offset)) { - d_gpair[i] = GradientPair(); - } - }); -} - } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 696ff0c4029d..67aa3cc6417e 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -500,12 +500,11 @@ struct GPUHistMakerDevice { interaction_constraints(param, n_features), batch_param(_batch_param), use_gradient_based_sampling(_page->matrix.n_rows != _n_rows) { - if (use_gradient_based_sampling) { - sampler.reset(new GradientBasedSampler(batch_param, - page->matrix.info, - n_rows, - param.subsample)); - } + sampler.reset(new GradientBasedSampler(page, + n_rows, + batch_param, + param.subsample, + param.sampling_method)); monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id)); } @@ -552,15 +551,10 @@ struct GPUHistMakerDevice { std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), GradientPair()); - if (use_gradient_based_sampling) { - auto sample = sampler->Sample(dh_gpair->DeviceSpan(), dmat); - n_rows = sample.sample_rows; - page = sample.page; - gpair = sample.gpair; - } else { - gpair = dh_gpair->DeviceSpan(); - SubsampleGradientPair(device_id, gpair, param.subsample); - } + auto sample = sampler->Sample(dh_gpair->DeviceSpan(), dmat); + n_rows = sample.sample_rows; + page = sample.page; + gpair = sample.gpair; row_partitioner.reset(); // Release the device memory first before reallocating row_partitioner.reset(new RowPartitioner(device_id, n_rows)); diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu index 466457de1a0f..f86f0ec7699a 100644 --- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -10,6 +10,7 @@ namespace tree { TEST(GradientBasedSampler, NoSampling) { constexpr size_t kRows = 1024; constexpr size_t kCols = 4; + constexpr float kSubsample = 1.0; constexpr size_t kPageSize = 1024; // Create a DMatrix with multiple batches. @@ -22,7 +23,7 @@ TEST(GradientBasedSampler, NoSampling) { BatchParam param{0, 256, 0, kPageSize}; auto page = (*dmat->GetBatches(param).begin()).Impl(); - GradientBasedSampler sampler(param, page->matrix.info, kRows); + GradientBasedSampler sampler(page, kRows, param, kSubsample, TrainParam::kUniform); auto sample = sampler.Sample(gpair.DeviceSpan(), dmat.get()); auto sampled_page = sample.page; auto sampled_gpair = sample.gpair; @@ -55,7 +56,7 @@ TEST(GradientBasedSampler, NoSampling) { } } -TEST(GradientBasedSampler, SequentialPoissonSampling) { +TEST(GradientBasedSampler, GradientBasedSampling) { constexpr size_t kRows = 2048; constexpr size_t kCols = 16; constexpr float kSubsample = 0.5; @@ -72,8 +73,7 @@ TEST(GradientBasedSampler, SequentialPoissonSampling) { BatchParam param{0, 256, 0, kPageSize}; auto page = (*dmat->GetBatches(param).begin()).Impl(); - GradientBasedSampler sampler(param, page->matrix.info, kRows, kSubsample, - GradientBasedSampler::kSequentialPoissonSampling); + GradientBasedSampler sampler(page, kRows, param, kSubsample, TrainParam::kGradientBased); auto sample = sampler.Sample(gpair.DeviceSpan(), dmat.get()); auto sampled_page = sample.page; auto sampled_gpair = sample.gpair; From e41496bec3d7cac48c674877b73fd94a02e6dfae Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 9 Jan 2020 15:46:14 -0800 Subject: [PATCH 30/48] fix clang tidy warning --- src/tree/gpu_hist/gradient_based_sampler.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index e28e8bcba2fe..44f5f066da30 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -157,10 +157,10 @@ GradientBasedSample GradientBasedSampler::UniformSampling(common::Span { - static constexpr float lambda = 0.1f; + static constexpr float kLambda = 0.1f; XGBOOST_DEVICE float operator()(const GradientPair& gpair) const { - return sqrtf(powf(gpair.GetGrad(), 2) + lambda * powf(gpair.GetHess(), 2)); + return sqrtf(powf(gpair.GetGrad(), 2) + kLambda * powf(gpair.GetHess(), 2)); } }; From fd458c63e6e2ac6a9fd55d079d2ddd3c415fe1a7 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 9 Jan 2020 16:24:54 -0800 Subject: [PATCH 31/48] add tests for in-core --- src/tree/gpu_hist/gradient_based_sampler.cu | 6 +- .../gpu_hist/test_gradient_based_sampler.cu | 137 +++++++++++++++++- tests/cpp/tree/test_gpu_hist.cu | 1 + 3 files changed, 138 insertions(+), 6 deletions(-) diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 44f5f066da30..c0aa4fad0fe5 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -168,7 +168,7 @@ struct CombineGradientPair : public thrust::unary_function */ struct CalculateWeight : public thrust::binary_function { RandomWeight rnd; - CombineGradientPair combine{}; + CombineGradientPair combine; XGBOOST_DEVICE explicit CalculateWeight(RandomWeight _rnd) : rnd(_rnd) {} @@ -186,8 +186,8 @@ struct CalculateWeight : public thrust::binary_function { size_t sample_rows; float normalization; - const RandomWeight rnd; - const CombineGradientPair combine{}; + RandomWeight rnd; + CombineGradientPair combine; XGBOOST_DEVICE PoissonSampling(size_t _sample_rows, float _normalization, RandomWeight _rnd) : sample_rows(_sample_rows), normalization(_normalization), rnd(_rnd) {} diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu index f86f0ec7699a..8b89629286e4 100644 --- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -8,6 +8,28 @@ namespace xgboost { namespace tree { TEST(GradientBasedSampler, NoSampling) { + constexpr size_t kRows = 1024; + constexpr size_t kCols = 4; + constexpr float kSubsample = 1.0; + constexpr size_t kPageSize = 0; + + // Create a DMatrix with a single batche. + std::unique_ptr dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true)); + auto gpair = GenerateRandomGradients(kRows); + gpair.SetDevice(0); + + BatchParam param{0, 256, 0, kPageSize}; + auto page = (*dmat->GetBatches(param).begin()).Impl(); + + GradientBasedSampler sampler(page, kRows, param, kSubsample, TrainParam::kUniform); + auto sample = sampler.Sample(gpair.DeviceSpan(), dmat.get()); + EXPECT_EQ(sample.sample_rows, kRows); + EXPECT_EQ(sample.page, page); + EXPECT_EQ(sample.gpair.size(), gpair.Size()); + EXPECT_EQ(sample.gpair.data(), gpair.DevicePointer()); +} + +TEST(GradientBasedSampler, NoSampling_ExternalMemory) { constexpr size_t kRows = 1024; constexpr size_t kCols = 4; constexpr float kSubsample = 1.0; @@ -56,7 +78,42 @@ TEST(GradientBasedSampler, NoSampling) { } } -TEST(GradientBasedSampler, GradientBasedSampling) { +TEST(GradientBasedSampler, UniformSampling) { + constexpr size_t kRows = 2048; + constexpr size_t kCols = 16; + constexpr float kSubsample = 0.5; + constexpr size_t kSampleRows = kRows * kSubsample; + constexpr size_t kPageSize = 0; + + // Create a DMatrix with a single batche. + std::unique_ptr dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true)); + auto gpair = GenerateRandomGradients(kRows); + float sum_gradients = 0; + for (auto gp : gpair.ConstHostVector()) { + sum_gradients += gp.GetGrad(); + } + gpair.SetDevice(0); + + BatchParam param{0, 256, 0, kPageSize}; + auto page = (*dmat->GetBatches(param).begin()).Impl(); + + GradientBasedSampler sampler(page, kRows, param, kSubsample, TrainParam::kUniform); + auto sample = sampler.Sample(gpair.DeviceSpan(), dmat.get()); + auto sampled_gpair = sample.gpair; + EXPECT_EQ(sample.sample_rows, kRows); + EXPECT_EQ(sample.page, page); + EXPECT_EQ(sample.gpair.size(), kRows); + + float sum_sampled_gradients = 0; + std::vector sampled_gpair_h(sampled_gpair.size()); + dh::CopyDeviceSpanToVector(&sampled_gpair_h, sampled_gpair); + for (auto gp : sampled_gpair_h) { + sum_sampled_gradients += gp.GetGrad(); + } + EXPECT_NEAR(sum_gradients / kRows, sum_sampled_gradients / kSampleRows, 0.002); +} + +TEST(GradientBasedSampler, UniformSampling_ExternalMemory) { constexpr size_t kRows = 2048; constexpr size_t kCols = 16; constexpr float kSubsample = 0.5; @@ -68,12 +125,16 @@ TEST(GradientBasedSampler, GradientBasedSampling) { std::unique_ptr dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir)); auto gpair = GenerateRandomGradients(kRows); + float sum_gradients = 0; + for (auto gp : gpair.ConstHostVector()) { + sum_gradients += gp.GetGrad(); + } gpair.SetDevice(0); BatchParam param{0, 256, 0, kPageSize}; auto page = (*dmat->GetBatches(param).begin()).Impl(); - GradientBasedSampler sampler(page, kRows, param, kSubsample, TrainParam::kGradientBased); + GradientBasedSampler sampler(page, kRows, param, kSubsample, TrainParam::kUniform); auto sample = sampler.Sample(gpair.DeviceSpan(), dmat.get()); auto sampled_page = sample.page; auto sampled_gpair = sample.gpair; @@ -81,10 +142,79 @@ TEST(GradientBasedSampler, GradientBasedSampling) { EXPECT_EQ(sampled_page->matrix.n_rows, kSampleRows); EXPECT_EQ(sampled_gpair.size(), kSampleRows); + float sum_sampled_gradients = 0; + std::vector sampled_gpair_h(sampled_gpair.size()); + dh::CopyDeviceSpanToVector(&sampled_gpair_h, sampled_gpair); + for (auto gp : sampled_gpair_h) { + sum_sampled_gradients += gp.GetGrad(); + } + EXPECT_NEAR(sum_gradients / kRows, sum_sampled_gradients / kSampleRows, 0.01); +} + +TEST(GradientBasedSampler, GradientBasedSampling) { + constexpr size_t kRows = 2048; + constexpr size_t kCols = 16; + constexpr float kSubsample = 0.5; + constexpr size_t kSampleRows = kRows * kSubsample; + constexpr size_t kPageSize = 0; + + // Create a DMatrix with a single batche. + std::unique_ptr dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true)); + auto gpair = GenerateRandomGradients(kRows); float sum_gradients = 0; for (auto gp : gpair.ConstHostVector()) { sum_gradients += gp.GetGrad(); } + gpair.SetDevice(0); + + BatchParam param{0, 256, 0, kPageSize}; + auto page = (*dmat->GetBatches(param).begin()).Impl(); + + GradientBasedSampler sampler(page, kRows, param, kSubsample, TrainParam::kGradientBased); + auto sample = sampler.Sample(gpair.DeviceSpan(), dmat.get()); + auto sampled_gpair = sample.gpair; + EXPECT_EQ(sample.sample_rows, kRows); + EXPECT_EQ(sample.page, page); + EXPECT_EQ(sampled_gpair.size(), kRows); + + float sum_sampled_gradients = 0; + std::vector sampled_gpair_h(sampled_gpair.size()); + dh::CopyDeviceSpanToVector(&sampled_gpair_h, sampled_gpair); + for (auto gp : sampled_gpair_h) { + sum_sampled_gradients += gp.GetGrad(); + } + // TODO(rongou): gradient pairs need to be rescaled to get accurate statistics. + EXPECT_NEAR(sum_gradients / kRows, sum_sampled_gradients / kSampleRows, 0.15); +} + +TEST(GradientBasedSampler, GradientBasedSampling_ExternalMemory) { + constexpr size_t kRows = 2048; + constexpr size_t kCols = 16; + constexpr float kSubsample = 0.5; + constexpr size_t kSampleRows = kRows * kSubsample; + constexpr size_t kPageSize = 1024; + + // Create a DMatrix with multiple batches. + dmlc::TemporaryDirectory tmpdir; + std::unique_ptr + dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir)); + auto gpair = GenerateRandomGradients(kRows); + float sum_gradients = 0; + for (auto gp : gpair.ConstHostVector()) { + sum_gradients += gp.GetGrad(); + } + gpair.SetDevice(0); + + BatchParam param{0, 256, 0, kPageSize}; + auto page = (*dmat->GetBatches(param).begin()).Impl(); + + GradientBasedSampler sampler(page, kRows, param, kSubsample, TrainParam::kGradientBased); + auto sample = sampler.Sample(gpair.DeviceSpan(), dmat.get()); + auto sampled_page = sample.page; + auto sampled_gpair = sample.gpair; + EXPECT_EQ(sample.sample_rows, kSampleRows); + EXPECT_EQ(sampled_page->matrix.n_rows, kSampleRows); + EXPECT_EQ(sampled_gpair.size(), kSampleRows); float sum_sampled_gradients = 0; std::vector sampled_gpair_h(sampled_gpair.size()); @@ -92,7 +222,8 @@ TEST(GradientBasedSampler, GradientBasedSampling) { for (auto gp : sampled_gpair_h) { sum_sampled_gradients += gp.GetGrad(); } - EXPECT_FLOAT_EQ(sum_gradients, sum_sampled_gradients); + // TODO(rongou): gradient pairs need to be rescaled to get accurate statistics. + EXPECT_NEAR(sum_gradients / kRows, sum_sampled_gradients / kSampleRows, 0.15); } }; // namespace tree diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index ee7502d8180f..c4ef8d1b2a75 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -370,6 +370,7 @@ void UpdateTree(HostDeviceVector* gpair, {"reg_alpha", "0"}, {"reg_lambda", "0"}, {"subsample", std::to_string(subsample)}, + {"sampling_method", "gradient_based"}, }; tree::GPUHistMakerSpecialised hist_maker; From 6e3a7fa2fa8d11f72629832cc55a547af6dcab22 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 9 Jan 2020 16:31:23 -0800 Subject: [PATCH 32/48] remove unused code --- src/common/compressed_iterator.h | 20 ------------------- tests/cpp/common/test_compressed_iterator.cc | 21 -------------------- 2 files changed, 41 deletions(-) diff --git a/src/common/compressed_iterator.h b/src/common/compressed_iterator.h index a414ca3d32a2..13d70e9ab9ff 100644 --- a/src/common/compressed_iterator.h +++ b/src/common/compressed_iterator.h @@ -82,26 +82,6 @@ class CompressedBufferWriter { return compressed_size + detail::kPadding; } - /** - * \brief Calculates maximum number of rows that can fit in a given number of bytes. - * \param num_bytes Number of bytes. - * \param num_symbols Max number of symbols (alphabet size). - * \param row_stride Number of features per row. - * \param extra_bytes_per_row Extra number of bytes needed per row. - * \return The calculated number of rows. - */ - static size_t CalculateMaxRows(size_t num_bytes, - size_t num_symbols, - size_t row_stride, - size_t extra_bytes_per_row) { - const int bits_per_byte = 8; - size_t usable_bits = (num_bytes - detail::kPadding) * bits_per_byte; - size_t extra_bits = extra_bytes_per_row * bits_per_byte; - size_t symbol_bits = row_stride * detail::SymbolBits(num_symbols); - size_t num_rows = static_cast(std::floor(usable_bits / (extra_bits + symbol_bits))); - return num_rows; - } - template void WriteSymbol(CompressedByteT *buffer, T symbol, size_t offset) { const int bits_per_byte = 8; diff --git a/tests/cpp/common/test_compressed_iterator.cc b/tests/cpp/common/test_compressed_iterator.cc index bfdae6814d9b..93243c0b336e 100644 --- a/tests/cpp/common/test_compressed_iterator.cc +++ b/tests/cpp/common/test_compressed_iterator.cc @@ -51,26 +51,5 @@ TEST(CompressedIterator, Test) { } } -TEST(CompressedIterator, CalculateMaxRows) { - const size_t num_bytes = 12652838912; - const size_t row_stride = 100; - const size_t num_symbols = 256 * row_stride + 1; - const size_t extra_bytes = 8; - size_t num_rows = - CompressedBufferWriter::CalculateMaxRows(num_bytes, num_symbols, row_stride, extra_bytes); - EXPECT_EQ(num_rows, 64720403); - - // The calculated # rows should fit within the given number of bytes. - size_t buffer = CompressedBufferWriter::CalculateBufferSize(num_rows * row_stride, num_symbols); - size_t extras = extra_bytes * num_rows; - EXPECT_LE(buffer + extras, num_bytes); - - // An extra row wouldn't fit. - num_rows++; - buffer = CompressedBufferWriter::CalculateBufferSize(num_rows * row_stride, num_symbols); - extras = extra_bytes * num_rows; - EXPECT_GT(buffer + extras, num_bytes); -} - } // namespace common } // namespace xgboost From 4daffbfa7562dbf46141d37b7c7a142fb3c4b50c Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 10 Jan 2020 16:31:59 -0800 Subject: [PATCH 33/48] relax tests --- tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu index 8b89629286e4..9461796e3b50 100644 --- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -110,7 +110,7 @@ TEST(GradientBasedSampler, UniformSampling) { for (auto gp : sampled_gpair_h) { sum_sampled_gradients += gp.GetGrad(); } - EXPECT_NEAR(sum_gradients / kRows, sum_sampled_gradients / kSampleRows, 0.002); + EXPECT_NEAR(sum_gradients / kRows, sum_sampled_gradients / kSampleRows, 0.02); } TEST(GradientBasedSampler, UniformSampling_ExternalMemory) { @@ -148,7 +148,7 @@ TEST(GradientBasedSampler, UniformSampling_ExternalMemory) { for (auto gp : sampled_gpair_h) { sum_sampled_gradients += gp.GetGrad(); } - EXPECT_NEAR(sum_gradients / kRows, sum_sampled_gradients / kSampleRows, 0.01); + EXPECT_NEAR(sum_gradients / kRows, sum_sampled_gradients / kSampleRows, 0.02); } TEST(GradientBasedSampler, GradientBasedSampling) { From 9aaa9ba26537847b7156dc57cb4eee2154957d65 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Wed, 15 Jan 2020 16:37:24 -0800 Subject: [PATCH 34/48] add test to verify sampling --- tests/cpp/tree/test_gpu_hist_sampling.cc | 128 +++++++++++++++++++++++ 1 file changed, 128 insertions(+) create mode 100644 tests/cpp/tree/test_gpu_hist_sampling.cc diff --git a/tests/cpp/tree/test_gpu_hist_sampling.cc b/tests/cpp/tree/test_gpu_hist_sampling.cc new file mode 100644 index 000000000000..b646f1afb21d --- /dev/null +++ b/tests/cpp/tree/test_gpu_hist_sampling.cc @@ -0,0 +1,128 @@ +/*! + * Copyright 2020 XGBoost contributors + */ +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" + +namespace xgboost { +namespace tree { + +class GpuHistSamplingTest : public ::testing::Test { + protected: + static void SetUpTestCase() { + constexpr size_t kRows = 1000; + constexpr size_t kCols = 1; + constexpr size_t kPageSize = 1024; + + temp_dir = new dmlc::TemporaryDirectory(); + const std::string tmp_file = temp_dir->path + "/random.libsvm"; + { + std::ofstream fo(tmp_file.c_str()); + + std::mt19937 gen{2020}; // NOLINT + std::normal_distribution<> rnd{0, 1}; + + for (size_t i = 0; i < kRows; i++) { + std::stringstream row; + row << rnd(gen); + + for (size_t j = 0; j < kCols; j++) { + row << " " << j << ":" << rnd(gen); + } + fo << row.str() << "\n"; + } + } + + dmat = std::shared_ptr(DMatrix::Load(tmp_file, true, false)); + const std::string ext_mem_file = tmp_file + "#" + tmp_file + ".cache"; + dmat_ext = std::shared_ptr( + DMatrix::Load(ext_mem_file, true, false, "auto", kPageSize)); + } + + static void TearDownTestCase() { + dmat.reset(); + dmat_ext.reset(); + delete temp_dir; + } + + static void VerifyPredictionMean(const std::shared_ptr& dtrain, + float subsample = 1.0f, + const std::string& sampling_method = "uniform") { + std::vector> cache_mats{dtrain}; + std::unique_ptr learner(Learner::Create(cache_mats)); + Args args { + {"tree_method", "gpu_hist"}, + {"max_depth", "1"}, + {"subsample", std::to_string(subsample)}, + {"sampling_method", sampling_method}, + + {"learning_rate", "1"}, + {"reg_alpha", "0"}, + {"reg_lambda", "0"}, + }; + learner->SetParams(args); + + constexpr int kNumRound = 10; + for (int i = 0; i < kNumRound; ++i) { + learner->UpdateOneIter(i, dtrain.get()); + } + + HostDeviceVector preds; + learner->Predict(dtrain.get(), true, &preds); + auto h_preds = preds.ConstHostVector(); + float mean = std::accumulate(h_preds.begin(), h_preds.end(), 0.0f) / h_preds.size(); + EXPECT_NEAR(mean, 0.0f, 2e-2) << "subsample=" << subsample; + } + + static dmlc::TemporaryDirectory* temp_dir; + static std::shared_ptr dmat; + static std::shared_ptr dmat_ext; +}; + +dmlc::TemporaryDirectory* GpuHistSamplingTest::temp_dir; +std::shared_ptr GpuHistSamplingTest::dmat; +std::shared_ptr GpuHistSamplingTest::dmat_ext; + +TEST_F(GpuHistSamplingTest, NoSampling) { + VerifyPredictionMean(dmat); +} + +TEST_F(GpuHistSamplingTest, NoSampling_ExternalMemory) { + VerifyPredictionMean(dmat_ext); +} + +TEST_F(GpuHistSamplingTest, UniformSampling) { + for (int i = 1; i < 10; i++) { + float subsample = static_cast(i) / 10.0f; + VerifyPredictionMean(dmat, subsample); + } +} + +TEST_F(GpuHistSamplingTest, UniformSampling_ExternalMemory) { + for (int i = 1; i < 10; i++) { + float subsample = static_cast(i) / 10.0f; + VerifyPredictionMean(dmat_ext, subsample); + } +} + +TEST_F(GpuHistSamplingTest, GradientBasedSampling) { + for (int i = 1; i < 10; i++) { + float subsample = static_cast(i) / 10.0f; + VerifyPredictionMean(dmat, subsample, "gradient_based"); + } +} + +TEST_F(GpuHistSamplingTest, GradientBasedSampling_ExternalMemory) { + for (int i = 1; i < 10; i++) { + float subsample = static_cast(i) / 10.0f; + VerifyPredictionMean(dmat_ext, subsample, "gradient_based"); + } +} + +} // namespace tree +} // namespace xgboost From 10fce05574c704af518c89bbab8a80d023f1714c Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 17 Jan 2020 16:53:15 -0800 Subject: [PATCH 35/48] inverse probability weighting estimation --- include/xgboost/base.h | 26 +++ src/tree/gpu_hist/gradient_based_sampler.cu | 144 ++++++------ src/tree/gpu_hist/gradient_based_sampler.cuh | 6 +- .../gpu_hist/test_gradient_based_sampler.cu | 209 +++++------------- 4 files changed, 153 insertions(+), 232 deletions(-) diff --git a/include/xgboost/base.h b/include/xgboost/base.h index bb8e5dd155e6..0b3a1e0af5f1 100644 --- a/include/xgboost/base.h +++ b/include/xgboost/base.h @@ -193,6 +193,32 @@ class GradientPairInternal { return g; } + XGBOOST_DEVICE GradientPairInternal &operator*=(float multiplier) { + grad_ *= multiplier; + hess_ *= multiplier; + return *this; + } + + XGBOOST_DEVICE GradientPairInternal operator*(float multiplier) const { + GradientPairInternal g; + g.grad_ = grad_ * multiplier; + g.hess_ = hess_ * multiplier; + return g; + } + + XGBOOST_DEVICE GradientPairInternal &operator/=(float divisor) { + grad_ /= divisor; + hess_ /= divisor; + return *this; + } + + XGBOOST_DEVICE GradientPairInternal operator/(float divisor) const { + GradientPairInternal g; + g.grad_ = grad_ / divisor; + g.hess_ = hess_ / divisor; + return g; + } + XGBOOST_DEVICE bool operator==(const GradientPairInternal &rhs) const { return grad_ == rhs.grad_ && hess_ == rhs.hess_; } diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index c0aa4fad0fe5..bc42ed01bdd6 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -31,22 +31,22 @@ GradientBasedSampler::GradientBasedSampler(EllpackPageImpl* page, sample_rows_(n_rows * subsample) { monitor_.Init("gradient_based_sampler"); - if (is_external_memory_) { + if (is_sampling_ || is_external_memory_) { // Create a new ELLPACK page with empty rows. page_.reset(new EllpackPageImpl(batch_param.gpu_id, original_page_->matrix.info, sample_rows_)); - // Allocate GPU memory for sampling. - if (is_sampling_) { - ba_.Allocate(batch_param_.gpu_id, - &gpair_, sample_rows_, - &row_weight_, n_rows, - &row_index_, n_rows, - &sample_row_index_, n_rows); - thrust::copy(thrust::counting_iterator(0), - thrust::counting_iterator(n_rows), - dh::tbegin(row_index_)); - } + } + // Allocate GPU memory for sampling. + if (is_sampling_) { + ba_.Allocate(batch_param_.gpu_id, + &gpair_, sample_rows_, + &row_weight_, n_rows, + &row_index_, n_rows, + &sample_row_index_, n_rows); + thrust::copy(thrust::counting_iterator(0), + thrust::counting_iterator(n_rows), + dh::tbegin(row_index_)); } } @@ -77,16 +77,17 @@ GradientBasedSample GradientBasedSampler::Sample(common::Span gpai GradientBasedSample GradientBasedSampler::NoSampling(common::Span gpair, DMatrix* dmat) { if (is_external_memory_) { - CollectPages(dmat); + ConcatenatePages(dmat); return {dmat->Info().num_row_, page_.get(), gpair}; } else { return {dmat->Info().num_row_, original_page_, gpair}; } } -// When not sampling, collect all the external memory ELLPACK pages into a single in-memory page. -void GradientBasedSampler::CollectPages(DMatrix* dmat) { - if (page_collected_) { +// When not sampling, concatenate all the external memory ELLPACK pages into a single in-memory +// page. +void GradientBasedSampler::ConcatenatePages(DMatrix* dmat) { + if (page_concatenated_) { return; } @@ -96,7 +97,7 @@ void GradientBasedSampler::CollectPages(DMatrix* dmat) { size_t num_elements = page_->Copy(batch_param_.gpu_id, page, offset); offset += num_elements; } - page_collected_ = true; + page_concatenated_ = true; } /*! \brief A functor that returns random weights. */ @@ -113,40 +114,29 @@ struct RandomWeight : public thrust::unary_function { } }; -/*! \brief A functor that performs Bernoulli sampling on gradient pairs. */ -struct BernoulliSampling : public thrust::binary_function { +/*! \brief A functor that scales gradient pairs by 1/p. */ +struct FixedScaling : public thrust::unary_function { float p; - RandomWeight rnd; - XGBOOST_DEVICE BernoulliSampling(float _p, RandomWeight _rnd) : p(_p), rnd(_rnd) {} + XGBOOST_DEVICE explicit FixedScaling(float _p) : p(_p) {} - XGBOOST_DEVICE GradientPair operator()(const GradientPair& gpair, size_t i) const { - if (rnd(i) <= p) { - return gpair; - } else { - return GradientPair(); - } + XGBOOST_DEVICE GradientPair operator()(const GradientPair& gpair) const { + return gpair / p; } }; GradientBasedSample GradientBasedSampler::UniformSampling(common::Span gpair, DMatrix* dmat) { - RandomWeight rnd(common::GlobalRandom()()); - if (is_external_memory_) { - // Generate random weights. - thrust::transform(thrust::counting_iterator(0), - thrust::counting_iterator(gpair.size()), - dh::tbegin(row_weight_), - rnd); - return SequentialPoissonSampling(gpair, dmat); - } else { - // Set gradient pair to 0 with p = 1 - subsample - thrust::transform(dh::tbegin(gpair), dh::tend(gpair), - thrust::counting_iterator(0), - dh::tbegin(gpair), - BernoulliSampling(subsample_, rnd)); - return {dmat->Info().num_row_, original_page_, gpair}; - } + // Generate random weights. + thrust::transform(thrust::counting_iterator(0), + thrust::counting_iterator(gpair.size()), + dh::tbegin(row_weight_), + RandomWeight(common::GlobalRandom()())); + // Scale gradient pairs by 1/subsample. + thrust::transform(dh::tbegin(gpair), dh::tend(gpair), + dh::tbegin(gpair), + FixedScaling(subsample_)); + return SequentialPoissonSampling(gpair, dmat); } /*! \brief A functor that combines the gradient pair into a single float. @@ -167,61 +157,67 @@ struct CombineGradientPair : public thrust::unary_function /*! \brief A functor that calculates the weight of each row. */ struct CalculateWeight : public thrust::binary_function { + size_t sample_rows; + float normalization; RandomWeight rnd; CombineGradientPair combine; - XGBOOST_DEVICE explicit CalculateWeight(RandomWeight _rnd) : rnd(_rnd) {} + XGBOOST_DEVICE CalculateWeight(size_t _sample_rows, float _normalization, RandomWeight _rnd) + : sample_rows(_sample_rows), normalization(_normalization), rnd(_rnd) {} XGBOOST_DEVICE float operator()(const GradientPair& gpair, size_t i) { + // If the gradient and hessian are both empty, we should never select this row. if (gpair.GetGrad() == 0 && gpair.GetHess() == 0) { return FLT_MAX; } - return rnd(i) / combine(gpair); + float combined_gradient = combine(gpair); + float p = sample_rows * combined_gradient / normalization; + if (p >= 1) { + // Always select this row. + return 0.0f; + } else { + // Select this row randomly with probability proportional to the combined gradient. + return rnd(i) / combined_gradient; + } } }; -/*! \brief A functor that performs Poisson sampling with probability proportional to the combined - * gradient pair. - */ -struct PoissonSampling : public thrust::binary_function { +/*! \brief A functor that scales gradient pairs by 1/p_i. */ +struct WeightedScaling : public thrust::unary_function { size_t sample_rows; float normalization; - RandomWeight rnd; CombineGradientPair combine; - XGBOOST_DEVICE PoissonSampling(size_t _sample_rows, float _normalization, RandomWeight _rnd) - : sample_rows(_sample_rows), normalization(_normalization), rnd(_rnd) {} + XGBOOST_DEVICE WeightedScaling(size_t _sample_rows, float _normalization) + : sample_rows(_sample_rows), normalization(_normalization) {} - XGBOOST_DEVICE GradientPair operator()(const GradientPair& gpair, size_t i) const { - if (rnd(i) <= sample_rows * combine(gpair) / normalization) { + XGBOOST_DEVICE GradientPair operator()(const GradientPair& gpair) const { + float p = sample_rows * combine(gpair) / normalization; + if (p >= 1) { return gpair; } else { - return GradientPair(); + return gpair / p; } } }; + GradientBasedSample GradientBasedSampler::GradientBasedSampling( common::Span gpair, DMatrix* dmat) { - RandomWeight rnd(common::GlobalRandom()()); - if (is_external_memory_) { - thrust::transform(dh::tbegin(gpair), dh::tend(gpair), - thrust::counting_iterator(0), - dh::tbegin(row_weight_), - CalculateWeight(rnd)); - return SequentialPoissonSampling(gpair, dmat); - } else { - float normalization = thrust::transform_reduce(dh::tbegin(gpair), dh::tend(gpair), - CombineGradientPair(), - 0.0f, - thrust::plus()); - // Set gradient pair to 0 with p = 1 - p_i - thrust::transform(dh::tbegin(gpair), dh::tend(gpair), - thrust::counting_iterator(0), - dh::tbegin(gpair), - PoissonSampling(sample_rows_, normalization, rnd)); - return {dmat->Info().num_row_, original_page_, gpair}; - } + float normalization = thrust::transform_reduce(dh::tbegin(gpair), dh::tend(gpair), + CombineGradientPair(), + 0.0f, + thrust::plus()); + thrust::transform(dh::tbegin(gpair), dh::tend(gpair), + thrust::counting_iterator(0), + dh::tbegin(row_weight_), + CalculateWeight(sample_rows_, normalization, + RandomWeight(common::GlobalRandom()()))); + // Scale gradient pairs by 1/p. + thrust::transform(dh::tbegin(gpair), dh::tend(gpair), + dh::tbegin(gpair), + WeightedScaling(sample_rows_, normalization)); + return SequentialPoissonSampling(gpair, dmat); } /*! \brief A functor that returns true if the gradient pair is non-zero. */ diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh index ff69f1bf122e..e4696327f080 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cuh +++ b/src/tree/gpu_hist/gradient_based_sampler.cuh @@ -47,8 +47,8 @@ class GradientBasedSampler { GradientBasedSample UniformSampling(common::Span gpair, DMatrix* dmat); GradientBasedSample GradientBasedSampling(common::Span gpair, DMatrix* dmat); - /*! \brief Collect all the rows from a DMatrix into a single ELLPACK page. */ - void CollectPages(DMatrix* dmat); + /*! \brief Concatenate all the rows from a DMatrix into a single ELLPACK page. */ + void ConcatenatePages(DMatrix* dmat); /*! \brief Fixed-size Poisson sampling after the row weights are calculated. */ GradientBasedSample SequentialPoissonSampling(common::Span gpair, DMatrix* dmat); @@ -67,7 +67,7 @@ class GradientBasedSampler { common::Span row_weight_; common::Span row_index_; common::Span sample_row_index_; - bool page_collected_{false}; + bool page_concatenated_{false}; }; }; // namespace tree }; // namespace xgboost diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu index 9461796e3b50..59319b06b41b 100644 --- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -7,32 +7,55 @@ namespace xgboost { namespace tree { -TEST(GradientBasedSampler, NoSampling) { - constexpr size_t kRows = 1024; - constexpr size_t kCols = 4; - constexpr float kSubsample = 1.0; - constexpr size_t kPageSize = 0; +void VerifySampling(size_t page_size, float subsample, int sampling_method) { + constexpr size_t kRows = 2048; + constexpr size_t kCols = 1; + size_t sample_rows = kRows * subsample; - // Create a DMatrix with a single batche. - std::unique_ptr dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true)); + dmlc::TemporaryDirectory tmpdir; + std::unique_ptr dmat( + CreateSparsePageDMatrixWithRC(kRows, kCols, page_size, true, tmpdir)); auto gpair = GenerateRandomGradients(kRows); + GradientPair sum_gpair{}; + for (const auto& gp : gpair.ConstHostVector()) { + sum_gpair += gp; + } gpair.SetDevice(0); - BatchParam param{0, 256, 0, kPageSize}; + BatchParam param{0, 256, 0, page_size}; auto page = (*dmat->GetBatches(param).begin()).Impl(); + if (page_size != 0) { + EXPECT_NE(page->matrix.n_rows, kRows); + } - GradientBasedSampler sampler(page, kRows, param, kSubsample, TrainParam::kUniform); + GradientBasedSampler sampler(page, kRows, param, subsample, sampling_method); auto sample = sampler.Sample(gpair.DeviceSpan(), dmat.get()); - EXPECT_EQ(sample.sample_rows, kRows); - EXPECT_EQ(sample.page, page); - EXPECT_EQ(sample.gpair.size(), gpair.Size()); - EXPECT_EQ(sample.gpair.data(), gpair.DevicePointer()); + EXPECT_EQ(sample.sample_rows, sample_rows); + EXPECT_EQ(sample.page->matrix.n_rows, sample_rows); + EXPECT_EQ(sample.gpair.size(), sample_rows); + + GradientPair sum_sampled_gpair{}; + std::vector sampled_gpair_h(sample.gpair.size()); + dh::CopyDeviceSpanToVector(&sampled_gpair_h, sample.gpair); + for (const auto& gp : sampled_gpair_h) { + sum_sampled_gpair += gp; + } + EXPECT_NEAR(sum_gpair.GetGrad(), sum_sampled_gpair.GetGrad(), 0.01f * kRows); + EXPECT_NEAR(sum_gpair.GetHess(), sum_sampled_gpair.GetHess(), 0.01f * kRows); +} + +TEST(GradientBasedSampler, NoSampling) { + constexpr size_t kPageSize = 0; + constexpr float kSubsample = 1.0f; + constexpr int kSamplingMethod = TrainParam::kUniform; + VerifySampling(kPageSize, kSubsample, kSamplingMethod); } +// In external mode, when not sampling, we concatenate the pages together. TEST(GradientBasedSampler, NoSampling_ExternalMemory) { - constexpr size_t kRows = 1024; - constexpr size_t kCols = 4; - constexpr float kSubsample = 1.0; + constexpr size_t kRows = 2048; + constexpr size_t kCols = 1; + constexpr float kSubsample = 1.0f; constexpr size_t kPageSize = 1024; // Create a DMatrix with multiple batches. @@ -44,20 +67,16 @@ TEST(GradientBasedSampler, NoSampling_ExternalMemory) { BatchParam param{0, 256, 0, kPageSize}; auto page = (*dmat->GetBatches(param).begin()).Impl(); + EXPECT_NE(page->matrix.n_rows, kRows); GradientBasedSampler sampler(page, kRows, param, kSubsample, TrainParam::kUniform); auto sample = sampler.Sample(gpair.DeviceSpan(), dmat.get()); auto sampled_page = sample.page; - auto sampled_gpair = sample.gpair; EXPECT_EQ(sample.sample_rows, kRows); - EXPECT_EQ(sampled_gpair.size(), kRows); + EXPECT_EQ(sample.gpair.size(), gpair.Size()); + EXPECT_EQ(sample.gpair.data(), gpair.DevicePointer()); EXPECT_EQ(sampled_page->matrix.n_rows, kRows); - auto gpair_h = gpair.ConstHostVector(); - std::vector sampled_gpair_h(sampled_gpair.size()); - dh::CopyDeviceSpanToVector(&sampled_gpair_h, sampled_gpair); - EXPECT_EQ(gpair_h, sampled_gpair_h); - std::vector buffer(sampled_page->gidx_buffer.size()); dh::CopyDeviceSpanToVector(&buffer, sampled_page->gidx_buffer); common::CompressedIterator @@ -79,151 +98,31 @@ TEST(GradientBasedSampler, NoSampling_ExternalMemory) { } TEST(GradientBasedSampler, UniformSampling) { - constexpr size_t kRows = 2048; - constexpr size_t kCols = 16; - constexpr float kSubsample = 0.5; - constexpr size_t kSampleRows = kRows * kSubsample; constexpr size_t kPageSize = 0; - - // Create a DMatrix with a single batche. - std::unique_ptr dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true)); - auto gpair = GenerateRandomGradients(kRows); - float sum_gradients = 0; - for (auto gp : gpair.ConstHostVector()) { - sum_gradients += gp.GetGrad(); - } - gpair.SetDevice(0); - - BatchParam param{0, 256, 0, kPageSize}; - auto page = (*dmat->GetBatches(param).begin()).Impl(); - - GradientBasedSampler sampler(page, kRows, param, kSubsample, TrainParam::kUniform); - auto sample = sampler.Sample(gpair.DeviceSpan(), dmat.get()); - auto sampled_gpair = sample.gpair; - EXPECT_EQ(sample.sample_rows, kRows); - EXPECT_EQ(sample.page, page); - EXPECT_EQ(sample.gpair.size(), kRows); - - float sum_sampled_gradients = 0; - std::vector sampled_gpair_h(sampled_gpair.size()); - dh::CopyDeviceSpanToVector(&sampled_gpair_h, sampled_gpair); - for (auto gp : sampled_gpair_h) { - sum_sampled_gradients += gp.GetGrad(); - } - EXPECT_NEAR(sum_gradients / kRows, sum_sampled_gradients / kSampleRows, 0.02); + constexpr float kSubsample = 0.5; + constexpr int kSamplingMethod = TrainParam::kUniform; + VerifySampling(kPageSize, kSubsample, kSamplingMethod); } TEST(GradientBasedSampler, UniformSampling_ExternalMemory) { - constexpr size_t kRows = 2048; - constexpr size_t kCols = 16; - constexpr float kSubsample = 0.5; - constexpr size_t kSampleRows = kRows * kSubsample; constexpr size_t kPageSize = 1024; - - // Create a DMatrix with multiple batches. - dmlc::TemporaryDirectory tmpdir; - std::unique_ptr - dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir)); - auto gpair = GenerateRandomGradients(kRows); - float sum_gradients = 0; - for (auto gp : gpair.ConstHostVector()) { - sum_gradients += gp.GetGrad(); - } - gpair.SetDevice(0); - - BatchParam param{0, 256, 0, kPageSize}; - auto page = (*dmat->GetBatches(param).begin()).Impl(); - - GradientBasedSampler sampler(page, kRows, param, kSubsample, TrainParam::kUniform); - auto sample = sampler.Sample(gpair.DeviceSpan(), dmat.get()); - auto sampled_page = sample.page; - auto sampled_gpair = sample.gpair; - EXPECT_EQ(sample.sample_rows, kSampleRows); - EXPECT_EQ(sampled_page->matrix.n_rows, kSampleRows); - EXPECT_EQ(sampled_gpair.size(), kSampleRows); - - float sum_sampled_gradients = 0; - std::vector sampled_gpair_h(sampled_gpair.size()); - dh::CopyDeviceSpanToVector(&sampled_gpair_h, sampled_gpair); - for (auto gp : sampled_gpair_h) { - sum_sampled_gradients += gp.GetGrad(); - } - EXPECT_NEAR(sum_gradients / kRows, sum_sampled_gradients / kSampleRows, 0.02); + constexpr float kSubsample = 0.5; + constexpr int kSamplingMethod = TrainParam::kUniform; + VerifySampling(kPageSize, kSubsample, kSamplingMethod); } TEST(GradientBasedSampler, GradientBasedSampling) { - constexpr size_t kRows = 2048; - constexpr size_t kCols = 16; - constexpr float kSubsample = 0.5; - constexpr size_t kSampleRows = kRows * kSubsample; constexpr size_t kPageSize = 0; - - // Create a DMatrix with a single batche. - std::unique_ptr dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true)); - auto gpair = GenerateRandomGradients(kRows); - float sum_gradients = 0; - for (auto gp : gpair.ConstHostVector()) { - sum_gradients += gp.GetGrad(); - } - gpair.SetDevice(0); - - BatchParam param{0, 256, 0, kPageSize}; - auto page = (*dmat->GetBatches(param).begin()).Impl(); - - GradientBasedSampler sampler(page, kRows, param, kSubsample, TrainParam::kGradientBased); - auto sample = sampler.Sample(gpair.DeviceSpan(), dmat.get()); - auto sampled_gpair = sample.gpair; - EXPECT_EQ(sample.sample_rows, kRows); - EXPECT_EQ(sample.page, page); - EXPECT_EQ(sampled_gpair.size(), kRows); - - float sum_sampled_gradients = 0; - std::vector sampled_gpair_h(sampled_gpair.size()); - dh::CopyDeviceSpanToVector(&sampled_gpair_h, sampled_gpair); - for (auto gp : sampled_gpair_h) { - sum_sampled_gradients += gp.GetGrad(); - } - // TODO(rongou): gradient pairs need to be rescaled to get accurate statistics. - EXPECT_NEAR(sum_gradients / kRows, sum_sampled_gradients / kSampleRows, 0.15); + constexpr float kSubsample = 0.5; + constexpr int kSamplingMethod = TrainParam::kGradientBased; + VerifySampling(kPageSize, kSubsample, kSamplingMethod); } TEST(GradientBasedSampler, GradientBasedSampling_ExternalMemory) { - constexpr size_t kRows = 2048; - constexpr size_t kCols = 16; - constexpr float kSubsample = 0.5; - constexpr size_t kSampleRows = kRows * kSubsample; constexpr size_t kPageSize = 1024; - - // Create a DMatrix with multiple batches. - dmlc::TemporaryDirectory tmpdir; - std::unique_ptr - dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir)); - auto gpair = GenerateRandomGradients(kRows); - float sum_gradients = 0; - for (auto gp : gpair.ConstHostVector()) { - sum_gradients += gp.GetGrad(); - } - gpair.SetDevice(0); - - BatchParam param{0, 256, 0, kPageSize}; - auto page = (*dmat->GetBatches(param).begin()).Impl(); - - GradientBasedSampler sampler(page, kRows, param, kSubsample, TrainParam::kGradientBased); - auto sample = sampler.Sample(gpair.DeviceSpan(), dmat.get()); - auto sampled_page = sample.page; - auto sampled_gpair = sample.gpair; - EXPECT_EQ(sample.sample_rows, kSampleRows); - EXPECT_EQ(sampled_page->matrix.n_rows, kSampleRows); - EXPECT_EQ(sampled_gpair.size(), kSampleRows); - - float sum_sampled_gradients = 0; - std::vector sampled_gpair_h(sampled_gpair.size()); - dh::CopyDeviceSpanToVector(&sampled_gpair_h, sampled_gpair); - for (auto gp : sampled_gpair_h) { - sum_sampled_gradients += gp.GetGrad(); - } - // TODO(rongou): gradient pairs need to be rescaled to get accurate statistics. - EXPECT_NEAR(sum_gradients / kRows, sum_sampled_gradients / kSampleRows, 0.15); + constexpr float kSubsample = 0.5; + constexpr int kSamplingMethod = TrainParam::kGradientBased; + VerifySampling(kPageSize, kSubsample, kSamplingMethod); } }; // namespace tree From 55bbe7426c31dd7b04d7395299c42bf88190ab69 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 21 Jan 2020 11:27:51 -0800 Subject: [PATCH 36/48] combine weight calculation and gpair scaling --- src/tree/gpu_hist/gradient_based_sampler.cu | 43 ++++++--------------- 1 file changed, 11 insertions(+), 32 deletions(-) diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index bc42ed01bdd6..0f17ed49a3a6 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -154,9 +154,9 @@ struct CombineGradientPair : public thrust::unary_function } }; -/*! \brief A functor that calculates the weight of each row. - */ -struct CalculateWeight : public thrust::binary_function { +/*! \brief A functor that calculates the weight of each row, and scales gradient pairs by 1/p_i. */ +struct CalculateWeight + : public thrust::binary_function> { size_t sample_rows; float normalization; RandomWeight rnd; @@ -165,43 +165,25 @@ struct CalculateWeight : public thrust::binary_function operator()(const GradientPair& gpair, + size_t i) { // If the gradient and hessian are both empty, we should never select this row. if (gpair.GetGrad() == 0 && gpair.GetHess() == 0) { - return FLT_MAX; + return thrust::make_tuple(FLT_MAX, gpair); } float combined_gradient = combine(gpair); float p = sample_rows * combined_gradient / normalization; if (p >= 1) { // Always select this row. - return 0.0f; + return thrust::make_tuple(0.0f, gpair); } else { // Select this row randomly with probability proportional to the combined gradient. - return rnd(i) / combined_gradient; + // Scale gpair by 1/p. + return thrust::make_tuple(rnd(i) / combined_gradient, gpair / p); } } }; -/*! \brief A functor that scales gradient pairs by 1/p_i. */ -struct WeightedScaling : public thrust::unary_function { - size_t sample_rows; - float normalization; - CombineGradientPair combine; - - XGBOOST_DEVICE WeightedScaling(size_t _sample_rows, float _normalization) - : sample_rows(_sample_rows), normalization(_normalization) {} - - XGBOOST_DEVICE GradientPair operator()(const GradientPair& gpair) const { - float p = sample_rows * combine(gpair) / normalization; - if (p >= 1) { - return gpair; - } else { - return gpair / p; - } - } -}; - - GradientBasedSample GradientBasedSampler::GradientBasedSampling( common::Span gpair, DMatrix* dmat) { float normalization = thrust::transform_reduce(dh::tbegin(gpair), dh::tend(gpair), @@ -210,13 +192,10 @@ GradientBasedSample GradientBasedSampler::GradientBasedSampling( thrust::plus()); thrust::transform(dh::tbegin(gpair), dh::tend(gpair), thrust::counting_iterator(0), - dh::tbegin(row_weight_), + thrust::make_zip_iterator(thrust::make_tuple( + dh::tbegin(row_weight_), dh::tbegin(gpair))), CalculateWeight(sample_rows_, normalization, RandomWeight(common::GlobalRandom()()))); - // Scale gradient pairs by 1/p. - thrust::transform(dh::tbegin(gpair), dh::tend(gpair), - dh::tbegin(gpair), - WeightedScaling(sample_rows_, normalization)); return SequentialPoissonSampling(gpair, dmat); } From be163b7c645aac53ff03fc758660fec7d78413f9 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 21 Jan 2020 14:54:04 -0800 Subject: [PATCH 37/48] fix tests --- tests/cpp/tree/test_gpu_hist.cu | 39 ++++++- tests/cpp/tree/test_gpu_hist_sampling.cc | 128 ----------------------- 2 files changed, 36 insertions(+), 131 deletions(-) delete mode 100644 tests/cpp/tree/test_gpu_hist_sampling.cc diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index c4ef8d1b2a75..ebc6f0cde4f1 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -383,10 +383,9 @@ void UpdateTree(HostDeviceVector* gpair, } TEST(GpuHist, ExternalMemory) { - constexpr size_t kRows = 6; + constexpr size_t kRows = 4096; constexpr size_t kCols = 2; - constexpr size_t kPageSize = 1; - constexpr float kSubsample = 0.99; + constexpr size_t kPageSize = 1024; // Create an in-memory DMatrix. std::unique_ptr dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true)); @@ -403,6 +402,40 @@ TEST(GpuHist, ExternalMemory) { HostDeviceVector preds(kRows, 0.0, 0); UpdateTree(&gpair, dmat.get(), 0, &tree, &preds); + // Build another tree using multiple ELLPACK pages. + RegTree tree_ext; + HostDeviceVector preds_ext(kRows, 0.0, 0); + UpdateTree(&gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext); + + // Make sure the predictions are the same. + auto preds_h = preds.ConstHostVector(); + auto preds_ext_h = preds_ext.ConstHostVector(); + for (int i = 0; i < kRows; i++) { + EXPECT_NEAR(preds_h[i], preds_ext_h[i], 2e-6); + } +} + +TEST(GpuHist, ExternalMemoryWithSampling) { + constexpr size_t kRows = 4096; + constexpr size_t kCols = 2; + constexpr size_t kPageSize = 1024; + constexpr float kSubsample = 0.5; + + // Create an in-memory DMatrix. + std::unique_ptr dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true)); + + // Create a DMatrix with multiple batches. + dmlc::TemporaryDirectory tmpdir; + std::unique_ptr + dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir)); + + auto gpair = GenerateRandomGradients(kRows); + + // Build a tree using the in-memory DMatrix. + RegTree tree; + HostDeviceVector preds(kRows, 0.0, 0); + UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, kSubsample); + // Build another tree using multiple ELLPACK pages. RegTree tree_ext; HostDeviceVector preds_ext(kRows, 0.0, 0); diff --git a/tests/cpp/tree/test_gpu_hist_sampling.cc b/tests/cpp/tree/test_gpu_hist_sampling.cc deleted file mode 100644 index b646f1afb21d..000000000000 --- a/tests/cpp/tree/test_gpu_hist_sampling.cc +++ /dev/null @@ -1,128 +0,0 @@ -/*! - * Copyright 2020 XGBoost contributors - */ -#include -#include -#include -#include -#include - -#include "gtest/gtest.h" - -namespace xgboost { -namespace tree { - -class GpuHistSamplingTest : public ::testing::Test { - protected: - static void SetUpTestCase() { - constexpr size_t kRows = 1000; - constexpr size_t kCols = 1; - constexpr size_t kPageSize = 1024; - - temp_dir = new dmlc::TemporaryDirectory(); - const std::string tmp_file = temp_dir->path + "/random.libsvm"; - { - std::ofstream fo(tmp_file.c_str()); - - std::mt19937 gen{2020}; // NOLINT - std::normal_distribution<> rnd{0, 1}; - - for (size_t i = 0; i < kRows; i++) { - std::stringstream row; - row << rnd(gen); - - for (size_t j = 0; j < kCols; j++) { - row << " " << j << ":" << rnd(gen); - } - fo << row.str() << "\n"; - } - } - - dmat = std::shared_ptr(DMatrix::Load(tmp_file, true, false)); - const std::string ext_mem_file = tmp_file + "#" + tmp_file + ".cache"; - dmat_ext = std::shared_ptr( - DMatrix::Load(ext_mem_file, true, false, "auto", kPageSize)); - } - - static void TearDownTestCase() { - dmat.reset(); - dmat_ext.reset(); - delete temp_dir; - } - - static void VerifyPredictionMean(const std::shared_ptr& dtrain, - float subsample = 1.0f, - const std::string& sampling_method = "uniform") { - std::vector> cache_mats{dtrain}; - std::unique_ptr learner(Learner::Create(cache_mats)); - Args args { - {"tree_method", "gpu_hist"}, - {"max_depth", "1"}, - {"subsample", std::to_string(subsample)}, - {"sampling_method", sampling_method}, - - {"learning_rate", "1"}, - {"reg_alpha", "0"}, - {"reg_lambda", "0"}, - }; - learner->SetParams(args); - - constexpr int kNumRound = 10; - for (int i = 0; i < kNumRound; ++i) { - learner->UpdateOneIter(i, dtrain.get()); - } - - HostDeviceVector preds; - learner->Predict(dtrain.get(), true, &preds); - auto h_preds = preds.ConstHostVector(); - float mean = std::accumulate(h_preds.begin(), h_preds.end(), 0.0f) / h_preds.size(); - EXPECT_NEAR(mean, 0.0f, 2e-2) << "subsample=" << subsample; - } - - static dmlc::TemporaryDirectory* temp_dir; - static std::shared_ptr dmat; - static std::shared_ptr dmat_ext; -}; - -dmlc::TemporaryDirectory* GpuHistSamplingTest::temp_dir; -std::shared_ptr GpuHistSamplingTest::dmat; -std::shared_ptr GpuHistSamplingTest::dmat_ext; - -TEST_F(GpuHistSamplingTest, NoSampling) { - VerifyPredictionMean(dmat); -} - -TEST_F(GpuHistSamplingTest, NoSampling_ExternalMemory) { - VerifyPredictionMean(dmat_ext); -} - -TEST_F(GpuHistSamplingTest, UniformSampling) { - for (int i = 1; i < 10; i++) { - float subsample = static_cast(i) / 10.0f; - VerifyPredictionMean(dmat, subsample); - } -} - -TEST_F(GpuHistSamplingTest, UniformSampling_ExternalMemory) { - for (int i = 1; i < 10; i++) { - float subsample = static_cast(i) / 10.0f; - VerifyPredictionMean(dmat_ext, subsample); - } -} - -TEST_F(GpuHistSamplingTest, GradientBasedSampling) { - for (int i = 1; i < 10; i++) { - float subsample = static_cast(i) / 10.0f; - VerifyPredictionMean(dmat, subsample, "gradient_based"); - } -} - -TEST_F(GpuHistSamplingTest, GradientBasedSampling_ExternalMemory) { - for (int i = 1; i < 10; i++) { - float subsample = static_cast(i) / 10.0f; - VerifyPredictionMean(dmat_ext, subsample, "gradient_based"); - } -} - -} // namespace tree -} // namespace xgboost From 3a0fd99de63e77b491d256228b7c72fdab23083e Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 21 Jan 2020 15:21:21 -0800 Subject: [PATCH 38/48] review feedback --- src/tree/gpu_hist/gradient_based_sampler.cu | 6 ++++-- src/tree/updater_basemaker-inl.h | 3 ++- src/tree/updater_colmaker.cc | 3 ++- src/tree/updater_quantile_hist.cc | 3 ++- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 0f17ed49a3a6..c5b03b0afb54 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -9,6 +9,7 @@ #include #include +#include #include "../../common/compressed_iterator.h" #include "../../common/random.h" @@ -169,7 +170,7 @@ struct CalculateWeight size_t i) { // If the gradient and hessian are both empty, we should never select this row. if (gpair.GetGrad() == 0 && gpair.GetHess() == 0) { - return thrust::make_tuple(FLT_MAX, gpair); + return thrust::make_tuple(std::numeric_limits::max(), gpair); } float combined_gradient = combine(gpair); float p = sample_rows * combined_gradient / normalization; @@ -212,7 +213,7 @@ struct ClearEmptyRows : public thrust::binary_function::max(); } } }; @@ -238,6 +239,7 @@ GradientBasedSample GradientBasedSampler::SequentialPoissonSampling( dh::tbegin(sample_row_index_)))); // Compact the non-zero gradient pairs. + thrust::fill(dh::tbegin(gpair_), dh::tend(gpair_), GradientPair()); thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(gpair_), IsNonZero()); // Index the sample rows. diff --git a/src/tree/updater_basemaker-inl.h b/src/tree/updater_basemaker-inl.h index 2ef24b11facc..8761e39901bd 100644 --- a/src/tree/updater_basemaker-inl.h +++ b/src/tree/updater_basemaker-inl.h @@ -149,7 +149,8 @@ class BaseMaker: public TreeUpdater { // mark subsample if (param_.subsample < 1.0f) { CHECK_EQ(param_.sampling_method, TrainParam::kUniform) - << "Only uniform sampling is supported"; + << "Only uniform sampling is supported, " + << "gradient-based sampling is only support by GPU Hist."; std::bernoulli_distribution coin_flip(param_.subsample); auto& rnd = common::GlobalRandom(); for (size_t i = 0; i < position_.size(); ++i) { diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index 908d42f0e15d..5b0f859c36d3 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -203,7 +203,8 @@ class ColMaker: public TreeUpdater { // mark subsample if (param_.subsample < 1.0f) { CHECK_EQ(param_.sampling_method, TrainParam::kUniform) - << "Only uniform sampling is supported"; + << "Only uniform sampling is supported, " + << "gradient-based sampling is only support by GPU Hist."; std::bernoulli_distribution coin_flip(param_.subsample); auto& rnd = common::GlobalRandom(); for (size_t ridx = 0; ridx < position_.size(); ++ridx) { diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 3c9d0b31eb94..6e2953afbd7e 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -442,7 +442,8 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat, if (param_.subsample < 1.0f) { CHECK_EQ(param_.sampling_method, TrainParam::kUniform) - << "Only uniform sampling is supported"; + << "Only uniform sampling is supported, " + << "gradient-based sampling is only support by GPU Hist."; std::bernoulli_distribution coin_flip(param_.subsample); auto& rnd = common::GlobalRandom(); size_t j = 0; From 9680a6964afe5372f2bcf0a888e535ccec395812 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Wed, 22 Jan 2020 16:54:44 -0800 Subject: [PATCH 39/48] calculate threshold --- src/tree/gpu_hist/gradient_based_sampler.cu | 82 +++++++++++++++++--- src/tree/gpu_hist/gradient_based_sampler.cuh | 4 + 2 files changed, 76 insertions(+), 10 deletions(-) diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index c5b03b0afb54..a643ac6d4919 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -43,6 +43,7 @@ GradientBasedSampler::GradientBasedSampler(EllpackPageImpl* page, ba_.Allocate(batch_param_.gpu_id, &gpair_, sample_rows_, &row_weight_, n_rows, + &threshold_, n_rows + 1, &row_index_, n_rows, &sample_row_index_, n_rows); thrust::copy(thrust::counting_iterator(0), @@ -155,16 +156,78 @@ struct CombineGradientPair : public thrust::unary_function } }; +/*! \brief A functor that calculates the difference between the sample rate and the desired sample + * rows, given a cumulative gradient sum. + */ +struct SampleRateDelta : public thrust::binary_function { + common::Span threshold; + size_t n_rows; + size_t sample_rows; + + XGBOOST_DEVICE SampleRateDelta(common::Span _threshold, + size_t _n_rows, + size_t _sample_rows) + : threshold(_threshold), n_rows(_n_rows), sample_rows(_sample_rows) {} + + XGBOOST_DEVICE float operator()(float gradient_sum, size_t row_index) const { + // For the last row, if gradient_sum/sample_rows > gradient, that means no row will be sampled + // with probability equal to 1, so we use the mean as the threshold. + if (row_index == n_rows - 1) { + float last = threshold[n_rows - 1]; + float mean = gradient_sum / sample_rows; + if (mean > last) { + threshold[n_rows] = mean; + return 0.0f; + } else { + return std::numeric_limits::max(); + } + } + + // For a given u = threshold[row_index], the summed sample rate for the rows above the current + // row is `gradient_sum / u`. + // + // Rows below (including the current row) are sampled with probability equal to 1, thus adding + // up to `n_rows - row_index`. + // + // The total sample rate is therefore `gradient_sum / u + n_rows - row_index`. + // + // We want to choose the threshold that makes this value as close to `sample_rows` as possible. + float u = threshold[row_index + 1]; + if (u == 0.0f) { + return std::numeric_limits::max(); + } else { + return fabsf(gradient_sum / u + n_rows - row_index - sample_rows); + } + } +}; + +size_t GradientBasedSampler::CalculateThresholdIndex(common::Span gpair) { + thrust::transform(dh::tbegin(gpair), dh::tend(gpair), + dh::tbegin(threshold_), + CombineGradientPair()); + thrust::sort(dh::tbegin(threshold_), dh::tend(threshold_) - 1); + thrust::inclusive_scan(dh::tbegin(threshold_), dh::tend(threshold_) - 1, dh::tbegin(row_weight_)); + thrust::transform(dh::tbegin(row_weight_), dh::tend(row_weight_), + thrust::counting_iterator(0), + dh::tbegin(row_weight_), + SampleRateDelta(threshold_, gpair.size(), sample_rows_)); + thrust::device_ptr min = thrust::min_element(dh::tbegin(row_weight_), + dh::tend(row_weight_)); + return thrust::distance(dh::tbegin(row_weight_), min) + 1; +} + /*! \brief A functor that calculates the weight of each row, and scales gradient pairs by 1/p_i. */ struct CalculateWeight : public thrust::binary_function> { - size_t sample_rows; - float normalization; + common::Span threshold; + size_t threshold_index; RandomWeight rnd; CombineGradientPair combine; - XGBOOST_DEVICE CalculateWeight(size_t _sample_rows, float _normalization, RandomWeight _rnd) - : sample_rows(_sample_rows), normalization(_normalization), rnd(_rnd) {} + XGBOOST_DEVICE CalculateWeight(common::Span _threshold, + size_t _threshold_index, + RandomWeight _rnd) + : threshold(_threshold), threshold_index(_threshold_index), rnd(_rnd) {} XGBOOST_DEVICE thrust::tuple operator()(const GradientPair& gpair, size_t i) { @@ -173,7 +236,8 @@ struct CalculateWeight return thrust::make_tuple(std::numeric_limits::max(), gpair); } float combined_gradient = combine(gpair); - float p = sample_rows * combined_gradient / normalization; + float u = threshold[threshold_index]; + float p = combined_gradient / u; if (p >= 1) { // Always select this row. return thrust::make_tuple(0.0f, gpair); @@ -187,15 +251,13 @@ struct CalculateWeight GradientBasedSample GradientBasedSampler::GradientBasedSampling( common::Span gpair, DMatrix* dmat) { - float normalization = thrust::transform_reduce(dh::tbegin(gpair), dh::tend(gpair), - CombineGradientPair(), - 0.0f, - thrust::plus()); + size_t threshold_index = CalculateThresholdIndex(gpair); + printf("threshold_index=%lu\n", threshold_index); thrust::transform(dh::tbegin(gpair), dh::tend(gpair), thrust::counting_iterator(0), thrust::make_zip_iterator(thrust::make_tuple( dh::tbegin(row_weight_), dh::tbegin(gpair))), - CalculateWeight(sample_rows_, normalization, + CalculateWeight(threshold_, threshold_index, RandomWeight(common::GlobalRandom()()))); return SequentialPoissonSampling(gpair, dmat); } diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh index e4696327f080..32ffe0357637 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cuh +++ b/src/tree/gpu_hist/gradient_based_sampler.cuh @@ -50,6 +50,9 @@ class GradientBasedSampler { /*! \brief Concatenate all the rows from a DMatrix into a single ELLPACK page. */ void ConcatenatePages(DMatrix* dmat); + /*! \brief Calculate the threshold used to normalize sampling probabilities. */ + size_t CalculateThresholdIndex(common::Span gpair); + /*! \brief Fixed-size Poisson sampling after the row weights are calculated. */ GradientBasedSample SequentialPoissonSampling(common::Span gpair, DMatrix* dmat); @@ -65,6 +68,7 @@ class GradientBasedSampler { std::unique_ptr page_; common::Span gpair_; common::Span row_weight_; + common::Span threshold_; common::Span row_index_; common::Span sample_row_index_; bool page_concatenated_{false}; From c93f20dcfe21c2219dbb20f0318847f3bdec2c9e Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 23 Jan 2020 10:24:33 -0800 Subject: [PATCH 40/48] tweak test --- tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu index 59319b06b41b..f2e27ced42aa 100644 --- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -8,7 +8,7 @@ namespace xgboost { namespace tree { void VerifySampling(size_t page_size, float subsample, int sampling_method) { - constexpr size_t kRows = 2048; + constexpr size_t kRows = 4096; constexpr size_t kCols = 1; size_t sample_rows = kRows * subsample; From 09864ed4effe34dae9a52906cc7bcb198b8cd6f5 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 24 Jan 2020 10:18:26 -0800 Subject: [PATCH 41/48] more accurate threshold --- src/tree/gpu_hist/gradient_based_sampler.cu | 35 +++-------- src/tree/updater_gpu_hist.cu | 11 +--- tests/cpp/tree/test_gpu_hist.cu | 66 +++++++++++++++++++-- 3 files changed, 73 insertions(+), 39 deletions(-) diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index a643ac6d4919..6e7233312537 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -170,38 +170,20 @@ struct SampleRateDelta : public thrust::binary_function { : threshold(_threshold), n_rows(_n_rows), sample_rows(_sample_rows) {} XGBOOST_DEVICE float operator()(float gradient_sum, size_t row_index) const { - // For the last row, if gradient_sum/sample_rows > gradient, that means no row will be sampled - // with probability equal to 1, so we use the mean as the threshold. - if (row_index == n_rows - 1) { - float last = threshold[n_rows - 1]; - float mean = gradient_sum / sample_rows; - if (mean > last) { - threshold[n_rows] = mean; - return 0.0f; - } else { - return std::numeric_limits::max(); - } - } - - // For a given u = threshold[row_index], the summed sample rate for the rows above the current - // row is `gradient_sum / u`. - // - // Rows below (including the current row) are sampled with probability equal to 1, thus adding - // up to `n_rows - row_index`. - // - // The total sample rate is therefore `gradient_sum / u + n_rows - row_index`. - // - // We want to choose the threshold that makes this value as close to `sample_rows` as possible. - float u = threshold[row_index + 1]; - if (u == 0.0f) { - return std::numeric_limits::max(); + float lower = threshold[row_index]; + float upper = threshold[row_index + 1]; + float u = gradient_sum / static_cast(sample_rows - n_rows + row_index + 1); + if (u > lower && u <= upper) { + threshold[row_index + 1] = u; + return 0.0f; } else { - return fabsf(gradient_sum / u + n_rows - row_index - sample_rows); + return std::numeric_limits::max(); } } }; size_t GradientBasedSampler::CalculateThresholdIndex(common::Span gpair) { + thrust::fill(dh::tend(threshold_) - 1, dh::tend(threshold_), std::numeric_limits::max()); thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(threshold_), CombineGradientPair()); @@ -252,7 +234,6 @@ struct CalculateWeight GradientBasedSample GradientBasedSampler::GradientBasedSampling( common::Span gpair, DMatrix* dmat) { size_t threshold_index = CalculateThresholdIndex(gpair); - printf("threshold_index=%lu\n", threshold_index); thrust::transform(dh::tbegin(gpair), dh::tend(gpair), thrust::counting_iterator(0), thrust::make_zip_iterator(thrust::make_tuple( diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 67aa3cc6417e..eafa227b05e3 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -416,11 +416,8 @@ __global__ void SharedMemHistKernel(xgboost::EllpackMatrix matrix, } for (auto idx : dh::GridStrideRange(static_cast(0), n_elements)) { int ridx = d_ridx[idx / matrix.info.row_stride]; - if (!matrix.IsInRange(ridx)) { - continue; - } - int gidx = matrix.gidx_iter[(ridx - matrix.base_rowid) * matrix.info.row_stride - + idx % matrix.info.row_stride]; + int gidx = + matrix.gidx_iter[ridx * matrix.info.row_stride + idx % matrix.info.row_stride]; if (gidx != matrix.info.n_bins) { // If we are not using shared memory, accumulate the values directly into // global memory @@ -481,7 +478,6 @@ struct GPUHistMakerDevice { std::function>; std::unique_ptr qexpand; - bool use_gradient_based_sampling {false}; std::unique_ptr sampler; GPUHistMakerDevice(int _device_id, @@ -498,8 +494,7 @@ struct GPUHistMakerDevice { prediction_cache_initialised(false), column_sampler(column_sampler_seed), interaction_constraints(param, n_features), - batch_param(_batch_param), - use_gradient_based_sampling(_page->matrix.n_rows != _n_rows) { + batch_param(_batch_param) { sampler.reset(new GradientBasedSampler(page, n_rows, batch_param, diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index ebc6f0cde4f1..faf9fa0f92d7 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -347,7 +347,8 @@ void UpdateTree(HostDeviceVector* gpair, size_t gpu_page_size, RegTree* tree, HostDeviceVector* preds, - float subsample = 1.0f) { + float subsample = 1.0f, + const std::string& sampling_method = "uniform") { constexpr size_t kMaxBin = 2; if (gpu_page_size > 0) { @@ -370,7 +371,7 @@ void UpdateTree(HostDeviceVector* gpair, {"reg_alpha", "0"}, {"reg_lambda", "0"}, {"subsample", std::to_string(subsample)}, - {"sampling_method", "gradient_based"}, + {"sampling_method", sampling_method}, }; tree::GPUHistMakerSpecialised hist_maker; @@ -382,6 +383,62 @@ void UpdateTree(HostDeviceVector* gpair, hist_maker.UpdatePredictionCache(dmat, preds); } +TEST(GpuHist, UniformSampling) { + constexpr size_t kRows = 6; + constexpr size_t kCols = 2; + constexpr float kSubsample = 0.99; + + // Create an in-memory DMatrix. + std::unique_ptr dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true)); + + auto gpair = GenerateRandomGradients(kRows); + + // Build a tree using the in-memory DMatrix. + RegTree tree; + HostDeviceVector preds(kRows, 0.0, 0); + UpdateTree(&gpair, dmat.get(), 0, &tree, &preds); + + // Build another tree using sampling. + RegTree tree_sampling; + HostDeviceVector preds_sampling(kRows, 0.0, 0); + UpdateTree(&gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample); + + // Make sure the predictions are the same. + auto preds_h = preds.ConstHostVector(); + auto preds_sampling_h = preds_sampling.ConstHostVector(); + for (int i = 0; i < kRows; i++) { + EXPECT_NEAR(preds_h[i], preds_sampling_h[i], 2e-6); + } +} + +TEST(GpuHist, GradientBasedSampling) { + constexpr size_t kRows = 6; + constexpr size_t kCols = 2; + constexpr float kSubsample = 0.99; + + // Create an in-memory DMatrix. + std::unique_ptr dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true)); + + auto gpair = GenerateRandomGradients(kRows); + + // Build a tree using the in-memory DMatrix. + RegTree tree; + HostDeviceVector preds(kRows, 0.0, 0); + UpdateTree(&gpair, dmat.get(), 0, &tree, &preds); + + // Build another tree using sampling. + RegTree tree_sampling; + HostDeviceVector preds_sampling(kRows, 0.0, 0); + UpdateTree(&gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample, "gradient_based"); + + // Make sure the predictions are the same. + auto preds_h = preds.ConstHostVector(); + auto preds_sampling_h = preds_sampling.ConstHostVector(); + for (int i = 0; i < kRows; i++) { + EXPECT_NEAR(preds_h[i], preds_sampling_h[i], 1e-5); + } +} + TEST(GpuHist, ExternalMemory) { constexpr size_t kRows = 4096; constexpr size_t kCols = 2; @@ -420,6 +477,7 @@ TEST(GpuHist, ExternalMemoryWithSampling) { constexpr size_t kCols = 2; constexpr size_t kPageSize = 1024; constexpr float kSubsample = 0.5; + const std::string kSamplingMethod = "gradient_based"; // Create an in-memory DMatrix. std::unique_ptr dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true)); @@ -434,12 +492,12 @@ TEST(GpuHist, ExternalMemoryWithSampling) { // Build a tree using the in-memory DMatrix. RegTree tree; HostDeviceVector preds(kRows, 0.0, 0); - UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, kSubsample); + UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, kSubsample, kSamplingMethod); // Build another tree using multiple ELLPACK pages. RegTree tree_ext; HostDeviceVector preds_ext(kRows, 0.0, 0); - UpdateTree(&gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext, kSubsample); + UpdateTree(&gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext, kSubsample, kSamplingMethod); // Make sure the predictions are the same. auto preds_h = preds.ConstHostVector(); From f8b7dbf322e1e3fd8a2ff724cc16e8afb3c2abca Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 24 Jan 2020 11:20:24 -0800 Subject: [PATCH 42/48] tweak test tolerance --- tests/cpp/tree/test_gpu_hist.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index faf9fa0f92d7..8fd5f3598b80 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -384,7 +384,7 @@ void UpdateTree(HostDeviceVector* gpair, } TEST(GpuHist, UniformSampling) { - constexpr size_t kRows = 6; + constexpr size_t kRows = 4096; constexpr size_t kCols = 2; constexpr float kSubsample = 0.99; @@ -407,12 +407,12 @@ TEST(GpuHist, UniformSampling) { auto preds_h = preds.ConstHostVector(); auto preds_sampling_h = preds_sampling.ConstHostVector(); for (int i = 0; i < kRows; i++) { - EXPECT_NEAR(preds_h[i], preds_sampling_h[i], 2e-6); + EXPECT_NEAR(preds_h[i], preds_sampling_h[i], 1e-3); } } TEST(GpuHist, GradientBasedSampling) { - constexpr size_t kRows = 6; + constexpr size_t kRows = 4096; constexpr size_t kCols = 2; constexpr float kSubsample = 0.99; @@ -435,7 +435,7 @@ TEST(GpuHist, GradientBasedSampling) { auto preds_h = preds.ConstHostVector(); auto preds_sampling_h = preds_sampling.ConstHostVector(); for (int i = 0; i < kRows; i++) { - EXPECT_NEAR(preds_h[i], preds_sampling_h[i], 1e-5); + EXPECT_NEAR(preds_h[i], preds_sampling_h[i], 1e-3); } } From 3aaae8938454c3fb705dcc6c1092fdeb11ed6cdd Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 24 Jan 2020 15:38:40 -0800 Subject: [PATCH 43/48] wip: refactor the code to disintangle sampling methods --- src/tree/gpu_hist/gradient_based_sampler.cu | 208 +++++++++++++------ src/tree/gpu_hist/gradient_based_sampler.cuh | 89 +++++++- 2 files changed, 223 insertions(+), 74 deletions(-) diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 6e7233312537..352641fd0b37 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -18,9 +18,115 @@ namespace xgboost { namespace tree { +NoSampling::NoSampling(EllpackPageImpl* page) : page_(page) {} + +GradientBasedSample NoSampling::Sample(common::Span gpair, DMatrix* dmat) { + return {dmat->Info().num_row_, page_, gpair}; +} + +ExternalMemoryNoSampling::ExternalMemoryNoSampling(EllpackPageImpl* page, + size_t n_rows, + const BatchParam& batch_param) + : batch_param_(batch_param), + page_(new EllpackPageImpl(batch_param.gpu_id, page->matrix.info, n_rows)) {} + +GradientBasedSample ExternalMemoryNoSampling::Sample(common::Span gpair, + DMatrix* dmat) { + ConcatenatePages(dmat); + return {dmat->Info().num_row_, page_.get(), gpair}; +} + +// Concatenate all the external memory ELLPACK pages into a single in-memory page. +void ExternalMemoryNoSampling::ConcatenatePages(DMatrix* dmat) { + if (page_concatenated_) { + return; + } + + size_t offset = 0; + for (auto& batch : dmat->GetBatches(batch_param_)) { + auto page = batch.Impl(); + size_t num_elements = page_->Copy(batch_param_.gpu_id, page, offset); + offset += num_elements; + } + page_concatenated_ = true; +} + +/*! \brief A functor that returns random weights. */ +class RandomWeight : public thrust::unary_function { + public: + explicit RandomWeight(size_t _seed) : seed(_seed) {} + + XGBOOST_DEVICE float operator()(size_t i) const { + thrust::default_random_engine rng(seed); + thrust::uniform_real_distribution dist; + rng.discard(i); + return dist(rng); + } + + private: + uint32_t seed; +}; + +/*! \brief A functor that performs a Bernoulli trial to discard a gradient pair. */ +class BernoulliTrial : public thrust::unary_function { + public: + BernoulliTrial(size_t _seed, float _p) : rnd(_seed), p(_p) {} + + XGBOOST_DEVICE bool operator()(size_t i) const { + return rnd(i) > p; + } + + private: + RandomWeight rnd; + float p; +}; + +UniformSampling::UniformSampling(EllpackPageImpl* page, float subsample) + : page_(page), subsample_(subsample) {} + +GradientBasedSample UniformSampling::Sample(common::Span gpair, DMatrix* dmat) { + // Set gradient pair to 0 with p = 1 - subsample + thrust::replace_if(dh::tbegin(gpair), dh::tend(gpair), + thrust::counting_iterator(0), + BernoulliTrial(common::GlobalRandom()(), subsample_), + GradientPair()); + return {dmat->Info().num_row_, page_, gpair}; +} + +ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(float subsample) + : subsample_(subsample) {} + +/*! \brief A functor that returns true if the gradient pair is non-zero. */ +struct IsNonZero : public thrust::unary_function { + XGBOOST_DEVICE bool operator()(const GradientPair& gpair) const { + return gpair.GetGrad() != 0 || gpair.GetHess() != 0; + } +}; + +GradientBasedSample ExternalMemoryUniformSampling::Sample(common::Span gpair, + DMatrix* dmat) { + // Set gradient pair to 0 with p = 1 - subsample + thrust::replace_if(dh::tbegin(gpair), dh::tend(gpair), + thrust::counting_iterator(0), + BernoulliTrial(common::GlobalRandom()(), subsample_), + GradientPair()); + size_t sample_rows = thrust::count_if(dh::tbegin(gpair), dh::tend(gpair), IsNonZero()); + return GradientBasedSample(); +} + +GradientBasedSample GradientBasedSampling::Sample(common::Span gpair, + DMatrix* dmat) { + return GradientBasedSample(); +} + +GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(common::Span gpair, + DMatrix* dmat) { + return GradientBasedSample(); +} + GradientBasedSampler::GradientBasedSampler(EllpackPageImpl* page, size_t n_rows, - BatchParam batch_param, + const BatchParam& batch_param, float subsample, int sampling_method) : original_page_(page), @@ -30,8 +136,37 @@ GradientBasedSampler::GradientBasedSampler(EllpackPageImpl* page, is_sampling_(subsample < 1.0), sampling_method_(sampling_method), sample_rows_(n_rows * subsample) { + monitor_.Init("gradient_based_sampler"); + if (is_sampling_) { + switch (sampling_method_) { + case TrainParam::kUniform: + if (is_external_memory_) { + strategy_.reset(new ExternalMemoryUniformSampling(subsample)); + } else { + strategy_.reset(new UniformSampling(page, subsample)); + } + break; + case TrainParam::kGradientBased: + if (is_external_memory_) { + strategy_.reset(new ExternalMemoryGradientBasedSampling()); + } else { + strategy_.reset(new GradientBasedSampling()); + } + break; + default: + LOG(FATAL) << "unknown sampling method"; + } + } else { + if (is_external_memory_) { + strategy_.reset(new ExternalMemoryNoSampling(page, n_rows, batch_param)); + } else { + strategy_.reset(new NoSampling(page)); + } + } + +/* if (is_sampling_ || is_external_memory_) { // Create a new ELLPACK page with empty rows. page_.reset(new EllpackPageImpl(batch_param.gpu_id, @@ -50,72 +185,18 @@ GradientBasedSampler::GradientBasedSampler(EllpackPageImpl* page, thrust::counting_iterator(n_rows), dh::tbegin(row_index_)); } +*/ } // Sample a DMatrix based on the given gradient pairs. GradientBasedSample GradientBasedSampler::Sample(common::Span gpair, DMatrix* dmat) { monitor_.StartCuda("Sample"); - GradientBasedSample sample; - if (is_sampling_) { - switch (sampling_method_) { - case TrainParam::kUniform: - sample = UniformSampling(gpair, dmat); - break; - case TrainParam::kGradientBased: - sample = GradientBasedSampling(gpair, dmat); - break; - default: - LOG(FATAL) << "unknown sampling method"; - sample = {0, nullptr, gpair}; - } - } else { - sample = NoSampling(gpair, dmat); - } + GradientBasedSample sample = strategy_->Sample(gpair, dmat); monitor_.StopCuda("Sample"); return sample; } -GradientBasedSample GradientBasedSampler::NoSampling(common::Span gpair, - DMatrix* dmat) { - if (is_external_memory_) { - ConcatenatePages(dmat); - return {dmat->Info().num_row_, page_.get(), gpair}; - } else { - return {dmat->Info().num_row_, original_page_, gpair}; - } -} - -// When not sampling, concatenate all the external memory ELLPACK pages into a single in-memory -// page. -void GradientBasedSampler::ConcatenatePages(DMatrix* dmat) { - if (page_concatenated_) { - return; - } - - size_t offset = 0; - for (auto& batch : dmat->GetBatches(batch_param_)) { - auto page = batch.Impl(); - size_t num_elements = page_->Copy(batch_param_.gpu_id, page, offset); - offset += num_elements; - } - page_concatenated_ = true; -} - -/*! \brief A functor that returns random weights. */ -struct RandomWeight : public thrust::unary_function { - uint32_t seed; - - XGBOOST_DEVICE explicit RandomWeight(size_t _seed) : seed(_seed) {} - - XGBOOST_DEVICE float operator()(size_t i) const { - thrust::default_random_engine rng(seed); - thrust::uniform_real_distribution dist; - rng.discard(i); - return dist(rng); - } -}; - /*! \brief A functor that scales gradient pairs by 1/p. */ struct FixedScaling : public thrust::unary_function { float p; @@ -127,6 +208,7 @@ struct FixedScaling : public thrust::unary_function } }; +/* GradientBasedSample GradientBasedSampler::UniformSampling(common::Span gpair, DMatrix* dmat) { // Generate random weights. @@ -140,6 +222,7 @@ GradientBasedSample GradientBasedSampler::UniformSampling(common::Span gpair, DMatrix* dmat) { + common::Span& gpair, DMatrix* dmat) { size_t threshold_index = CalculateThresholdIndex(gpair); thrust::transform(dh::tbegin(gpair), dh::tend(gpair), thrust::counting_iterator(0), @@ -242,13 +326,7 @@ GradientBasedSample GradientBasedSampler::GradientBasedSampling( RandomWeight(common::GlobalRandom()()))); return SequentialPoissonSampling(gpair, dmat); } - -/*! \brief A functor that returns true if the gradient pair is non-zero. */ -struct IsNonZero : public thrust::unary_function { - XGBOOST_DEVICE bool operator()(const GradientPair& gpair) const { - return gpair.GetGrad() != 0 || gpair.GetHess() != 0; - } -}; +*/ /*! \brief A functor that clears the row indexes with empty gradient. */ struct ClearEmptyRows : public thrust::binary_function { diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh index 32ffe0357637..fceb28ba1ae5 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cuh +++ b/src/tree/gpu_hist/gradient_based_sampler.cuh @@ -21,6 +21,82 @@ struct GradientBasedSample { common::Span gpair; }; +class SamplingStrategy { + public: + /*! \brief Sample from a DMatrix based on the given gradient pairs. */ + virtual GradientBasedSample Sample(common::Span gpair, DMatrix* dmat) = 0; +}; + +/*! \brief No sampling in in-memory mode. */ +class NoSampling : public SamplingStrategy { + public: + explicit NoSampling(EllpackPageImpl* page); + + GradientBasedSample Sample(common::Span gpair, DMatrix* dmat) override; + + private: + EllpackPageImpl* page_; +}; + +/*! \brief No sampling in external memory mode. */ +class ExternalMemoryNoSampling : public SamplingStrategy { + public: + ExternalMemoryNoSampling(EllpackPageImpl* page, + size_t n_rows, + const BatchParam& batch_param); + + GradientBasedSample Sample(common::Span gpair, DMatrix* dmat) override; + + private: + /*! \brief Concatenate all the rows from a DMatrix into a single ELLPACK page. */ + void ConcatenatePages(DMatrix* dmat); + + BatchParam batch_param_; + std::unique_ptr page_; + bool page_concatenated_{false}; +}; + +/*! \brief Uniform sampling in in-memory mode. */ +class UniformSampling : public SamplingStrategy { + public: + UniformSampling(EllpackPageImpl* page, float subsample); + + GradientBasedSample Sample(common::Span gpair, DMatrix* dmat) override; + + private: + EllpackPageImpl* page_; + float subsample_; +}; + +/*! \brief No sampling in external memory mode. */ +class ExternalMemoryUniformSampling : public SamplingStrategy { + public: + ExternalMemoryUniformSampling(float subsample); + + GradientBasedSample Sample(common::Span gpair, DMatrix* dmat) override; + + private: + dh::BulkAllocator ba_; + EllpackPageImpl* original_page_; + float subsample_; + BatchParam batch_param_; + std::unique_ptr page_; + common::Span gpair_; + common::Span sample_row_index_; +}; + +class GradientBasedSampling : public SamplingStrategy { + public: + /*! \brief Gradient-based sampling in in-memory mode.. */ + GradientBasedSample Sample(common::Span gpair, DMatrix* dmat) override; +}; + +class ExternalMemoryGradientBasedSampling : public SamplingStrategy { + public: + /*! \brief Gradient-based sampling in external memory mode.. */ + GradientBasedSample Sample(common::Span gpair, DMatrix* dmat) override; +}; + /*! \brief Draw a sample of rows from a DMatrix. * * \see Ke, G., Meng, Q., Finley, T., Wang, T., Chen, W., Ma, W., ... & Liu, T. Y. (2017). @@ -32,10 +108,9 @@ struct GradientBasedSample { */ class GradientBasedSampler { public: - GradientBasedSampler(EllpackPageImpl* page, size_t n_rows, - BatchParam batch_param, + const BatchParam& batch_param, float subsample, int sampling_method); @@ -43,13 +118,6 @@ class GradientBasedSampler { GradientBasedSample Sample(common::Span gpair, DMatrix* dmat); private: - GradientBasedSample NoSampling(common::Span gpair, DMatrix* dmat); - GradientBasedSample UniformSampling(common::Span gpair, DMatrix* dmat); - GradientBasedSample GradientBasedSampling(common::Span gpair, DMatrix* dmat); - - /*! \brief Concatenate all the rows from a DMatrix into a single ELLPACK page. */ - void ConcatenatePages(DMatrix* dmat); - /*! \brief Calculate the threshold used to normalize sampling probabilities. */ size_t CalculateThresholdIndex(common::Span gpair); @@ -57,6 +125,9 @@ class GradientBasedSampler { GradientBasedSample SequentialPoissonSampling(common::Span gpair, DMatrix* dmat); common::Monitor monitor_; + std::unique_ptr strategy_; + + dh::BulkAllocator ba_; EllpackPageImpl* original_page_; float subsample_; From 5cbca76157963c8e001eca0a946f793a5c4d61fb Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 27 Jan 2020 15:53:59 -0800 Subject: [PATCH 44/48] done with refactoring --- src/tree/gpu_hist/gradient_based_sampler.cu | 458 ++++++++++-------- src/tree/gpu_hist/gradient_based_sampler.cuh | 68 +-- .../gpu_hist/test_gradient_based_sampler.cu | 31 +- 3 files changed, 306 insertions(+), 251 deletions(-) diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 352641fd0b37..029b206b2b62 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -4,7 +4,6 @@ #include #include #include -#include #include #include @@ -30,12 +29,6 @@ ExternalMemoryNoSampling::ExternalMemoryNoSampling(EllpackPageImpl* page, : batch_param_(batch_param), page_(new EllpackPageImpl(batch_param.gpu_id, page->matrix.info, n_rows)) {} -GradientBasedSample ExternalMemoryNoSampling::Sample(common::Span gpair, - DMatrix* dmat) { - ConcatenatePages(dmat); - return {dmat->Info().num_row_, page_.get(), gpair}; -} - // Concatenate all the external memory ELLPACK pages into a single in-memory page. void ExternalMemoryNoSampling::ConcatenatePages(DMatrix* dmat) { if (page_concatenated_) { @@ -51,51 +44,45 @@ void ExternalMemoryNoSampling::ConcatenatePages(DMatrix* dmat) { page_concatenated_ = true; } +GradientBasedSample ExternalMemoryNoSampling::Sample(common::Span gpair, + DMatrix* dmat) { + ConcatenatePages(dmat); + return {dmat->Info().num_row_, page_.get(), gpair}; +} + /*! \brief A functor that returns random weights. */ class RandomWeight : public thrust::unary_function { public: - explicit RandomWeight(size_t _seed) : seed(_seed) {} + explicit RandomWeight(size_t seed) : seed_(seed) {} XGBOOST_DEVICE float operator()(size_t i) const { - thrust::default_random_engine rng(seed); + thrust::default_random_engine rng(seed_); thrust::uniform_real_distribution dist; rng.discard(i); return dist(rng); } private: - uint32_t seed; + uint32_t seed_; }; +UniformSampling::UniformSampling(EllpackPageImpl* page, float subsample) + : page_(page), subsample_(subsample) {} + /*! \brief A functor that performs a Bernoulli trial to discard a gradient pair. */ class BernoulliTrial : public thrust::unary_function { public: - BernoulliTrial(size_t _seed, float _p) : rnd(_seed), p(_p) {} + BernoulliTrial(size_t seed, float p) : rnd_(seed), p_(p) {} XGBOOST_DEVICE bool operator()(size_t i) const { - return rnd(i) > p; + return rnd_(i) > p_; } private: - RandomWeight rnd; - float p; + RandomWeight rnd_; + float p_; }; -UniformSampling::UniformSampling(EllpackPageImpl* page, float subsample) - : page_(page), subsample_(subsample) {} - -GradientBasedSample UniformSampling::Sample(common::Span gpair, DMatrix* dmat) { - // Set gradient pair to 0 with p = 1 - subsample - thrust::replace_if(dh::tbegin(gpair), dh::tend(gpair), - thrust::counting_iterator(0), - BernoulliTrial(common::GlobalRandom()(), subsample_), - GradientPair()); - return {dmat->Info().num_row_, page_, gpair}; -} - -ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(float subsample) - : subsample_(subsample) {} - /*! \brief A functor that returns true if the gradient pair is non-zero. */ struct IsNonZero : public thrust::unary_function { XGBOOST_DEVICE bool operator()(const GradientPair& gpair) const { @@ -103,126 +90,100 @@ struct IsNonZero : public thrust::unary_function { } }; -GradientBasedSample ExternalMemoryUniformSampling::Sample(common::Span gpair, - DMatrix* dmat) { +/*! \brief A functor that scales gradient pairs by 1/p. */ +class FixedScaling : public thrust::unary_function { + public: + explicit FixedScaling(float p) : p_(p) {} + + XGBOOST_DEVICE GradientPair operator()(const GradientPair& gpair) const { + return gpair / p_; + } + + private: + float p_; +}; + +GradientBasedSample UniformSampling::Sample(common::Span gpair, DMatrix* dmat) { // Set gradient pair to 0 with p = 1 - subsample thrust::replace_if(dh::tbegin(gpair), dh::tend(gpair), thrust::counting_iterator(0), BernoulliTrial(common::GlobalRandom()(), subsample_), GradientPair()); + + // Count the sampled rows. size_t sample_rows = thrust::count_if(dh::tbegin(gpair), dh::tend(gpair), IsNonZero()); - return GradientBasedSample(); -} + size_t n_rows = dmat->Info().num_row_; -GradientBasedSample GradientBasedSampling::Sample(common::Span gpair, - DMatrix* dmat) { - return GradientBasedSample(); -} + // Rescale the gradient pairs by 1/p. + thrust::transform(dh::tbegin(gpair), dh::tend(gpair), + dh::tbegin(gpair), + FixedScaling(static_cast(sample_rows) / static_cast(n_rows))); -GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(common::Span gpair, - DMatrix* dmat) { - return GradientBasedSample(); + return {n_rows, page_, gpair}; } -GradientBasedSampler::GradientBasedSampler(EllpackPageImpl* page, - size_t n_rows, - const BatchParam& batch_param, - float subsample, - int sampling_method) - : original_page_(page), - batch_param_(batch_param), - is_external_memory_(page->matrix.n_rows != n_rows), - subsample_(subsample), - is_sampling_(subsample < 1.0), - sampling_method_(sampling_method), - sample_rows_(n_rows * subsample) { - - monitor_.Init("gradient_based_sampler"); +ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(EllpackPageImpl* page, + size_t n_rows, + const BatchParam& batch_param, + float subsample) + : original_page_(page), batch_param_(batch_param), subsample_(subsample) { + ba_.Allocate(batch_param_.gpu_id, &sample_row_index_, n_rows); +} - if (is_sampling_) { - switch (sampling_method_) { - case TrainParam::kUniform: - if (is_external_memory_) { - strategy_.reset(new ExternalMemoryUniformSampling(subsample)); - } else { - strategy_.reset(new UniformSampling(page, subsample)); - } - break; - case TrainParam::kGradientBased: - if (is_external_memory_) { - strategy_.reset(new ExternalMemoryGradientBasedSampling()); - } else { - strategy_.reset(new GradientBasedSampling()); - } - break; - default: - LOG(FATAL) << "unknown sampling method"; - } - } else { - if (is_external_memory_) { - strategy_.reset(new ExternalMemoryNoSampling(page, n_rows, batch_param)); +/*! \brief A functor that clears the row indexes with empty gradient. */ +struct ClearEmptyRows : public thrust::binary_function { + XGBOOST_DEVICE size_t operator()(const GradientPair& gpair, size_t row_index) const { + if (gpair.GetGrad() != 0 || gpair.GetHess() != 0) { + return row_index; } else { - strategy_.reset(new NoSampling(page)); + return std::numeric_limits::max(); } } +}; -/* - if (is_sampling_ || is_external_memory_) { - // Create a new ELLPACK page with empty rows. - page_.reset(new EllpackPageImpl(batch_param.gpu_id, - original_page_->matrix.info, - sample_rows_)); - } - // Allocate GPU memory for sampling. - if (is_sampling_) { - ba_.Allocate(batch_param_.gpu_id, - &gpair_, sample_rows_, - &row_weight_, n_rows, - &threshold_, n_rows + 1, - &row_index_, n_rows, - &sample_row_index_, n_rows); - thrust::copy(thrust::counting_iterator(0), - thrust::counting_iterator(n_rows), - dh::tbegin(row_index_)); - } -*/ -} +GradientBasedSample ExternalMemoryUniformSampling::Sample(common::Span gpair, + DMatrix* dmat) { + // Set gradient pair to 0 with p = 1 - subsample + thrust::replace_if(dh::tbegin(gpair), dh::tend(gpair), + thrust::counting_iterator(0), + BernoulliTrial(common::GlobalRandom()(), subsample_), + GradientPair()); -// Sample a DMatrix based on the given gradient pairs. -GradientBasedSample GradientBasedSampler::Sample(common::Span gpair, - DMatrix* dmat) { - monitor_.StartCuda("Sample"); - GradientBasedSample sample = strategy_->Sample(gpair, dmat); - monitor_.StopCuda("Sample"); - return sample; -} + // Count the sampled rows. + size_t sample_rows = thrust::count_if(dh::tbegin(gpair), dh::tend(gpair), IsNonZero()); + size_t n_rows = dmat->Info().num_row_; -/*! \brief A functor that scales gradient pairs by 1/p. */ -struct FixedScaling : public thrust::unary_function { - float p; + // Compact gradient pairs. + gpair_.resize(sample_rows); + thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero()); - XGBOOST_DEVICE explicit FixedScaling(float _p) : p(_p) {} + // Rescale the gradient pairs by 1/p. + thrust::transform(gpair_.begin(), gpair_.end(), + gpair_.begin(), + FixedScaling(static_cast(sample_rows) / static_cast(n_rows))); - XGBOOST_DEVICE GradientPair operator()(const GradientPair& gpair) const { - return gpair / p; + // Index the sample rows. + thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(sample_row_index_), IsNonZero()); + thrust::exclusive_scan(dh::tbegin(sample_row_index_), dh::tend(sample_row_index_), + dh::tbegin(sample_row_index_)); + thrust::transform(dh::tbegin(gpair), dh::tend(gpair), + dh::tbegin(sample_row_index_), + dh::tbegin(sample_row_index_), + ClearEmptyRows()); + + // Create a new ELLPACK page with empty rows. + page_.reset(new EllpackPageImpl(batch_param_.gpu_id, + original_page_->matrix.info, + sample_rows)); + + // Compact the ELLPACK pages into the single sample page. + thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0); + for (auto& batch : dmat->GetBatches(batch_param_)) { + page_->Compact(batch_param_.gpu_id, batch.Impl(), sample_row_index_); } -}; -/* -GradientBasedSample GradientBasedSampler::UniformSampling(common::Span gpair, - DMatrix* dmat) { - // Generate random weights. - thrust::transform(thrust::counting_iterator(0), - thrust::counting_iterator(gpair.size()), - dh::tbegin(row_weight_), - RandomWeight(common::GlobalRandom()())); - // Scale gradient pairs by 1/subsample. - thrust::transform(dh::tbegin(gpair), dh::tend(gpair), - dh::tbegin(gpair), - FixedScaling(subsample_)); - return SequentialPoissonSampling(gpair, dmat); + return {sample_rows, page_.get(), dh::ToSpan(gpair_)}; } -*/ /*! \brief A functor that combines the gradient pair into a single float. * @@ -231,139 +192,154 @@ GradientBasedSample GradientBasedSampler::UniformSampling(common::Span { - static constexpr float kLambda = 0.1f; - +class CombineGradientPair : public thrust::unary_function { + public: XGBOOST_DEVICE float operator()(const GradientPair& gpair) const { return sqrtf(powf(gpair.GetGrad(), 2) + kLambda * powf(gpair.GetHess(), 2)); } + + private: + static constexpr float kLambda = 0.1f; }; /*! \brief A functor that calculates the difference between the sample rate and the desired sample * rows, given a cumulative gradient sum. */ -struct SampleRateDelta : public thrust::binary_function { - common::Span threshold; - size_t n_rows; - size_t sample_rows; - - XGBOOST_DEVICE SampleRateDelta(common::Span _threshold, - size_t _n_rows, - size_t _sample_rows) - : threshold(_threshold), n_rows(_n_rows), sample_rows(_sample_rows) {} +class SampleRateDelta : public thrust::binary_function { + public: + SampleRateDelta(common::Span threshold, size_t n_rows, size_t sample_rows) + : threshold_(threshold), n_rows_(n_rows), sample_rows_(sample_rows) {} XGBOOST_DEVICE float operator()(float gradient_sum, size_t row_index) const { - float lower = threshold[row_index]; - float upper = threshold[row_index + 1]; - float u = gradient_sum / static_cast(sample_rows - n_rows + row_index + 1); + float lower = threshold_[row_index]; + float upper = threshold_[row_index + 1]; + float u = gradient_sum / static_cast(sample_rows_ - n_rows_ + row_index + 1); if (u > lower && u <= upper) { - threshold[row_index + 1] = u; + threshold_[row_index + 1] = u; return 0.0f; } else { return std::numeric_limits::max(); } } + + private: + common::Span threshold_; + size_t n_rows_; + size_t sample_rows_; }; -size_t GradientBasedSampler::CalculateThresholdIndex(common::Span gpair) { - thrust::fill(dh::tend(threshold_) - 1, dh::tend(threshold_), std::numeric_limits::max()); +size_t GradientBasedSampler::CalculateThresholdIndex(common::Span gpair, + common::Span threshold, + common::Span grad_sum, + size_t sample_rows) { + thrust::fill(dh::tend(threshold) - 1, dh::tend(threshold), std::numeric_limits::max()); thrust::transform(dh::tbegin(gpair), dh::tend(gpair), - dh::tbegin(threshold_), + dh::tbegin(threshold), CombineGradientPair()); - thrust::sort(dh::tbegin(threshold_), dh::tend(threshold_) - 1); - thrust::inclusive_scan(dh::tbegin(threshold_), dh::tend(threshold_) - 1, dh::tbegin(row_weight_)); - thrust::transform(dh::tbegin(row_weight_), dh::tend(row_weight_), + thrust::sort(dh::tbegin(threshold), dh::tend(threshold) - 1); + thrust::inclusive_scan(dh::tbegin(threshold), dh::tend(threshold) - 1, dh::tbegin(grad_sum)); + thrust::transform(dh::tbegin(grad_sum), dh::tend(grad_sum), thrust::counting_iterator(0), - dh::tbegin(row_weight_), - SampleRateDelta(threshold_, gpair.size(), sample_rows_)); - thrust::device_ptr min = thrust::min_element(dh::tbegin(row_weight_), - dh::tend(row_weight_)); - return thrust::distance(dh::tbegin(row_weight_), min) + 1; + dh::tbegin(grad_sum), + SampleRateDelta(threshold, gpair.size(), sample_rows)); + thrust::device_ptr min = thrust::min_element(dh::tbegin(grad_sum), dh::tend(grad_sum)); + return thrust::distance(dh::tbegin(grad_sum), min) + 1; } -/*! \brief A functor that calculates the weight of each row, and scales gradient pairs by 1/p_i. */ -struct CalculateWeight - : public thrust::binary_function> { - common::Span threshold; - size_t threshold_index; - RandomWeight rnd; - CombineGradientPair combine; - - XGBOOST_DEVICE CalculateWeight(common::Span _threshold, - size_t _threshold_index, - RandomWeight _rnd) - : threshold(_threshold), threshold_index(_threshold_index), rnd(_rnd) {} - - XGBOOST_DEVICE thrust::tuple operator()(const GradientPair& gpair, - size_t i) { +GradientBasedSampling::GradientBasedSampling(EllpackPageImpl* page, + size_t n_rows, + const BatchParam& batch_param, + float subsample) : page_(page), subsample_(subsample) { + ba_.Allocate(batch_param.gpu_id, + &threshold_, n_rows + 1, + &grad_sum_, n_rows); +} + +/*! \brief A functor that performs Poisson sampling, and scales gradient pairs by 1/p_i. */ +class PoissonSampling : public thrust::binary_function { + public: + PoissonSampling(common::Span threshold, size_t threshold_index, RandomWeight rnd) + : threshold_(threshold), threshold_index_(threshold_index), rnd_(rnd) {} + + XGBOOST_DEVICE GradientPair operator()(const GradientPair& gpair, size_t i) { // If the gradient and hessian are both empty, we should never select this row. if (gpair.GetGrad() == 0 && gpair.GetHess() == 0) { - return thrust::make_tuple(std::numeric_limits::max(), gpair); + return gpair; } - float combined_gradient = combine(gpair); - float u = threshold[threshold_index]; + float combined_gradient = combine_(gpair); + float u = threshold_[threshold_index_]; float p = combined_gradient / u; if (p >= 1) { // Always select this row. - return thrust::make_tuple(0.0f, gpair); + return gpair; } else { // Select this row randomly with probability proportional to the combined gradient. // Scale gpair by 1/p. - return thrust::make_tuple(rnd(i) / combined_gradient, gpair / p); + if (rnd_(i) <= p) { + return gpair / p; + } else { + return GradientPair(); + } } } + + private: + common::Span threshold_; + size_t threshold_index_; + RandomWeight rnd_; + CombineGradientPair combine_; }; -/* -GradientBasedSample GradientBasedSampler::GradientBasedSampling( - common::Span& gpair, DMatrix* dmat) { - size_t threshold_index = CalculateThresholdIndex(gpair); +GradientBasedSample GradientBasedSampling::Sample(common::Span gpair, + DMatrix* dmat) { + size_t n_rows = dmat->Info().num_row_; + size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex( + gpair, threshold_, grad_sum_, n_rows * subsample_); + + // Perform Poisson sampling in place. thrust::transform(dh::tbegin(gpair), dh::tend(gpair), thrust::counting_iterator(0), - thrust::make_zip_iterator(thrust::make_tuple( - dh::tbegin(row_weight_), dh::tbegin(gpair))), - CalculateWeight(threshold_, threshold_index, - RandomWeight(common::GlobalRandom()()))); - return SequentialPoissonSampling(gpair, dmat); + dh::tbegin(gpair), + PoissonSampling(threshold_, + threshold_index, + RandomWeight(common::GlobalRandom()()))); + return {n_rows, page_, gpair}; } -*/ -/*! \brief A functor that clears the row indexes with empty gradient. */ -struct ClearEmptyRows : public thrust::binary_function { - XGBOOST_DEVICE size_t operator()(const GradientPair& gpair, size_t row_index) const { - if (gpair.GetGrad() != 0 || gpair.GetHess() != 0) { - return row_index; - } else { - return std::numeric_limits::max(); - } - } -}; - -// Perform sampling after the weights are calculated. -GradientBasedSample GradientBasedSampler::SequentialPoissonSampling( - common::Span gpair, DMatrix* dmat) { - // Sort the gradient pairs and row indexes by weight. - thrust::sort_by_key(dh::tbegin(row_weight_), dh::tend(row_weight_), - thrust::make_zip_iterator(thrust::make_tuple(dh::tbegin(gpair), - dh::tbegin(row_index_)))); +ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling( + EllpackPageImpl* page, + size_t n_rows, + const BatchParam& batch_param, + float subsample) : original_page_(page), batch_param_(batch_param), subsample_(subsample) { + ba_.Allocate(batch_param.gpu_id, + &threshold_, n_rows + 1, + &grad_sum_, n_rows, + &sample_row_index_, n_rows); +} - // Clear the gradient pairs not included in the sample. - thrust::fill(dh::tbegin(gpair) + sample_rows_, dh::tend(gpair), GradientPair()); +GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(common::Span gpair, + DMatrix* dmat) { + size_t n_rows = dmat->Info().num_row_; + size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex( + gpair, threshold_, grad_sum_, n_rows * subsample_); - // Mask the sample rows. - thrust::fill(dh::tbegin(sample_row_index_), dh::tbegin(sample_row_index_) + sample_rows_, 1); - thrust::fill(dh::tbegin(sample_row_index_) + sample_rows_, dh::tend(sample_row_index_), 0); + // Perform Poisson sampling in place. + thrust::transform(dh::tbegin(gpair), dh::tend(gpair), + thrust::counting_iterator(0), + dh::tbegin(gpair), + PoissonSampling(threshold_, + threshold_index, + RandomWeight(common::GlobalRandom()()))); - // Sort the gradient pairs and sample row indexes by the original row index. - thrust::sort_by_key(dh::tbegin(row_index_), dh::tend(row_index_), - thrust::make_zip_iterator(thrust::make_tuple(dh::tbegin(gpair), - dh::tbegin(sample_row_index_)))); + // Count the sampled rows. + size_t sample_rows = thrust::count_if(dh::tbegin(gpair), dh::tend(gpair), IsNonZero()); - // Compact the non-zero gradient pairs. - thrust::fill(dh::tbegin(gpair_), dh::tend(gpair_), GradientPair()); - thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(gpair_), IsNonZero()); + // Compact gradient pairs. + gpair_.resize(sample_rows); + thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero()); // Index the sample rows. + thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(sample_row_index_), IsNonZero()); thrust::exclusive_scan(dh::tbegin(sample_row_index_), dh::tend(sample_row_index_), dh::tbegin(sample_row_index_)); thrust::transform(dh::tbegin(gpair), dh::tend(gpair), @@ -371,13 +347,65 @@ GradientBasedSample GradientBasedSampler::SequentialPoissonSampling( dh::tbegin(sample_row_index_), ClearEmptyRows()); + // Create a new ELLPACK page with empty rows. + page_.reset(new EllpackPageImpl(batch_param_.gpu_id, + original_page_->matrix.info, + sample_rows)); + // Compact the ELLPACK pages into the single sample page. thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0); for (auto& batch : dmat->GetBatches(batch_param_)) { page_->Compact(batch_param_.gpu_id, batch.Impl(), sample_row_index_); } - return {sample_rows_, page_.get(), gpair_}; + return {sample_rows, page_.get(), dh::ToSpan(gpair_)}; +} + +GradientBasedSampler::GradientBasedSampler(EllpackPageImpl* page, + size_t n_rows, + const BatchParam& batch_param, + float subsample, + int sampling_method) { + monitor_.Init("gradient_based_sampler"); + + bool is_sampling = subsample < 1.0; + bool is_external_memory = page->matrix.n_rows != n_rows; + + if (is_sampling) { + switch (sampling_method) { + case TrainParam::kUniform: + if (is_external_memory) { + strategy_.reset(new ExternalMemoryUniformSampling(page, n_rows, batch_param, subsample)); + } else { + strategy_.reset(new UniformSampling(page, subsample)); + } + break; + case TrainParam::kGradientBased: + if (is_external_memory) { + strategy_.reset( + new ExternalMemoryGradientBasedSampling(page, n_rows, batch_param, subsample)); + } else { + strategy_.reset(new GradientBasedSampling(page, n_rows, batch_param, subsample)); + } + break; + default:LOG(FATAL) << "unknown sampling method"; + } + } else { + if (is_external_memory) { + strategy_.reset(new ExternalMemoryNoSampling(page, n_rows, batch_param)); + } else { + strategy_.reset(new NoSampling(page)); + } + } +} + +// Sample a DMatrix based on the given gradient pairs. +GradientBasedSample GradientBasedSampler::Sample(common::Span gpair, + DMatrix* dmat) { + monitor_.StartCuda("Sample"); + GradientBasedSample sample = strategy_->Sample(gpair, dmat); + monitor_.StopCuda("Sample"); + return sample; } }; // namespace tree diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh index fceb28ba1ae5..1aa0893ad4aa 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cuh +++ b/src/tree/gpu_hist/gradient_based_sampler.cuh @@ -71,30 +71,61 @@ class UniformSampling : public SamplingStrategy { /*! \brief No sampling in external memory mode. */ class ExternalMemoryUniformSampling : public SamplingStrategy { public: - ExternalMemoryUniformSampling(float subsample); + ExternalMemoryUniformSampling(EllpackPageImpl* page, + size_t n_rows, + const BatchParam& batch_param, + float subsample); GradientBasedSample Sample(common::Span gpair, DMatrix* dmat) override; private: dh::BulkAllocator ba_; EllpackPageImpl* original_page_; - float subsample_; BatchParam batch_param_; + float subsample_; std::unique_ptr page_; - common::Span gpair_; + dh::device_vector gpair_; common::Span sample_row_index_; }; +/*! \brief Gradient-based sampling in in-memory mode.. */ class GradientBasedSampling : public SamplingStrategy { public: - /*! \brief Gradient-based sampling in in-memory mode.. */ + GradientBasedSampling(EllpackPageImpl* page, + size_t n_rows, + const BatchParam& batch_param, + float subsample); + GradientBasedSample Sample(common::Span gpair, DMatrix* dmat) override; + + private: + EllpackPageImpl* page_; + float subsample_; + dh::BulkAllocator ba_; + common::Span threshold_; + common::Span grad_sum_; }; +/*! \brief Gradient-based sampling in external memory mode.. */ class ExternalMemoryGradientBasedSampling : public SamplingStrategy { public: - /*! \brief Gradient-based sampling in external memory mode.. */ + ExternalMemoryGradientBasedSampling(EllpackPageImpl* page, + size_t n_rows, + const BatchParam& batch_param, + float subsample); + GradientBasedSample Sample(common::Span gpair, DMatrix* dmat) override; + + private: + dh::BulkAllocator ba_; + EllpackPageImpl* original_page_; + BatchParam batch_param_; + float subsample_; + common::Span threshold_; + common::Span grad_sum_; + std::unique_ptr page_; + dh::device_vector gpair_; + common::Span sample_row_index_; }; /*! \brief Draw a sample of rows from a DMatrix. @@ -117,32 +148,15 @@ class GradientBasedSampler { /*! \brief Sample from a DMatrix based on the given gradient pairs. */ GradientBasedSample Sample(common::Span gpair, DMatrix* dmat); - private: /*! \brief Calculate the threshold used to normalize sampling probabilities. */ - size_t CalculateThresholdIndex(common::Span gpair); - - /*! \brief Fixed-size Poisson sampling after the row weights are calculated. */ - GradientBasedSample SequentialPoissonSampling(common::Span gpair, DMatrix* dmat); + static size_t CalculateThresholdIndex(common::Span gpair, + common::Span threshold, + common::Span grad_sum, + size_t sample_rows); + private: common::Monitor monitor_; std::unique_ptr strategy_; - - - dh::BulkAllocator ba_; - EllpackPageImpl* original_page_; - float subsample_; - bool is_external_memory_; - bool is_sampling_; - BatchParam batch_param_; - int sampling_method_; - size_t sample_rows_; - std::unique_ptr page_; - common::Span gpair_; - common::Span row_weight_; - common::Span threshold_; - common::Span row_index_; - common::Span sample_row_index_; - bool page_concatenated_{false}; }; }; // namespace tree }; // namespace xgboost diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu index f2e27ced42aa..281b01b09148 100644 --- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -7,7 +7,10 @@ namespace xgboost { namespace tree { -void VerifySampling(size_t page_size, float subsample, int sampling_method) { +void VerifySampling(size_t page_size, + float subsample, + int sampling_method, + bool fixed_size_sampling = false) { constexpr size_t kRows = 4096; constexpr size_t kCols = 1; size_t sample_rows = kRows * subsample; @@ -30,9 +33,16 @@ void VerifySampling(size_t page_size, float subsample, int sampling_method) { GradientBasedSampler sampler(page, kRows, param, subsample, sampling_method); auto sample = sampler.Sample(gpair.DeviceSpan(), dmat.get()); - EXPECT_EQ(sample.sample_rows, sample_rows); - EXPECT_EQ(sample.page->matrix.n_rows, sample_rows); - EXPECT_EQ(sample.gpair.size(), sample_rows); + + if (fixed_size_sampling) { + EXPECT_EQ(sample.sample_rows, kRows); + EXPECT_EQ(sample.page->matrix.n_rows, kRows); + EXPECT_EQ(sample.gpair.size(), kRows); + } else { + EXPECT_NEAR(sample.sample_rows, sample_rows, kRows * 0.012); + EXPECT_NEAR(sample.page->matrix.n_rows, sample_rows, kRows * 0.012); + EXPECT_NEAR(sample.gpair.size(), sample_rows, kRows * 0.012); + } GradientPair sum_sampled_gpair{}; std::vector sampled_gpair_h(sample.gpair.size()); @@ -48,7 +58,8 @@ TEST(GradientBasedSampler, NoSampling) { constexpr size_t kPageSize = 0; constexpr float kSubsample = 1.0f; constexpr int kSamplingMethod = TrainParam::kUniform; - VerifySampling(kPageSize, kSubsample, kSamplingMethod); + constexpr bool kFixedSizeSampling = true; + VerifySampling(kPageSize, kSubsample, kSamplingMethod, kFixedSizeSampling); } // In external mode, when not sampling, we concatenate the pages together. @@ -101,7 +112,8 @@ TEST(GradientBasedSampler, UniformSampling) { constexpr size_t kPageSize = 0; constexpr float kSubsample = 0.5; constexpr int kSamplingMethod = TrainParam::kUniform; - VerifySampling(kPageSize, kSubsample, kSamplingMethod); + constexpr bool kFixedSizeSampling = true; + VerifySampling(kPageSize, kSubsample, kSamplingMethod, kFixedSizeSampling); } TEST(GradientBasedSampler, UniformSampling_ExternalMemory) { @@ -113,14 +125,15 @@ TEST(GradientBasedSampler, UniformSampling_ExternalMemory) { TEST(GradientBasedSampler, GradientBasedSampling) { constexpr size_t kPageSize = 0; - constexpr float kSubsample = 0.5; + constexpr float kSubsample = 0.8; constexpr int kSamplingMethod = TrainParam::kGradientBased; - VerifySampling(kPageSize, kSubsample, kSamplingMethod); + constexpr bool kFixedSizeSampling = true; + VerifySampling(kPageSize, kSubsample, kSamplingMethod, kFixedSizeSampling); } TEST(GradientBasedSampler, GradientBasedSampling_ExternalMemory) { constexpr size_t kPageSize = 1024; - constexpr float kSubsample = 0.5; + constexpr float kSubsample = 0.8; constexpr int kSamplingMethod = TrainParam::kGradientBased; VerifySampling(kPageSize, kSubsample, kSamplingMethod); } From 7cf91107f101f5ade3e216b99311a4c1b6ebfd0d Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 27 Jan 2020 16:29:20 -0800 Subject: [PATCH 45/48] fix tests --- tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu | 4 ++-- tests/cpp/tree/test_gpu_hist.cu | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu index 281b01b09148..9b136e73fa45 100644 --- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -50,8 +50,8 @@ void VerifySampling(size_t page_size, for (const auto& gp : sampled_gpair_h) { sum_sampled_gpair += gp; } - EXPECT_NEAR(sum_gpair.GetGrad(), sum_sampled_gpair.GetGrad(), 0.01f * kRows); - EXPECT_NEAR(sum_gpair.GetHess(), sum_sampled_gpair.GetHess(), 0.01f * kRows); + EXPECT_NEAR(sum_gpair.GetGrad(), sum_sampled_gpair.GetGrad(), 0.02f * kRows); + EXPECT_NEAR(sum_gpair.GetHess(), sum_sampled_gpair.GetHess(), 0.02f * kRows); } TEST(GradientBasedSampler, NoSampling) { diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 8fd5f3598b80..6cb0aad26719 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -407,7 +407,7 @@ TEST(GpuHist, UniformSampling) { auto preds_h = preds.ConstHostVector(); auto preds_sampling_h = preds_sampling.ConstHostVector(); for (int i = 0; i < kRows; i++) { - EXPECT_NEAR(preds_h[i], preds_sampling_h[i], 1e-3); + EXPECT_NEAR(preds_h[i], preds_sampling_h[i], 2e-3); } } @@ -503,7 +503,7 @@ TEST(GpuHist, ExternalMemoryWithSampling) { auto preds_h = preds.ConstHostVector(); auto preds_ext_h = preds_ext.ConstHostVector(); for (int i = 0; i < kRows; i++) { - EXPECT_NEAR(preds_h[i], preds_ext_h[i], 2e-6); + EXPECT_NEAR(preds_h[i], preds_ext_h[i], 3e-3); } } From 3a734a9796a98ab328187b2391536bc43db5d55e Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 28 Jan 2020 11:13:54 -0800 Subject: [PATCH 46/48] release device memory --- rabit | 2 +- src/tree/gpu_hist/gradient_based_sampler.cu | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/rabit b/rabit index 2f7fcff4d770..2f253471680f 160000 --- a/rabit +++ b/rabit @@ -1 +1 @@ -Subproject commit 2f7fcff4d770a3eb4fba6b25ded74b45e196ccd6 +Subproject commit 2f253471680f1bdafc1dfa17395ca0f309fe96de diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 029b206b2b62..248853533e55 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -172,6 +172,7 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(common::Spanmatrix.info, sample_rows)); @@ -348,6 +349,7 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(common::Spanmatrix.info, sample_rows)); From 55b36f229f3c8cb62929cff72395cda2dd16a436 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 28 Jan 2020 13:29:55 -0800 Subject: [PATCH 47/48] remove scaling in uniform sampling --- src/tree/gpu_hist/gradient_based_sampler.cu | 264 ++++++++---------- src/tree/gpu_hist/gradient_based_sampler.cuh | 11 +- .../gpu_hist/test_gradient_based_sampler.cu | 34 ++- 3 files changed, 137 insertions(+), 172 deletions(-) diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 248853533e55..f294855b5dbd 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -17,39 +17,6 @@ namespace xgboost { namespace tree { -NoSampling::NoSampling(EllpackPageImpl* page) : page_(page) {} - -GradientBasedSample NoSampling::Sample(common::Span gpair, DMatrix* dmat) { - return {dmat->Info().num_row_, page_, gpair}; -} - -ExternalMemoryNoSampling::ExternalMemoryNoSampling(EllpackPageImpl* page, - size_t n_rows, - const BatchParam& batch_param) - : batch_param_(batch_param), - page_(new EllpackPageImpl(batch_param.gpu_id, page->matrix.info, n_rows)) {} - -// Concatenate all the external memory ELLPACK pages into a single in-memory page. -void ExternalMemoryNoSampling::ConcatenatePages(DMatrix* dmat) { - if (page_concatenated_) { - return; - } - - size_t offset = 0; - for (auto& batch : dmat->GetBatches(batch_param_)) { - auto page = batch.Impl(); - size_t num_elements = page_->Copy(batch_param_.gpu_id, page, offset); - offset += num_elements; - } - page_concatenated_ = true; -} - -GradientBasedSample ExternalMemoryNoSampling::Sample(common::Span gpair, - DMatrix* dmat) { - ConcatenatePages(dmat); - return {dmat->Info().num_row_, page_.get(), gpair}; -} - /*! \brief A functor that returns random weights. */ class RandomWeight : public thrust::unary_function { public: @@ -66,9 +33,6 @@ class RandomWeight : public thrust::unary_function { uint32_t seed_; }; -UniformSampling::UniformSampling(EllpackPageImpl* page, float subsample) - : page_(page), subsample_(subsample) {} - /*! \brief A functor that performs a Bernoulli trial to discard a gradient pair. */ class BernoulliTrial : public thrust::unary_function { public: @@ -90,46 +54,6 @@ struct IsNonZero : public thrust::unary_function { } }; -/*! \brief A functor that scales gradient pairs by 1/p. */ -class FixedScaling : public thrust::unary_function { - public: - explicit FixedScaling(float p) : p_(p) {} - - XGBOOST_DEVICE GradientPair operator()(const GradientPair& gpair) const { - return gpair / p_; - } - - private: - float p_; -}; - -GradientBasedSample UniformSampling::Sample(common::Span gpair, DMatrix* dmat) { - // Set gradient pair to 0 with p = 1 - subsample - thrust::replace_if(dh::tbegin(gpair), dh::tend(gpair), - thrust::counting_iterator(0), - BernoulliTrial(common::GlobalRandom()(), subsample_), - GradientPair()); - - // Count the sampled rows. - size_t sample_rows = thrust::count_if(dh::tbegin(gpair), dh::tend(gpair), IsNonZero()); - size_t n_rows = dmat->Info().num_row_; - - // Rescale the gradient pairs by 1/p. - thrust::transform(dh::tbegin(gpair), dh::tend(gpair), - dh::tbegin(gpair), - FixedScaling(static_cast(sample_rows) / static_cast(n_rows))); - - return {n_rows, page_, gpair}; -} - -ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(EllpackPageImpl* page, - size_t n_rows, - const BatchParam& batch_param, - float subsample) - : original_page_(page), batch_param_(batch_param), subsample_(subsample) { - ba_.Allocate(batch_param_.gpu_id, &sample_row_index_, n_rows); -} - /*! \brief A functor that clears the row indexes with empty gradient. */ struct ClearEmptyRows : public thrust::binary_function { XGBOOST_DEVICE size_t operator()(const GradientPair& gpair, size_t row_index) const { @@ -141,51 +65,6 @@ struct ClearEmptyRows : public thrust::binary_function gpair, - DMatrix* dmat) { - // Set gradient pair to 0 with p = 1 - subsample - thrust::replace_if(dh::tbegin(gpair), dh::tend(gpair), - thrust::counting_iterator(0), - BernoulliTrial(common::GlobalRandom()(), subsample_), - GradientPair()); - - // Count the sampled rows. - size_t sample_rows = thrust::count_if(dh::tbegin(gpair), dh::tend(gpair), IsNonZero()); - size_t n_rows = dmat->Info().num_row_; - - // Compact gradient pairs. - gpair_.resize(sample_rows); - thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero()); - - // Rescale the gradient pairs by 1/p. - thrust::transform(gpair_.begin(), gpair_.end(), - gpair_.begin(), - FixedScaling(static_cast(sample_rows) / static_cast(n_rows))); - - // Index the sample rows. - thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(sample_row_index_), IsNonZero()); - thrust::exclusive_scan(dh::tbegin(sample_row_index_), dh::tend(sample_row_index_), - dh::tbegin(sample_row_index_)); - thrust::transform(dh::tbegin(gpair), dh::tend(gpair), - dh::tbegin(sample_row_index_), - dh::tbegin(sample_row_index_), - ClearEmptyRows()); - - // Create a new ELLPACK page with empty rows. - page_.reset(); // Release the device memory first before reallocating - page_.reset(new EllpackPageImpl(batch_param_.gpu_id, - original_page_->matrix.info, - sample_rows)); - - // Compact the ELLPACK pages into the single sample page. - thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0); - for (auto& batch : dmat->GetBatches(batch_param_)) { - page_->Compact(batch_param_.gpu_id, batch.Impl(), sample_row_index_); - } - - return {sample_rows, page_.get(), dh::ToSpan(gpair_)}; -} - /*! \brief A functor that combines the gradient pair into a single float. * * The approach here is based on Minimal Variance Sampling (MVS), with lambda set to 0.1. @@ -229,38 +108,11 @@ class SampleRateDelta : public thrust::binary_function { size_t sample_rows_; }; -size_t GradientBasedSampler::CalculateThresholdIndex(common::Span gpair, - common::Span threshold, - common::Span grad_sum, - size_t sample_rows) { - thrust::fill(dh::tend(threshold) - 1, dh::tend(threshold), std::numeric_limits::max()); - thrust::transform(dh::tbegin(gpair), dh::tend(gpair), - dh::tbegin(threshold), - CombineGradientPair()); - thrust::sort(dh::tbegin(threshold), dh::tend(threshold) - 1); - thrust::inclusive_scan(dh::tbegin(threshold), dh::tend(threshold) - 1, dh::tbegin(grad_sum)); - thrust::transform(dh::tbegin(grad_sum), dh::tend(grad_sum), - thrust::counting_iterator(0), - dh::tbegin(grad_sum), - SampleRateDelta(threshold, gpair.size(), sample_rows)); - thrust::device_ptr min = thrust::min_element(dh::tbegin(grad_sum), dh::tend(grad_sum)); - return thrust::distance(dh::tbegin(grad_sum), min) + 1; -} - -GradientBasedSampling::GradientBasedSampling(EllpackPageImpl* page, - size_t n_rows, - const BatchParam& batch_param, - float subsample) : page_(page), subsample_(subsample) { - ba_.Allocate(batch_param.gpu_id, - &threshold_, n_rows + 1, - &grad_sum_, n_rows); -} - /*! \brief A functor that performs Poisson sampling, and scales gradient pairs by 1/p_i. */ class PoissonSampling : public thrust::binary_function { public: PoissonSampling(common::Span threshold, size_t threshold_index, RandomWeight rnd) - : threshold_(threshold), threshold_index_(threshold_index), rnd_(rnd) {} + : threshold_(threshold), threshold_index_(threshold_index), rnd_(rnd) {} XGBOOST_DEVICE GradientPair operator()(const GradientPair& gpair, size_t i) { // If the gradient and hessian are both empty, we should never select this row. @@ -291,6 +143,102 @@ class PoissonSampling : public thrust::binary_function gpair, DMatrix* dmat) { + return {dmat->Info().num_row_, page_, gpair}; +} + +ExternalMemoryNoSampling::ExternalMemoryNoSampling(EllpackPageImpl* page, + size_t n_rows, + const BatchParam& batch_param) + : batch_param_(batch_param), + page_(new EllpackPageImpl(batch_param.gpu_id, page->matrix.info, n_rows)) {} + +GradientBasedSample ExternalMemoryNoSampling::Sample(common::Span gpair, + DMatrix* dmat) { + if (!page_concatenated_) { + // Concatenate all the external memory ELLPACK pages into a single in-memory page. + size_t offset = 0; + for (auto& batch : dmat->GetBatches(batch_param_)) { + auto page = batch.Impl(); + size_t num_elements = page_->Copy(batch_param_.gpu_id, page, offset); + offset += num_elements; + } + page_concatenated_ = true; + } + return {dmat->Info().num_row_, page_.get(), gpair}; +} + +UniformSampling::UniformSampling(EllpackPageImpl* page, float subsample) + : page_(page), subsample_(subsample) {} + +GradientBasedSample UniformSampling::Sample(common::Span gpair, DMatrix* dmat) { + // Set gradient pair to 0 with p = 1 - subsample + thrust::replace_if(dh::tbegin(gpair), dh::tend(gpair), + thrust::counting_iterator(0), + BernoulliTrial(common::GlobalRandom()(), subsample_), + GradientPair()); + return {dmat->Info().num_row_, page_, gpair}; +} + +ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(EllpackPageImpl* page, + size_t n_rows, + const BatchParam& batch_param, + float subsample) + : original_page_(page), batch_param_(batch_param), subsample_(subsample) { + ba_.Allocate(batch_param_.gpu_id, &sample_row_index_, n_rows); +} + +GradientBasedSample ExternalMemoryUniformSampling::Sample(common::Span gpair, + DMatrix* dmat) { + // Set gradient pair to 0 with p = 1 - subsample + thrust::replace_if(dh::tbegin(gpair), dh::tend(gpair), + thrust::counting_iterator(0), + BernoulliTrial(common::GlobalRandom()(), subsample_), + GradientPair()); + + // Count the sampled rows. + size_t sample_rows = thrust::count_if(dh::tbegin(gpair), dh::tend(gpair), IsNonZero()); + size_t n_rows = dmat->Info().num_row_; + + // Compact gradient pairs. + gpair_.resize(sample_rows); + thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero()); + + // Index the sample rows. + thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(sample_row_index_), IsNonZero()); + thrust::exclusive_scan(dh::tbegin(sample_row_index_), dh::tend(sample_row_index_), + dh::tbegin(sample_row_index_)); + thrust::transform(dh::tbegin(gpair), dh::tend(gpair), + dh::tbegin(sample_row_index_), + dh::tbegin(sample_row_index_), + ClearEmptyRows()); + + // Create a new ELLPACK page with empty rows. + page_.reset(); // Release the device memory first before reallocating + page_.reset(new EllpackPageImpl(batch_param_.gpu_id, + original_page_->matrix.info, + sample_rows)); + + // Compact the ELLPACK pages into the single sample page. + thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0); + for (auto& batch : dmat->GetBatches(batch_param_)) { + page_->Compact(batch_param_.gpu_id, batch.Impl(), sample_row_index_); + } + + return {sample_rows, page_.get(), dh::ToSpan(gpair_)}; +} + +GradientBasedSampling::GradientBasedSampling(EllpackPageImpl* page, + size_t n_rows, + const BatchParam& batch_param, + float subsample) : page_(page), subsample_(subsample) { + ba_.Allocate(batch_param.gpu_id, + &threshold_, n_rows + 1, + &grad_sum_, n_rows); +} + GradientBasedSample GradientBasedSampling::Sample(common::Span gpair, DMatrix* dmat) { size_t n_rows = dmat->Info().num_row_; @@ -410,5 +358,23 @@ GradientBasedSample GradientBasedSampler::Sample(common::Span gpai return sample; } +size_t GradientBasedSampler::CalculateThresholdIndex(common::Span gpair, + common::Span threshold, + common::Span grad_sum, + size_t sample_rows) { + thrust::fill(dh::tend(threshold) - 1, dh::tend(threshold), std::numeric_limits::max()); + thrust::transform(dh::tbegin(gpair), dh::tend(gpair), + dh::tbegin(threshold), + CombineGradientPair()); + thrust::sort(dh::tbegin(threshold), dh::tend(threshold) - 1); + thrust::inclusive_scan(dh::tbegin(threshold), dh::tend(threshold) - 1, dh::tbegin(grad_sum)); + thrust::transform(dh::tbegin(grad_sum), dh::tend(grad_sum), + thrust::counting_iterator(0), + dh::tbegin(grad_sum), + SampleRateDelta(threshold, gpair.size(), sample_rows)); + thrust::device_ptr min = thrust::min_element(dh::tbegin(grad_sum), dh::tend(grad_sum)); + return thrust::distance(dh::tbegin(grad_sum), min) + 1; +} + }; // namespace tree }; // namespace xgboost diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh index 1aa0893ad4aa..41099e3bc134 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cuh +++ b/src/tree/gpu_hist/gradient_based_sampler.cuh @@ -31,7 +31,6 @@ class SamplingStrategy { class NoSampling : public SamplingStrategy { public: explicit NoSampling(EllpackPageImpl* page); - GradientBasedSample Sample(common::Span gpair, DMatrix* dmat) override; private: @@ -44,13 +43,9 @@ class ExternalMemoryNoSampling : public SamplingStrategy { ExternalMemoryNoSampling(EllpackPageImpl* page, size_t n_rows, const BatchParam& batch_param); - GradientBasedSample Sample(common::Span gpair, DMatrix* dmat) override; private: - /*! \brief Concatenate all the rows from a DMatrix into a single ELLPACK page. */ - void ConcatenatePages(DMatrix* dmat); - BatchParam batch_param_; std::unique_ptr page_; bool page_concatenated_{false}; @@ -60,7 +55,6 @@ class ExternalMemoryNoSampling : public SamplingStrategy { class UniformSampling : public SamplingStrategy { public: UniformSampling(EllpackPageImpl* page, float subsample); - GradientBasedSample Sample(common::Span gpair, DMatrix* dmat) override; private: @@ -75,7 +69,6 @@ class ExternalMemoryUniformSampling : public SamplingStrategy { size_t n_rows, const BatchParam& batch_param, float subsample); - GradientBasedSample Sample(common::Span gpair, DMatrix* dmat) override; private: @@ -84,7 +77,7 @@ class ExternalMemoryUniformSampling : public SamplingStrategy { BatchParam batch_param_; float subsample_; std::unique_ptr page_; - dh::device_vector gpair_; + dh::device_vector gpair_{}; common::Span sample_row_index_; }; @@ -95,7 +88,6 @@ class GradientBasedSampling : public SamplingStrategy { size_t n_rows, const BatchParam& batch_param, float subsample); - GradientBasedSample Sample(common::Span gpair, DMatrix* dmat) override; private: @@ -113,7 +105,6 @@ class ExternalMemoryGradientBasedSampling : public SamplingStrategy { size_t n_rows, const BatchParam& batch_param, float subsample); - GradientBasedSample Sample(common::Span gpair, DMatrix* dmat) override; private: diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu index 9b136e73fa45..579436245c7f 100644 --- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -10,7 +10,8 @@ namespace tree { void VerifySampling(size_t page_size, float subsample, int sampling_method, - bool fixed_size_sampling = false) { + bool fixed_size_sampling = true, + bool check_sum = true) { constexpr size_t kRows = 4096; constexpr size_t kCols = 1; size_t sample_rows = kRows * subsample; @@ -39,9 +40,9 @@ void VerifySampling(size_t page_size, EXPECT_EQ(sample.page->matrix.n_rows, kRows); EXPECT_EQ(sample.gpair.size(), kRows); } else { - EXPECT_NEAR(sample.sample_rows, sample_rows, kRows * 0.012); - EXPECT_NEAR(sample.page->matrix.n_rows, sample_rows, kRows * 0.012); - EXPECT_NEAR(sample.gpair.size(), sample_rows, kRows * 0.012); + EXPECT_NEAR(sample.sample_rows, sample_rows, kRows * 0.012f); + EXPECT_NEAR(sample.page->matrix.n_rows, sample_rows, kRows * 0.012f); + EXPECT_NEAR(sample.gpair.size(), sample_rows, kRows * 0.012f); } GradientPair sum_sampled_gpair{}; @@ -50,16 +51,20 @@ void VerifySampling(size_t page_size, for (const auto& gp : sampled_gpair_h) { sum_sampled_gpair += gp; } - EXPECT_NEAR(sum_gpair.GetGrad(), sum_sampled_gpair.GetGrad(), 0.02f * kRows); - EXPECT_NEAR(sum_gpair.GetHess(), sum_sampled_gpair.GetHess(), 0.02f * kRows); + if (check_sum) { + EXPECT_NEAR(sum_gpair.GetGrad(), sum_sampled_gpair.GetGrad(), 0.02f * kRows); + EXPECT_NEAR(sum_gpair.GetHess(), sum_sampled_gpair.GetHess(), 0.02f * kRows); + } else { + EXPECT_NEAR(sum_gpair.GetGrad() / kRows, sum_sampled_gpair.GetGrad() / sample_rows, 0.02f); + EXPECT_NEAR(sum_gpair.GetHess() / kRows, sum_sampled_gpair.GetHess() / sample_rows, 0.02f); + } } TEST(GradientBasedSampler, NoSampling) { constexpr size_t kPageSize = 0; constexpr float kSubsample = 1.0f; constexpr int kSamplingMethod = TrainParam::kUniform; - constexpr bool kFixedSizeSampling = true; - VerifySampling(kPageSize, kSubsample, kSamplingMethod, kFixedSizeSampling); + VerifySampling(kPageSize, kSubsample, kSamplingMethod); } // In external mode, when not sampling, we concatenate the pages together. @@ -113,29 +118,32 @@ TEST(GradientBasedSampler, UniformSampling) { constexpr float kSubsample = 0.5; constexpr int kSamplingMethod = TrainParam::kUniform; constexpr bool kFixedSizeSampling = true; - VerifySampling(kPageSize, kSubsample, kSamplingMethod, kFixedSizeSampling); + constexpr bool kCheckSum = false; + VerifySampling(kPageSize, kSubsample, kSamplingMethod, kFixedSizeSampling, kCheckSum); } TEST(GradientBasedSampler, UniformSampling_ExternalMemory) { constexpr size_t kPageSize = 1024; constexpr float kSubsample = 0.5; constexpr int kSamplingMethod = TrainParam::kUniform; - VerifySampling(kPageSize, kSubsample, kSamplingMethod); + constexpr bool kFixedSizeSampling = false; + constexpr bool kCheckSum = false; + VerifySampling(kPageSize, kSubsample, kSamplingMethod, kFixedSizeSampling, kCheckSum); } TEST(GradientBasedSampler, GradientBasedSampling) { constexpr size_t kPageSize = 0; constexpr float kSubsample = 0.8; constexpr int kSamplingMethod = TrainParam::kGradientBased; - constexpr bool kFixedSizeSampling = true; - VerifySampling(kPageSize, kSubsample, kSamplingMethod, kFixedSizeSampling); + VerifySampling(kPageSize, kSubsample, kSamplingMethod); } TEST(GradientBasedSampler, GradientBasedSampling_ExternalMemory) { constexpr size_t kPageSize = 1024; constexpr float kSubsample = 0.8; constexpr int kSamplingMethod = TrainParam::kGradientBased; - VerifySampling(kPageSize, kSubsample, kSamplingMethod); + constexpr bool kFixedSizeSampling = false; + VerifySampling(kPageSize, kSubsample, kSamplingMethod, kFixedSizeSampling); } }; // namespace tree From 7fd7c31e5418cc760b20923bd0f8ba9deae1d390 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 31 Jan 2020 12:03:43 -0800 Subject: [PATCH 48/48] revert rabit --- rabit | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rabit b/rabit index 2f253471680f..2f7fcff4d770 160000 --- a/rabit +++ b/rabit @@ -1 +1 @@ -Subproject commit 2f253471680f1bdafc1dfa17395ca0f309fe96de +Subproject commit 2f7fcff4d770a3eb4fba6b25ded74b45e196ccd6