diff --git a/include/xgboost/base.h b/include/xgboost/base.h
index 9b8af5f17902..0b3a1e0af5f1 100644
--- a/include/xgboost/base.h
+++ b/include/xgboost/base.h
@@ -193,6 +193,36 @@ class GradientPairInternal {
     return g;
   }

+  XGBOOST_DEVICE GradientPairInternal &operator*=(float multiplier) {
+    grad_ *= multiplier;
+    hess_ *= multiplier;
+    return *this;
+  }
+
+  XGBOOST_DEVICE GradientPairInternal operator*(float multiplier) const {
+    GradientPairInternal g;
+    g.grad_ = grad_ * multiplier;
+    g.hess_ = hess_ * multiplier;
+    return g;
+  }
+
+  XGBOOST_DEVICE GradientPairInternal &operator/=(float divisor) {
+    grad_ /= divisor;
+    hess_ /= divisor;
+    return *this;
+  }
+
+  XGBOOST_DEVICE GradientPairInternal operator/(float divisor) const {
+    GradientPairInternal g;
+    g.grad_ = grad_ / divisor;
+    g.hess_ = hess_ / divisor;
+    return g;
+  }
+
+  XGBOOST_DEVICE bool operator==(const GradientPairInternal &rhs) const {
+    return grad_ == rhs.grad_ && hess_ == rhs.hess_;
+  }
+
   XGBOOST_DEVICE explicit GradientPairInternal(int value) {
     *this = GradientPairInternal<T>(static_cast<T>(value),
                                     static_cast<T>(value));
diff --git a/src/common/compressed_iterator.h b/src/common/compressed_iterator.h
index 6057aab17216..1c5543e43f5c 100644
--- a/src/common/compressed_iterator.h
+++ b/src/common/compressed_iterator.h
@@ -63,7 +63,7 @@ class CompressedBufferWriter {
    * \fn static size_t CompressedBufferWriter::CalculateBufferSize(int
    * num_elements, int num_symbols)
    *
-   * \brief Calculates number of bytes requiredm for a given number of elements
+   * \brief Calculates number of bytes required for a given number of elements
    * and a symbol range.
    *
    * \author Rory
@@ -74,7 +74,6 @@ class CompressedBufferWriter {
    *
    * \return The calculated buffer size.
    */
-
   static size_t CalculateBufferSize(size_t num_elements, size_t num_symbols) {
     const int bits_per_byte = 8;
     size_t compressed_size = static_cast<size_t>(std::ceil(
@@ -188,7 +187,7 @@ class CompressedIterator {
  public:
   CompressedIterator() : buffer_(nullptr), symbol_bits_(0), offset_(0) {}
-  CompressedIterator(CompressedByteT *buffer, int num_symbols)
+  CompressedIterator(CompressedByteT *buffer, size_t num_symbols)
       : buffer_(buffer), offset_(0) {
     symbol_bits_ = detail::SymbolBits(num_symbols);
   }
diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index 6ec021ef3465..ad5d961c091f 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -1266,6 +1266,26 @@ thrust::device_ptr<T const> tcend(xgboost::HostDeviceVector<T> const& vector) {
   return tcbegin(vector) + vector.Size();
 }

+template <typename T>
+thrust::device_ptr<T> tbegin(xgboost::common::Span<T>& span) {  // NOLINT
+  return thrust::device_ptr<T>(span.data());
+}
+
+template <typename T>
+thrust::device_ptr<T> tend(xgboost::common::Span<T>& span) {  // NOLINT
+  return tbegin(span) + span.size();
+}
+
+template <typename T>
+thrust::device_ptr<T const> tcbegin(xgboost::common::Span<T> const& span) {
+  return thrust::device_ptr<T const>(span.data());
+}
+
+template <typename T>
+thrust::device_ptr<T const> tcend(xgboost::common::Span<T> const& span) {
+  return tcbegin(span) + span.size();
+}
+
 template <typename T>
 class LauncherItr {
  public:
diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu
index a58c74e6f60d..760d47b06f8d 100644
--- a/src/data/ellpack_page.cu
+++ b/src/data/ellpack_page.cu
@@ -64,6 +64,20 @@ __global__ void CompressBinEllpackKernel(
   wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature);
 }

+// Construct an ELLPACK matrix with the given number of empty rows.
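+// Illustrative call (names hypothetical): a sampler sizing a destination page
+// up front would write
+//   EllpackPageImpl result(device, src_page->matrix.info, sample_rows);
+// and then fill it via Copy() or Compact() below.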
+EllpackPageImpl::EllpackPageImpl(int device, EllpackInfo info, size_t n_rows) { + monitor_.Init("ellpack_page"); + dh::safe_cuda(cudaSetDevice(device)); + + matrix.info = info; + matrix.base_rowid = 0; + matrix.n_rows = n_rows; + + monitor_.StartCuda("InitCompressedData"); + InitCompressedData(device, n_rows); + monitor_.StopCuda("InitCompressedData"); +} + // Construct an ELLPACK matrix in memory. EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param) { monitor_.Init("ellpack_page"); @@ -96,6 +110,85 @@ EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param) { monitor_.StopCuda("BinningCompression"); } +// A functor that copies the data from one EllpackPage to another. +struct CopyPage { + common::CompressedBufferWriter cbw; + common::CompressedByteT* dst_data_d; + common::CompressedIterator src_iterator_d; + // The number of elements to skip. + size_t offset; + + CopyPage(EllpackPageImpl* dst, EllpackPageImpl* src, size_t offset) + : cbw{dst->matrix.info.NumSymbols()}, + dst_data_d{dst->gidx_buffer.data()}, + src_iterator_d{src->gidx_buffer.data(), src->matrix.info.NumSymbols()}, + offset(offset) {} + + __device__ void operator()(size_t element_id) { + cbw.AtomicWriteSymbol(dst_data_d, src_iterator_d[element_id], element_id + offset); + } +}; + +// Copy the data from the given EllpackPage to the current page. +size_t EllpackPageImpl::Copy(int device, EllpackPageImpl* page, size_t offset) { + monitor_.StartCuda("Copy"); + size_t num_elements = page->matrix.n_rows * page->matrix.info.row_stride; + CHECK_EQ(matrix.info.row_stride, page->matrix.info.row_stride); + CHECK_EQ(matrix.info.NumSymbols(), page->matrix.info.NumSymbols()); + CHECK_GE(matrix.n_rows * matrix.info.row_stride, offset + num_elements); + dh::LaunchN(device, num_elements, CopyPage(this, page, offset)); + monitor_.StopCuda("Copy"); + return num_elements; +} + +// A functor that compacts the rows from one EllpackPage into another. +struct CompactPage { + common::CompressedBufferWriter cbw; + common::CompressedByteT* dst_data_d; + common::CompressedIterator src_iterator_d; + /*! \brief An array that maps the rows from the full DMatrix to the compacted page. + * + * The total size is the number of rows in the original, uncompacted DMatrix. Elements are the + * row ids in the compacted page. Rows not needed are set to SIZE_MAX. + * + * An example compacting 16 rows to 8 rows: + * [SIZE_MAX, 0, 1, SIZE_MAX, SIZE_MAX, 2, SIZE_MAX, 3, 4, 5, SIZE_MAX, 6, SIZE_MAX, 7, SIZE_MAX, + * SIZE_MAX] + */ + common::Span row_indexes; + size_t base_rowid; + size_t row_stride; + + CompactPage(EllpackPageImpl* dst, EllpackPageImpl* src, common::Span row_indexes) + : cbw{dst->matrix.info.NumSymbols()}, + dst_data_d{dst->gidx_buffer.data()}, + src_iterator_d{src->gidx_buffer.data(), src->matrix.info.NumSymbols()}, + row_indexes(row_indexes), + base_rowid{src->matrix.base_rowid}, + row_stride{src->matrix.info.row_stride} {} + + __device__ void operator()(size_t row_id) { + size_t src_row = base_rowid + row_id; + size_t dst_row = row_indexes[src_row]; + if (dst_row == SIZE_MAX) return; + size_t dst_offset = dst_row * row_stride; + size_t src_offset = row_id * row_stride; + for (size_t j = 0; j < row_stride; j++) { + cbw.AtomicWriteSymbol(dst_data_d, src_iterator_d[src_offset + j], dst_offset + j); + } + } +}; + +// Compacts the data from the given EllpackPage into the current page. 
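+// Illustrative flow, mirroring the samplers and tests below:
+//   for (auto& batch : dmat->GetBatches<EllpackPage>(param)) {
+//     result.Compact(device, batch.Impl(), row_indexes);
+//   }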
+void EllpackPageImpl::Compact(int device, EllpackPageImpl* page, common::Span row_indexes) { + monitor_.StartCuda("Compact"); + CHECK_EQ(matrix.info.row_stride, page->matrix.info.row_stride); + CHECK_EQ(matrix.info.NumSymbols(), page->matrix.info.NumSymbols()); + CHECK_LE(page->matrix.base_rowid + page->matrix.n_rows, row_indexes.size()); + dh::LaunchN(device, page->matrix.n_rows, CompactPage(this, page, row_indexes)); + monitor_.StopCuda("Compact"); +} + // Construct an EllpackInfo based on histogram cuts of features. EllpackInfo::EllpackInfo(int device, bool is_dense, @@ -123,16 +216,14 @@ void EllpackPageImpl::InitInfo(int device, // Initialize the buffer to stored compressed features. void EllpackPageImpl::InitCompressedData(int device, size_t num_rows) { - size_t num_symbols = matrix.info.n_bins + 1; + size_t num_symbols = matrix.info.NumSymbols(); // Required buffer size for storing data matrix in ELLPack format. size_t compressed_size_bytes = common::CompressedBufferWriter::CalculateBufferSize( matrix.info.row_stride * num_rows, num_symbols); ba_.Allocate(device, &gidx_buffer, compressed_size_bytes); - thrust::fill( - thrust::device_pointer_cast(gidx_buffer.data()), - thrust::device_pointer_cast(gidx_buffer.data() + gidx_buffer.size()), 0); + thrust::fill(dh::tbegin(gidx_buffer), dh::tend(gidx_buffer), 0); matrix.gidx_iter = common::CompressedIterator(gidx_buffer.data(), num_symbols); } @@ -149,7 +240,6 @@ void EllpackPageImpl::CreateHistIndices(int device, const auto& offset_vec = row_batch.offset.ConstHostVector(); - int num_symbols = matrix.info.n_bins + 1; // bin and compress entries in batches of rows size_t gpu_batch_nrows = std::min( dh::TotalMemory(device) / (16 * row_stride * sizeof(Entry)), @@ -193,7 +283,7 @@ void EllpackPageImpl::CreateHistIndices(int device, 1); dh::LaunchKernel {grid3, block3} ( CompressBinEllpackKernel, - common::CompressedBufferWriter(num_symbols), + common::CompressedBufferWriter(matrix.info.NumSymbols()), gidx_buffer.data(), row_ptrs.data().get(), entries_d.data().get(), @@ -254,11 +344,9 @@ void EllpackPageImpl::CompressSparsePage(int device) { // Return the memory cost for storing the compressed features. size_t EllpackPageImpl::MemCostBytes() const { - size_t num_symbols = matrix.info.n_bins + 1; - // Required buffer size for storing data matrix in ELLPack format. size_t compressed_size_bytes = common::CompressedBufferWriter::CalculateBufferSize( - matrix.info.row_stride * matrix.n_rows, num_symbols); + matrix.info.row_stride * matrix.n_rows, matrix.info.NumSymbols()); return compressed_size_bytes; } @@ -280,5 +368,4 @@ void EllpackPageImpl::InitDevice(int device, EllpackInfo info) { device_initialized_ = true; } - } // namespace xgboost diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh index 47fd98910a95..fcf89ab8fe98 100644 --- a/src/data/ellpack_page.cuh +++ b/src/data/ellpack_page.cuh @@ -71,6 +71,11 @@ struct EllpackInfo { size_t row_stride, const common::HistogramCuts& hmat, dh::BulkAllocator* ba); + + /*! \brief Return the total number of symbols (total number of bins plus 1 for not found). */ + size_t NumSymbols() const { + return n_bins + 1; + } }; /** \brief Struct for accessing and manipulating an ellpack matrix on the @@ -200,6 +205,14 @@ class EllpackPageImpl { */ EllpackPageImpl() = default; + /*! + * \brief Constructor from an existing EllpackInfo. + * + * This is used in the sampling case. The ELLPACK page is constructed from an existing EllpackInfo + * and the given number of rows. 
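+   *
+   * A sketch of the intended call (names illustrative):
+   *   EllpackPageImpl page(device, existing_page->matrix.info, sampled_rows);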
+ */ + explicit EllpackPageImpl(int device, EllpackInfo info, size_t n_rows); + /*! * \brief Constructor from an existing DMatrix. * @@ -208,6 +221,23 @@ class EllpackPageImpl { */ explicit EllpackPageImpl(DMatrix* dmat, const BatchParam& parm); + /*! \brief Copy the elements of the given ELLPACK page into this page. + * + * @param device The GPU device to use. + * @param page The ELLPACK page to copy from. + * @param offset The number of elements to skip before copying. + * @returns The number of elements copied. + */ + size_t Copy(int device, EllpackPageImpl* page, size_t offset); + + /*! \brief Compact the given ELLPACK page into the current page. + * + * @param device The GPU device to use. + * @param page The ELLPACK page to compact from. + * @param row_indexes Row indexes for the compacted page. + */ + void Compact(int device, EllpackPageImpl* page, common::Span row_indexes); + /*! * \brief Initialize the EllpackInfo contained in the EllpackMatrix. * diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu new file mode 100644 index 000000000000..f294855b5dbd --- /dev/null +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -0,0 +1,380 @@ +/*! + * Copyright 2019 by XGBoost Contributors + */ +#include +#include +#include +#include +#include + +#include +#include + +#include "../../common/compressed_iterator.h" +#include "../../common/random.h" +#include "gradient_based_sampler.cuh" + +namespace xgboost { +namespace tree { + +/*! \brief A functor that returns random weights. */ +class RandomWeight : public thrust::unary_function { + public: + explicit RandomWeight(size_t seed) : seed_(seed) {} + + XGBOOST_DEVICE float operator()(size_t i) const { + thrust::default_random_engine rng(seed_); + thrust::uniform_real_distribution dist; + rng.discard(i); + return dist(rng); + } + + private: + uint32_t seed_; +}; + +/*! \brief A functor that performs a Bernoulli trial to discard a gradient pair. */ +class BernoulliTrial : public thrust::unary_function { + public: + BernoulliTrial(size_t seed, float p) : rnd_(seed), p_(p) {} + + XGBOOST_DEVICE bool operator()(size_t i) const { + return rnd_(i) > p_; + } + + private: + RandomWeight rnd_; + float p_; +}; + +/*! \brief A functor that returns true if the gradient pair is non-zero. */ +struct IsNonZero : public thrust::unary_function { + XGBOOST_DEVICE bool operator()(const GradientPair& gpair) const { + return gpair.GetGrad() != 0 || gpair.GetHess() != 0; + } +}; + +/*! \brief A functor that clears the row indexes with empty gradient. */ +struct ClearEmptyRows : public thrust::binary_function { + XGBOOST_DEVICE size_t operator()(const GradientPair& gpair, size_t row_index) const { + if (gpair.GetGrad() != 0 || gpair.GetHess() != 0) { + return row_index; + } else { + return std::numeric_limits::max(); + } + } +}; + +/*! \brief A functor that combines the gradient pair into a single float. + * + * The approach here is based on Minimal Variance Sampling (MVS), with lambda set to 0.1. + * + * \see Ibragimov, B., & Gusev, G. (2019). Minimal Variance Sampling in Stochastic Gradient + * Boosting. In Advances in Neural Information Processing Systems (pp. 15061-15071). + */ +class CombineGradientPair : public thrust::unary_function { + public: + XGBOOST_DEVICE float operator()(const GradientPair& gpair) const { + return sqrtf(powf(gpair.GetGrad(), 2) + kLambda * powf(gpair.GetHess(), 2)); + } + + private: + static constexpr float kLambda = 0.1f; +}; + +/*! 
\brief A functor that calculates the difference between the sample rate and the desired sample + * rows, given a cumulative gradient sum. + */ +class SampleRateDelta : public thrust::binary_function { + public: + SampleRateDelta(common::Span threshold, size_t n_rows, size_t sample_rows) + : threshold_(threshold), n_rows_(n_rows), sample_rows_(sample_rows) {} + + XGBOOST_DEVICE float operator()(float gradient_sum, size_t row_index) const { + float lower = threshold_[row_index]; + float upper = threshold_[row_index + 1]; + float u = gradient_sum / static_cast(sample_rows_ - n_rows_ + row_index + 1); + if (u > lower && u <= upper) { + threshold_[row_index + 1] = u; + return 0.0f; + } else { + return std::numeric_limits::max(); + } + } + + private: + common::Span threshold_; + size_t n_rows_; + size_t sample_rows_; +}; + +/*! \brief A functor that performs Poisson sampling, and scales gradient pairs by 1/p_i. */ +class PoissonSampling : public thrust::binary_function { + public: + PoissonSampling(common::Span threshold, size_t threshold_index, RandomWeight rnd) + : threshold_(threshold), threshold_index_(threshold_index), rnd_(rnd) {} + + XGBOOST_DEVICE GradientPair operator()(const GradientPair& gpair, size_t i) { + // If the gradient and hessian are both empty, we should never select this row. + if (gpair.GetGrad() == 0 && gpair.GetHess() == 0) { + return gpair; + } + float combined_gradient = combine_(gpair); + float u = threshold_[threshold_index_]; + float p = combined_gradient / u; + if (p >= 1) { + // Always select this row. + return gpair; + } else { + // Select this row randomly with probability proportional to the combined gradient. + // Scale gpair by 1/p. + if (rnd_(i) <= p) { + return gpair / p; + } else { + return GradientPair(); + } + } + } + + private: + common::Span threshold_; + size_t threshold_index_; + RandomWeight rnd_; + CombineGradientPair combine_; +}; + +NoSampling::NoSampling(EllpackPageImpl* page) : page_(page) {} + +GradientBasedSample NoSampling::Sample(common::Span gpair, DMatrix* dmat) { + return {dmat->Info().num_row_, page_, gpair}; +} + +ExternalMemoryNoSampling::ExternalMemoryNoSampling(EllpackPageImpl* page, + size_t n_rows, + const BatchParam& batch_param) + : batch_param_(batch_param), + page_(new EllpackPageImpl(batch_param.gpu_id, page->matrix.info, n_rows)) {} + +GradientBasedSample ExternalMemoryNoSampling::Sample(common::Span gpair, + DMatrix* dmat) { + if (!page_concatenated_) { + // Concatenate all the external memory ELLPACK pages into a single in-memory page. 
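+    // Note: the concatenated page must fit in GPU memory; its buffer is
+    // CalculateBufferSize(row_stride * n_rows, NumSymbols()) bytes, as
+    // allocated by the EllpackPageImpl constructor above.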
+ size_t offset = 0; + for (auto& batch : dmat->GetBatches(batch_param_)) { + auto page = batch.Impl(); + size_t num_elements = page_->Copy(batch_param_.gpu_id, page, offset); + offset += num_elements; + } + page_concatenated_ = true; + } + return {dmat->Info().num_row_, page_.get(), gpair}; +} + +UniformSampling::UniformSampling(EllpackPageImpl* page, float subsample) + : page_(page), subsample_(subsample) {} + +GradientBasedSample UniformSampling::Sample(common::Span gpair, DMatrix* dmat) { + // Set gradient pair to 0 with p = 1 - subsample + thrust::replace_if(dh::tbegin(gpair), dh::tend(gpair), + thrust::counting_iterator(0), + BernoulliTrial(common::GlobalRandom()(), subsample_), + GradientPair()); + return {dmat->Info().num_row_, page_, gpair}; +} + +ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(EllpackPageImpl* page, + size_t n_rows, + const BatchParam& batch_param, + float subsample) + : original_page_(page), batch_param_(batch_param), subsample_(subsample) { + ba_.Allocate(batch_param_.gpu_id, &sample_row_index_, n_rows); +} + +GradientBasedSample ExternalMemoryUniformSampling::Sample(common::Span gpair, + DMatrix* dmat) { + // Set gradient pair to 0 with p = 1 - subsample + thrust::replace_if(dh::tbegin(gpair), dh::tend(gpair), + thrust::counting_iterator(0), + BernoulliTrial(common::GlobalRandom()(), subsample_), + GradientPair()); + + // Count the sampled rows. + size_t sample_rows = thrust::count_if(dh::tbegin(gpair), dh::tend(gpair), IsNonZero()); + size_t n_rows = dmat->Info().num_row_; + + // Compact gradient pairs. + gpair_.resize(sample_rows); + thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero()); + + // Index the sample rows. + thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(sample_row_index_), IsNonZero()); + thrust::exclusive_scan(dh::tbegin(sample_row_index_), dh::tend(sample_row_index_), + dh::tbegin(sample_row_index_)); + thrust::transform(dh::tbegin(gpair), dh::tend(gpair), + dh::tbegin(sample_row_index_), + dh::tbegin(sample_row_index_), + ClearEmptyRows()); + + // Create a new ELLPACK page with empty rows. + page_.reset(); // Release the device memory first before reallocating + page_.reset(new EllpackPageImpl(batch_param_.gpu_id, + original_page_->matrix.info, + sample_rows)); + + // Compact the ELLPACK pages into the single sample page. + thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0); + for (auto& batch : dmat->GetBatches(batch_param_)) { + page_->Compact(batch_param_.gpu_id, batch.Impl(), sample_row_index_); + } + + return {sample_rows, page_.get(), dh::ToSpan(gpair_)}; +} + +GradientBasedSampling::GradientBasedSampling(EllpackPageImpl* page, + size_t n_rows, + const BatchParam& batch_param, + float subsample) : page_(page), subsample_(subsample) { + ba_.Allocate(batch_param.gpu_id, + &threshold_, n_rows + 1, + &grad_sum_, n_rows); +} + +GradientBasedSample GradientBasedSampling::Sample(common::Span gpair, + DMatrix* dmat) { + size_t n_rows = dmat->Info().num_row_; + size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex( + gpair, threshold_, grad_sum_, n_rows * subsample_); + + // Perform Poisson sampling in place. 
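+  // Each row i is kept with probability p_i = min(1, sqrt(g_i^2 + 0.1 * h_i^2) / u),
+  // where u = threshold_[threshold_index]; kept rows are scaled by 1/p_i so the
+  // expected gradient sum is unchanged (see PoissonSampling above).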
+ thrust::transform(dh::tbegin(gpair), dh::tend(gpair), + thrust::counting_iterator(0), + dh::tbegin(gpair), + PoissonSampling(threshold_, + threshold_index, + RandomWeight(common::GlobalRandom()()))); + return {n_rows, page_, gpair}; +} + +ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling( + EllpackPageImpl* page, + size_t n_rows, + const BatchParam& batch_param, + float subsample) : original_page_(page), batch_param_(batch_param), subsample_(subsample) { + ba_.Allocate(batch_param.gpu_id, + &threshold_, n_rows + 1, + &grad_sum_, n_rows, + &sample_row_index_, n_rows); +} + +GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(common::Span gpair, + DMatrix* dmat) { + size_t n_rows = dmat->Info().num_row_; + size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex( + gpair, threshold_, grad_sum_, n_rows * subsample_); + + // Perform Poisson sampling in place. + thrust::transform(dh::tbegin(gpair), dh::tend(gpair), + thrust::counting_iterator(0), + dh::tbegin(gpair), + PoissonSampling(threshold_, + threshold_index, + RandomWeight(common::GlobalRandom()()))); + + // Count the sampled rows. + size_t sample_rows = thrust::count_if(dh::tbegin(gpair), dh::tend(gpair), IsNonZero()); + + // Compact gradient pairs. + gpair_.resize(sample_rows); + thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero()); + + // Index the sample rows. + thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(sample_row_index_), IsNonZero()); + thrust::exclusive_scan(dh::tbegin(sample_row_index_), dh::tend(sample_row_index_), + dh::tbegin(sample_row_index_)); + thrust::transform(dh::tbegin(gpair), dh::tend(gpair), + dh::tbegin(sample_row_index_), + dh::tbegin(sample_row_index_), + ClearEmptyRows()); + + // Create a new ELLPACK page with empty rows. + page_.reset(); // Release the device memory first before reallocating + page_.reset(new EllpackPageImpl(batch_param_.gpu_id, + original_page_->matrix.info, + sample_rows)); + + // Compact the ELLPACK pages into the single sample page. + thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0); + for (auto& batch : dmat->GetBatches(batch_param_)) { + page_->Compact(batch_param_.gpu_id, batch.Impl(), sample_row_index_); + } + + return {sample_rows, page_.get(), dh::ToSpan(gpair_)}; +} + +GradientBasedSampler::GradientBasedSampler(EllpackPageImpl* page, + size_t n_rows, + const BatchParam& batch_param, + float subsample, + int sampling_method) { + monitor_.Init("gradient_based_sampler"); + + bool is_sampling = subsample < 1.0; + bool is_external_memory = page->matrix.n_rows != n_rows; + + if (is_sampling) { + switch (sampling_method) { + case TrainParam::kUniform: + if (is_external_memory) { + strategy_.reset(new ExternalMemoryUniformSampling(page, n_rows, batch_param, subsample)); + } else { + strategy_.reset(new UniformSampling(page, subsample)); + } + break; + case TrainParam::kGradientBased: + if (is_external_memory) { + strategy_.reset( + new ExternalMemoryGradientBasedSampling(page, n_rows, batch_param, subsample)); + } else { + strategy_.reset(new GradientBasedSampling(page, n_rows, batch_param, subsample)); + } + break; + default:LOG(FATAL) << "unknown sampling method"; + } + } else { + if (is_external_memory) { + strategy_.reset(new ExternalMemoryNoSampling(page, n_rows, batch_param)); + } else { + strategy_.reset(new NoSampling(page)); + } + } +} + +// Sample a DMatrix based on the given gradient pairs. 
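+// Typical call, as exercised by the unit tests:
+//   auto sample = sampler.Sample(gpair.DeviceSpan(), dmat);
+//   // sample.sample_rows, sample.page and sample.gpair describe the result.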
+GradientBasedSample GradientBasedSampler::Sample(common::Span<GradientPair> gpair,
+                                                 DMatrix* dmat) {
+  monitor_.StartCuda("Sample");
+  GradientBasedSample sample = strategy_->Sample(gpair, dmat);
+  monitor_.StopCuda("Sample");
+  return sample;
+}
+
+size_t GradientBasedSampler::CalculateThresholdIndex(common::Span<GradientPair> gpair,
+                                                     common::Span<float> threshold,
+                                                     common::Span<float> grad_sum,
+                                                     size_t sample_rows) {
+  thrust::fill(dh::tend(threshold) - 1, dh::tend(threshold),
+               std::numeric_limits<float>::max());
+  thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
+                    dh::tbegin(threshold),
+                    CombineGradientPair());
+  thrust::sort(dh::tbegin(threshold), dh::tend(threshold) - 1);
+  thrust::inclusive_scan(dh::tbegin(threshold), dh::tend(threshold) - 1, dh::tbegin(grad_sum));
+  thrust::transform(dh::tbegin(grad_sum), dh::tend(grad_sum),
+                    thrust::counting_iterator<size_t>(0),
+                    dh::tbegin(grad_sum),
+                    SampleRateDelta(threshold, gpair.size(), sample_rows));
+  thrust::device_ptr<float> min = thrust::min_element(dh::tbegin(grad_sum), dh::tend(grad_sum));
+  return thrust::distance(dh::tbegin(grad_sum), min) + 1;
+}
+
+};  // namespace tree
+};  // namespace xgboost
diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh
new file mode 100644
index 000000000000..41099e3bc134
--- /dev/null
+++ b/src/tree/gpu_hist/gradient_based_sampler.cuh
@@ -0,0 +1,153 @@
+/*!
+ * Copyright 2019 by XGBoost Contributors
+ */
+#pragma once
+#include
+#include
+#include
+
+#include "../../common/device_helpers.cuh"
+#include "../../data/ellpack_page.cuh"
+
+namespace xgboost {
+namespace tree {
+
+struct GradientBasedSample {
+  /*!\brief Number of sampled rows. */
+  size_t sample_rows;
+  /*!\brief Sampled rows in ELLPACK format. */
+  EllpackPageImpl* page;
+  /*!\brief Gradient pairs for the sampled rows. */
+  common::Span<GradientPair> gpair;
+};
+
+class SamplingStrategy {
+ public:
+  /*! \brief Sample from a DMatrix based on the given gradient pairs. */
+  virtual GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) = 0;
+};
+
+/*! \brief No sampling in in-memory mode. */
+class NoSampling : public SamplingStrategy {
+ public:
+  explicit NoSampling(EllpackPageImpl* page);
+  GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
+
+ private:
+  EllpackPageImpl* page_;
+};
+
+/*! \brief No sampling in external memory mode. */
+class ExternalMemoryNoSampling : public SamplingStrategy {
+ public:
+  ExternalMemoryNoSampling(EllpackPageImpl* page,
+                           size_t n_rows,
+                           const BatchParam& batch_param);
+  GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
+
+ private:
+  BatchParam batch_param_;
+  std::unique_ptr<EllpackPageImpl> page_;
+  bool page_concatenated_{false};
+};
+
+/*! \brief Uniform sampling in in-memory mode. */
+class UniformSampling : public SamplingStrategy {
+ public:
+  UniformSampling(EllpackPageImpl* page, float subsample);
+  GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
+
+ private:
+  EllpackPageImpl* page_;
+  float subsample_;
+};
+
+/*! \brief Uniform sampling in external memory mode. */
+class ExternalMemoryUniformSampling : public SamplingStrategy {
+ public:
+  ExternalMemoryUniformSampling(EllpackPageImpl* page,
+                                size_t n_rows,
+                                const BatchParam& batch_param,
+                                float subsample);
+  GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
+
+ private:
+  dh::BulkAllocator ba_;
+  EllpackPageImpl* original_page_;
+  BatchParam batch_param_;
+  float subsample_;
+  std::unique_ptr<EllpackPageImpl> page_;
+  dh::device_vector<GradientPair> gpair_{};
+  common::Span<size_t> sample_row_index_;
+};
+
+/*! \brief Gradient-based sampling in in-memory mode. */
+class GradientBasedSampling : public SamplingStrategy {
+ public:
+  GradientBasedSampling(EllpackPageImpl* page,
+                        size_t n_rows,
+                        const BatchParam& batch_param,
+                        float subsample);
+  GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
+
+ private:
+  EllpackPageImpl* page_;
+  float subsample_;
+  dh::BulkAllocator ba_;
+  common::Span<float> threshold_;
+  common::Span<float> grad_sum_;
+};
+
+/*! \brief Gradient-based sampling in external memory mode. */
+class ExternalMemoryGradientBasedSampling : public SamplingStrategy {
+ public:
+  ExternalMemoryGradientBasedSampling(EllpackPageImpl* page,
+                                      size_t n_rows,
+                                      const BatchParam& batch_param,
+                                      float subsample);
+  GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
+
+ private:
+  dh::BulkAllocator ba_;
+  EllpackPageImpl* original_page_;
+  BatchParam batch_param_;
+  float subsample_;
+  common::Span<float> threshold_;
+  common::Span<float> grad_sum_;
+  std::unique_ptr<EllpackPageImpl> page_;
+  dh::device_vector<GradientPair> gpair_;
+  common::Span<size_t> sample_row_index_;
+};
+
+/*! \brief Draw a sample of rows from a DMatrix.
+ *
+ * \see Ke, G., Meng, Q., Finley, T., Wang, T., Chen, W., Ma, W., ... & Liu, T. Y. (2017).
+ * LightGBM: A highly efficient gradient boosting decision tree. In Advances in Neural Information
+ * Processing Systems (pp. 3146-3154).
+ * \see Zhu, R. (2016). Gradient-based sampling: An adaptive importance sampling for least-squares.
+ * In Advances in Neural Information Processing Systems (pp. 406-414).
+ * \see Ohlsson, E. (1998). Sequential Poisson sampling. Journal of Official Statistics, 14(2), 149.
+ */
+class GradientBasedSampler {
+ public:
+  GradientBasedSampler(EllpackPageImpl* page,
+                       size_t n_rows,
+                       const BatchParam& batch_param,
+                       float subsample,
+                       int sampling_method);
+
+  /*! \brief Sample from a DMatrix based on the given gradient pairs. */
+  GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat);
+
+  /*! \brief Calculate the threshold used to normalize sampling probabilities.
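+   *
+   * The returned index identifies the threshold u for which
+   * sum_i min(1, combined_gradient_i / u) is approximately sample_rows
+   * (see SampleRateDelta in the .cu file).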
+   */
+  static size_t CalculateThresholdIndex(common::Span<GradientPair> gpair,
+                                        common::Span<float> threshold,
+                                        common::Span<float> grad_sum,
+                                        size_t sample_rows);
+
+ private:
+  common::Monitor monitor_;
+  std::unique_ptr<SamplingStrategy> strategy_;
+};
+};  // namespace tree
+};  // namespace xgboost
diff --git a/src/tree/gpu_hist/row_partitioner.cuh b/src/tree/gpu_hist/row_partitioner.cuh
index 2d9500faf4a7..4818d71abc9f 100644
--- a/src/tree/gpu_hist/row_partitioner.cuh
+++ b/src/tree/gpu_hist/row_partitioner.cuh
@@ -125,7 +125,6 @@ class RowPartitioner {
     idx += segment.begin;
     RowIndexT ridx = d_ridx[idx];
     bst_node_t new_position = op(ridx);  // new node id
-    if (new_position == kIgnoredTreePosition) return;
     KERNEL_CHECK(new_position == left_nidx || new_position == right_nidx);
     AtomicIncrement(d_left_count, new_position == left_nidx);
     d_position[idx] = new_position;
diff --git a/src/tree/param.h b/src/tree/param.h
index 3d991a6f1738..da63895a5ef8 100644
--- a/src/tree/param.h
+++ b/src/tree/param.h
@@ -50,6 +50,9 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
   float max_delta_step;
   // whether we want to do subsample
   float subsample;
+  // sampling method
+  enum SamplingMethod { kUniform = 0, kGradientBased = 1 };
+  int sampling_method;
   // whether to subsample columns in each split (node)
   float colsample_bynode;
   // whether to subsample columns in each level
@@ -144,6 +147,14 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
         .set_range(0.0f, 1.0f)
         .set_default(1.0f)
         .describe("Row subsample ratio of training instance.");
+    DMLC_DECLARE_FIELD(sampling_method)
+        .set_default(kUniform)
+        .add_enum("uniform", kUniform)
+        .add_enum("gradient_based", kGradientBased)
+        .describe(
+            "Sampling method. 0: select random training instances uniformly. "
+            "1: select random training instances with higher probability when the "
+            "gradient and hessian are larger. (cf.
CatBoost)"); DMLC_DECLARE_FIELD(colsample_bynode) .set_range(0.0f, 1.0f) .set_default(1.0f) diff --git a/src/tree/updater_basemaker-inl.h b/src/tree/updater_basemaker-inl.h index f806628dd023..8761e39901bd 100644 --- a/src/tree/updater_basemaker-inl.h +++ b/src/tree/updater_basemaker-inl.h @@ -148,6 +148,9 @@ class BaseMaker: public TreeUpdater { } // mark subsample if (param_.subsample < 1.0f) { + CHECK_EQ(param_.sampling_method, TrainParam::kUniform) + << "Only uniform sampling is supported, " + << "gradient-based sampling is only support by GPU Hist."; std::bernoulli_distribution coin_flip(param_.subsample); auto& rnd = common::GlobalRandom(); for (size_t i = 0; i < position_.size(); ++i) { diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index 0fb9c067c581..5b0f859c36d3 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -202,6 +202,9 @@ class ColMaker: public TreeUpdater { } // mark subsample if (param_.subsample < 1.0f) { + CHECK_EQ(param_.sampling_method, TrainParam::kUniform) + << "Only uniform sampling is supported, " + << "gradient-based sampling is only support by GPU Hist."; std::bernoulli_distribution coin_flip(param_.subsample); auto& rnd = common::GlobalRandom(); for (size_t ridx = 0; ridx < position_.size(); ++ridx) { diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh index dacb32f0aac4..b64eeaae6e70 100644 --- a/src/tree/updater_gpu_common.cuh +++ b/src/tree/updater_gpu_common.cuh @@ -187,41 +187,5 @@ XGBOOST_DEVICE inline int MaxNodesDepth(int depth) { return (1 << (depth + 1)) - 1; } -/* - * Random - */ -struct BernoulliRng { - float p; - uint32_t seed; - - XGBOOST_DEVICE BernoulliRng(float p, size_t seed_) : p(p) { - seed = static_cast(seed_); - } - - XGBOOST_DEVICE bool operator()(const int i) const { - thrust::default_random_engine rng(seed); - thrust::uniform_real_distribution dist; - rng.discard(i); - return dist(rng) <= p; - } -}; - -// Set gradient pair to 0 with p = 1 - subsample -inline void SubsampleGradientPair(int device_idx, - common::Span d_gpair, - float subsample, int offset = 0) { - if (subsample == 1.0) { - return; - } - - BernoulliRng rng(subsample, common::GlobalRandom()()); - - dh::LaunchN(device_idx, d_gpair.size(), [=] XGBOOST_DEVICE(int i) { - if (!rng(i + offset)) { - d_gpair[i] = GradientPair(); - } - }); -} - } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index f8b7731c9b33..eafa227b05e3 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -29,6 +29,7 @@ #include "param.h" #include "updater_gpu_common.cuh" #include "constraints.cuh" +#include "gpu_hist/gradient_based_sampler.cuh" #include "gpu_hist/row_partitioner.cuh" namespace xgboost { @@ -415,11 +416,8 @@ __global__ void SharedMemHistKernel(xgboost::EllpackMatrix matrix, } for (auto idx : dh::GridStrideRange(static_cast(0), n_elements)) { int ridx = d_ridx[idx / matrix.info.row_stride]; - if (!matrix.IsInRange(ridx)) { - continue; - } - int gidx = matrix.gidx_iter[(ridx - matrix.base_rowid) * matrix.info.row_stride - + idx % matrix.info.row_stride]; + int gidx = + matrix.gidx_iter[ridx * matrix.info.row_stride + idx % matrix.info.row_stride]; if (gidx != matrix.info.n_bins) { // If we are not using shared memory, accumulate the values directly into // global memory @@ -480,6 +478,8 @@ struct GPUHistMakerDevice { std::function>; std::unique_ptr qexpand; + std::unique_ptr sampler; + GPUHistMakerDevice(int _device_id, 
EllpackPageImpl* _page, bst_uint _n_rows, @@ -495,6 +495,11 @@ struct GPUHistMakerDevice { column_sampler(column_sampler_seed), interaction_constraints(param, n_features), batch_param(_batch_param) { + sampler.reset(new GradientBasedSampler(page, + n_rows, + batch_param, + param.subsample, + param.sampling_method)); monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id)); } @@ -528,7 +533,7 @@ struct GPUHistMakerDevice { // Reset values for each update iteration // Note that the column sampler must be passed by value because it is not // thread safe - void Reset(HostDeviceVector* dh_gpair, int64_t num_columns) { + void Reset(HostDeviceVector* dh_gpair, DMatrix* dmat, int64_t num_columns) { if (param.grow_policy == TrainParam::kLossGuide) { qexpand.reset(new ExpandQueue(LossGuide)); } else { @@ -540,13 +545,14 @@ struct GPUHistMakerDevice { this->interaction_constraints.Reset(); std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), GradientPair()); + + auto sample = sampler->Sample(dh_gpair->DeviceSpan(), dmat); + n_rows = sample.sample_rows; + page = sample.page; + gpair = sample.gpair; + row_partitioner.reset(); // Release the device memory first before reallocating row_partitioner.reset(new RowPartitioner(device_id, n_rows)); - - dh::safe_cuda(cudaMemcpyAsync( - gpair.data(), dh_gpair->ConstDevicePointer(), - gpair.size() * sizeof(GradientPair), cudaMemcpyHostToHost)); - SubsampleGradientPair(device_id, gpair, param.subsample); hist.Reset(); } @@ -632,14 +638,6 @@ struct GPUHistMakerDevice { return std::vector(result_all.begin(), result_all.end()); } - // Build gradient histograms for a given node across all the batches in the DMatrix. - void BuildHistBatches(int nidx, DMatrix* p_fmat) { - for (auto& batch : p_fmat->GetBatches(batch_param)) { - page = batch.Impl(); - BuildHist(nidx); - } - } - void BuildHist(int nidx) { hist.AllocateHistogram(nidx); auto d_node_hist = hist.GetNodeHistogram(nidx); @@ -687,10 +685,7 @@ struct GPUHistMakerDevice { row_partitioner->UpdatePosition( nidx, split_node.LeftChild(), split_node.RightChild(), - [=] __device__(size_t ridx) { - if (!d_matrix.IsInRange(ridx)) { - return RowPartitioner::kIgnoredTreePosition; - } + [=] __device__(bst_uint ridx) { // given a row index, returns the node id it belongs to bst_float cut_value = d_matrix.GetElement(ridx, split_node.SplitIndex()); @@ -719,35 +714,46 @@ struct GPUHistMakerDevice { d_nodes.size() * sizeof(RegTree::Node), cudaMemcpyHostToDevice)); - for (auto& batch : p_fmat->GetBatches(batch_param)) { - page = batch.Impl(); - auto d_matrix = page->matrix; - row_partitioner->FinalisePosition( - [=] __device__(size_t row_id, int position) { - if (!d_matrix.IsInRange(row_id)) { - return RowPartitioner::kIgnoredTreePosition; - } - auto node = d_nodes[position]; - - while (!node.IsLeaf()) { - bst_float element = d_matrix.GetElement(row_id, node.SplitIndex()); - // Missing value - if (isnan(element)) { - position = node.DefaultChild(); - } else { - if (element <= node.SplitCond()) { - position = node.LeftChild(); - } else { - position = node.RightChild(); - } - } - node = d_nodes[position]; - } - return position; - }); + if (row_partitioner->GetRows().size() != p_fmat->Info().num_row_) { + row_partitioner.reset(); // Release the device memory first before reallocating + row_partitioner.reset(new RowPartitioner(device_id, p_fmat->Info().num_row_)); + } + if (page->matrix.n_rows == p_fmat->Info().num_row_) { + FinalisePositionInPage(page, d_nodes); + } else { + for (auto& batch : 
p_fmat->GetBatches(batch_param)) { + FinalisePositionInPage(batch.Impl(), d_nodes); + } } } + void FinalisePositionInPage(EllpackPageImpl* page, const common::Span d_nodes) { + auto d_matrix = page->matrix; + row_partitioner->FinalisePosition( + [=] __device__(size_t row_id, int position) { + if (!d_matrix.IsInRange(row_id)) { + return RowPartitioner::kIgnoredTreePosition; + } + auto node = d_nodes[position]; + + while (!node.IsLeaf()) { + bst_float element = d_matrix.GetElement(row_id, node.SplitIndex()); + // Missing value + if (isnan(element)) { + position = node.DefaultChild(); + } else { + if (element <= node.SplitCond()) { + position = node.LeftChild(); + } else { + position = node.RightChild(); + } + } + node = d_nodes[position]; + } + return position; + }); + } + void UpdatePredictionCache(bst_float* out_preds_d) { dh::safe_cuda(cudaSetDevice(device_id)); if (!prediction_cache_initialised) { @@ -797,7 +803,8 @@ struct GPUHistMakerDevice { /** * \brief Build GPU local histograms for the left and right child of some parent node */ - void BuildHistLeftRight(const ExpandEntry &candidate, int nidx_left, int nidx_right) { + void BuildHistLeftRight(const ExpandEntry &candidate, int nidx_left, + int nidx_right, dh::AllReducer* reducer) { auto build_hist_nidx = nidx_left; auto subtraction_trick_nidx = nidx_right; @@ -809,34 +816,6 @@ struct GPUHistMakerDevice { } this->BuildHist(build_hist_nidx); - - // Check whether we can use the subtraction trick to calculate the other - bool do_subtraction_trick = this->CanDoSubtractionTrick( - candidate.nid, build_hist_nidx, subtraction_trick_nidx); - - if (!do_subtraction_trick) { - // Calculate other histogram manually - this->BuildHist(subtraction_trick_nidx); - } - } - - /** - * \brief AllReduce GPU histograms for the left and right child of some parent node. 
- */ - void ReduceHistLeftRight(const ExpandEntry& candidate, - int nidx_left, - int nidx_right, - dh::AllReducer* reducer) { - auto build_hist_nidx = nidx_left; - auto subtraction_trick_nidx = nidx_right; - - // Decide whether to build the left histogram or right histogram - // Use sum of Hessian as a heuristic to select node with fewest training instances - bool fewer_right = candidate.split.right_sum.GetHess() < candidate.split.left_sum.GetHess(); - if (fewer_right) { - std::swap(build_hist_nidx, subtraction_trick_nidx); - } - this->AllReduceHist(build_hist_nidx, reducer); // Check whether we can use the subtraction trick to calculate the other @@ -849,6 +828,7 @@ struct GPUHistMakerDevice { subtraction_trick_nidx); } else { // Calculate other histogram manually + this->BuildHist(subtraction_trick_nidx); this->AllReduceHist(subtraction_trick_nidx, reducer); } } @@ -889,14 +869,10 @@ struct GPUHistMakerDevice { tree[candidate.nid].RightChild()); } - void InitRoot(RegTree* p_tree, HostDeviceVector* gpair_all, DMatrix* p_fmat, - dh::AllReducer* reducer, int64_t num_columns) { + void InitRoot(RegTree* p_tree, dh::AllReducer* reducer, int64_t num_columns) { constexpr int kRootNIdx = 0; - const auto &gpair = gpair_all->DeviceSpan(); - - dh::SumReduction(temp_memory, gpair, node_sum_gradients_d, - gpair.size()); + dh::SumReduction(temp_memory, gpair, node_sum_gradients_d, gpair.size()); reducer->AllReduceSum( reinterpret_cast(node_sum_gradients_d.data()), reinterpret_cast(node_sum_gradients_d.data()), 2); @@ -905,7 +881,7 @@ struct GPUHistMakerDevice { node_sum_gradients_d.data(), sizeof(GradientPair), cudaMemcpyDeviceToHost)); - this->BuildHistBatches(kRootNIdx, p_fmat); + this->BuildHist(kRootNIdx); this->AllReduceHist(kRootNIdx, reducer); // Remember root stats @@ -928,11 +904,11 @@ struct GPUHistMakerDevice { auto& tree = *p_tree; monitor.StartCuda("Reset"); - this->Reset(gpair_all, p_fmat->Info().num_col_); + this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_); monitor.StopCuda("Reset"); monitor.StartCuda("InitRoot"); - this->InitRoot(p_tree, gpair_all, p_fmat, reducer, p_fmat->Info().num_col_); + this->InitRoot(p_tree, reducer, p_fmat->Info().num_col_); monitor.StopCuda("InitRoot"); auto timestamp = qexpand->size(); @@ -951,21 +927,15 @@ struct GPUHistMakerDevice { int left_child_nidx = tree[candidate.nid].LeftChild(); int right_child_nidx = tree[candidate.nid].RightChild(); // Only create child entries if needed - if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx), num_leaves)) { - for (auto& batch : p_fmat->GetBatches(batch_param)) { - page = batch.Impl(); - - monitor.StartCuda("UpdatePosition"); - this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]); - monitor.StopCuda("UpdatePosition"); + if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx), + num_leaves)) { + monitor.StartCuda("UpdatePosition"); + this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]); + monitor.StopCuda("UpdatePosition"); - monitor.StartCuda("BuildHist"); - this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx); - monitor.StopCuda("BuildHist"); - } - monitor.StartCuda("ReduceHist"); - this->ReduceHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer); - monitor.StopCuda("ReduceHist"); + monitor.StartCuda("BuildHist"); + this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer); + monitor.StopCuda("BuildHist"); monitor.StartCuda("EvaluateSplits"); auto splits = this->EvaluateSplits({left_child_nidx, 
right_child_nidx}, @@ -997,7 +967,6 @@ inline void GPUHistMakerDevice::InitHistogram() { param.max_leaves > 0 ? param.max_leaves * 2 : MaxNodesDepth(param.max_depth); ba.Allocate(device_id, - &gpair, n_rows, &prediction_cache, n_rows, &node_sum_gradients_d, max_nodes, &monotone_constraints, param.monotone_constraints.size()); diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index d1d206587c0b..11a7bdaf2123 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -534,6 +534,9 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat, // mark subsample and build list of member rows if (param_.subsample < 1.0f) { + CHECK_EQ(param_.sampling_method, TrainParam::kUniform) + << "Only uniform sampling is supported, " + << "gradient-based sampling is only support by GPU Hist."; std::bernoulli_distribution coin_flip(param_.subsample); auto& rnd = common::GlobalRandom(); size_t j = 0; diff --git a/tests/cpp/data/test_ellpack_page.cu b/tests/cpp/data/test_ellpack_page.cu index 6dd97ef7ca05..6479e9feea93 100644 --- a/tests/cpp/data/test_ellpack_page.cu +++ b/tests/cpp/data/test_ellpack_page.cu @@ -81,4 +81,119 @@ TEST(EllpackPage, BuildGidxSparse) { } } +struct ReadRowFunction { + EllpackMatrix matrix; + int row; + bst_float* row_data_d; + ReadRowFunction(EllpackMatrix matrix, int row, bst_float* row_data_d) + : matrix(std::move(matrix)), row(row), row_data_d(row_data_d) {} + + __device__ void operator()(size_t col) { + auto value = matrix.GetElement(row, col); + if (isnan(value)) { + value = -1; + } + row_data_d[col] = value; + } +}; + +TEST(EllpackPage, Copy) { + constexpr size_t kRows = 1024; + constexpr size_t kCols = 16; + constexpr size_t kPageSize = 1024; + + // Create a DMatrix with multiple batches. + dmlc::TemporaryDirectory tmpdir; + std::unique_ptr + dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir)); + BatchParam param{0, 256, 0, kPageSize}; + auto page = (*dmat->GetBatches(param).begin()).Impl(); + + // Create an empty result page. + EllpackPageImpl result(0, page->matrix.info, kRows); + + // Copy batch pages into the result page. + size_t offset = 0; + for (auto& batch : dmat->GetBatches(param)) { + size_t num_elements = result.Copy(0, batch.Impl(), offset); + offset += num_elements; + } + + size_t current_row = 0; + thrust::device_vector row_d(kCols); + thrust::device_vector row_result_d(kCols); + std::vector row(kCols); + std::vector row_result(kCols); + for (auto& page : dmat->GetBatches(param)) { + auto impl = page.Impl(); + EXPECT_EQ(impl->matrix.base_rowid, current_row); + + for (size_t i = 0; i < impl->Size(); i++) { + dh::LaunchN(0, kCols, ReadRowFunction(impl->matrix, current_row, row_d.data().get())); + thrust::copy(row_d.begin(), row_d.end(), row.begin()); + + dh::LaunchN(0, kCols, ReadRowFunction(result.matrix, current_row, row_result_d.data().get())); + thrust::copy(row_result_d.begin(), row_result_d.end(), row_result.begin()); + + EXPECT_EQ(row, row_result); + current_row++; + } + } +} + +TEST(EllpackPage, Compact) { + constexpr size_t kRows = 16; + constexpr size_t kCols = 2; + constexpr size_t kPageSize = 1; + constexpr size_t kCompactedRows = 8; + + // Create a DMatrix with multiple batches. + dmlc::TemporaryDirectory tmpdir; + std::unique_ptr + dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir)); + BatchParam param{0, 256, 0, kPageSize}; + auto page = (*dmat->GetBatches(param).begin()).Impl(); + + // Create an empty result page. 
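+  // kCompactedRows matches the number of non-SIZE_MAX entries in row_indexes_h
+  // below: 8 of the 16 rows are kept.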
+ EllpackPageImpl result(0, page->matrix.info, kCompactedRows); + + // Compact batch pages into the result page. + std::vector row_indexes_h { + SIZE_MAX, 0, 1, 2, SIZE_MAX, 3, SIZE_MAX, 4, 5, SIZE_MAX, 6, SIZE_MAX, 7, SIZE_MAX, SIZE_MAX, + SIZE_MAX}; + thrust::device_vector row_indexes_d = row_indexes_h; + common::Span row_indexes_span(row_indexes_d.data().get(), kRows); + for (auto& batch : dmat->GetBatches(param)) { + result.Compact(0, batch.Impl(), row_indexes_span); + } + + size_t current_row = 0; + thrust::device_vector row_d(kCols); + thrust::device_vector row_result_d(kCols); + std::vector row(kCols); + std::vector row_result(kCols); + for (auto& page : dmat->GetBatches(param)) { + auto impl = page.Impl(); + EXPECT_EQ(impl->matrix.base_rowid, current_row); + + for (size_t i = 0; i < impl->Size(); i++) { + size_t compacted_row = row_indexes_h[current_row]; + if (compacted_row == SIZE_MAX) { + current_row++; + continue; + } + + dh::LaunchN(0, kCols, ReadRowFunction(impl->matrix, current_row, row_d.data().get())); + thrust::copy(row_d.begin(), row_d.end(), row.begin()); + + dh::LaunchN(0, kCols, + ReadRowFunction(result.matrix, compacted_row, row_result_d.data().get())); + thrust::copy(row_result_d.begin(), row_result_d.end(), row_result.begin()); + + EXPECT_EQ(row, row_result); + current_row++; + } + } +} + } // namespace xgboost diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index 10101d92983c..ca124234337e 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -221,6 +221,19 @@ inline GenericParameter CreateEmptyGenericParam(int gpu_id) { return tparam; } +inline HostDeviceVector GenerateRandomGradients(const size_t n_rows) { + xgboost::SimpleLCG gen; + xgboost::SimpleRealUniformDistribution dist(0.0f, 1.0f); + std::vector h_gpair(n_rows); + for (auto &gpair : h_gpair) { + bst_float grad = dist(&gen); + bst_float hess = dist(&gen); + gpair = GradientPair(grad, hess); + } + HostDeviceVector gpair(h_gpair); + return gpair; +} + #if defined(__CUDACC__) namespace { class HistogramCutsWrapper : public common::HistogramCuts { diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu new file mode 100644 index 000000000000..579436245c7f --- /dev/null +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -0,0 +1,150 @@ +#include + +#include "../../../../src/data/ellpack_page.cuh" +#include "../../../../src/tree/gpu_hist/gradient_based_sampler.cuh" +#include "../../helpers.h" + +namespace xgboost { +namespace tree { + +void VerifySampling(size_t page_size, + float subsample, + int sampling_method, + bool fixed_size_sampling = true, + bool check_sum = true) { + constexpr size_t kRows = 4096; + constexpr size_t kCols = 1; + size_t sample_rows = kRows * subsample; + + dmlc::TemporaryDirectory tmpdir; + std::unique_ptr dmat( + CreateSparsePageDMatrixWithRC(kRows, kCols, page_size, true, tmpdir)); + auto gpair = GenerateRandomGradients(kRows); + GradientPair sum_gpair{}; + for (const auto& gp : gpair.ConstHostVector()) { + sum_gpair += gp; + } + gpair.SetDevice(0); + + BatchParam param{0, 256, 0, page_size}; + auto page = (*dmat->GetBatches(param).begin()).Impl(); + if (page_size != 0) { + EXPECT_NE(page->matrix.n_rows, kRows); + } + + GradientBasedSampler sampler(page, kRows, param, subsample, sampling_method); + auto sample = sampler.Sample(gpair.DeviceSpan(), dmat.get()); + + if (fixed_size_sampling) { + EXPECT_EQ(sample.sample_rows, kRows); + EXPECT_EQ(sample.page->matrix.n_rows, kRows); + 
EXPECT_EQ(sample.gpair.size(), kRows); + } else { + EXPECT_NEAR(sample.sample_rows, sample_rows, kRows * 0.012f); + EXPECT_NEAR(sample.page->matrix.n_rows, sample_rows, kRows * 0.012f); + EXPECT_NEAR(sample.gpair.size(), sample_rows, kRows * 0.012f); + } + + GradientPair sum_sampled_gpair{}; + std::vector sampled_gpair_h(sample.gpair.size()); + dh::CopyDeviceSpanToVector(&sampled_gpair_h, sample.gpair); + for (const auto& gp : sampled_gpair_h) { + sum_sampled_gpair += gp; + } + if (check_sum) { + EXPECT_NEAR(sum_gpair.GetGrad(), sum_sampled_gpair.GetGrad(), 0.02f * kRows); + EXPECT_NEAR(sum_gpair.GetHess(), sum_sampled_gpair.GetHess(), 0.02f * kRows); + } else { + EXPECT_NEAR(sum_gpair.GetGrad() / kRows, sum_sampled_gpair.GetGrad() / sample_rows, 0.02f); + EXPECT_NEAR(sum_gpair.GetHess() / kRows, sum_sampled_gpair.GetHess() / sample_rows, 0.02f); + } +} + +TEST(GradientBasedSampler, NoSampling) { + constexpr size_t kPageSize = 0; + constexpr float kSubsample = 1.0f; + constexpr int kSamplingMethod = TrainParam::kUniform; + VerifySampling(kPageSize, kSubsample, kSamplingMethod); +} + +// In external mode, when not sampling, we concatenate the pages together. +TEST(GradientBasedSampler, NoSampling_ExternalMemory) { + constexpr size_t kRows = 2048; + constexpr size_t kCols = 1; + constexpr float kSubsample = 1.0f; + constexpr size_t kPageSize = 1024; + + // Create a DMatrix with multiple batches. + dmlc::TemporaryDirectory tmpdir; + std::unique_ptr + dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir)); + auto gpair = GenerateRandomGradients(kRows); + gpair.SetDevice(0); + + BatchParam param{0, 256, 0, kPageSize}; + auto page = (*dmat->GetBatches(param).begin()).Impl(); + EXPECT_NE(page->matrix.n_rows, kRows); + + GradientBasedSampler sampler(page, kRows, param, kSubsample, TrainParam::kUniform); + auto sample = sampler.Sample(gpair.DeviceSpan(), dmat.get()); + auto sampled_page = sample.page; + EXPECT_EQ(sample.sample_rows, kRows); + EXPECT_EQ(sample.gpair.size(), gpair.Size()); + EXPECT_EQ(sample.gpair.data(), gpair.DevicePointer()); + EXPECT_EQ(sampled_page->matrix.n_rows, kRows); + + std::vector buffer(sampled_page->gidx_buffer.size()); + dh::CopyDeviceSpanToVector(&buffer, sampled_page->gidx_buffer); + common::CompressedIterator + ci(buffer.data(), sampled_page->matrix.info.NumSymbols()); + + size_t offset = 0; + for (auto& batch : dmat->GetBatches(param)) { + auto page = batch.Impl(); + std::vector page_buffer(page->gidx_buffer.size()); + dh::CopyDeviceSpanToVector(&page_buffer, page->gidx_buffer); + common::CompressedIterator + page_ci(page_buffer.data(), page->matrix.info.NumSymbols()); + size_t num_elements = page->matrix.n_rows * page->matrix.info.row_stride; + for (size_t i = 0; i < num_elements; i++) { + EXPECT_EQ(ci[i + offset], page_ci[i]); + } + offset += num_elements; + } +} + +TEST(GradientBasedSampler, UniformSampling) { + constexpr size_t kPageSize = 0; + constexpr float kSubsample = 0.5; + constexpr int kSamplingMethod = TrainParam::kUniform; + constexpr bool kFixedSizeSampling = true; + constexpr bool kCheckSum = false; + VerifySampling(kPageSize, kSubsample, kSamplingMethod, kFixedSizeSampling, kCheckSum); +} + +TEST(GradientBasedSampler, UniformSampling_ExternalMemory) { + constexpr size_t kPageSize = 1024; + constexpr float kSubsample = 0.5; + constexpr int kSamplingMethod = TrainParam::kUniform; + constexpr bool kFixedSizeSampling = false; + constexpr bool kCheckSum = false; + VerifySampling(kPageSize, kSubsample, kSamplingMethod, 
kFixedSizeSampling, kCheckSum); +} + +TEST(GradientBasedSampler, GradientBasedSampling) { + constexpr size_t kPageSize = 0; + constexpr float kSubsample = 0.8; + constexpr int kSamplingMethod = TrainParam::kGradientBased; + VerifySampling(kPageSize, kSubsample, kSamplingMethod); +} + +TEST(GradientBasedSampler, GradientBasedSampling_ExternalMemory) { + constexpr size_t kPageSize = 1024; + constexpr float kSubsample = 0.8; + constexpr int kSamplingMethod = TrainParam::kGradientBased; + constexpr bool kFixedSizeSampling = false; + VerifySampling(kPageSize, kSubsample, kSamplingMethod, kFixedSizeSampling); +} + +}; // namespace tree +}; // namespace xgboost diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 2ba1c9b196c7..6cb0aad26719 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -88,12 +88,13 @@ void TestBuildHist(bool use_shared_memory_histograms) { xgboost::SimpleLCG gen; xgboost::SimpleRealUniformDistribution dist(0.0f, 1.0f); - std::vector h_gpair(kNRows); - for (auto &gpair : h_gpair) { + HostDeviceVector gpair(kNRows); + for (auto &gp : gpair.HostVector()) { bst_float grad = dist(&gen); bst_float hess = dist(&gen); - gpair = GradientPair(grad, hess); + gp = GradientPair(grad, hess); } + gpair.SetDevice(0); thrust::host_vector h_gidx_buffer (page->gidx_buffer.size()); @@ -104,7 +105,7 @@ void TestBuildHist(bool use_shared_memory_histograms) { maker.row_partitioner.reset(new RowPartitioner(0, kNRows)); maker.hist.AllocateHistogram(0); - dh::CopyVectorToDeviceSpan(maker.gpair, h_gpair); + maker.gpair = gpair.DeviceSpan(); maker.use_shared_memory_histograms = use_shared_memory_histograms; maker.BuildHist(0); @@ -319,19 +320,6 @@ int32_t TestMinSplitLoss(DMatrix* dmat, float gamma, HostDeviceVector GenerateRandomGradients(const size_t n_rows) { - xgboost::SimpleLCG gen; - xgboost::SimpleRealUniformDistribution dist(0.0f, 1.0f); - std::vector h_gpair(n_rows); - for (auto &gpair : h_gpair) { - bst_float grad = dist(&gen); - bst_float hess = dist(&gen); - gpair = GradientPair(grad, hess); - } - HostDeviceVector gpair(h_gpair); - return gpair; -} - TEST(GpuHist, MinSplitLoss) { constexpr size_t kRows = 32; constexpr size_t kCols = 16; @@ -358,7 +346,9 @@ void UpdateTree(HostDeviceVector* gpair, DMatrix* dmat, size_t gpu_page_size, RegTree* tree, - HostDeviceVector* preds) { + HostDeviceVector* preds, + float subsample = 1.0f, + const std::string& sampling_method = "uniform") { constexpr size_t kMaxBin = 2; if (gpu_page_size > 0) { @@ -379,7 +369,9 @@ void UpdateTree(HostDeviceVector* gpair, {"max_bin", std::to_string(kMaxBin)}, {"min_child_weight", "0.0"}, {"reg_alpha", "0"}, - {"reg_lambda", "0"} + {"reg_lambda", "0"}, + {"subsample", std::to_string(subsample)}, + {"sampling_method", sampling_method}, }; tree::GPUHistMakerSpecialised hist_maker; @@ -391,10 +383,66 @@ void UpdateTree(HostDeviceVector* gpair, hist_maker.UpdatePredictionCache(dmat, preds); } +TEST(GpuHist, UniformSampling) { + constexpr size_t kRows = 4096; + constexpr size_t kCols = 2; + constexpr float kSubsample = 0.99; + + // Create an in-memory DMatrix. + std::unique_ptr dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true)); + + auto gpair = GenerateRandomGradients(kRows); + + // Build a tree using the in-memory DMatrix. + RegTree tree; + HostDeviceVector preds(kRows, 0.0, 0); + UpdateTree(&gpair, dmat.get(), 0, &tree, &preds); + + // Build another tree using sampling. 
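+  // With subsample = 0.99, very few rows are dropped, so the sampled tree
+  // should produce predictions within the small tolerance checked below.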
+ RegTree tree_sampling; + HostDeviceVector preds_sampling(kRows, 0.0, 0); + UpdateTree(&gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample); + + // Make sure the predictions are the same. + auto preds_h = preds.ConstHostVector(); + auto preds_sampling_h = preds_sampling.ConstHostVector(); + for (int i = 0; i < kRows; i++) { + EXPECT_NEAR(preds_h[i], preds_sampling_h[i], 2e-3); + } +} + +TEST(GpuHist, GradientBasedSampling) { + constexpr size_t kRows = 4096; + constexpr size_t kCols = 2; + constexpr float kSubsample = 0.99; + + // Create an in-memory DMatrix. + std::unique_ptr dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true)); + + auto gpair = GenerateRandomGradients(kRows); + + // Build a tree using the in-memory DMatrix. + RegTree tree; + HostDeviceVector preds(kRows, 0.0, 0); + UpdateTree(&gpair, dmat.get(), 0, &tree, &preds); + + // Build another tree using sampling. + RegTree tree_sampling; + HostDeviceVector preds_sampling(kRows, 0.0, 0); + UpdateTree(&gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample, "gradient_based"); + + // Make sure the predictions are the same. + auto preds_h = preds.ConstHostVector(); + auto preds_sampling_h = preds_sampling.ConstHostVector(); + for (int i = 0; i < kRows; i++) { + EXPECT_NEAR(preds_h[i], preds_sampling_h[i], 1e-3); + } +} + TEST(GpuHist, ExternalMemory) { - constexpr size_t kRows = 6; + constexpr size_t kRows = 4096; constexpr size_t kCols = 2; - constexpr size_t kPageSize = 1; + constexpr size_t kPageSize = 1024; // Create an in-memory DMatrix. std::unique_ptr dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true)); @@ -420,7 +468,42 @@ TEST(GpuHist, ExternalMemory) { auto preds_h = preds.ConstHostVector(); auto preds_ext_h = preds_ext.ConstHostVector(); for (int i = 0; i < kRows; i++) { - ASSERT_FLOAT_EQ(preds_h[i], preds_ext_h[i]); + EXPECT_NEAR(preds_h[i], preds_ext_h[i], 2e-6); + } +} + +TEST(GpuHist, ExternalMemoryWithSampling) { + constexpr size_t kRows = 4096; + constexpr size_t kCols = 2; + constexpr size_t kPageSize = 1024; + constexpr float kSubsample = 0.5; + const std::string kSamplingMethod = "gradient_based"; + + // Create an in-memory DMatrix. + std::unique_ptr dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true)); + + // Create a DMatrix with multiple batches. + dmlc::TemporaryDirectory tmpdir; + std::unique_ptr + dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir)); + + auto gpair = GenerateRandomGradients(kRows); + + // Build a tree using the in-memory DMatrix. + RegTree tree; + HostDeviceVector preds(kRows, 0.0, 0); + UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, kSubsample, kSamplingMethod); + + // Build another tree using multiple ELLPACK pages. + RegTree tree_ext; + HostDeviceVector preds_ext(kRows, 0.0, 0); + UpdateTree(&gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext, kSubsample, kSamplingMethod); + + // Make sure the predictions are the same. + auto preds_h = preds.ConstHostVector(); + auto preds_ext_h = preds_ext.ConstHostVector(); + for (int i = 0; i < kRows; i++) { + EXPECT_NEAR(preds_h[i], preds_ext_h[i], 3e-3); } }
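
For context, a minimal sketch of how the new parameters added in param.h would be configured by a user of the GPU updater, following the style of the UpdateTree test harness above (the surrounding setup is hypothetical):

    std::vector<std::pair<std::string, std::string>> args = {
        {"tree_method", "gpu_hist"},           // sampling is implemented for GPU Hist only
        {"subsample", "0.5"},                  // keep roughly half of the rows per tree
        {"sampling_method", "gradient_based"}, // or "uniform" (the default)
    };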