From c2d008febdd1b95d9549c569e90b47ab82bc4ed7 Mon Sep 17 00:00:00 2001 From: fis Date: Tue, 16 Jun 2020 01:15:35 +0800 Subject: [PATCH 01/15] Implement GK sketching on GPU. --- Jenkinsfile | 1 + include/xgboost/span.h | 9 +- src/common/device_helpers.cu | 27 + src/common/device_helpers.cuh | 177 ++++++- src/common/hist_util.cc | 10 +- src/common/hist_util.cu | 231 ++++++--- src/common/hist_util.cuh | 365 ++++++------- src/common/hist_util.h | 2 +- src/common/host_device_vector.cu | 8 +- src/common/quantile.cu | 569 +++++++++++++++++++++ src/common/quantile.cuh | 142 +++++ src/common/quantile.h | 15 +- src/common/threading_utils.h | 0 tests/cpp/common/test_device_helpers.cu | 126 +++++ tests/cpp/common/test_hist_util.cu | 87 ++-- tests/cpp/common/test_hist_util.h | 4 +- tests/cpp/common/test_partition_builder.cc | 0 tests/cpp/common/test_quantile.cu | 388 ++++++++++++++ tests/cpp/common/test_span.cc | 4 +- tests/cpp/common/test_threading_utils.cc | 0 tests/cpp/helpers.h | 4 +- tests/pytest.ini | 3 +- tests/python-gpu/test_gpu_with_dask.py | 53 +- tests/python/testing.py | 15 +- 24 files changed, 1883 insertions(+), 357 deletions(-) create mode 100644 src/common/quantile.cu create mode 100644 src/common/quantile.cuh mode change 100755 => 100644 src/common/threading_utils.h mode change 100755 => 100644 tests/cpp/common/test_partition_builder.cc create mode 100644 tests/cpp/common/test_quantile.cu mode change 100755 => 100644 tests/cpp/common/test_threading_utils.cc diff --git a/Jenkinsfile b/Jenkinsfile index 21d67f231bc9..277bb5133590 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -313,6 +313,7 @@ def TestPythonGPU(args) { nodeReq = (args.multi_gpu) ? 'linux && mgpu' : 'linux && gpu' node(nodeReq) { unstash name: 'xgboost_whl_cuda10' + unstash name: 'xgboost_cpp_tests' unstash name: 'srcs' echo "Test Python GPU: CUDA ${args.cuda_version}" def container_type = "gpu" diff --git a/include/xgboost/span.h b/include/xgboost/span.h index 29b0b9a7930d..ed8c97bd47b1 100644 --- a/include/xgboost/span.h +++ b/include/xgboost/span.h @@ -573,8 +573,8 @@ class Span { XGBOOST_DEVICE auto subspan() const -> // NOLINT Span::value> { - SPAN_CHECK(Offset < size() || size() == 0); - SPAN_CHECK(Count == dynamic_extent || (Offset + Count <= size())); + SPAN_CHECK((Count == dynamic_extent) ? + (Offset <= size()) : (Offset + Count <= size())); return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count}; } @@ -582,9 +582,8 @@ class Span { XGBOOST_DEVICE Span subspan( // NOLINT index_type _offset, index_type _count = dynamic_extent) const { - SPAN_CHECK(_offset < size() || size() == 0); - SPAN_CHECK((_count == dynamic_extent) || (_offset + _count <= size())); - + SPAN_CHECK((_count == dynamic_extent) ? + (_offset <= size()) : (_offset + _count <= size())); return {data() + _offset, _count == dynamic_extent ? 
size() - _offset : _count}; } diff --git a/src/common/device_helpers.cu b/src/common/device_helpers.cu index c005c5a91d86..a80f2464c246 100644 --- a/src/common/device_helpers.cu +++ b/src/common/device_helpers.cu @@ -78,6 +78,33 @@ void AllReducer::Init(int _device_ordinal) { #endif // XGBOOST_USE_NCCL } +void AllReducer::AllGather(void const *data, size_t length_bytes, + std::vector *segments, + dh::caching_device_vector *recvbuf) { +#ifdef XGBOOST_USE_NCCL + CHECK(initialised_); + dh::safe_cuda(cudaSetDevice(device_ordinal_)); + size_t world = rabit::GetWorldSize(); + segments->clear(); + segments->resize(world, 0); + segments->at(rabit::GetRank()) = length_bytes; + rabit::Allreduce(segments->data(), segments->size()); + auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0); + recvbuf->resize(total_bytes); + + size_t offset = 0; + safe_nccl(ncclGroupStart()); + for (int32_t i = 0; i < world; ++i) { + size_t as_bytes = segments->at(i); + safe_nccl( + ncclBroadcast(data, recvbuf->data().get() + offset, + as_bytes, ncclChar, i, comm_, stream_)); + offset += as_bytes; + } + safe_nccl(ncclGroupEnd()); +#endif // XGBOOST_USE_NCCL +} + AllReducer::~AllReducer() { #ifdef XGBOOST_USE_NCCL if (initialised_) { diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 324674b33dbe..d339dc72d712 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -5,10 +5,16 @@ #include #include #include +#include +#include #include #include +#include + +#include #include #include +#include #include #include @@ -53,6 +59,36 @@ __device__ __forceinline__ double atomicAdd(double* address, double val) { // N } #endif +namespace dh { +namespace detail { +template +struct AtomicDispatcher; + +template <> +struct AtomicDispatcher { + using Type = unsigned int; // NOLINT + static_assert(sizeof(Type) == sizeof(uint32_t), "Unsigned should be of size 32 bits."); +}; + +template <> +struct AtomicDispatcher { + using Type = unsigned long long; // NOLINT + static_assert(sizeof(Type) == sizeof(uint64_t), "Unsigned long long should be of size 64 bits."); +}; +} // namespace detail +} // namespace dh + +// atomicAdd is not defined for size_t. +template ::value && + !std::is_same::value> * = // NOLINT + nullptr> +T __device__ __forceinline__ atomicAdd(T *addr, T v) { // NOLINT + using Type = typename dh::detail::AtomicDispatcher::Type; + Type ret = ::atomicAdd(reinterpret_cast(addr), static_cast(v)); + return static_cast(ret); +} + namespace dh { #define HOST_DEV_INLINE XGBOOST_DEVICE __forceinline__ @@ -291,10 +327,12 @@ public: safe_cuda(cudaGetDevice(¤t_device)); stats_.RegisterDeallocation(ptr, n, current_device); } - size_t PeakMemory() - { + size_t PeakMemory() const { return stats_.peak_allocated_bytes; } + size_t CurrentlyAllocatedBytes() const { + return stats_.currently_allocated_bytes; + } void Clear() { stats_ = DeviceStats(); @@ -529,7 +567,6 @@ class AllReducer { bool initialised_ {false}; size_t allreduce_bytes_ {0}; // Keep statistics of the number of bytes communicated size_t allreduce_calls_ {0}; // Keep statistics of the number of reduce calls - std::vector host_data_; // Used for all reduce on host #ifdef XGBOOST_USE_NCCL ncclComm_t comm_; cudaStream_t stream_; @@ -569,6 +606,27 @@ class AllReducer { #endif } + /** + * \brief Allgather implemented as grouped calls to Broadcast. This way we can accept + * different size of data on different workers. + * \param length_bytes Size of input data in bytes. 
+ * \param segments Size of data on each worker. + * \param recvbuf Buffer storing the result of data from all workers. + */ + void AllGather(void const* data, size_t length_bytes, + std::vector* segments, dh::caching_device_vector* recvbuf); + + void AllGather(uint32_t const* data, size_t length, + dh::caching_device_vector* recvbuf) { +#ifdef XGBOOST_USE_NCCL + CHECK(initialised_); + size_t world = rabit::GetWorldSize(); + recvbuf->resize(length * world); + safe_nccl(ncclAllGather(data, recvbuf->data().get(), length, ncclUint32, + comm_, stream_)); +#endif // XGBOOST_USE_NCCL + } + /** * \brief Allreduce. Use in exactly the same way as NCCL but without needing * streams or comms. @@ -607,6 +665,40 @@ class AllReducer { #endif } + void AllReduceSum(const uint32_t *sendbuff, uint32_t *recvbuff, int count) { +#ifdef XGBOOST_USE_NCCL + CHECK(initialised_); + + dh::safe_cuda(cudaSetDevice(device_ordinal_)); + dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint32, ncclSum, comm_, stream_)); +#endif + } + + void AllReduceSum(const uint64_t *sendbuff, uint64_t *recvbuff, int count) { +#ifdef XGBOOST_USE_NCCL + CHECK(initialised_); + + dh::safe_cuda(cudaSetDevice(device_ordinal_)); + dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint64, ncclSum, comm_, stream_)); +#endif + } + + // Specialization for size_t, which is implementation defined so it might or might not + // be one of uint64_t/uint32_t/unsigned long long/unsigned long. + template ::value && + !std::is_same::value> // NOLINT + * = nullptr> + void AllReduceSum(const T *sendbuff, T *recvbuff, int count) { // NOLINT +#ifdef XGBOOST_USE_NCCL + CHECK(initialised_); + + dh::safe_cuda(cudaSetDevice(device_ordinal_)); + static_assert(sizeof(unsigned long long) == sizeof(uint64_t), ""); // NOLINT + dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint64, ncclSum, comm_, stream_)); +#endif + } + /** * \fn void Synchronize() * @@ -886,9 +978,86 @@ DEV_INLINE void AtomicAddGpair(OutputGradientT* dest, // Thrust version of this function causes error on Windows template -thrust::transform_iterator MakeTransformIterator( +XGBOOST_DEVICE thrust::transform_iterator MakeTransformIterator( IterT iter, FuncT func) { return thrust::transform_iterator(iter, func); } +template +size_t XGBOOST_DEVICE SegmentId(It first, It last, size_t idx) { + size_t segment_id = thrust::upper_bound(thrust::seq, first, last, idx) - + 1 - first; + return segment_id; +} + +template +size_t XGBOOST_DEVICE SegmentId(xgboost::common::Span segments_ptr, size_t idx) { + return SegmentId(segments_ptr.cbegin(), segments_ptr.cend(), idx); +} + +namespace detail { +template +struct SegmentedUniqueReduceOp { + KeyOutIt key_out; + __device__ Key const& operator()(Key const& key) const { + auto constexpr kOne = static_cast>(1); + atomicAdd(&(*(key_out + key.first)), kOne); + return key; + } +}; +} // namespace detail + +/* \brief Segmented unique function. Keys are pointers to segments with key_segments_last - + * key_segments_first = n_segments + 1. + * + * \pre Input segment and output segment must not overlap. + * + * \param key_segments_first Beginning iterator of segments. + * \param key_segments_last End iterator of segments. + * \param val_first Beginning iterator of values. + * \param val_last End iterator of values. + * \param key_segments_out Output iterator of segments. + * \param val_out Output iterator of values. + * + * \return Number of unique values in total. 
+ */ +template +size_t +SegmentedUnique(KeyInIt key_segments_first, KeyInIt key_segments_last, ValInIt val_first, + ValInIt val_last, KeyOutIt key_segments_out, ValOutIt val_out, + Comp comp) { + using Key = thrust::pair::value_type>; + dh::XGBCachingDeviceAllocator alloc; + auto unique_key_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(static_cast(0)), + [=] __device__(size_t i) { + size_t seg = dh::SegmentId(key_segments_first, key_segments_last, i); + return thrust::make_pair(seg, *(val_first + i)); + }); + size_t segments_len = key_segments_last - key_segments_first; + thrust::fill(thrust::device, key_segments_out, key_segments_out + segments_len, 0); + size_t n_inputs = std::distance(val_first, val_last); + // Reduce the number of uniques elements per segment, avoid creating an intermediate + // array for `reduce_by_key`. It's limited by the types that atomicAdd supports. For + // example, size_t is not supported as of CUDA 10.2. + auto reduce_it = thrust::make_transform_output_iterator( + thrust::make_discard_iterator(), + detail::SegmentedUniqueReduceOp{key_segments_out}); + auto uniques_ret = thrust::unique_by_key_copy( + thrust::cuda::par(alloc), unique_key_it, unique_key_it + n_inputs, + val_first, reduce_it, val_out, + [=] __device__(Key const &l, Key const &r) { + if (l.first == r.first) { + // In the same segment. + return comp(l.second, r.second); + } + return false; + }); + auto n_uniques = uniques_ret.second - val_out; + CHECK_LE(n_uniques, n_inputs); + thrust::exclusive_scan(thrust::cuda::par(alloc), key_segments_out, + key_segments_out + segments_len, key_segments_out, 0); + return n_uniques; +} } // namespace dh diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index d44a705586f1..f8e42f2f454a 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -158,7 +158,6 @@ void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info, uint32_t beg_col, uint32_t end_col, uint32_t thread_id) { CHECK_GE(end_col, beg_col); - constexpr float kFactor = 8; // Data groups, used in ranking. std::vector const& group_ptr = info.group_ptr_; @@ -175,11 +174,12 @@ void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info, max_num_bins); if (n_bins == 0) { // cut_ptrs_ is initialized with a zero, so there's always an element at the back + CHECK_GE(local_ptrs.size(), 1); local_ptrs.emplace_back(local_ptrs.back()); continue; } - sketch.Init(info.num_row_, 1.0 / (n_bins * kFactor)); + sketch.Init(info.num_row_, 1.0 / (n_bins * WQSketch::kFactor)); for (auto const& entry : column) { uint32_t weight_ind = 0; if (use_group_ind) { @@ -329,7 +329,6 @@ void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) { const MetaInfo& info = p_fmat->Info(); // safe factor for better accuracy - constexpr int kFactor = 8; std::vector sketchs; const int nthread = omp_get_max_threads(); @@ -339,7 +338,7 @@ void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) { unsigned const ncol = static_cast(info.num_col_); sketchs.resize(info.num_col_); for (auto& s : sketchs) { - s.Init(info.num_row_, 1.0 / (max_num_bins * kFactor)); + s.Init(info.num_row_, 1.0 / (max_num_bins * WQSketch::kFactor)); } // Data groups, used in ranking. 
@@ -410,9 +409,8 @@ void DenseCuts::Init // This allows efficient training on wide data size_t global_max_rows = max_rows; rabit::Allreduce(&global_max_rows, 1); - constexpr int kFactor = 8; size_t intermediate_num_cuts = - std::min(global_max_rows, static_cast(max_num_bins * kFactor)); + std::min(global_max_rows, static_cast(max_num_bins * WQSketch::kFactor)); // gather the histogram data rabit::SerializeReducer sreducer; std::vector summary_array; diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index f7744c7884cb..803a1df02658 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -31,21 +32,20 @@ namespace common { constexpr float SketchContainer::kFactor; +namespace detail { + // Count the entries in each column and exclusive scan -void ExtractCuts(int device, - size_t num_cuts_per_feature, - Span sorted_data, - Span column_sizes_scan, - Span out_cuts) { +void ExtractCutsSparse(int device, common::Span cuts_ptr, + Span sorted_data, + Span column_sizes_scan, + Span out_cuts) { dh::LaunchN(device, out_cuts.size(), [=] __device__(size_t idx) { // Each thread is responsible for obtaining one cut from the sorted input - size_t column_idx = idx / num_cuts_per_feature; + size_t column_idx = dh::SegmentId(cuts_ptr, idx); size_t column_size = column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx]; - size_t num_available_cuts = - min(static_cast(num_cuts_per_feature), column_size); - size_t cut_idx = idx % num_cuts_per_feature; - if (cut_idx >= num_available_cuts) return; + size_t num_available_cuts = cuts_ptr[column_idx + 1] - cuts_ptr[column_idx]; + size_t cut_idx = idx - cuts_ptr[column_idx]; Span column_entries = sorted_data.subspan(column_sizes_scan[column_idx], column_size); size_t rank = (column_entries.size() * cut_idx) / @@ -55,31 +55,20 @@ void ExtractCuts(int device, }); } -/** - * \brief Extracts the cuts from sorted data, considering weights. - * - * \param device The device. - * \param cuts Output cuts. - * \param num_cuts_per_feature Number of cuts per feature. - * \param sorted_data Sorted entries in segments of columns. - * \param weights_scan Inclusive scan of weights for each entry in sorted_data. - * \param column_sizes_scan Describes the boundaries of column segments in sorted data. 
- */ -void ExtractWeightedCuts(int device, - size_t num_cuts_per_feature, - Span sorted_data, - Span weights_scan, - Span column_sizes_scan, - Span cuts) { +void ExtractWeightedCutsSparse(int device, + common::Span cuts_ptr, + Span sorted_data, + Span weights_scan, + Span column_sizes_scan, + Span cuts) { dh::LaunchN(device, cuts.size(), [=] __device__(size_t idx) { // Each thread is responsible for obtaining one cut from the sorted input - size_t column_idx = idx / num_cuts_per_feature; + size_t column_idx = dh::SegmentId(cuts_ptr, idx); size_t column_size = column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx]; - size_t num_available_cuts = - min(static_cast(num_cuts_per_feature), column_size); - size_t cut_idx = idx % num_cuts_per_feature; - if (cut_idx >= num_available_cuts) return; + size_t num_available_cuts = cuts_ptr[column_idx + 1] - cuts_ptr[column_idx]; + size_t cut_idx = idx - cuts_ptr[column_idx]; + Span column_entries = sorted_data.subspan(column_sizes_scan[column_idx], column_size); @@ -109,7 +98,7 @@ void ExtractWeightedCuts(int device, max(static_cast(0), min(sample_idx, column_entries.size() - 1)); } - // repeated values will be filtered out on the CPU + // repeated values will be filtered out later. bst_float rmin = sample_idx > 0 ? column_weights_scan[sample_idx - 1] : 0.0f; bst_float rmax = column_weights_scan[sample_idx]; cuts[idx] = WQSketch::Entry(rmin, rmax, rmax - rmin, @@ -117,31 +106,73 @@ void ExtractWeightedCuts(int device, }); } -void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end, - SketchContainer* sketch_container, int num_cuts, - size_t num_columns) { - dh::XGBCachingDeviceAllocator alloc; - const auto& host_data = page.data.ConstHostVector(); - dh::caching_device_vector sorted_entries(host_data.begin() + begin, - host_data.begin() + end); - thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(), - sorted_entries.end(), EntryCompareOp()); +size_t RequiredSampleCutsPerColumn(int max_bins, size_t num_rows) { + double eps = 1.0 / (WQSketch::kFactor * max_bins); + size_t dummy_nlevel; + size_t num_cuts; + WQuantileSketch::LimitSizeLevel( + num_rows, eps, &dummy_nlevel, &num_cuts); + return std::min(num_cuts, num_rows); +} - dh::caching_device_vector column_sizes_scan; - GetColumnSizesScan(device, &column_sizes_scan, - {sorted_entries.data().get(), sorted_entries.size()}, - num_columns); - thrust::host_vector host_column_sizes_scan(column_sizes_scan); +size_t RequiredSampleCuts(bst_row_t num_rows, bst_feature_t num_columns, + size_t max_bins, size_t nnz) { + auto per_column = RequiredSampleCutsPerColumn(max_bins, num_rows); + auto if_dense = num_columns * per_column; + auto result = std::min(nnz, if_dense); + return result; +} - dh::caching_device_vector cuts(num_columns * num_cuts); - ExtractCuts(device, num_cuts, - dh::ToSpan(sorted_entries), - dh::ToSpan(column_sizes_scan), - dh::ToSpan(cuts)); +size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz, + size_t num_bins, bool with_weights) { + size_t peak = 0; + // 0. Allocate cut pointer in quantile container by increasing: n_columns + 1 + size_t total = (num_columns + 1) * sizeof(SketchContainer::OffsetT); + // 1. Copy and sort: 2 * bytes_per_element * shape + total += BytesPerElement(with_weights) * num_rows * num_columns; + peak = std::max(peak, total); + // 2. Deallocate bytes_per_element * shape due to reusing memory in sort. + total -= BytesPerElement(with_weights) * num_rows * num_columns / 2; + // 3. 
Allocate column size scan by increasing: n_columns + 1 + total += (num_columns + 1) * sizeof(SketchContainer::OffsetT); + // 4. Allocate cut pointer by increasing: n_columns + 1 + total += (num_columns + 1) * sizeof(SketchContainer::OffsetT); + // 5. Allocate cuts: assuming rows is greater than bins: n_columns * limit_size + total += RequiredSampleCuts(num_rows, num_bins, num_bins, nnz) * sizeof(SketchEntry); + // 6. Deallocate copied entries by reducing: bytes_per_element * shape. + peak = std::max(peak, total); + total -= (BytesPerElement(with_weights) * num_rows * num_columns) / 2; + // 7. Deallocate column size scan. + peak = std::max(peak, total); + total -= (num_columns + 1) * sizeof(SketchContainer::OffsetT); + // 8. Deallocate cut size scan. + total -= (num_columns + 1) * sizeof(SketchContainer::OffsetT); + // 9. Allocate std::min(rows, bins * factor) * shape due to pruning to global num rows. + total += RequiredSampleCuts(num_rows, num_bins, num_bins, nnz) * sizeof(SketchEntry); + // 10. Allocate final cut values, min values, cut ptrs: std::min(rows, bins + 1) * + // n_columns + n_columns + n_columns + 1 + total += std::min(num_rows, num_bins) * num_columns * sizeof(float); + total += num_columns * + sizeof(std::remove_reference_t().MinValues())>::value_type); + total += (num_columns + 1) * + sizeof(std::remove_reference_t().Ptrs())>::value_type); + peak = std::max(peak, total); - // add cuts into sketches - thrust::host_vector host_cuts(cuts); - sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan); + return peak; +} + +size_t SketchBatchNumElements(size_t sketch_batch_num_elements, + bst_row_t num_rows, size_t columns, size_t nnz, int device, + size_t num_cuts, bool has_weight) { + if (sketch_batch_num_elements == 0) { + auto required_memory = RequiredMemory(num_rows, columns, nnz, num_cuts, has_weight); + // use up to 80% of available space + sketch_batch_num_elements = (dh::AvailableMemory(device) - + required_memory * 0.8); + } + return sketch_batch_num_elements; } void SortByWeight(dh::XGBCachingDeviceAllocator* alloc, @@ -150,7 +181,7 @@ void SortByWeight(dh::XGBCachingDeviceAllocator* alloc, // Sort both entries and weights. thrust::sort_by_key(thrust::cuda::par(*alloc), sorted_entries->begin(), sorted_entries->end(), weights->begin(), - EntryCompareOp()); + detail::EntryCompareOp()); // Scan weights thrust::inclusive_scan_by_key(thrust::cuda::par(*alloc), @@ -160,6 +191,46 @@ void SortByWeight(dh::XGBCachingDeviceAllocator* alloc, return a.index == b.index; }); } +} // namespace detail + +void ProcessBatch(int device, const SparsePage &page, size_t begin, size_t end, + SketchContainer *sketch_container, int num_cuts_per_feature, + size_t num_columns) { + dh::XGBCachingDeviceAllocator alloc; + const auto& host_data = page.data.ConstHostVector(); + dh::caching_device_vector sorted_entries(host_data.begin() + begin, + host_data.begin() + end); + thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(), + sorted_entries.end(), detail::EntryCompareOp()); + + HostDeviceVector cuts_ptr; + dh::caching_device_vector column_sizes_scan; + data::IsValidFunctor dummy_is_valid(std::numeric_limits::quiet_NaN()); + auto batch_it = dh::MakeTransformIterator( + sorted_entries.data().get(), + [] __device__(Entry const &e) -> data::COOTuple { + return {0, e.index, e.fvalue}; // row_idx is not needed for scanning column size. 
+ }); + detail::GetColumnSizesScan(device, num_columns, num_cuts_per_feature, + batch_it, dummy_is_valid, + 0, sorted_entries.size(), + &cuts_ptr, &column_sizes_scan); + + auto const& h_cuts_ptr = cuts_ptr.ConstHostVector(); + dh::caching_device_vector cuts(h_cuts_ptr.back()); + auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan(); + + CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size()); + detail::ExtractCutsSparse(device, d_cuts_ptr, dh::ToSpan(sorted_entries), + dh::ToSpan(column_sizes_scan), dh::ToSpan(cuts)); + + // add cuts into sketches + sorted_entries.clear(); + sorted_entries.shrink_to_fit(); + CHECK_EQ(sorted_entries.capacity(), 0); + CHECK_NE(cuts_ptr.Size(), 0); + sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts); +} void ProcessWeightedBatch(int device, const SparsePage& page, Span weights, size_t begin, size_t end, @@ -204,40 +275,53 @@ void ProcessWeightedBatch(int device, const SparsePage& page, d_temp_weights[idx] = weights[ridx + base_rowid]; }); } - SortByWeight(&alloc, &temp_weights, &sorted_entries); + detail::SortByWeight(&alloc, &temp_weights, &sorted_entries); + HostDeviceVector cuts_ptr; dh::caching_device_vector column_sizes_scan; - GetColumnSizesScan(device, &column_sizes_scan, - {sorted_entries.data().get(), sorted_entries.size()}, - num_columns); - thrust::host_vector host_column_sizes_scan(column_sizes_scan); + data::IsValidFunctor dummy_is_valid(std::numeric_limits::quiet_NaN()); + auto batch_it = dh::MakeTransformIterator( + sorted_entries.data().get(), + [] __device__(Entry const &e) -> data::COOTuple { + return {0, e.index, e.fvalue}; // row_idx is not needed for scanning column size. + }); + detail::GetColumnSizesScan(device, num_columns, num_cuts_per_feature, + batch_it, dummy_is_valid, + 0, sorted_entries.size(), + &cuts_ptr, &column_sizes_scan); + + auto const& h_cuts_ptr = cuts_ptr.ConstHostVector(); + dh::caching_device_vector cuts(h_cuts_ptr.back()); + auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan(); // Extract cuts - dh::caching_device_vector cuts(num_columns * num_cuts_per_feature); - ExtractWeightedCuts(device, num_cuts_per_feature, - dh::ToSpan(sorted_entries), - dh::ToSpan(temp_weights), - dh::ToSpan(column_sizes_scan), - dh::ToSpan(cuts)); + detail::ExtractWeightedCutsSparse(device, d_cuts_ptr, + dh::ToSpan(sorted_entries), + dh::ToSpan(temp_weights), + dh::ToSpan(column_sizes_scan), + dh::ToSpan(cuts)); + sorted_entries.clear(); + sorted_entries.shrink_to_fit(); // add cuts into sketches - thrust::host_vector host_cuts(cuts); - sketch_container->Push(num_cuts_per_feature, host_cuts, host_column_sizes_scan); + sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts); } HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, size_t sketch_batch_num_elements) { // Configure batch size based on available memory bool has_weights = dmat->Info().weights_.Size() > 0; - size_t num_cuts_per_feature = RequiredSampleCuts(max_bins, dmat->Info().num_row_); - sketch_batch_num_elements = SketchBatchNumElements( + size_t num_cuts_per_feature = + detail::RequiredSampleCutsPerColumn(max_bins, dmat->Info().num_row_); + sketch_batch_num_elements = detail::SketchBatchNumElements( sketch_batch_num_elements, - dmat->Info().num_col_, device, num_cuts_per_feature, has_weights); + dmat->Info().num_row_, + dmat->Info().num_col_, + dmat->Info().num_nonzero_, + device, num_cuts_per_feature, has_weights); HistogramCuts cuts; DenseCuts dense_cuts(&cuts); SketchContainer sketch_container(max_bins, dmat->Info().num_col_, - dmat->Info().num_row_); + dmat->Info().num_row_, device); dmat->Info().weights_.SetDevice(device); for 
(const auto& batch : dmat->GetBatches()) { @@ -261,8 +345,7 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, } } } - - dense_cuts.Init(&sketch_container.sketches_, max_bins, dmat->Info().num_row_); + sketch_container.MakeCuts(&cuts); return cuts; } } // namespace common diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh index fe720c530b06..94744513aa81 100644 --- a/src/common/hist_util.cuh +++ b/src/common/hist_util.cuh @@ -1,5 +1,8 @@ /*! * Copyright 2020 XGBoost contributors + * + * \brief Front end and utilities for GPU based sketching. Works on sliding window + * instead of stream. */ #ifndef COMMON_HIST_UTIL_CUH_ #define COMMON_HIST_UTIL_CUH_ @@ -7,74 +10,15 @@ #include #include "hist_util.h" -#include "threading_utils.h" +#include "quantile.cuh" #include "device_helpers.cuh" +#include "timer.h" #include "../data/device_adapter.cuh" namespace xgboost { namespace common { -using WQSketch = DenseCuts::WQSketch; -using SketchEntry = WQSketch::Entry; - -/*! - * \brief A container that holds the device sketches across all - * sparse page batches which are distributed to different devices. - * As sketches are aggregated by column, the mutex guards - * multiple devices pushing sketch summary for the same column - * across distinct rows. - */ -struct SketchContainer { - std::vector sketches_; // NOLINT - static constexpr int kOmpNumColsParallelizeLimit = 1000; - static constexpr float kFactor = 8; - - SketchContainer(int max_bin, size_t num_columns, size_t num_rows) { - // Initialize Sketches for this dmatrix - sketches_.resize(num_columns); -#pragma omp parallel for schedule(static) if (num_columns > kOmpNumColsParallelizeLimit) // NOLINT - for (int icol = 0; icol < num_columns; ++icol) { // NOLINT - sketches_[icol].Init(num_rows, 1.0 / (8 * max_bin)); - } - } - - /** - * \brief Pushes cuts to the sketches. - * - * \param entries_per_column The entries per column. - * \param entries Vector of cuts from all columns, length - * entries_per_column * num_columns. \param column_scan Exclusive scan - * of column sizes. Used to detect cases where there are fewer entries than we - * have storage for. - */ - void Push(size_t entries_per_column, - const thrust::host_vector& entries, - const thrust::host_vector& column_scan) { -#pragma omp parallel for schedule(static) if (sketches_.size() > SketchContainer::kOmpNumColsParallelizeLimit) // NOLINT - for (int icol = 0; icol < sketches_.size(); ++icol) { - size_t column_size = column_scan[icol + 1] - column_scan[icol]; - if (column_size == 0) continue; - WQuantileSketch::SummaryContainer summary; - size_t num_available_cuts = - std::min(size_t(entries_per_column), column_size); - summary.Reserve(num_available_cuts); - summary.MakeFromSorted(&entries[entries_per_column * icol], - num_available_cuts); - - sketches_[icol].PushSummary(summary); - } - } - - // Prevent copying/assigning/moving this as its internals can't be - // assigned/copied/moved - SketchContainer(const SketchContainer&) = delete; - SketchContainer(SketchContainer&& that) { - std::swap(sketches_, that.sketches_); - } - SketchContainer& operator=(const SketchContainer&) = delete; - SketchContainer& operator=(SketchContainer&&) = delete; -}; - +namespace detail { struct EntryCompareOp { __device__ bool operator()(const Entry& a, const Entry& b) { if (a.index == b.index) { @@ -88,100 +32,105 @@ struct EntryCompareOp { * \brief Extracts the cuts from sorted data. * * \param device The device. 
- * \param cuts Output cuts - * \param num_cuts_per_feature Number of cuts per feature. + * \param cuts_ptr Column pointers to CSC structured cuts * \param sorted_data Sorted entries in segments of columns - * \param column_sizes_scan Describes the boundaries of column segments in - * sorted data + * \param column_sizes_scan Describes the boundaries of column segments in sorted data + * \param out_cuts Output cut values */ -void ExtractCuts(int device, - size_t num_cuts_per_feature, - Span sorted_data, - Span column_sizes_scan, - Span out_cuts); - -// Count the entries in each column and exclusive scan -inline void GetColumnSizesScan(int device, - dh::caching_device_vector* column_sizes_scan, - Span entries, size_t num_columns) { - column_sizes_scan->resize(num_columns + 1, 0); - auto d_column_sizes_scan = column_sizes_scan->data().get(); - auto d_entries = entries.data(); - dh::LaunchN(device, entries.size(), [=] __device__(size_t idx) { - auto& e = d_entries[idx]; - atomicAdd(reinterpret_cast( // NOLINT - &d_column_sizes_scan[e.index]), - static_cast(1)); // NOLINT - }); - dh::XGBCachingDeviceAllocator alloc; - thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(), - column_sizes_scan->end(), column_sizes_scan->begin()); -} +void ExtractCutsSparse(int device, common::Span cuts_ptr, + Span sorted_data, + Span column_sizes_scan, + Span out_cuts); -// For adapter. +/** + * \brief Extracts the cuts from sorted data, considering weights. + * + * \param device The device. + * \param cuts_ptr Column pointers to CSC structured cuts + * \param sorted_data Sorted entries in segments of columns. + * \param weights_scan Inclusive scan of weights for each entry in sorted_data. + * \param column_sizes_scan Describes the boundaries of column segments in sorted data. + * \param cuts Output cuts. + */ +void ExtractWeightedCutsSparse(int device, + common::Span cuts_ptr, + Span sorted_data, + Span weights_scan, + Span column_sizes_scan, + Span cuts); + +// Get column size from adapter batch and for output cuts. template -void GetColumnSizesScan(int device, size_t num_columns, +void GetColumnSizesScan(int device, size_t num_columns, size_t num_cuts_per_feature, Iter batch_iter, data::IsValidFunctor is_valid, size_t begin, size_t end, + HostDeviceVector *cuts_ptr, dh::caching_device_vector* column_sizes_scan) { - dh::XGBCachingDeviceAllocator alloc; column_sizes_scan->resize(num_columns + 1, 0); + cuts_ptr->SetDevice(device); + cuts_ptr->Resize(num_columns + 1, 0); + + dh::XGBCachingDeviceAllocator alloc; auto d_column_sizes_scan = column_sizes_scan->data().get(); dh::LaunchN(device, end - begin, [=] __device__(size_t idx) { auto e = batch_iter[begin + idx]; if (is_valid(e)) { - atomicAdd(reinterpret_cast( // NOLINT - &d_column_sizes_scan[e.column_idx]), - static_cast(1)); // NOLINT + atomicAdd(&d_column_sizes_scan[e.column_idx], static_cast(1)); } }); + // Calculate cuts CSC pointer + auto cut_ptr_it = dh::MakeTransformIterator( + column_sizes_scan->begin(), [=] __device__(size_t column_size) { + return thrust::min(num_cuts_per_feature, column_size); + }); + thrust::exclusive_scan(thrust::cuda::par(alloc), cut_ptr_it, + cut_ptr_it + column_sizes_scan->size(), + cuts_ptr->DevicePointer()); thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(), column_sizes_scan->end(), column_sizes_scan->begin()); } -inline size_t BytesPerElement(bool has_weight) { +inline size_t constexpr BytesPerElement(bool has_weight) { // Double the memory usage for sorting. 
We need to assign weight for each element, so // sizeof(float) is added to all elements. return (has_weight ? sizeof(Entry) + sizeof(float) : sizeof(Entry)) * 2; } -inline size_t SketchBatchNumElements(size_t sketch_batch_num_elements, - size_t columns, int device, - size_t num_cuts, bool has_weight) { - if (sketch_batch_num_elements == 0) { - size_t bytes_per_element = BytesPerElement(has_weight); - size_t bytes_cuts = num_cuts * columns * sizeof(SketchEntry); - size_t bytes_num_columns = (columns + 1) * sizeof(size_t); - // use up to 80% of available space - sketch_batch_num_elements = (dh::AvailableMemory(device) - - bytes_cuts - bytes_num_columns) * - 0.8 / bytes_per_element; - } - return sketch_batch_num_elements; -} - +/* \brief Calculate the length of the sliding window. Returns `sketch_batch_num_elements` + * directly if it's not 0. + */ +size_t SketchBatchNumElements(size_t sketch_batch_num_elements, + bst_row_t num_rows, size_t columns, size_t nnz, int device, + size_t num_cuts, bool has_weight); // Compute number of sample cuts needed on local node to maintain accuracy // We take more cuts than needed and then reduce them later -inline size_t RequiredSampleCuts(int max_bins, size_t num_rows) { - double eps = 1.0 / (SketchContainer::kFactor * max_bins); - size_t dummy_nlevel; - size_t num_cuts; - WQuantileSketch::LimitSizeLevel( - num_rows, eps, &dummy_nlevel, &num_cuts); - return std::min(num_cuts, num_rows); -} - -// sketch_batch_num_elements 0 means autodetect. Only modify this for testing. -HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, - size_t sketch_batch_num_elements = 0); +size_t RequiredSampleCutsPerColumn(int max_bins, size_t num_rows); +/* \brief Estimate required memory for each sliding window. + * + * It's not precise, as obtaining the exact memory usage for a sparse dataset requires + * walking through the whole dataset first. Also, if data is from a host DMatrix, we copy + * the weight, group and offset on the first batch, which is not considered in this function. + * + * \param num_rows Number of rows in this worker. + * \param num_columns Number of columns for this dataset. + * \param nnz Number of non-zero elements. Put in something greater than rows * + * cols if nnz is unknown. + * \param num_bins Number of histogram bins. + * \param with_weights Whether weight is used, works the same for ranking and other models. + * + * \return The estimated bytes + */ +size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz, + size_t num_bins, bool with_weights); +// Count the valid entries in each column and copy them out. 
template void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter, Range1d range, float missing, - size_t columns, int device, - thrust::host_vector* host_column_sizes_scan, + size_t columns, size_t cuts_per_feature, int device, + HostDeviceVector* cut_sizes_scan, dh::caching_device_vector* column_sizes_scan, dh::caching_device_vector* sorted_entries) { auto entry_iter = dh::MakeTransformIterator( @@ -191,16 +140,12 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter, }); data::IsValidFunctor is_valid(missing); // Work out how many valid entries we have in each column - GetColumnSizesScan(device, columns, + GetColumnSizesScan(device, columns, cuts_per_feature, batch_iter, is_valid, range.begin(), range.end(), + cut_sizes_scan, column_sizes_scan); - host_column_sizes_scan->resize(column_sizes_scan->size()); - thrust::copy(column_sizes_scan->begin(), column_sizes_scan->end(), - host_column_sizes_scan->begin()); - - size_t num_valid = host_column_sizes_scan->back(); - + size_t num_valid = column_sizes_scan->back(); // Copy current subset of valid elements into temporary storage and sort sorted_entries->resize(num_valid); dh::XGBCachingDeviceAllocator alloc; @@ -208,6 +153,16 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter, entry_iter + range.end(), sorted_entries->begin(), is_valid); } +void SortByWeight(dh::XGBCachingDeviceAllocator* alloc, + dh::caching_device_vector* weights, + dh::caching_device_vector* sorted_entries); +} // namespace detail + +// Compute sketch on DMatrix. +// sketch_batch_num_elements 0 means autodetect. Only modify this for testing. +HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, + size_t sketch_batch_num_elements = 0); + template void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns, size_t begin, size_t end, float missing, @@ -215,41 +170,33 @@ void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns, // Copy current subset of valid elements into temporary storage and sort dh::caching_device_vector sorted_entries; dh::caching_device_vector column_sizes_scan; - thrust::host_vector host_column_sizes_scan; auto batch_iter = dh::MakeTransformIterator( thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) { return batch.GetElement(idx); }); - MakeEntriesFromAdapter(batch, batch_iter, {begin, end}, missing, columns, device, - &host_column_sizes_scan, - &column_sizes_scan, - &sorted_entries); + HostDeviceVector cuts_ptr; + detail::MakeEntriesFromAdapter(batch, batch_iter, {begin, end}, missing, + columns, num_cuts, device, + &cuts_ptr, + &column_sizes_scan, + &sorted_entries); dh::XGBCachingDeviceAllocator alloc; thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(), - sorted_entries.end(), EntryCompareOp()); + sorted_entries.end(), detail::EntryCompareOp()); + auto const& h_cuts_ptr = cuts_ptr.ConstHostVector(); + auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan(); + dh::caching_device_vector cuts(h_cuts_ptr.back()); // Extract the cuts from all columns concurrently - dh::caching_device_vector cuts(columns * num_cuts); - ExtractCuts(device, num_cuts, - dh::ToSpan(sorted_entries), - dh::ToSpan(column_sizes_scan), - dh::ToSpan(cuts)); - - // Push cuts into sketches stored in host memory - thrust::host_vector host_cuts(cuts); - sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan); + detail::ExtractCutsSparse(device, d_cuts_ptr, + dh::ToSpan(sorted_entries), + dh::ToSpan(column_sizes_scan), + 
dh::ToSpan(cuts)); + sorted_entries.clear(); + sorted_entries.shrink_to_fit(); + + sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts); } -void ExtractWeightedCuts(int device, - size_t num_cuts_per_feature, - Span sorted_data, - Span weights_scan, - Span column_sizes_scan, - Span cuts); - -void SortByWeight(dh::XGBCachingDeviceAllocator* alloc, - dh::caching_device_vector* weights, - dh::caching_device_vector* sorted_entries); - template void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info, int num_cuts_per_feature, @@ -268,12 +215,13 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info, [=] __device__(size_t idx) { return batch.GetElement(idx); }); dh::caching_device_vector sorted_entries; dh::caching_device_vector column_sizes_scan; - thrust::host_vector host_column_sizes_scan; - MakeEntriesFromAdapter(batch, batch_iter, - {begin, end}, missing, columns, device, - &host_column_sizes_scan, - &column_sizes_scan, - &sorted_entries); + HostDeviceVector cuts_ptr; + detail::MakeEntriesFromAdapter(batch, batch_iter, + {begin, end}, missing, + columns, num_cuts_per_feature, device, + &cuts_ptr, + &column_sizes_scan, + &sorted_entries); data::IsValidFunctor is_valid(missing); dh::caching_device_vector temp_weights(sorted_entries.size()); @@ -297,6 +245,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info, is_valid); CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size()); } else { + CHECK_EQ(batch.NumRows(), weights.size()); auto const weight_iter = dh::MakeTransformIterator( thrust::make_counting_iterator(0lu), [=]__device__(size_t idx) -> float { @@ -310,90 +259,114 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info, CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size()); } - SortByWeight(&alloc, &temp_weights, &sorted_entries); - // Extract cuts - dh::caching_device_vector cuts(columns * num_cuts_per_feature); - ExtractWeightedCuts(device, num_cuts_per_feature, - dh::ToSpan(sorted_entries), - dh::ToSpan(temp_weights), - dh::ToSpan(column_sizes_scan), - dh::ToSpan(cuts)); + detail::SortByWeight(&alloc, &temp_weights, &sorted_entries); + + auto const& h_cuts_ptr = cuts_ptr.ConstHostVector(); + auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan(); + // Extract cuts + dh::caching_device_vector cuts(h_cuts_ptr.back()); + detail::ExtractWeightedCutsSparse(device, d_cuts_ptr, + dh::ToSpan(sorted_entries), + dh::ToSpan(temp_weights), + dh::ToSpan(column_sizes_scan), + dh::ToSpan(cuts)); + sorted_entries.clear(); + sorted_entries.shrink_to_fit(); // add cuts into sketches - thrust::host_vector host_cuts(cuts); - sketch_container->Push(num_cuts_per_feature, host_cuts, host_column_sizes_scan); + sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts); } template HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins, float missing, size_t sketch_batch_num_elements = 0) { - size_t num_cuts = RequiredSampleCuts(num_bins, adapter->NumRows()); + size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, adapter->NumRows()); CHECK(adapter->NumRows() != data::kAdapterUnknownSize); CHECK(adapter->NumColumns() != data::kAdapterUnknownSize); adapter->BeforeFirst(); adapter->Next(); auto& batch = adapter->Value(); - sketch_batch_num_elements = SketchBatchNumElements( + sketch_batch_num_elements = detail::SketchBatchNumElements( sketch_batch_num_elements, - adapter->NumColumns(), adapter->DeviceIdx(), num_cuts, false); + adapter->NumRows(), adapter->NumColumns(), std::numeric_limits::max(), + 
adapter->DeviceIdx(), + num_cuts_per_feature, false); // Enforce single batch CHECK(!adapter->Next()); HistogramCuts cuts; - DenseCuts dense_cuts(&cuts); SketchContainer sketch_container(num_bins, adapter->NumColumns(), - adapter->NumRows()); + adapter->NumRows(), adapter->DeviceIdx()); - for (auto begin = 0ull; begin < batch.Size(); - begin += sketch_batch_num_elements) { + for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) { size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements)); auto const& batch = adapter->Value(); ProcessSlidingWindow(batch, adapter->DeviceIdx(), adapter->NumColumns(), - begin, end, missing, &sketch_container, num_cuts); + begin, end, missing, &sketch_container, num_cuts_per_feature); } - dense_cuts.Init(&sketch_container.sketches_, num_bins, adapter->NumRows()); + sketch_container.MakeCuts(&cuts); return cuts; } +/* + * \brief Perform sketching on GPU. + * + * \param batch A batch from adapter. + * \param num_bins Bins per column. + * \param missing Floating point value that represents invalid value. + * \param sketch_container Container for output sketch. + * \param sketch_batch_num_elements Number of element per-sliding window, use it only for + * testing. + */ template void AdapterDeviceSketch(Batch batch, int num_bins, - float missing, int device, - SketchContainer* sketch_container, + float missing, SketchContainer* sketch_container, size_t sketch_batch_num_elements = 0) { size_t num_rows = batch.NumRows(); size_t num_cols = batch.NumCols(); - size_t num_cuts = RequiredSampleCuts(num_bins, num_rows); - sketch_batch_num_elements = SketchBatchNumElements( + size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows); + int32_t device = sketch_container->DeviceIdx(); + sketch_batch_num_elements = detail::SketchBatchNumElements( sketch_batch_num_elements, - num_cols, device, num_cuts, false); + num_rows, num_cols, std::numeric_limits::max(), + device, num_cuts_per_feature, false); for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) { size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements)); ProcessSlidingWindow(batch, device, num_cols, - begin, end, missing, sketch_container, num_cuts); + begin, end, missing, sketch_container, num_cuts_per_feature); } } +/* + * \brief Perform weighted sketching on GPU. + * + * When weight in info is empty, this function is equivalent to unweighted version. 
+ */ template void AdapterDeviceSketchWeighted(Batch batch, int num_bins, MetaInfo const& info, - float missing, - int device, - SketchContainer* sketch_container, + float missing, SketchContainer* sketch_container, size_t sketch_batch_num_elements = 0) { + if (info.weights_.Size() == 0) { + return AdapterDeviceSketch(batch, num_bins, missing, sketch_container, sketch_batch_num_elements); + } + size_t num_rows = batch.NumRows(); size_t num_cols = batch.NumCols(); - size_t num_cuts = RequiredSampleCuts(num_bins, num_rows); - sketch_batch_num_elements = SketchBatchNumElements( + size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows); + int32_t device = sketch_container->DeviceIdx(); + sketch_batch_num_elements = detail::SketchBatchNumElements( sketch_batch_num_elements, - num_cols, device, num_cuts, true); + num_rows, num_cols, std::numeric_limits::max(), + device, num_cuts_per_feature, true); for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) { size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements)); ProcessWeightedSlidingWindow(batch, info, - num_cuts, + num_cuts_per_feature, CutsBuilder::UseGroup(info), missing, device, num_cols, begin, end, sketch_container); } diff --git a/src/common/hist_util.h b/src/common/hist_util.h index c48eafad84dc..b736670d22bf 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -167,7 +167,7 @@ class CutsBuilder { /*! \brief Cut configuration for sparse dataset. */ class SparseCuts : public CutsBuilder { - /* \brief Distrbute columns to each thread according to number of entries. */ + /* \brief Distribute columns to each thread according to number of entries. */ static std::vector LoadBalance(SparsePage const& page, size_t const nthreads); Monitor monitor_; diff --git a/src/common/host_device_vector.cu b/src/common/host_device_vector.cu index 7950096ca756..6ba153656bba 100644 --- a/src/common/host_device_vector.cu +++ b/src/common/host_device_vector.cu @@ -205,10 +205,10 @@ class HostDeviceVectorImpl { // data is on the host LazyResizeDevice(data_h_.size()); SetDevice(); - dh::safe_cuda(cudaMemcpy(data_d_->data().get(), - data_h_.data(), - data_d_->size() * sizeof(T), - cudaMemcpyHostToDevice)); + dh::safe_cuda(cudaMemcpyAsync(data_d_->data().get(), + data_h_.data(), + data_d_->size() * sizeof(T), + cudaMemcpyHostToDevice)); gpu_access_ = access; } diff --git a/src/common/quantile.cu b/src/common/quantile.cu new file mode 100644 index 000000000000..05519fea9702 --- /dev/null +++ b/src/common/quantile.cu @@ -0,0 +1,569 @@ +/*! + * Copyright 2020 by XGBoost Contributors + */ +#include +#include +#include +#include +#include + +#include + +#include "xgboost/span.h" +#include "quantile.h" +#include "quantile.cuh" +#include "hist_util.h" +#include "device_helpers.cuh" +#include "common.h" + +namespace xgboost { +namespace common { + +using WQSketch = DenseCuts::WQSketch; +using SketchEntry = WQSketch::Entry; + +// Algorithm 4 in XGBoost's paper, using binary search to find i. 
+__device__ SketchEntry BinarySearchQuery(Span const& entries, float rank) { + assert(entries.size() >= 2); + rank *= 2; + if (rank < entries.front().rmin + entries.front().rmax) { + return entries.front(); + } + if (rank >= entries.back().rmin + entries.back().rmax) { + return entries.back(); + } + + auto begin = dh::MakeTransformIterator( + entries.begin(), [=] __device__(SketchEntry const &entry) { + return entry.rmin + entry.rmax; + }); + auto end = begin + entries.size(); + auto i = thrust::upper_bound(thrust::seq, begin + 1, end - 1, rank) - begin - 1; + if (rank < entries[i].RMinNext() + entries[i+1].RMaxPrev()) { + return entries[i]; + } else { + return entries[i+1]; + } +} + +template +void CopyTo(Span out, Span src) { + CHECK_EQ(out.size(), src.size()); + dh::safe_cuda(cudaMemcpyAsync(out.data(), src.data(), + out.size_bytes(), + cudaMemcpyDefault)); +} + +// Compute the merge path. +common::Span> MergePath( + Span const &d_x, Span const &x_ptr, + Span const &d_y, Span const &y_ptr, + Span out, Span out_ptr) { + auto x_merge_key_it = thrust::make_zip_iterator(thrust::make_tuple( + dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), + [=] __device__(size_t idx) { return dh::SegmentId(x_ptr, idx); }), + d_x.data())); + auto y_merge_key_it = thrust::make_zip_iterator(thrust::make_tuple( + dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), + [=] __device__(size_t idx) { return dh::SegmentId(y_ptr, idx); }), + d_y.data())); + + using Tuple = thrust::tuple; + + thrust::constant_iterator a_ind_iter(0ul); + thrust::constant_iterator b_ind_iter(1ul); + + auto place_holder = thrust::make_constant_iterator(0u); + auto x_merge_val_it = + thrust::make_zip_iterator(thrust::make_tuple(a_ind_iter, place_holder)); + auto y_merge_val_it = + thrust::make_zip_iterator(thrust::make_tuple(b_ind_iter, place_holder)); + + dh::XGBCachingDeviceAllocator alloc; + static_assert(sizeof(Tuple) == sizeof(SketchEntry), ""); + // We reuse the memory for storing merge path. + common::Span merge_path{reinterpret_cast(out.data()), out.size()}; + // Determine the merge path, 0 if element is from x, 1 if it's from y. + thrust::merge_by_key( + thrust::cuda::par(alloc), x_merge_key_it, x_merge_key_it + d_x.size(), + y_merge_key_it, y_merge_key_it + d_y.size(), x_merge_val_it, + y_merge_val_it, thrust::make_discard_iterator(), merge_path.data(), + [=] __device__(auto const &l, auto const &r) -> bool { + auto l_column_id = thrust::get<0>(l); + auto r_column_id = thrust::get<0>(r); + if (l_column_id == r_column_id) { + return thrust::get<1>(l).value < thrust::get<1>(r).value; + } + return l_column_id < r_column_id; + }); + + // Compute output ptr + auto transform_it = + thrust::make_zip_iterator(thrust::make_tuple(x_ptr.data(), y_ptr.data())); + thrust::transform( + thrust::cuda::par(alloc), transform_it, transform_it + x_ptr.size(), + out_ptr.data(), + [] __device__(auto const& t) { return thrust::get<0>(t) + thrust::get<1>(t); }); + + // 0^th is the indicator, 1^th is placeholder + auto get_ind = []XGBOOST_DEVICE(Tuple const& t) { return thrust::get<0>(t); }; + // 0^th is the counter for x, 1^th for y. 
+ auto get_x = []XGBOOST_DEVICE(Tuple const &t) { return thrust::get<0>(t); }; + auto get_y = []XGBOOST_DEVICE(Tuple const &t) { return thrust::get<1>(t); }; + + auto scan_key_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), + [=] __device__(size_t idx) { return dh::SegmentId(out_ptr, idx); }); + + auto scan_val_it = dh::MakeTransformIterator( + merge_path.data(), [=] __device__(Tuple const &t) -> Tuple { + auto ind = get_ind(t); // == 0 if element is from x + // x_counter, y_counter + return thrust::make_tuple(!ind, ind); + }); + + // Compute the index for both x and y (which of the elements in a and b are used in each + // comparison) by scanning the binary merge path. Take output [(x_0, y_0), (x_0, y_1), + // ...] as an example, the comparison between (x_0, y_0) adds 1 step in the merge path. + // Assuming y_0 is less than x_0, this step is toward the end of y. After the + // comparison, the index of y is incremented by 1 from y_0 to y_1, and at the same time, y_0 + // is landed into the output as the first element in the merge result. The scan result is the + // subscript of x and y. + thrust::exclusive_scan_by_key( + thrust::cuda::par(alloc), scan_key_it, scan_key_it + merge_path.size(), + scan_val_it, merge_path.data(), + thrust::make_tuple(0ul, 0ul), + thrust::equal_to{}, + [=] __device__(Tuple const &l, Tuple const &r) -> Tuple { + return thrust::make_tuple(get_x(l) + get_x(r), get_y(l) + get_y(r)); + }); + + return merge_path; +} + +// Merge d_x and d_y into out. Because the final output depends on the predicate result +// (which summary does the output element come from) by definition of merged rank, we +// run it in 2 passes to obtain the merge path and then customize the standard merge +// algorithm. +void MergeImpl(int32_t device, Span const &d_x, + Span const &x_ptr, + Span const &d_y, + Span const &y_ptr, + Span out, + Span out_ptr) { + dh::safe_cuda(cudaSetDevice(device)); + CHECK_EQ(d_x.size() + d_y.size(), out.size()); + CHECK_EQ(x_ptr.size(), out_ptr.size()); + CHECK_EQ(y_ptr.size(), out_ptr.size()); + + auto d_merge_path = MergePath(d_x, x_ptr, d_y, y_ptr, out, out_ptr); + auto d_out = out; + + dh::LaunchN(device, d_out.size(), [=] __device__(size_t idx) { + auto column_id = dh::SegmentId(out_ptr, idx); + idx -= out_ptr[column_id]; + + auto d_x_column = + d_x.subspan(x_ptr[column_id], x_ptr[column_id + 1] - x_ptr[column_id]); + auto d_y_column = + d_y.subspan(y_ptr[column_id], y_ptr[column_id + 1] - y_ptr[column_id]); + auto d_out_column = d_out.subspan( + out_ptr[column_id], out_ptr[column_id + 1] - out_ptr[column_id]); + auto d_path_column = d_merge_path.subspan( + out_ptr[column_id], out_ptr[column_id + 1] - out_ptr[column_id]); + + uint64_t a_ind, b_ind; + thrust::tie(a_ind, b_ind) = d_path_column[idx]; + + // Handle empty column. If both columns are empty, we should not get this column_id + // as result of binary search. + assert((d_x_column.size() != 0) || (d_y_column.size() != 0)); + if (d_x_column.size() == 0) { + d_out_column[idx] = d_y_column[b_ind]; + return; + } + if (d_y_column.size() == 0) { + d_out_column[idx] = d_x_column[a_ind]; + return; + } + + // Handle trailing elements. + assert(a_ind <= d_x_column.size()); + if (a_ind == d_x_column.size()) { + // Trailing elements are from y because there's no more x to land. 
+ auto y_elem = d_y_column[b_ind]; + d_out_column[idx] = SketchEntry(y_elem.rmin + d_x_column.back().RMinNext(), + y_elem.rmax + d_x_column.back().rmax, + y_elem.wmin, y_elem.value); + return; + } + auto x_elem = d_x_column[a_ind]; + assert(b_ind <= d_y_column.size()); + if (b_ind == d_y_column.size()) { + d_out_column[idx] = SketchEntry(x_elem.rmin + d_y_column.back().RMinNext(), + x_elem.rmax + d_y_column.back().rmax, + x_elem.wmin, x_elem.value); + return; + } + auto y_elem = d_y_column[b_ind]; + + /* Merge procedure. See A.3 merge operation eq (26) ~ (28). The trick to interpret + it is rewriting the symbols on both side of equality. Take eq (26) as an example: + Expand it according to definition of extended rank then rewrite it into: + + If $k_i$ is the $i$ element in output and \textbf{comes from $D_1$}: + + r_\bar{D}(k_i) = r_{\bar{D_1}}(k_i) + w_{\bar{{D_1}}}(k_i) + + [r_{\bar{D_2}}(x_i) + w_{\bar{D_2}}(x_i)] + + Where $x_i$ is the largest element in $D_2$ that's less than $k_i$. $k_i$ can be + used in $D_1$ as it's since $k_i \in D_1$. Other 2 equations can be applied + similarly with $k_i$ comes from different $D$. just use different symbol on + different source of summary. + */ + assert(idx < d_out_column.size()); + if (x_elem.value == y_elem.value) { + d_out_column[idx] = + SketchEntry{x_elem.rmin + y_elem.rmin, x_elem.rmax + y_elem.rmax, + x_elem.wmin + y_elem.wmin, x_elem.value}; + } else if (x_elem.value < y_elem.value) { + // elem from x is landed. yprev_min is the element in D_2 that's 1 rank less than + // x_elem if we put x_elem in D_2. + float yprev_min = b_ind == 0 ? 0.0f : d_y_column[b_ind - 1].RMinNext(); + // rmin should be equal to x_elem.rmin + x_elem.wmin + yprev_min. But for + // implementation, the weight is stored in a separated field and we compute the + // extended definition on the fly when needed. + d_out_column[idx] = + SketchEntry{x_elem.rmin + yprev_min, x_elem.rmax + y_elem.RMaxPrev(), + x_elem.wmin, x_elem.value}; + } else { + // elem from y is landed. + float xprev_min = a_ind == 0 ? 0.0f : d_x_column[a_ind - 1].RMinNext(); + d_out_column[idx] = + SketchEntry{xprev_min + y_elem.rmin, x_elem.RMaxPrev() + y_elem.rmax, + y_elem.wmin, y_elem.value}; + } + }); +} + +void SketchContainer::Push(common::Span cuts_ptr, + dh::caching_device_vector* entries) { + timer_.Start(__func__); + dh::safe_cuda(cudaSetDevice(device_)); + // Copy or merge the new cuts, pruning is performed during `MakeCuts`. + if (this->Current().size() == 0) { + CHECK_EQ(this->columns_ptr_.Size(), cuts_ptr.size()); + // See thrust issue 1030, THRUST_CPP_DIALECT is not correctly defined so + // move constructor is not used. 
+ this->Current().swap(*entries); + CHECK_EQ(entries->size(), 0); + auto d_cuts_ptr = this->columns_ptr_.DevicePointer(); + thrust::copy(thrust::device, cuts_ptr.data(), + cuts_ptr.data() + cuts_ptr.size(), d_cuts_ptr); + } else { + auto d_entries = dh::ToSpan(*entries); + this->Merge(cuts_ptr, d_entries); + this->FixError(); + } + CHECK_NE(this->columns_ptr_.Size(), 0); + timer_.Stop(__func__); +} + +size_t SketchContainer::Unique() { + timer_.Start(__func__); + dh::safe_cuda(cudaSetDevice(device_)); + this->columns_ptr_.SetDevice(device_); + Span d_column_scan = this->columns_ptr_.DeviceSpan(); + CHECK_EQ(d_column_scan.size(), num_columns_ + 1); + Span entries = dh::ToSpan(this->Current()); + HostDeviceVector scan_out(d_column_scan.size()); + scan_out.SetDevice(device_); + auto d_scan_out = scan_out.DeviceSpan(); + + d_column_scan = this->columns_ptr_.DeviceSpan(); + size_t n_uniques = dh::SegmentedUnique( + d_column_scan.data(), d_column_scan.data() + d_column_scan.size(), + entries.data(), entries.data() + entries.size(), scan_out.DevicePointer(), + entries.data(), + detail::SketchUnique{}); + this->columns_ptr_.Copy(scan_out); + CHECK(!this->columns_ptr_.HostCanRead()); + + this->Current().resize(n_uniques); + timer_.Stop(__func__); + return n_uniques; +} + +void SketchContainer::Prune(size_t to) { + timer_.Start(__func__); + dh::safe_cuda(cudaSetDevice(device_)); + + this->Unique(); + OffsetT to_total = 0; + HostDeviceVector new_columns_ptr{to_total}; + for (bst_feature_t i = 0; i < num_columns_; ++i) { + size_t length = this->Column(i).size(); + length = std::min(length, to); + to_total += length; + new_columns_ptr.HostVector().emplace_back(to_total); + } + new_columns_ptr.SetDevice(device_); + this->Other().resize(to_total); + + auto d_columns_ptr_in = this->columns_ptr_.ConstDeviceSpan(); + auto d_columns_ptr_out = new_columns_ptr.ConstDeviceSpan(); + auto out = dh::ToSpan(this->Other()); + auto in = dh::ToSpan(this->Current()); + dh::LaunchN(0, to_total, [=] __device__(size_t idx) { + size_t column_id = dh::SegmentId(d_columns_ptr_out, idx); + auto out_column = out.subspan(d_columns_ptr_out[column_id], + d_columns_ptr_out[column_id + 1] - + d_columns_ptr_out[column_id]); + auto in_column = in.subspan(d_columns_ptr_in[column_id], + d_columns_ptr_in[column_id + 1] - + d_columns_ptr_in[column_id]); + idx -= d_columns_ptr_out[column_id]; + // Input has lesser columns than `to`, just copy them to the output. This is correct + // as the new output size is calculated based on both the size of `to` and current + // column. + if (in_column.size() <= to) { + out_column[idx] = in_column[idx]; + return; + } + // 1 thread for each output. See A.4 for detail. 
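    // Restated (no new behaviour): apart from the two endpoints handled below, output
    // slot idx asks for the entry covering rank
    //     q = entries.front().rmax + idx * w / (to - 1),
    //         with w = entries.back().rmin - entries.front().rmax,
    // i.e. `to` evenly spaced ranks across the column, answered by BinarySearchQuery.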
+ auto entries = in_column; + auto d_out = out_column; + if (idx == 0) { + d_out.front() = entries.front(); + return; + } + if (idx == to - 1) { + d_out.back() = entries.back(); + return; + } + + float w = entries.back().rmin - entries.front().rmax; + assert(w != 0); + auto budget = static_cast(d_out.size()); + assert(budget != 0); + auto q = ((idx * w) / (to - 1) + entries.front().rmax); + d_out[idx] = BinarySearchQuery(entries, q); + }); + this->columns_ptr_.HostVector() = new_columns_ptr.HostVector(); + this->Alternate(); + timer_.Stop(__func__); +} + +void SketchContainer::Merge(Span d_that_columns_ptr, + Span that) { + dh::safe_cuda(cudaSetDevice(device_)); + timer_.Start(__func__); + if (this->Current().size() == 0) { + CHECK_EQ(this->columns_ptr_.HostVector().back(), 0); + CHECK_EQ(this->columns_ptr_.HostVector().size(), d_that_columns_ptr.size()); + CHECK_EQ(columns_ptr_.Size(), num_columns_ + 1); + thrust::copy(thrust::device, d_that_columns_ptr.data(), + d_that_columns_ptr.data() + d_that_columns_ptr.size(), + this->columns_ptr_.DevicePointer()); + auto total = this->columns_ptr_.HostVector().back(); + this->Current().resize(total); + CopyTo(dh::ToSpan(this->Current()), that); + timer_.Stop(__func__); + return; + } + + this->Other().resize(this->Current().size() + that.size()); + CHECK_EQ(d_that_columns_ptr.size(), this->columns_ptr_.Size()); + + HostDeviceVector new_columns_ptr; + new_columns_ptr.SetDevice(device_); + new_columns_ptr.Resize(this->ColumnsPtr().size()); + MergeImpl(device_, this->Data(), this->ColumnsPtr(), + that, d_that_columns_ptr, + dh::ToSpan(this->Other()), new_columns_ptr.DeviceSpan()); + this->columns_ptr_ = std::move(new_columns_ptr); + CHECK_EQ(this->columns_ptr_.Size(), num_columns_ + 1); + CHECK_EQ(new_columns_ptr.Size(), 0); + this->Alternate(); + timer_.Stop(__func__); +} + +void SketchContainer::FixError() { + dh::safe_cuda(cudaSetDevice(device_)); + auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan(); + auto in = dh::ToSpan(this->Current()); + dh::LaunchN(device_, in.size(), [=] __device__(size_t idx) { + auto column_id = dh::SegmentId(d_columns_ptr, idx); + auto in_column = in.subspan(d_columns_ptr[column_id], + d_columns_ptr[column_id + 1] - + d_columns_ptr[column_id]); + idx -= d_columns_ptr[column_id]; + float prev_rmin = idx == 0 ? 0.0f : in_column[idx-1].rmin; + if (in_column[idx].rmin < prev_rmin) { + in_column[idx].rmin = prev_rmin; + } + float prev_rmax = idx == 0 ? 
0.0f : in_column[idx-1].rmax; + if (in_column[idx].rmax < prev_rmax) { + in_column[idx].rmax = prev_rmax; + } + float rmin_next = in_column[idx].RMinNext(); + if (in_column[idx].rmax < rmin_next) { + in_column[idx].rmax = rmin_next; + } + }); +} + +void SketchContainer::AllReduce() { + dh::safe_cuda(cudaSetDevice(device_)); + auto world = rabit::GetWorldSize(); + if (world == 1) { + return; + } + + timer_.Start(__func__); + if (!reducer_) { + reducer_ = std::make_unique(); + reducer_->Init(device_); + } + auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan(); + dh::device_vector gathered_ptrs; + + CHECK_EQ(d_columns_ptr.size(), num_columns_ + 1); + size_t n = d_columns_ptr.size(); + rabit::Allreduce(&n, 1); + CHECK_EQ(n, d_columns_ptr.size()) << "Number of columns differs across workers"; + + // Get the columns ptr from all workers + gathered_ptrs.resize(d_columns_ptr.size() * world, 0); + size_t rank = rabit::GetRank(); + auto offset = rank * d_columns_ptr.size(); + thrust::copy(thrust::device, d_columns_ptr.data(), d_columns_ptr.data() + d_columns_ptr.size(), + gathered_ptrs.begin() + offset); + reducer_->AllReduceSum(gathered_ptrs.data().get(), gathered_ptrs.data().get(), + gathered_ptrs.size()); + + // Get the data from all workers. + std::vector recv_lengths; + dh::caching_device_vector recvbuf; + reducer_->AllGather(this->Current().data().get(), + dh::ToSpan(this->Current()).size_bytes(), &recv_lengths, + &recvbuf); + reducer_->Synchronize(); + + // Segment the received data. + auto s_recvbuf = dh::ToSpan(recvbuf); + std::vector> allworkers; + offset = 0; + for (int32_t i = 0; i < world; ++i) { + size_t length_as_bytes = recv_lengths.at(i); + auto raw = s_recvbuf.subspan(offset, length_as_bytes); + auto sketch = Span(reinterpret_cast(raw.data()), + length_as_bytes / sizeof(SketchEntry)); + allworkers.emplace_back(sketch); + offset += length_as_bytes; + } + + // Merge them into current sketch. + for (size_t i = 0; i < allworkers.size(); ++i) { + if (i == rank) { + continue; + } + auto worker = allworkers[i]; + auto worker_ptr = + dh::ToSpan(gathered_ptrs) + .subspan(i * d_columns_ptr.size(), d_columns_ptr.size()); + this->Merge(worker_ptr, worker); + this->FixError(); + } + timer_.Stop(__func__); +} + +void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { + timer_.Start(__func__); + dh::safe_cuda(cudaSetDevice(device_)); + p_cuts->min_vals_.Resize(num_columns_); + size_t global_max_rows = num_rows_; + rabit::Allreduce(&global_max_rows, 1); + + // Sync between workers. + size_t intermediate_num_cuts = + std::min(global_max_rows, static_cast(num_bins_ * kFactor)); + this->Prune(intermediate_num_cuts); + this->AllReduce(); + + // Prune to final number of bins. 
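    // (Flow recap: the container was pruned to the intermediate size and synced via
    // AllReduce above; what follows prunes to the requested bin count, de-duplicates
    // and repairs the rank invariants before the cut values are emitted.)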
+ this->Prune(num_bins_ + 1); + this->Unique(); + this->FixError(); + + // Set up inputs + auto d_in_columns_ptr = this->columns_ptr_.ConstDeviceSpan(); + + p_cuts->min_vals_.SetDevice(device_); + auto d_min_values = p_cuts->min_vals_.DeviceSpan(); + auto in_cut_values = dh::ToSpan(this->Current()); + + // Set up output ptr + p_cuts->cut_ptrs_.SetDevice(device_); + auto& h_out_columns_ptr = p_cuts->cut_ptrs_.HostVector(); + h_out_columns_ptr.clear(); + h_out_columns_ptr.push_back(0); + for (bst_feature_t i = 0; i < num_columns_; ++i) { + h_out_columns_ptr.push_back( + std::min(static_cast(std::max(static_cast(1ul), + this->Column(i).size())), + static_cast(num_bins_))); + } + std::partial_sum(h_out_columns_ptr.begin(), h_out_columns_ptr.end(), + h_out_columns_ptr.begin()); + auto d_out_columns_ptr = p_cuts->cut_ptrs_.ConstDeviceSpan(); + + // Set up output cuts + size_t total_bins = h_out_columns_ptr.back(); + p_cuts->cut_values_.SetDevice(device_); + p_cuts->cut_values_.Resize(total_bins); + auto out_cut_values = p_cuts->cut_values_.DeviceSpan(); + + dh::LaunchN(0, total_bins, [=] __device__(size_t idx) { + auto column_id = dh::SegmentId(d_out_columns_ptr, idx); + auto in_column = in_cut_values.subspan(d_in_columns_ptr[column_id], + d_in_columns_ptr[column_id + 1] - + d_in_columns_ptr[column_id]); + auto out_column = out_cut_values.subspan(d_out_columns_ptr[column_id], + d_out_columns_ptr[column_id + 1] - + d_out_columns_ptr[column_id]); + idx -= d_out_columns_ptr[column_id]; + if (in_column.size() == 0) { + // If the column is empty, we push a dummy value. If won't effect training as the + // column is empty, trees cannot split on it. This is just to be consistent with + // rest of the library. + if (idx == 0) { + d_min_values[column_id] = kRtEps; + out_column[0] = kRtEps; + assert(out_column.size() == 1); + } + return; + } + + // First thread is responsible for setting min values. + if (idx == 0) { + auto mval = in_column[idx].value; + d_min_values[column_id] = mval - (fabs(mval) + 1e-5); + } + // Last thread is responsible for setting a value that's greater than other cuts. + if (idx == out_column.size() - 1) { + const bst_float cpt = in_column.back().value; + // this must be bigger than last value in a scale + const bst_float last = cpt + (fabs(cpt) + 1e-5); + out_column[idx] = last; + return; + } + assert(idx+1 < in_column.size()); + out_column[idx] = in_column[idx+1].value; + }); + timer_.Stop(__func__); +} +} // namespace common +} // namespace xgboost diff --git a/src/common/quantile.cuh b/src/common/quantile.cuh new file mode 100644 index 000000000000..5c833c20a120 --- /dev/null +++ b/src/common/quantile.cuh @@ -0,0 +1,142 @@ +#ifndef XGBOOST_COMMON_QUANTILE_CUH_ +#define XGBOOST_COMMON_QUANTILE_CUH_ + +#include + +#include "xgboost/span.h" +#include "device_helpers.cuh" +#include "quantile.h" +#include "timer.h" + +namespace xgboost { +namespace common { + +class HistogramCuts; +using WQSketch = WQuantileSketch; +using SketchEntry = WQSketch::Entry; + +/*! + * \brief A container that holds the device sketches. Sketching is performed per-column, + * but fused into single operation for performance. 
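 *
 * A rough usage sketch (illustrative only, pieced together from the call sites in
 * this patch; the members are declared below):
 *
 *   SketchContainer sketch(max_bin, n_columns, n_rows, device);
 *   sketch.Push(cuts_ptr, &entries);                 // once per input batch
 *   sketch.Merge(other.ColumnsPtr(), other.Data());  // fold in another sketch
 *   sketch.FixError();                               // re-establish rank invariants
 *   HistogramCuts cuts;
 *   sketch.MakeCuts(&cuts);                          // prune, sync workers, emit cuts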
+ */ +class SketchContainer { + public: + static constexpr float kFactor = WQSketch::kFactor; + using OffsetT = bst_row_t; + static_assert(sizeof(OffsetT) == sizeof(size_t), "Wrong type for sketch element offset."); + + private: + Monitor timer_; + std::unique_ptr reducer_; + bst_row_t num_rows_; + bst_feature_t num_columns_; + int32_t num_bins_; + int32_t device_; + + // Double buffer as neither prune nor merge can be performed inplace. + dh::caching_device_vector entries_a_; + dh::caching_device_vector entries_b_; + bool current_buffer_ {true}; + // The container is just a CSC matrix. + HostDeviceVector columns_ptr_; + + dh::caching_device_vector& Current() { + if (current_buffer_) { + return entries_a_; + } else { + return entries_b_; + } + } + dh::caching_device_vector& Other() { + if (!current_buffer_) { + return entries_a_; + } else { + return entries_b_; + } + } + dh::caching_device_vector const& Current() const { + return const_cast(this)->Current(); + } + dh::caching_device_vector const& Other() const { + return const_cast(this)->Other(); + } + void Alternate() { + current_buffer_ = !current_buffer_; + } + + // Get the span of one column. + Span Column(bst_feature_t i) { + auto data = dh::ToSpan(this->Current()); + auto h_ptr = columns_ptr_.ConstHostSpan(); + auto c = data.subspan(h_ptr[i], h_ptr[i+1] - h_ptr[i]); + return c; + } + + public: + /* \breif GPU quantile structure, with sketch data for each columns. + * + * \param max_bin Maximum number of bins per columns + * \param num_columns Total number of columns in dataset. + * \param num_rows Total number of rows in known dataset (typically the rows in current worker). + * \param device GPU ID. + */ + SketchContainer(int32_t max_bin, bst_feature_t num_columns, bst_row_t num_rows, int32_t device) : + num_rows_{num_rows}, num_columns_{num_columns}, num_bins_{max_bin}, device_{device} { + // Initialize Sketches for this dmatrix + this->columns_ptr_.SetDevice(device_); + this->columns_ptr_.Resize(num_columns + 1); + timer_.Init(__func__); + } + /* \brief Return GPU ID for this container. */ + int32_t DeviceIdx() const { return device_; } + /* \brief Removes all the duplicated elements in quantile structure. */ + size_t Unique(); + /* Fix rounding error and re-establish invariance. The error is mostly generated by the + * addition inside `RMinNext` and subtraction in `RMaxPrev`. */ + void FixError(); + + /* \brief Push a CSC structured cut matrix. */ + void Push(common::Span cuts_ptr, + dh::caching_device_vector* entries); + /* \brief Prune the quantile structure. + * + * \param to The maximum size of pruned quantile. If the size of quantile structure is + * already less than `to`, then no operation is performed. + */ + void Prune(size_t to); + /* \brief Merge another set of sketch. + * \param that columns of other. + */ + void Merge(Span that_columns_ptr, + Span that); + + /* \brief Merge quantiles from other GPU workers. */ + void AllReduce(); + /* \brief Create the final histogram cut values. 
*/ + void MakeCuts(HistogramCuts* cuts); + + Span Data() const { + return {this->Current().data().get(), this->Current().size()}; + } + + Span ColumnsPtr() const { return this->columns_ptr_.ConstDeviceSpan(); } + + // Prevent copying/assigning/moving this as its internals can't be + // assigned/copied/moved + SketchContainer(const SketchContainer&) = delete; + SketchContainer(const SketchContainer&&) = delete; + SketchContainer& operator=(const SketchContainer&) = delete; + SketchContainer& operator=(const SketchContainer&&) = delete; +}; + +namespace detail { +struct SketchUnique { + XGBOOST_DEVICE bool operator()(SketchEntry const& a, SketchEntry const& b) const { + return a.value - b.value == 0; + } +}; +} // anonymous detail +} // namespace common +} // namespace xgboost + +#endif // XGBOOST_COMMON_QUANTILE_CUH_ \ No newline at end of file diff --git a/src/common/quantile.h b/src/common/quantile.h index c0079ff8ebc8..ee2f44cd21ec 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -55,6 +55,14 @@ struct WQSummary { XGBOOST_DEVICE inline RType RMaxPrev() const { return rmax - wmin; } + + friend std::ostream& operator<<(std::ostream& os, Entry const& e) { + os << "rmin: " << e.rmin << ", " + << "rmax: " << e.rmax << ", " + << "wmin: " << e.wmin << ", " + << "value: " << e.value; + return os; + } }; /*! \brief input data queue before entering the summary */ struct Queue { @@ -184,14 +192,14 @@ struct WQSummary { } } } + /*! * \brief set current summary to be pruned summary of src * assume data field is already allocated to be at least maxsize * \param src source summary * \param maxsize size we can afford in the pruned sketch */ - - inline void SetPrune(const WQSummary &src, size_t maxsize) { + void SetPrune(const WQSummary &src, size_t maxsize) { if (src.size <= maxsize) { this->CopyFrom(src); return; } @@ -454,6 +462,9 @@ struct WXQSummary : public WQSummary { */ template class QuantileSketchTemplate { + public: + static float constexpr kFactor = 8.0; + public: /*! 
\brief type of summary type */ using Summary = TSummary; diff --git a/src/common/threading_utils.h b/src/common/threading_utils.h old mode 100755 new mode 100644 diff --git a/tests/cpp/common/test_device_helpers.cu b/tests/cpp/common/test_device_helpers.cu index 52436c025148..d857a01188b7 100644 --- a/tests/cpp/common/test_device_helpers.cu +++ b/tests/cpp/common/test_device_helpers.cu @@ -5,6 +5,7 @@ #include #include #include "../../../src/common/device_helpers.cuh" +#include "../../../src/common/quantile.h" #include "../helpers.h" #include "gtest/gtest.h" @@ -14,3 +15,128 @@ TEST(SumReduce, Test) { ASSERT_NEAR(sum, 100.0f, 1e-5); } +void TestAtomicSizeT() { + size_t constexpr kThreads = 235; + dh::device_vector out(1, 0); + auto d_out = dh::ToSpan(out); + dh::LaunchN(0, kThreads, [=]__device__(size_t idx){ + atomicAdd(&d_out[0], static_cast(1)); + }); + ASSERT_EQ(out[0], kThreads); +} + +TEST(AtomicAdd, SizeT) { + TestAtomicSizeT(); +} + +TEST(SegmentedUnique, Basic) { + std::vector values{0.1f, 0.2f, 0.3f, 0.62448811531066895f, 0.62448811531066895f, 0.4f}; + std::vector segments{0, 3, 6}; + + thrust::device_vector d_values(values); + thrust::device_vector d_segments{segments}; + + thrust::device_vector d_segs_out(d_segments.size()); + thrust::device_vector d_vals_out(d_values.size()); + + size_t n_uniques = dh::SegmentedUnique( + d_segments.data().get(), d_segments.data().get() + d_segments.size(), + d_values.data().get(), d_values.data().get() + d_values.size(), + d_segs_out.data().get(), d_vals_out.data().get(), + thrust::equal_to{}); + CHECK_EQ(n_uniques, 5); + + std::vector values_sol{0.1f, 0.2f, 0.3f, 0.62448811531066895f, 0.4f}; + for (auto i = 0 ; i < values_sol.size(); i ++) { + ASSERT_EQ(d_vals_out[i], values_sol[i]); + } + + std::vector segments_sol{0, 3, 5}; + for (size_t i = 0; i < d_segments.size(); ++i) { + ASSERT_EQ(segments_sol[i], d_segs_out[i]); + } + + d_segments[1] = 4; + d_segments[2] = 6; + n_uniques = dh::SegmentedUnique( + d_segments.data().get(), d_segments.data().get() + d_segments.size(), + d_values.data().get(), d_values.data().get() + d_values.size(), + d_segs_out.data().get(), d_vals_out.data().get(), + thrust::equal_to{}); + ASSERT_EQ(n_uniques, values.size()); + for (auto i = 0 ; i < values.size(); i ++) { + ASSERT_EQ(d_vals_out[i], values[i]); + } +} + +namespace { +using SketchEntry = xgboost::common::WQSummary::Entry; +struct SketchUnique { + bool __device__ operator()(SketchEntry const& a, SketchEntry const& b) const { + return a.value - b.value == 0; + } +}; +struct IsSorted { + bool __device__ operator()(SketchEntry const& a, SketchEntry const& b) const { + return a.value < b.value; + } +}; +} // namespace + +namespace xgboost { +namespace common { + +void TestSegmentedUniqueRegression(std::vector values, size_t n_duplicated) { + std::vector segments{0, static_cast(values.size())}; + + thrust::device_vector d_values(values); + thrust::device_vector d_segments(segments); + thrust::device_vector d_segments_out(segments.size()); + + size_t n_uniques = dh::SegmentedUnique( + d_segments.data().get(), d_segments.data().get() + d_segments.size(), d_values.data().get(), + d_values.data().get() + d_values.size(), d_segments_out.data().get(), d_values.data().get(), + SketchUnique{}); + ASSERT_EQ(n_uniques, values.size() - n_duplicated); + ASSERT_TRUE(thrust::is_sorted(thrust::device, d_values.begin(), + d_values.begin() + n_uniques, IsSorted{})); + ASSERT_EQ(segments.at(0), d_segments_out[0]); + ASSERT_EQ(segments.at(1), d_segments_out[1] + n_duplicated); +} 
+ + +TEST(SegmentedUnique, Regression) { + { + std::vector values{{3149, 3150, 1, 0.62392902374267578}, + {3151, 3152, 1, 0.62418866157531738}, + {3152, 3153, 1, 0.62419462203979492}, + {3153, 3154, 1, 0.62431186437606812}, + {3154, 3155, 1, 0.6244881153106689453125}, + {3155, 3156, 1, 0.6244881153106689453125}, + {3155, 3156, 1, 0.6244881153106689453125}, + {3155, 3156, 1, 0.6244881153106689453125}, + {3157, 3158, 1, 0.62552797794342041}, + {3158, 3159, 1, 0.6256556510925293}, + {3159, 3160, 1, 0.62571090459823608}, + {3160, 3161, 1, 0.62577134370803833}}; + TestSegmentedUniqueRegression(values, 3); + } + { + std::vector values{{3149, 3150, 1, 0.62392902374267578}, + {3151, 3152, 1, 0.62418866157531738}, + {3152, 3153, 1, 0.62419462203979492}, + {3153, 3154, 1, 0.62431186437606812}, + {3154, 3155, 1, 0.6244881153106689453125}, + {3157, 3158, 1, 0.62552797794342041}, + {3158, 3159, 1, 0.6256556510925293}, + {3159, 3160, 1, 0.62571090459823608}, + {3160, 3161, 1, 0.62577134370803833}}; + TestSegmentedUniqueRegression(values, 0); + } + { + std::vector values; + TestSegmentedUniqueRegression(values, 0); + } +} +} // namespace common +} // namespace xgboost \ No newline at end of file diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index 365306fb8d6b..d8a75ba49ef2 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -30,11 +30,12 @@ HistogramCuts GetHostCuts(AdapterT *adapter, int num_bins, float missing) { builder.Build(&dmat, num_bins); return cuts; } + TEST(HistUtil, DeviceSketch) { - int num_rows = 5; int num_columns = 1; int num_bins = 4; - std::vector x = {1.0, 2.0, 3.0, 4.0, 5.0}; + std::vector x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 7.0f, -1.0f}; + int num_rows = x.size(); auto dmat = GetDMatrixFromData(x, num_rows, num_columns); auto device_cuts = DeviceSketch(0, dmat.get(), num_bins); @@ -47,26 +48,6 @@ TEST(HistUtil, DeviceSketch) { EXPECT_EQ(device_cuts.MinValues(), host_cuts.MinValues()); } -// Duplicate this function from hist_util.cu so we don't have to expose it in -// header -size_t RequiredSampleCutsTest(int max_bins, size_t num_rows) { - double eps = 1.0 / (SketchContainer::kFactor * max_bins); - size_t dummy_nlevel; - size_t num_cuts; - WQuantileSketch::LimitSizeLevel( - num_rows, eps, &dummy_nlevel, &num_cuts); - return std::min(num_cuts, num_rows); -} - -size_t BytesRequiredForTest(size_t num_rows, size_t num_columns, size_t num_bins, - bool with_weights) { - size_t bytes_num_elements = BytesPerElement(with_weights) * num_rows * num_columns; - size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns * - sizeof(DenseCuts::WQSketch::Entry); - // divide by 2 is because the memory quota used in sorting is reused for storing cuts. 
- return bytes_num_elements / 2 + bytes_cuts; -} - TEST(HistUtil, DeviceSketchMemory) { int num_columns = 100; int num_rows = 1000; @@ -77,15 +58,15 @@ TEST(HistUtil, DeviceSketchMemory) { dh::GlobalMemoryLogger().Clear(); ConsoleLogger::Configure({{"verbosity", "3"}}); auto device_cuts = DeviceSketch(0, dmat.get(), num_bins); - ConsoleLogger::Configure({{"verbosity", "0"}}); - size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, false); - size_t bytes_constant = 1000; - EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required + bytes_constant); - EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required); + size_t bytes_required = detail::RequiredMemory( + num_rows, num_columns, num_rows * num_columns, num_bins, false); + EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05); + EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 0.95); + ConsoleLogger::Configure({{"verbosity", "0"}}); } -TEST(HistUtil, DeviceSketchMemoryWeights) { +TEST(HistUtil, DeviceSketchWeightsMemory) { int num_columns = 100; int num_rows = 1000; int num_bins = 256; @@ -98,7 +79,8 @@ TEST(HistUtil, DeviceSketchMemoryWeights) { auto device_cuts = DeviceSketch(0, dmat.get(), num_bins); ConsoleLogger::Configure({{"verbosity", "0"}}); - size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, true); + size_t bytes_required = detail::RequiredMemory( + num_rows, num_columns, num_rows * num_columns, num_bins, true); EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05); EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required); } @@ -118,7 +100,7 @@ TEST(HistUtil, DeviceSketchDeterminism) { } } - TEST(HistUtil, DeviceSketchCategorical) { +TEST(HistUtil, DeviceSketchCategorical) { int categorical_sizes[] = {2, 6, 8, 12}; int num_bins = 256; int sizes[] = {25, 100, 1000}; @@ -231,11 +213,10 @@ template void ValidateBatchedCuts(Adapter adapter, int num_bins, int num_columns, int num_rows, DMatrix* dmat) { common::HistogramCuts batched_cuts; - SketchContainer sketch_container(num_bins, num_columns, num_rows); + SketchContainer sketch_container(num_bins, num_columns, num_rows, 0); AdapterDeviceSketch(adapter.Value(), num_bins, std::numeric_limits::quiet_NaN(), - 0, &sketch_container); - common::DenseCuts dense_cuts(&batched_cuts); - dense_cuts.Init(&sketch_container.sketches_, num_bins, num_rows); + &sketch_container); + sketch_container.MakeCuts(&batched_cuts); ValidateCuts(batched_cuts, dmat, num_bins); } @@ -275,12 +256,13 @@ TEST(HistUtil, AdapterDeviceSketchMemory) { std::numeric_limits::quiet_NaN()); ConsoleLogger::Configure({{"verbosity", "0"}}); size_t bytes_constant = 1000; - size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, false); + size_t bytes_required = detail::RequiredMemory( + num_rows, num_columns, num_rows * num_columns, num_bins, false); EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required + bytes_constant); - EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required); + EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 0.95); } -TEST(HistUtil, AdapterSketchBatchMemory) { +TEST(HistUtil, AdapterSketchSlidingWindowMemory) { int num_columns = 100; int num_rows = 1000; int num_bins = 256; @@ -291,17 +273,19 @@ TEST(HistUtil, AdapterSketchBatchMemory) { dh::GlobalMemoryLogger().Clear(); ConsoleLogger::Configure({{"verbosity", "3"}}); common::HistogramCuts batched_cuts; - SketchContainer sketch_container(num_bins, num_columns, num_rows); + 
SketchContainer sketch_container(num_bins, num_columns, num_rows, 0); AdapterDeviceSketch(adapter.Value(), num_bins, std::numeric_limits::quiet_NaN(), - 0, &sketch_container); + &sketch_container); + HistogramCuts cuts; + sketch_container.MakeCuts(&cuts); + size_t bytes_required = detail::RequiredMemory( + num_rows, num_columns, num_rows * num_columns, num_bins, false); + EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05); + EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 0.95); ConsoleLogger::Configure({{"verbosity", "0"}}); - size_t bytes_constant = 1000; - size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, false); - EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required + bytes_constant); - EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required); } -TEST(HistUtil, AdapterSketchBatchWeightedMemory) { +TEST(HistUtil, AdapterSketchSlidingWindowWeightedMemory) { int num_columns = 100; int num_rows = 1000; int num_bins = 256; @@ -316,12 +300,15 @@ TEST(HistUtil, AdapterSketchBatchWeightedMemory) { dh::GlobalMemoryLogger().Clear(); ConsoleLogger::Configure({{"verbosity", "3"}}); common::HistogramCuts batched_cuts; - SketchContainer sketch_container(num_bins, num_columns, num_rows); + SketchContainer sketch_container(num_bins, num_columns, num_rows, 0); AdapterDeviceSketchWeighted(adapter.Value(), num_bins, info, - std::numeric_limits::quiet_NaN(), 0, + std::numeric_limits::quiet_NaN(), &sketch_container); + HistogramCuts cuts; + sketch_container.MakeCuts(&cuts); ConsoleLogger::Configure({{"verbosity", "0"}}); - size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, true); + size_t bytes_required = detail::RequiredMemory( + num_rows, num_columns, num_rows * num_columns, num_bins, true); EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05); EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required); } @@ -462,13 +449,11 @@ void TestAdapterSketchFromWeights(bool with_group) { data::CupyAdapter adapter(m); auto const& batch = adapter.Value(); - SketchContainer sketch_container(kBins, kCols, kRows); + SketchContainer sketch_container(kBins, kCols, kRows, 0); AdapterDeviceSketchWeighted(adapter.Value(), kBins, info, std::numeric_limits::quiet_NaN(), - 0, &sketch_container); common::HistogramCuts cuts; - common::DenseCuts dense_cuts(&cuts); - dense_cuts.Init(&sketch_container.sketches_, kBins, kRows); + sketch_container.MakeCuts(&cuts); auto dmat = GetDMatrixFromData(storage.HostVector(), kRows, kCols); if (with_group) { diff --git a/tests/cpp/common/test_hist_util.h b/tests/cpp/common/test_hist_util.h index 55edb324fee1..bd88d14ef1f2 100644 --- a/tests/cpp/common/test_hist_util.h +++ b/tests/cpp/common/test_hist_util.h @@ -117,7 +117,7 @@ inline void TestBinDistribution(const HistogramCuts &cuts, int column_idx, // First and last bin can have smaller for (auto& kv : bin_weights) { - EXPECT_LE(std::abs(bin_weights[kv.first] - expected_bin_weight), + ASSERT_LE(std::abs(bin_weights[kv.first] - expected_bin_weight), allowable_error); } } @@ -189,7 +189,7 @@ inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat, // Collect data into columns std::vector> columns(dmat->Info().num_col_); for (auto& batch : dmat->GetBatches()) { - CHECK_GT(batch.Size(), 0); + ASSERT_GT(batch.Size(), 0); for (auto i = 0ull; i < batch.Size(); i++) { for (auto e : batch[i]) { columns[e.index].push_back(e.fvalue); diff --git a/tests/cpp/common/test_partition_builder.cc 
b/tests/cpp/common/test_partition_builder.cc old mode 100755 new mode 100644 diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu new file mode 100644 index 000000000000..b67e9d060d5e --- /dev/null +++ b/tests/cpp/common/test_quantile.cu @@ -0,0 +1,388 @@ +#include +#include "../helpers.h" +#include "../../../src/common/hist_util.cuh" +#include "../../../src/common/quantile.cuh" + +namespace xgboost { +namespace common { +TEST(GPUQuantile, Basic) { + constexpr size_t kRows = 1000, kCols = 100, kBins = 256; + SketchContainer sketch(kBins, kCols, kRows, 0); + dh::caching_device_vector entries; + dh::device_vector cuts_ptr(kCols+1); + thrust::fill(cuts_ptr.begin(), cuts_ptr.end(), 0); + // Push empty + sketch.Push(dh::ToSpan(cuts_ptr), &entries); + ASSERT_EQ(sketch.Data().size(), 0); +} + +template void RunWithSeedsAndBins(size_t rows, Fn fn) { + std::vector seeds(4); + SimpleLCG lcg; + SimpleRealUniformDistribution dist(3, 1000); + std::generate(seeds.begin(), seeds.end(), [&](){ return dist(&lcg); }); + + std::vector bins(8); + for (size_t i = 0; i < bins.size() - 1; ++i) { + bins[i] = i * 35 + 2; + } + bins.back() = rows + 80; // provide a bin number greater than rows. + + std::vector infos(2); + auto& h_weights = infos.front().weights_.HostVector(); + h_weights.resize(rows); + std::generate(h_weights.begin(), h_weights.end(), [&]() { return dist(&lcg); }); + + for (auto seed : seeds) { + for (auto n_bin : bins) { + for (auto const& info : infos) { + fn(seed, n_bin, info); + } + } + } +} + +void TestSketchUnique(float sparsity) { + constexpr size_t kRows = 1000, kCols = 100; + RunWithSeedsAndBins(kRows, [kRows, kCols, sparsity](int32_t seed, size_t n_bins, MetaInfo const& info) { + SketchContainer sketch(n_bins, kCols, kRows, 0); + + HostDeviceVector storage; + std::string interface_str = RandomDataGenerator{kRows, kCols, sparsity} + .Seed(seed) + .Device(0) + .GenerateArrayInterface(&storage); + data::CupyAdapter adapter(interface_str); + AdapterDeviceSketchWeighted(adapter.Value(), n_bins, info, + std::numeric_limits::quiet_NaN(), &sketch); + auto n_cuts = detail::RequiredSampleCutsPerColumn(n_bins, kRows); + + dh::caching_device_vector column_sizes_scan; + HostDeviceVector cut_sizes_scan; + auto batch = adapter.Value(); + data::IsValidFunctor is_valid(std::numeric_limits::quiet_NaN()); + auto batch_iter = dh::MakeTransformIterator( + thrust::make_counting_iterator(0llu), + [=] __device__(size_t idx) { return batch.GetElement(idx); }); + auto end = kCols * kRows; + detail::GetColumnSizesScan(0, kCols, n_cuts, batch_iter, is_valid, 0, end, + &cut_sizes_scan, &column_sizes_scan); + auto const& cut_sizes = cut_sizes_scan.HostVector(); + + if (sparsity == 0) { + ASSERT_EQ(sketch.Data().size(), n_cuts * kCols); + } else { + ASSERT_EQ(sketch.Data().size(), cut_sizes.back()); + } + + sketch.Unique(); + ASSERT_TRUE(thrust::is_sorted(thrust::device, sketch.Data().data(), + sketch.Data().data() + sketch.Data().size(), + detail::SketchUnique{})); + }); +} + +TEST(GPUQuantile, Unique) { + TestSketchUnique(0); + TestSketchUnique(0.5); +} + +// if with_error is true, the test tolerates floating point error +void TestQuantileElemRank(int32_t device, Span in, + Span d_columns_ptr, bool with_error = false) { + dh::LaunchN(device, in.size(), [=]XGBOOST_DEVICE(size_t idx) { + auto column_id = dh::SegmentId(d_columns_ptr, idx); + auto in_column = in.subspan(d_columns_ptr[column_id], + d_columns_ptr[column_id + 1] - + d_columns_ptr[column_id]); + auto constexpr kEps = 1e-6f; + 
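      // The checks below verify the per-column summary invariants: rmin and rmax are
      // each non-decreasing along the column, and every entry satisfies
      // rmax >= rmin + wmin (i.e. rmax >= RMinNext()). Weighted merges can leave a
      // tiny floating point drift, hence the kEps-relaxed variant when `with_error`
      // is set.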
idx -= d_columns_ptr[column_id]; + float prev_rmin = idx == 0 ? 0.0f : in_column[idx-1].rmin; + float prev_rmax = idx == 0 ? 0.0f : in_column[idx-1].rmax; + float rmin_next = in_column[idx].RMinNext(); + + if (with_error) { + SPAN_CHECK(in_column[idx].rmin + in_column[idx].rmin * kEps >= prev_rmin); + SPAN_CHECK(in_column[idx].rmax + in_column[idx].rmin * kEps >= prev_rmax); + SPAN_CHECK(in_column[idx].rmax + in_column[idx].rmin * kEps >= rmin_next); + } else { + SPAN_CHECK(in_column[idx].rmin >= prev_rmin); + SPAN_CHECK(in_column[idx].rmax >= prev_rmax); + SPAN_CHECK(in_column[idx].rmax >= rmin_next); + } + }); + // Force sync to terminate current test instead of a later one. + dh::DebugSyncDevice(__FILE__, __LINE__); +} + + +TEST(GPUQuantile, Prune) { + constexpr size_t kRows = 1000, kCols = 100; + RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) { + SketchContainer sketch(n_bins, kCols, kRows, 0); + + HostDeviceVector storage; + std::string interface_str = RandomDataGenerator{kRows, kCols, 0} + .Device(0) + .Seed(seed) + .GenerateArrayInterface(&storage); + data::CupyAdapter adapter(interface_str); + AdapterDeviceSketchWeighted(adapter.Value(), n_bins, info, + std::numeric_limits::quiet_NaN(), &sketch); + auto n_cuts = detail::RequiredSampleCutsPerColumn(n_bins, kRows); + ASSERT_EQ(sketch.Data().size(), n_cuts * kCols); + + sketch.Prune(n_bins); + if (n_bins <= kRows) { + ASSERT_EQ(sketch.Data().size(), n_bins * kCols); + } else { + // LE because kRows * kCols is pushed into sketch, after removing duplicated entries + // we might not have that much inputs for prune. + ASSERT_LE(sketch.Data().size(), kRows * kCols); + } + // This is not necessarily true for all inputs without calling unique after + // prune. + ASSERT_TRUE(thrust::is_sorted(thrust::device, sketch.Data().data(), + sketch.Data().data() + sketch.Data().size(), + detail::SketchUnique{})); + TestQuantileElemRank(0, sketch.Data(), sketch.ColumnsPtr()); + }); +} + +TEST(GPUQuantile, MergeEmpty) { + constexpr size_t kRows = 1000, kCols = 100; + size_t n_bins = 10; + SketchContainer sketch_0(n_bins, kCols, kRows, 0); + HostDeviceVector storage_0; + std::string interface_str_0 = + RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateArrayInterface( + &storage_0); + data::CupyAdapter adapter_0(interface_str_0); + AdapterDeviceSketch(adapter_0.Value(), n_bins, + std::numeric_limits::quiet_NaN(), &sketch_0); + + std::vector entries_before(sketch_0.Data().size()); + dh::CopyDeviceSpanToVector(&entries_before, sketch_0.Data()); + std::vector ptrs_before(sketch_0.ColumnsPtr().size()); + dh::CopyDeviceSpanToVector(&ptrs_before, sketch_0.ColumnsPtr()); + thrust::device_vector columns_ptr(kCols + 1); + // Merge an empty sketch + sketch_0.Merge(dh::ToSpan(columns_ptr), Span{}); + + std::vector entries_after(sketch_0.Data().size()); + dh::CopyDeviceSpanToVector(&entries_after, sketch_0.Data()); + std::vector ptrs_after(sketch_0.ColumnsPtr().size()); + dh::CopyDeviceSpanToVector(&ptrs_after, sketch_0.ColumnsPtr()); + + CHECK_EQ(entries_before.size(), entries_after.size()); + CHECK_EQ(ptrs_before.size(), ptrs_after.size()); + for (size_t i = 0; i < entries_before.size(); ++i) { + CHECK_EQ(entries_before[i].value, entries_after[i].value); + CHECK_EQ(entries_before[i].rmin, entries_after[i].rmin); + CHECK_EQ(entries_before[i].rmax, entries_after[i].rmax); + CHECK_EQ(entries_before[i].wmin, entries_after[i].wmin); + } + for (size_t i = 0; i < ptrs_before.size(); ++i) { + CHECK_EQ(ptrs_before[i], ptrs_after[i]); + 
} +} + +TEST(GPUQuantile, MergeBasic) { + constexpr size_t kRows = 1000, kCols = 100; + RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) { + SketchContainer sketch_0(n_bins, kCols, kRows, 0); + HostDeviceVector storage_0; + std::string interface_str_0 = RandomDataGenerator{kRows, kCols, 0} + .Device(0) + .Seed(seed) + .GenerateArrayInterface(&storage_0); + data::CupyAdapter adapter_0(interface_str_0); + AdapterDeviceSketchWeighted(adapter_0.Value(), n_bins, info, + std::numeric_limits::quiet_NaN(), &sketch_0); + + SketchContainer sketch_1(n_bins, kCols, kRows, 0); + HostDeviceVector storage_1; + std::string interface_str_1 = RandomDataGenerator{kRows, kCols, 0} + .Device(0) + .Seed(seed) + .GenerateArrayInterface(&storage_1); + data::CupyAdapter adapter_1(interface_str_1); + AdapterDeviceSketchWeighted(adapter_1.Value(), n_bins, info, + std::numeric_limits::quiet_NaN(), &sketch_1); + + size_t size_before_merge = sketch_0.Data().size(); + sketch_0.Merge(sketch_1.ColumnsPtr(), sketch_1.Data()); + if (info.weights_.Size() != 0) { + TestQuantileElemRank(0, sketch_0.Data(), sketch_0.ColumnsPtr(), true); + sketch_0.FixError(); + TestQuantileElemRank(0, sketch_0.Data(), sketch_0.ColumnsPtr(), false); + } else { + TestQuantileElemRank(0, sketch_0.Data(), sketch_0.ColumnsPtr()); + } + + auto columns_ptr = sketch_0.ColumnsPtr(); + std::vector h_columns_ptr(columns_ptr.size()); + dh::CopyDeviceSpanToVector(&h_columns_ptr, columns_ptr); + ASSERT_EQ(h_columns_ptr.back(), sketch_1.Data().size() + size_before_merge); + + sketch_0.Unique(); + ASSERT_TRUE( + thrust::is_sorted(thrust::device, sketch_0.Data().data(), + sketch_0.Data().data() + sketch_0.Data().size(), + detail::SketchUnique{})); + }); +} + +void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) { + MetaInfo info; + int32_t seed = 0; + SketchContainer sketch_0(n_bins, cols, rows, 0); + HostDeviceVector storage_0; + std::string interface_str_0 = RandomDataGenerator{rows, cols, 0} + .Device(0) + .Seed(seed) + .GenerateArrayInterface(&storage_0); + data::CupyAdapter adapter_0(interface_str_0); + AdapterDeviceSketchWeighted(adapter_0.Value(), n_bins, info, + std::numeric_limits::quiet_NaN(), + &sketch_0); + + size_t f_rows = rows * frac; + SketchContainer sketch_1(n_bins, cols, f_rows, 0); + HostDeviceVector storage_1; + std::string interface_str_1 = RandomDataGenerator{f_rows, cols, 0} + .Device(0) + .Seed(seed) + .GenerateArrayInterface(&storage_1); + auto data_1 = storage_1.DeviceSpan(); + auto tuple_it = thrust::make_tuple( + thrust::make_counting_iterator(0ul), data_1.data()); + using Tuple = thrust::tuple; + auto it = thrust::make_zip_iterator(tuple_it); + thrust::transform(thrust::device, it, it + data_1.size(), data_1.data(), + [=] __device__(Tuple const &tuple) { + auto i = thrust::get<0>(tuple); + if (thrust::get<0>(tuple) % 2 == 0) { + return 0.0f; + } else { + return thrust::get<1>(tuple); + } + }); + data::CupyAdapter adapter_1(interface_str_1); + AdapterDeviceSketchWeighted(adapter_1.Value(), n_bins, info, + std::numeric_limits::quiet_NaN(), + &sketch_1); + + size_t size_before_merge = sketch_0.Data().size(); + sketch_0.Merge(sketch_1.ColumnsPtr(), sketch_1.Data()); + TestQuantileElemRank(0, sketch_0.Data(), sketch_0.ColumnsPtr()); + + auto columns_ptr = sketch_0.ColumnsPtr(); + std::vector h_columns_ptr(columns_ptr.size()); + dh::CopyDeviceSpanToVector(&h_columns_ptr, columns_ptr); + ASSERT_EQ(h_columns_ptr.back(), sketch_1.Data().size() + size_before_merge); + + 
sketch_0.Unique(); + ASSERT_TRUE(thrust::is_sorted(thrust::device, sketch_0.Data().data(), + sketch_0.Data().data() + sketch_0.Data().size(), + detail::SketchUnique{})); +} + +TEST(GPUQuantile, MergeDuplicated) { + size_t n_bins = 256; + constexpr size_t kRows = 1000, kCols = 100; + for (float frac = 0.5; frac < 2.5; frac += 0.5) { + TestMergeDuplicated(n_bins, kRows, kCols, frac); + } +} + +TEST(GPUQuantile, AllReduce) { + // This test is supposed to run by a python test that setups the environment. + std::string msg {"Skipping AllReduce test"}; +#if defined(__linux__) && defined(XGBOOST_USE_NCCL) + auto n_gpus = AllVisibleGPUs(); + auto port = std::getenv("DMLC_TRACKER_PORT"); + std::string port_str; + if (port) { + port_str = port; + } else { + LOG(WARNING) << msg << " as `DMLC_TRACKER_PORT` is not set up."; + return; + } + + std::vector envs{ + "DMLC_TRACKER_PORT=" + port_str, + "DMLC_TRACKER_URI=127.0.0.1", + "DMLC_NUM_WORKER=" + std::to_string(n_gpus)}; + char* c_envs[] {&(envs[0][0]), &(envs[1][0]), &(envs[2][0])}; + rabit::Init(3, c_envs); + constexpr size_t kRows = 1000, kCols = 100; + RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) { + // Set up single node version; + SketchContainer sketch_on_single_node(n_bins, kCols, kRows, 0); + auto world = rabit::GetWorldSize(); + if (world != 1) { + ASSERT_EQ(world, n_gpus); + } + + for (auto rank = 0; rank < world; ++rank) { + HostDeviceVector storage; + std::string interface_str = RandomDataGenerator{kRows, kCols, 0} + .Device(0) + .Seed(rank + seed) + .GenerateArrayInterface(&storage); + data::CupyAdapter adapter(interface_str); + AdapterDeviceSketchWeighted(adapter.Value(), n_bins, info, + std::numeric_limits::quiet_NaN(), + &sketch_on_single_node); + } + sketch_on_single_node.Unique(); + + // Set up distributed version. We rely on using rank as seed to generate + // the exact same copy of data. 
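    // In other words: the single-node container above regenerated and merged all
    // `world` datasets locally, while the distributed container below only sketches
    // this rank's data and relies on AllReduce() to pull in the rest. The two are
    // then expected to agree up to a small numeric tolerance.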
+ auto rank = rabit::GetRank(); + SketchContainer sketch_distributed(n_bins, kCols, kRows, 0); + HostDeviceVector storage; + std::string interface_str = RandomDataGenerator{kRows, kCols, 0} + .Device(0) + .Seed(rank + seed) + .GenerateArrayInterface(&storage); + data::CupyAdapter adapter(interface_str); + AdapterDeviceSketchWeighted(adapter.Value(), n_bins, info, + std::numeric_limits::quiet_NaN(), + &sketch_distributed); + sketch_distributed.AllReduce(); + sketch_distributed.Unique(); + + ASSERT_EQ(sketch_distributed.ColumnsPtr().size(), + sketch_on_single_node.ColumnsPtr().size()); + ASSERT_EQ(sketch_distributed.Data().size(), + sketch_on_single_node.Data().size()); + + TestQuantileElemRank(0, sketch_distributed.Data(), + sketch_distributed.ColumnsPtr()); + + std::vector single_node_data( + sketch_on_single_node.Data().size()); + dh::CopyDeviceSpanToVector(&single_node_data, sketch_on_single_node.Data()); + + std::vector distributed_data(sketch_distributed.Data().size()); + dh::CopyDeviceSpanToVector(&distributed_data, sketch_distributed.Data()); + float Eps = 2e-4 * world; + + for (size_t i = 0; i < single_node_data.size(); ++i) { + ASSERT_NEAR(single_node_data[i].value, distributed_data[i].value, Eps); + ASSERT_NEAR(single_node_data[i].rmax, distributed_data[i].rmax, Eps); + ASSERT_NEAR(single_node_data[i].rmin, distributed_data[i].rmin, Eps); + ASSERT_NEAR(single_node_data[i].wmin, distributed_data[i].wmin, Eps); + } + }); + rabit::Finalize(); +#else + LOG(WARNING) << msg; + return; +#endif // !defined(__linux__) +} + +} // namespace common +} // namespace xgboost diff --git a/tests/cpp/common/test_span.cc b/tests/cpp/common/test_span.cc index 550d826b8257..c49763f3f61b 100644 --- a/tests/cpp/common/test_span.cc +++ b/tests/cpp/common/test_span.cc @@ -422,11 +422,11 @@ TEST(Span, Subspan) { ASSERT_EQ(s4.size(), s1.size() - 2); EXPECT_DEATH(s1.subspan(-1, 0), "\\[xgboost\\] Condition .* failed.\n"); - EXPECT_DEATH(s1.subspan(16, 0), "\\[xgboost\\] Condition .* failed.\n"); + EXPECT_DEATH(s1.subspan(17, 0), "\\[xgboost\\] Condition .* failed.\n"); auto constexpr kOne = static_cast::index_type>(-1); EXPECT_DEATH(s1.subspan(), "\\[xgboost\\] Condition .* failed.\n"); - EXPECT_DEATH(s1.subspan<16>(), "\\[xgboost\\] Condition .* failed.\n"); + EXPECT_DEATH(s1.subspan<17>(), "\\[xgboost\\] Condition .* failed.\n"); } TEST(Span, Compare) { diff --git a/tests/cpp/common/test_threading_utils.cc b/tests/cpp/common/test_threading_utils.cc old mode 100755 new mode 100644 diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index e0c322a46aa5..0783b9a89dba 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -168,7 +168,9 @@ class SimpleRealUniformDistribution { ResultT operator()(GeneratorT* rng) const { ResultT tmp = GenerateCanonical::digits, GeneratorT>(rng); - return (tmp * (upper_ - lower_)) + lower_; + auto ret = (tmp * (upper_ - lower_)) + lower_; + // Correct floating point error. + return std::max(ret, lower_); } }; diff --git a/tests/pytest.ini b/tests/pytest.ini index aa0c89344ca5..136782056f95 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -1,4 +1,5 @@ [pytest] markers = mgpu: Mark a test that requires multiple GPUs to run. - ci: Mark a test that runs only on CI. \ No newline at end of file + ci: Mark a test that runs only on CI. + gtest: Mark a test that requires C++ Google Test executable. 
\ No newline at end of file diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py index 0a34dfe236ea..75dfcfe49a25 100644 --- a/tests/python-gpu/test_gpu_with_dask.py +++ b/tests/python-gpu/test_gpu_with_dask.py @@ -1,8 +1,10 @@ import sys +import os import pytest import numpy as np import unittest import xgboost +import subprocess if sys.platform.startswith("win"): pytest.skip("Skipping dask tests on Windows", allow_module_level=True) @@ -42,7 +44,8 @@ def test_dask_dataframe(self): y = y.map_partitions(cudf.from_pandas) dtrain = dxgb.DaskDMatrix(client, X, y) - out = dxgb.train(client, {'tree_method': 'gpu_hist'}, + out = dxgb.train(client, {'tree_method': 'gpu_hist', + 'debug_synchronize': True}, dtrain=dtrain, evals=[(dtrain, 'X')], num_boost_round=4) @@ -89,7 +92,8 @@ def test_dask_array(self): X = X.map_blocks(cp.asarray) y = y.map_blocks(cp.asarray) dtrain = dxgb.DaskDMatrix(client, X, y) - out = dxgb.train(client, {'tree_method': 'gpu_hist'}, + out = dxgb.train(client, {'tree_method': 'gpu_hist', + 'debug_synchronize': True}, dtrain=dtrain, evals=[(dtrain, 'X')], num_boost_round=2) @@ -107,12 +111,53 @@ def test_dask_array(self): single_node, inplace_predictions) - @pytest.mark.skipif(**tm.no_dask()) @pytest.mark.skipif(**tm.no_dask_cuda()) @pytest.mark.mgpu def test_empty_dmatrix(self): with LocalCUDACluster() as cluster: with Client(cluster) as client: - parameters = {'tree_method': 'gpu_hist'} + parameters = {'tree_method': 'gpu_hist', + 'debug_synchronize': True} run_empty_dmatrix(client, parameters) + + @pytest.mark.skipif(**tm.no_dask()) + @pytest.mark.mgpu + @pytest.mark.gtest + def test_quantile(self): + if sys.platform.startswith("win"): + pytest.skip("Skipping dask tests on Windows") + + exe = None + for possible_path in {'./testxgboost', './build/testxgboost', + '../build/testxgboost'}: + if os.path.exists(possible_path): + exe = possible_path + assert exe, 'No testxgboost executable found.' + test = "--gtest_filter=GPUQuantile.AllReduce" + + def runit(worker_addr, rabit_args): + port = None + # setup environment for running the c++ part. 
+ for arg in rabit_args: + if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'): + port = arg.decode('utf-8') + port = port.split('=') + env = os.environ.copy() + env[port[0]] = port[1] + return subprocess.run([exe, test], env=env, stdout=subprocess.PIPE) + + with LocalCUDACluster() as cluster: + with Client(cluster) as client: + workers = list(dxgb._get_client_workers(client).keys()) + rabit_args = dxgb._get_rabit_args(workers, client) + futures = client.map(runit, + workers, + pure=False, + workers=workers, + rabit_args=rabit_args) + results = client.gather(futures) + for ret in results: + msg = ret.stdout.decode('utf-8') + assert msg.find('1 test from GPUQuantile') != -1 + assert ret.returncode == 0, msg diff --git a/tests/python/testing.py b/tests/python/testing.py index 4f9f3394aadc..411d96493695 100644 --- a/tests/python/testing.py +++ b/tests/python/testing.py @@ -1,10 +1,12 @@ # coding: utf-8 +import os from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED from xgboost.compat import DASK_INSTALLED from hypothesis import strategies from hypothesis.extra.numpy import arrays from joblib import Memory from sklearn import datasets +import tempfile import xgboost as xgb import numpy as np @@ -123,10 +125,15 @@ def get_device_dmat(self): return xgb.DeviceQuantileDMatrix(X, y, w) def get_external_dmat(self): - np.savetxt('tmptmp_1234.csv', np.hstack((self.y.reshape(len(self.y), 1), self.X)), - delimiter=',') - return xgb.DMatrix('tmptmp_1234.csv?format=csv&label_column=0#tmptmp_', - weight=self.w) + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, 'tmptmp_1234.csv') + np.savetxt(path, + np.hstack((self.y.reshape(len(self.y), 1), self.X)), + delimiter=',') + uri = path + '?format=csv&label_column=0#tmptmp_' + # The uri looks like: + # 'tmptmp_1234.csv?format=csv&label_column=0#tmptmp_' + return xgb.DMatrix(uri, weight=self.w) def __repr__(self): return self.name From a612555a9a8ea2f2c1a7aa660d6b07fa3d153379 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 3 Jul 2020 11:43:08 +0800 Subject: [PATCH 02/15] Typo. --- src/common/quantile.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common/quantile.cu b/src/common/quantile.cu index 05519fea9702..cc672665e6b1 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -536,7 +536,7 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { d_out_columns_ptr[column_id]); idx -= d_out_columns_ptr[column_id]; if (in_column.size() == 0) { - // If the column is empty, we push a dummy value. If won't effect training as the + // If the column is empty, we push a dummy value. It won't affect training as the // column is empty, trees cannot split on it. This is just to be consistent with // rest of the library. if (idx == 0) { From 0a1038f98fe29b3b619e62689d62f51b5cba9d13 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 3 Jul 2020 16:59:08 +0800 Subject: [PATCH 03/15] Rebase for iterative DMatrix. 
--- src/common/quantile.cuh | 5 +- src/data/iterative_device_dmatrix.cu | 77 +++++++++------------------- 2 files changed, 26 insertions(+), 56 deletions(-) diff --git a/src/common/quantile.cuh b/src/common/quantile.cuh index 5c833c20a120..7525871971fa 100644 --- a/src/common/quantile.cuh +++ b/src/common/quantile.cuh @@ -121,10 +121,9 @@ class SketchContainer { Span ColumnsPtr() const { return this->columns_ptr_.ConstDeviceSpan(); } - // Prevent copying/assigning/moving this as its internals can't be - // assigned/copied/moved + SketchContainer(SketchContainer&&) = default; + SketchContainer(const SketchContainer&) = delete; - SketchContainer(const SketchContainer&&) = delete; SketchContainer& operator=(const SketchContainer&) = delete; SketchContainer& operator=(const SketchContainer&&) = delete; }; diff --git a/src/data/iterative_device_dmatrix.cu b/src/data/iterative_device_dmatrix.cu index 6a4e06e02449..2380d3355a33 100644 --- a/src/data/iterative_device_dmatrix.cu +++ b/src/data/iterative_device_dmatrix.cu @@ -57,80 +57,51 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin size_t nnz = 0; // Sketch for all batches. iter.Reset(); - common::HistogramCuts cuts; - common::DenseCuts dense_cuts(&cuts); std::vector sketch_containers; size_t batches = 0; size_t accumulated_rows = 0; bst_feature_t cols = 0; + int32_t device = -1; while (iter.Next()) { - auto device = proxy->DeviceIdx(); + device = proxy->DeviceIdx(); dh::safe_cuda(cudaSetDevice(device)); if (cols == 0) { cols = num_cols(); } else { CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns."; } - sketch_containers.emplace_back(batch_param_.max_bin, num_cols(), num_rows()); + sketch_containers.emplace_back(batch_param_.max_bin, num_cols(), num_rows(), device); auto* p_sketch = &sketch_containers.back(); - if (proxy->Info().weights_.Size() != 0) { proxy->Info().weights_.SetDevice(device); Dispatch(proxy, [&](auto const &value) { - common::AdapterDeviceSketchWeighted(value, batch_param_.max_bin, - proxy->Info(), - missing, device, p_sketch); - }); - } else { - Dispatch(proxy, [&](auto const &value) { - common::AdapterDeviceSketch(value, batch_param_.max_bin, missing, - device, p_sketch); - }); - } + common::AdapterDeviceSketchWeighted(value, batch_param_.max_bin, + proxy->Info(), missing, p_sketch); + }); - auto batch_rows = num_rows(); - accumulated_rows += batch_rows; - dh::caching_device_vector row_counts(batch_rows + 1, 0); - common::Span row_counts_span(row_counts.data().get(), - row_counts.size()); - row_stride = - std::max(row_stride, Dispatch(proxy, [=](auto const& value) { - return GetRowCounts(value, row_counts_span, device, missing); - })); - nnz += thrust::reduce(thrust::cuda::par(alloc), - row_counts.begin(), row_counts.end()); - batches++; + auto batch_rows = num_rows(); + accumulated_rows += batch_rows; + dh::caching_device_vector row_counts(batch_rows + 1, 0); + common::Span row_counts_span(row_counts.data().get(), + row_counts.size()); + row_stride = std::max(row_stride, Dispatch(proxy, [=](auto const &value) { + return GetRowCounts(value, row_counts_span, + device, missing); + })); + nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(), + row_counts.end()); + batches++; } - // Merging multiple batches for each column - std::vector summary_array(cols); - size_t intermediate_num_cuts = std::min( - accumulated_rows, static_cast(batch_param_.max_bin * - common::SketchContainer::kFactor)); - size_t nbytes = - 
common::WQSketch::SummaryContainer::CalcMemCost(intermediate_num_cuts); -#pragma omp parallel for num_threads(nthread) if (nthread > 0) - for (omp_ulong c = 0; c < cols; ++c) { - for (auto& sketch_batch : sketch_containers) { - common::WQSketch::SummaryContainer summary; - sketch_batch.sketches_.at(c).GetSummary(&summary); - sketch_batch.sketches_.at(c).Init(0, 1); - summary_array.at(c).Reduce(summary, nbytes); - } + common::SketchContainer final_sketch(batch_param_.max_bin, cols, accumulated_rows, device); + for (auto const& sketch: sketch_containers) { + final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data()); } sketch_containers.clear(); + sketch_containers.shrink_to_fit(); - // Build the final summary. - std::vector sketches(cols); -#pragma omp parallel for num_threads(nthread) if (nthread > 0) - for (omp_ulong c = 0; c < cols; ++c) { - sketches.at(c).Init( - accumulated_rows, - 1.0 / (common::SketchContainer::kFactor * batch_param_.max_bin)); - sketches.at(c).PushSummary(summary_array.at(c)); - } - dense_cuts.Init(&sketches, batch_param_.max_bin, accumulated_rows); - summary_array.clear(); + common::HistogramCuts cuts; + final_sketch.MakeCuts(&cuts); this->info_.num_col_ = cols; this->info_.num_row_ = accumulated_rows; From 0474abec0c0f963ecdab3e05452839b08d0c7ccc Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 3 Jul 2020 17:13:04 +0800 Subject: [PATCH 04/15] Correct type. --- src/common/quantile.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/common/quantile.cu b/src/common/quantile.cu index cc672665e6b1..6cfa0d090366 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -113,7 +113,7 @@ common::Span> MergePath( auto get_x = []XGBOOST_DEVICE(Tuple const &t) { return thrust::get<0>(t); }; auto get_y = []XGBOOST_DEVICE(Tuple const &t) { return thrust::get<1>(t); }; - auto scan_key_it = dh::MakeTransformIterator( + auto scan_key_it = dh::MakeTransformIterator( thrust::make_counting_iterator(0ul), [=] __device__(size_t idx) { return dh::SegmentId(out_ptr, idx); }); @@ -135,7 +135,7 @@ common::Span> MergePath( thrust::cuda::par(alloc), scan_key_it, scan_key_it + merge_path.size(), scan_val_it, merge_path.data(), thrust::make_tuple(0ul, 0ul), - thrust::equal_to{}, + thrust::equal_to{}, [=] __device__(Tuple const &l, Tuple const &r) -> Tuple { return thrust::make_tuple(get_x(l) + get_x(r), get_y(l) + get_y(r)); }); From b4590e00a8932ce4a25329c93bf2938274eb516d Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 3 Jul 2020 18:32:02 +0800 Subject: [PATCH 05/15] Lint. 
--- src/common/quantile.cu | 1 + src/data/iterative_device_dmatrix.cu | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/common/quantile.cu b/src/common/quantile.cu index 6cfa0d090366..1976079b6310 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -7,6 +7,7 @@ #include #include +#include #include #include "xgboost/span.h" diff --git a/src/data/iterative_device_dmatrix.cu b/src/data/iterative_device_dmatrix.cu index 2380d3355a33..96a45604a217 100644 --- a/src/data/iterative_device_dmatrix.cu +++ b/src/data/iterative_device_dmatrix.cu @@ -94,7 +94,7 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin } common::SketchContainer final_sketch(batch_param_.max_bin, cols, accumulated_rows, device); - for (auto const& sketch: sketch_containers) { + for (auto const& sketch : sketch_containers) { final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data()); } sketch_containers.clear(); From ba9b1bc0a0ce96ab8fa871ed053815e97acc5f0f Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 3 Jul 2020 18:37:56 +0800 Subject: [PATCH 06/15] Format. --- src/data/iterative_device_dmatrix.cu | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/data/iterative_device_dmatrix.cu b/src/data/iterative_device_dmatrix.cu index 96a45604a217..c198f06fa859 100644 --- a/src/data/iterative_device_dmatrix.cu +++ b/src/data/iterative_device_dmatrix.cu @@ -73,24 +73,24 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin } sketch_containers.emplace_back(batch_param_.max_bin, num_cols(), num_rows(), device); auto* p_sketch = &sketch_containers.back(); - proxy->Info().weights_.SetDevice(device); - Dispatch(proxy, [&](auto const &value) { + proxy->Info().weights_.SetDevice(device); + Dispatch(proxy, [&](auto const &value) { common::AdapterDeviceSketchWeighted(value, batch_param_.max_bin, proxy->Info(), missing, p_sketch); }); - auto batch_rows = num_rows(); - accumulated_rows += batch_rows; - dh::caching_device_vector row_counts(batch_rows + 1, 0); - common::Span row_counts_span(row_counts.data().get(), - row_counts.size()); - row_stride = std::max(row_stride, Dispatch(proxy, [=](auto const &value) { - return GetRowCounts(value, row_counts_span, - device, missing); - })); - nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(), - row_counts.end()); - batches++; + auto batch_rows = num_rows(); + accumulated_rows += batch_rows; + dh::caching_device_vector row_counts(batch_rows + 1, 0); + common::Span row_counts_span(row_counts.data().get(), + row_counts.size()); + row_stride = std::max(row_stride, Dispatch(proxy, [=](auto const &value) { + return GetRowCounts(value, row_counts_span, + device, missing); + })); + nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(), + row_counts.end()); + batches++; } common::SketchContainer final_sketch(batch_param_.max_bin, cols, accumulated_rows, device); From 019dad59809d26c0e1bb3fa068023c2aa645a612 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 3 Jul 2020 20:21:59 +0800 Subject: [PATCH 07/15] Preserve sketches order between workers. 
--- src/common/hist_util.cu | 6 ++---- src/common/quantile.cu | 28 ++++++++++++++------------ src/common/quantile.cuh | 2 +- tests/cpp/common/test_quantile.cu | 14 +++++++++---- tests/python-gpu/test_gpu_with_dask.py | 2 +- 5 files changed, 29 insertions(+), 23 deletions(-) diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index 803a1df02658..5da1006ed972 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -147,10 +147,8 @@ size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz, total -= (num_columns + 1) * sizeof(SketchContainer::OffsetT); // 8. Deallocate cut size scan. total -= (num_columns + 1) * sizeof(SketchContainer::OffsetT); - // 9. Allocate std::min(rows, bins * factor) * shape due to pruning to global num rows. - total += RequiredSampleCuts(num_rows, num_bins, num_bins, nnz) * sizeof(SketchEntry); - // 10. Allocate final cut values, min values, cut ptrs: std::min(rows, bins + 1) * - // n_columns + n_columns + n_columns + 1 + // 9. Allocate final cut values, min values, cut ptrs: std::min(rows, bins + 1) * + // n_columns + n_columns + n_columns + 1 total += std::min(num_rows, num_bins) * num_columns * sizeof(float); total += num_columns * sizeof(std::remove_reference_t(); reducer_->Init(device_); } - auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan(); - dh::device_vector gathered_ptrs; + // Reduce the overhead on syncing. + size_t global_sum_rows = num_rows_; + rabit::Allreduce(&global_sum_rows, 1); + size_t intermediate_num_cuts = + std::min(global_sum_rows, static_cast(num_bins_ * kFactor)); + this->Prune(intermediate_num_cuts); + auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan(); CHECK_EQ(d_columns_ptr.size(), num_columns_ + 1); size_t n = d_columns_ptr.size(); rabit::Allreduce(&n, 1); CHECK_EQ(n, d_columns_ptr.size()) << "Number of columns differs across workers"; // Get the columns ptr from all workers + dh::device_vector gathered_ptrs; gathered_ptrs.resize(d_columns_ptr.size() * world, 0); size_t rank = rabit::GetRank(); auto offset = rank * d_columns_ptr.size(); @@ -466,18 +472,19 @@ void SketchContainer::AllReduce() { offset += length_as_bytes; } - // Merge them into current sketch. + // Merge them into a new sketch. + SketchContainer new_sketch(num_bins_, this->num_columns_, global_sum_rows, + this->device_); for (size_t i = 0; i < allworkers.size(); ++i) { - if (i == rank) { - continue; - } auto worker = allworkers[i]; auto worker_ptr = dh::ToSpan(gathered_ptrs) .subspan(i * d_columns_ptr.size(), d_columns_ptr.size()); - this->Merge(worker_ptr, worker); - this->FixError(); + new_sketch.Merge(worker_ptr, worker); + new_sketch.FixError(); } + + *this = std::move(new_sketch); timer_.Stop(__func__); } @@ -485,13 +492,8 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { timer_.Start(__func__); dh::safe_cuda(cudaSetDevice(device_)); p_cuts->min_vals_.Resize(num_columns_); - size_t global_max_rows = num_rows_; - rabit::Allreduce(&global_max_rows, 1); // Sync between workers. - size_t intermediate_num_cuts = - std::min(global_max_rows, static_cast(num_bins_ * kFactor)); - this->Prune(intermediate_num_cuts); this->AllReduce(); // Prune to final number of bins. 
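The reworked AllReduce above now prunes the local sketch to a common intermediate size before any communication, gathers every worker's column pointers and data, and merges them in rank order into a fresh container, so each node performs the identical sequence of merges. A rough CPU analogue of that flow, with plain sorted floats standing in for SketchEntry and kFactor assumed to keep its usual value of 8 (both simplifications are mine, for illustration only):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <vector>

constexpr std::size_t kFactor = 8;  // assumed value, mirroring WQSketch::kFactor

std::size_t IntermediateNumCuts(std::size_t global_rows, std::size_t num_bins) {
  // Every worker prunes to this size before the summaries are exchanged.
  return std::min(global_rows, num_bins * kFactor);
}

int main() {
  // One column, three workers, each already sorted after its local prune.
  std::vector<std::vector<float>> workers{{0.1f, 0.5f}, {0.2f, 0.9f}, {0.3f, 0.4f}};
  std::vector<float> merged;
  for (auto const& w : workers) {  // always rank 0, 1, 2, ... on every node
    std::vector<float> next;
    std::merge(merged.cbegin(), merged.cend(), w.cbegin(), w.cend(),
               std::back_inserter(next));
    merged = std::move(next);
  }
  std::cout << IntermediateNumCuts(1000, 256) << "\n";  // min(1000, 2048) = 1000
  for (float v : merged) {
    std::cout << v << " ";  // 0.1 0.2 0.3 0.4 0.5 0.9
  }
  std::cout << "\n";
  return 0;
}

Merging in a fixed rank order is what the commit subject refers to as preserving the sketch order between workers: the merged result no longer depends on which worker happens to drive the loop.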
diff --git a/src/common/quantile.cuh b/src/common/quantile.cuh index 7525871971fa..e7a1218bb263 100644 --- a/src/common/quantile.cuh +++ b/src/common/quantile.cuh @@ -122,10 +122,10 @@ class SketchContainer { Span ColumnsPtr() const { return this->columns_ptr_.ConstDeviceSpan(); } SketchContainer(SketchContainer&&) = default; + SketchContainer& operator=(SketchContainer&&) = default; SketchContainer(const SketchContainer&) = delete; SketchContainer& operator=(const SketchContainer&) = delete; - SketchContainer& operator=(const SketchContainer&&) = delete; }; namespace detail { diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu index b67e9d060d5e..61b1a1d858b4 100644 --- a/tests/cpp/common/test_quantile.cu +++ b/tests/cpp/common/test_quantile.cu @@ -324,7 +324,9 @@ TEST(GPUQuantile, AllReduce) { if (world != 1) { ASSERT_EQ(world, n_gpus); } - + size_t intermediate_num_cuts = + std::min(kRows * world, static_cast(n_bins * WQSketch::kFactor)); + std::vector containers; for (auto rank = 0; rank < world; ++rank) { HostDeviceVector storage; std::string interface_str = RandomDataGenerator{kRows, kCols, 0} @@ -332,11 +334,16 @@ TEST(GPUQuantile, AllReduce) { .Seed(rank + seed) .GenerateArrayInterface(&storage); data::CupyAdapter adapter(interface_str); + containers.emplace_back(n_bins, kCols, kRows, 0); AdapterDeviceSketchWeighted(adapter.Value(), n_bins, info, std::numeric_limits::quiet_NaN(), - &sketch_on_single_node); + &containers.back()); + } + for (auto& sketch : containers) { + sketch.Prune(intermediate_num_cuts); + sketch_on_single_node.Merge(sketch.ColumnsPtr(), sketch.Data()); + sketch_on_single_node.FixError(); } - sketch_on_single_node.Unique(); // Set up distributed version. We rely on using rank as seed to generate // the exact same copy of data. @@ -352,7 +359,6 @@ TEST(GPUQuantile, AllReduce) { std::numeric_limits::quiet_NaN(), &sketch_distributed); sketch_distributed.AllReduce(); - sketch_distributed.Unique(); ASSERT_EQ(sketch_distributed.ColumnsPtr().size(), sketch_on_single_node.ColumnsPtr().size()); diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py index 75dfcfe49a25..12904befc1c9 100644 --- a/tests/python-gpu/test_gpu_with_dask.py +++ b/tests/python-gpu/test_gpu_with_dask.py @@ -130,7 +130,7 @@ def test_quantile(self): exe = None for possible_path in {'./testxgboost', './build/testxgboost', - '../build/testxgboost'}: + '../build/testxgboost', '../gpu-build/testxgboost'}: if os.path.exists(possible_path): exe = possible_path assert exe, 'No testxgboost executable found.' From 6846dc9b7454e2b381275f581dfee55c1f2c30eb Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 3 Jul 2020 20:34:31 +0800 Subject: [PATCH 08/15] Call fix error after merge. 
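A side note on the quantile.cuh change above: replacing the deleted const-rvalue assignment with a defaulted move assignment is what allows AllReduce to end with *this = std::move(new_sketch) while copies remain forbidden. A tiny stand-in type with the same set of special members (the real container also owns device buffers, which is omitted here):

#include <utility>
#include <vector>

class Container {
 public:
  Container() = default;
  Container(Container&&) = default;
  Container& operator=(Container&&) = default;  // enables *this = std::move(other)
  Container(Container const&) = delete;         // copying stays forbidden
  Container& operator=(Container const&) = delete;

 private:
  std::vector<float> data_;
};

int main() {
  Container a;
  Container b;
  a = std::move(b);  // fine: move assignment is defaulted
  // a = b;          // would not compile: copy assignment is deleted
  return 0;
}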
--- src/data/iterative_device_dmatrix.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/src/data/iterative_device_dmatrix.cu b/src/data/iterative_device_dmatrix.cu index c198f06fa859..5b953cf45e99 100644 --- a/src/data/iterative_device_dmatrix.cu +++ b/src/data/iterative_device_dmatrix.cu @@ -96,6 +96,7 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin common::SketchContainer final_sketch(batch_param_.max_bin, cols, accumulated_rows, device); for (auto const& sketch : sketch_containers) { final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data()); + final_sketch.FixError(); } sketch_containers.clear(); sketch_containers.shrink_to_fit(); From 74731254c830b6b1c98eb56f0ba0587c6540c4f8 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 3 Jul 2020 21:28:50 +0800 Subject: [PATCH 09/15] More tests. --- tests/cpp/common/test_quantile.cu | 78 ++++++++++++++++++++++++-- tests/python-gpu/test_gpu_with_dask.py | 21 +++++-- 2 files changed, 88 insertions(+), 11 deletions(-) diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu index 61b1a1d858b4..1fd34209d066 100644 --- a/tests/cpp/common/test_quantile.cu +++ b/tests/cpp/common/test_quantile.cu @@ -296,10 +296,7 @@ TEST(GPUQuantile, MergeDuplicated) { } } -TEST(GPUQuantile, AllReduce) { - // This test is supposed to run by a python test that setups the environment. - std::string msg {"Skipping AllReduce test"}; -#if defined(__linux__) && defined(XGBOOST_USE_NCCL) +void InitRabitContext(std::string msg) { auto n_gpus = AllVisibleGPUs(); auto port = std::getenv("DMLC_TRACKER_PORT"); std::string port_str; @@ -316,6 +313,14 @@ TEST(GPUQuantile, AllReduce) { "DMLC_NUM_WORKER=" + std::to_string(n_gpus)}; char* c_envs[] {&(envs[0][0]), &(envs[1][0]), &(envs[2][0])}; rabit::Init(3, c_envs); +} + +TEST(GPUQuantile, AllReduceBasic) { + // This test is supposed to run by a python test that setups the environment. + std::string msg {"Skipping AllReduce test"}; +#if defined(__linux__) && defined(XGBOOST_USE_NCCL) + InitRabitContext(msg); + auto n_gpus = AllVisibleGPUs(); constexpr size_t kRows = 1000, kCols = 100; RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) { // Set up single node version; @@ -387,8 +392,71 @@ TEST(GPUQuantile, AllReduce) { #else LOG(WARNING) << msg; return; -#endif // !defined(__linux__) +#endif // !defined(__linux__) && defined(XGBOOST_USE_NCCL) } +TEST(GPUQuantile, SameOnAllWorkers) { + std::string msg {"Skipping SameOnAllWorkers test"}; +#if defined(__linux__) && defined(XGBOOST_USE_NCCL) + InitRabitContext(msg); + + constexpr size_t kRows = 1000, kCols = 100; + RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, + MetaInfo const &info) { + auto world = rabit::GetWorldSize(); + auto rank = rabit::GetRank(); + SketchContainer sketch_distributed(n_bins, kCols, kRows, 0); + HostDeviceVector storage; + std::string interface_str = RandomDataGenerator{kRows, kCols, 0} + .Device(0) + .Seed(rank + seed) + .GenerateArrayInterface(&storage); + data::CupyAdapter adapter(interface_str); + AdapterDeviceSketchWeighted(adapter.Value(), n_bins, info, + std::numeric_limits::quiet_NaN(), + &sketch_distributed); + sketch_distributed.AllReduce(); + + // Test for all workers having the same sketch. 
+ size_t n_data = sketch_distributed.Data().size(); + rabit::Allreduce(&n_data, 1); + ASSERT_EQ(n_data, sketch_distributed.Data().size()); + size_t size_as_float = + sketch_distributed.Data().size_bytes() / sizeof(float); + auto local_data = Span{ + reinterpret_cast(sketch_distributed.Data().data()), + size_as_float}; + + dh::caching_device_vector all_workers(size_as_float * world); + thrust::fill(all_workers.begin(), all_workers.end(), 0); + thrust::copy(thrust::device, local_data.data(), + local_data.data() + local_data.size(), + all_workers.begin() + local_data.size() * rank); + dh::AllReducer reducer; + reducer.Init(0); + + reducer.AllReduceSum(all_workers.data().get(), all_workers.data().get(), + all_workers.size()); + reducer.Synchronize(); + + auto base_line = dh::ToSpan(all_workers).subspan(0, size_as_float); + std::vector h_base_line(base_line.size()); + dh::CopyDeviceSpanToVector(&h_base_line, base_line); + + size_t offset = 0; + for (size_t i = 0; i < world; ++i) { + auto comp = dh::ToSpan(all_workers).subspan(offset, size_as_float); + std::vector h_comp(comp.size()); + dh::CopyDeviceSpanToVector(&h_comp, comp); + ASSERT_EQ(comp.size(), base_line.size()); + ASSERT_EQ(h_base_line, h_comp); + offset += size_as_float; + } + }); +#else + LOG(WARNING) << msg; + return; +#endif // !defined(__linux__) && defined(XGBOOST_USE_NCCL) +} } // namespace common } // namespace xgboost diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py index 12904befc1c9..9bed8fa09b3f 100644 --- a/tests/python-gpu/test_gpu_with_dask.py +++ b/tests/python-gpu/test_gpu_with_dask.py @@ -121,10 +121,7 @@ def test_empty_dmatrix(self): 'debug_synchronize': True} run_empty_dmatrix(client, parameters) - @pytest.mark.skipif(**tm.no_dask()) - @pytest.mark.mgpu - @pytest.mark.gtest - def test_quantile(self): + def run_quantile(self, name): if sys.platform.startswith("win"): pytest.skip("Skipping dask tests on Windows") @@ -134,7 +131,7 @@ def test_quantile(self): if os.path.exists(possible_path): exe = possible_path assert exe, 'No testxgboost executable found.' - test = "--gtest_filter=GPUQuantile.AllReduce" + test = "--gtest_filter=GPUQuantile." + name def runit(worker_addr, rabit_args): port = None @@ -159,5 +156,17 @@ def runit(worker_addr, rabit_args): results = client.gather(futures) for ret in results: msg = ret.stdout.decode('utf-8') - assert msg.find('1 test from GPUQuantile') != -1 + assert msg.find('1 test from GPUQuantile') != -1, msg assert ret.returncode == 0, msg + + @pytest.mark.skipif(**tm.no_dask()) + @pytest.mark.mgpu + @pytest.mark.gtest + def test_quantile_basic(self): + self.run_quantile('AllReduceBasic') + + @pytest.mark.skipif(**tm.no_dask()) + @pytest.mark.mgpu + @pytest.mark.gtest + def test_quantile_same_on_all_workers(self): + self.run_quantile('SameOnAllWorkers') From 94931327df8fa2e6b1e94fa8f3dce27ccd067f95 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 3 Jul 2020 21:46:39 +0800 Subject: [PATCH 10/15] Test rank. 
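The SameOnAllWorkers check above gathers every worker's flattened sketch by writing it into its own rank slot of a zero-filled buffer and then summing with AllReduceSum, so the element-wise reduction doubles as an all-gather. A small single-process simulation of that trick (the reducer is emulated with a plain loop; none of this is the real NCCL path):

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  std::size_t const world = 3, size_per_worker = 4;

  // Each worker's local slice of the flattened sketch.
  std::vector<std::vector<float>> local(world, std::vector<float>(size_per_worker));
  for (std::size_t r = 0; r < world; ++r) {
    for (std::size_t i = 0; i < size_per_worker; ++i) {
      local[r][i] = static_cast<float>(r * 100 + i);
    }
  }

  // Every worker contributes a buffer that is zero outside its own rank slot;
  // summing those buffers element-wise therefore concatenates the slices.
  std::vector<float> gathered(world * size_per_worker, 0.0f);
  for (std::size_t r = 0; r < world; ++r) {  // stands in for AllReduceSum
    for (std::size_t i = 0; i < size_per_worker; ++i) {
      gathered[r * size_per_worker + i] += local[r][i];
    }
  }

  // After the reduction every worker can compare its slice against rank 0's.
  for (std::size_t r = 0; r < world; ++r) {
    for (std::size_t i = 0; i < size_per_worker; ++i) {
      assert(gathered[r * size_per_worker + i] == local[r][i]);
    }
  }
  return 0;
}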
--- tests/cpp/common/test_quantile.cu | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu index 1fd34209d066..9ce6c856170c 100644 --- a/tests/cpp/common/test_quantile.cu +++ b/tests/cpp/common/test_quantile.cu @@ -349,6 +349,8 @@ TEST(GPUQuantile, AllReduceBasic) { sketch_on_single_node.Merge(sketch.ColumnsPtr(), sketch.Data()); sketch_on_single_node.FixError(); } + TestQuantileElemRank(0, sketch_on_single_node.Data(), + sketch_on_single_node.ColumnsPtr()); // Set up distributed version. We rely on using rank as seed to generate // the exact same copy of data. @@ -438,6 +440,7 @@ TEST(GPUQuantile, SameOnAllWorkers) { reducer.AllReduceSum(all_workers.data().get(), all_workers.data().get(), all_workers.size()); reducer.Synchronize(); + TestQuantileElemRank(0, sketch_distributed.Data(), sketch_distributed.ColumnsPtr()); auto base_line = dh::ToSpan(all_workers).subspan(0, size_as_float); std::vector h_base_line(base_line.size()); From 345f70bfd7161155d0a0911f5be41264cf364483 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 3 Jul 2020 22:07:37 +0800 Subject: [PATCH 11/15] Call unique on test. --- tests/cpp/common/test_quantile.cu | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu index 9ce6c856170c..a479833b53f2 100644 --- a/tests/cpp/common/test_quantile.cu +++ b/tests/cpp/common/test_quantile.cu @@ -349,6 +349,7 @@ TEST(GPUQuantile, AllReduceBasic) { sketch_on_single_node.Merge(sketch.ColumnsPtr(), sketch.Data()); sketch_on_single_node.FixError(); } + sketch_on_single_node.Unique(); TestQuantileElemRank(0, sketch_on_single_node.Data(), sketch_on_single_node.ColumnsPtr()); @@ -366,6 +367,7 @@ TEST(GPUQuantile, AllReduceBasic) { std::numeric_limits::quiet_NaN(), &sketch_distributed); sketch_distributed.AllReduce(); + sketch_distributed.Unique(); ASSERT_EQ(sketch_distributed.ColumnsPtr().size(), sketch_on_single_node.ColumnsPtr().size()); @@ -440,6 +442,7 @@ TEST(GPUQuantile, SameOnAllWorkers) { reducer.AllReduceSum(all_workers.data().get(), all_workers.data().get(), all_workers.size()); reducer.Synchronize(); + sketch_distributed.Unique(); TestQuantileElemRank(0, sketch_distributed.Data(), sketch_distributed.ColumnsPtr()); auto base_line = dh::ToSpan(all_workers).subspan(0, size_as_float); From 4e670ad12cd7f5aba6bddf7567ab795c05c13ace Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 3 Jul 2020 23:14:56 +0800 Subject: [PATCH 12/15] Allow possible floating point error. 
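On the Unique calls added above: merging summaries that cover overlapping value ranges can leave repeated cut candidates, and the tests deduplicate before running TestQuantileElemRank, presumably so the rank check operates on distinct entries. Per column this amounts to a unique pass over already-sorted values (a plain CPU sketch, with floats standing in for the full sketch entries):

#include <algorithm>
#include <iostream>
#include <vector>

int main() {
  // One column of merged, sorted cut candidates containing duplicates.
  std::vector<float> column{0.1f, 0.2f, 0.2f, 0.5f, 0.5f, 0.9f};
  column.erase(std::unique(column.begin(), column.end()), column.end());
  for (float v : column) {
    std::cout << v << " ";  // 0.1 0.2 0.5 0.9
  }
  std::cout << "\n";
  return 0;
}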
--- tests/cpp/common/test_quantile.cu | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu index a479833b53f2..581b0ec6f37a 100644 --- a/tests/cpp/common/test_quantile.cu +++ b/tests/cpp/common/test_quantile.cu @@ -321,14 +321,18 @@ TEST(GPUQuantile, AllReduceBasic) { #if defined(__linux__) && defined(XGBOOST_USE_NCCL) InitRabitContext(msg); auto n_gpus = AllVisibleGPUs(); + auto world = rabit::GetWorldSize(); + if (world != 1) { + ASSERT_EQ(world, n_gpus); + } else { + return; + } + constexpr size_t kRows = 1000, kCols = 100; RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) { // Set up single node version; SketchContainer sketch_on_single_node(n_bins, kCols, kRows, 0); - auto world = rabit::GetWorldSize(); - if (world != 1) { - ASSERT_EQ(world, n_gpus); - } + size_t intermediate_num_cuts = std::min(kRows * world, static_cast(n_bins * WQSketch::kFactor)); std::vector containers; @@ -403,11 +407,17 @@ TEST(GPUQuantile, SameOnAllWorkers) { std::string msg {"Skipping SameOnAllWorkers test"}; #if defined(__linux__) && defined(XGBOOST_USE_NCCL) InitRabitContext(msg); + auto world = rabit::GetWorldSize(); + auto n_gpus = AllVisibleGPUs(); + if (world != 1) { + ASSERT_EQ(world, n_gpus); + } else { + return; + } constexpr size_t kRows = 1000, kCols = 100; RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const &info) { - auto world = rabit::GetWorldSize(); auto rank = rabit::GetRank(); SketchContainer sketch_distributed(n_bins, kCols, kRows, 0); HostDeviceVector storage; @@ -420,6 +430,7 @@ TEST(GPUQuantile, SameOnAllWorkers) { std::numeric_limits::quiet_NaN(), &sketch_distributed); sketch_distributed.AllReduce(); + TestQuantileElemRank(0, sketch_distributed.Data(), sketch_distributed.ColumnsPtr()); // Test for all workers having the same sketch. size_t n_data = sketch_distributed.Data().size(); @@ -442,8 +453,6 @@ TEST(GPUQuantile, SameOnAllWorkers) { reducer.AllReduceSum(all_workers.data().get(), all_workers.data().get(), all_workers.size()); reducer.Synchronize(); - sketch_distributed.Unique(); - TestQuantileElemRank(0, sketch_distributed.Data(), sketch_distributed.ColumnsPtr()); auto base_line = dh::ToSpan(all_workers).subspan(0, size_as_float); std::vector h_base_line(base_line.size()); @@ -455,7 +464,9 @@ TEST(GPUQuantile, SameOnAllWorkers) { std::vector h_comp(comp.size()); dh::CopyDeviceSpanToVector(&h_comp, comp); ASSERT_EQ(comp.size(), base_line.size()); - ASSERT_EQ(h_base_line, h_comp); + for (size_t j = 0; j < h_comp.size(); ++j) { + ASSERT_NEAR(h_base_line[j], h_comp[j], kRtEps); + } offset += size_as_float; } }); From c1d43c7e1947c63ce050a95ec9d75e5814a410f0 Mon Sep 17 00:00:00 2001 From: fis Date: Sat, 4 Jul 2020 00:02:02 +0800 Subject: [PATCH 13/15] Floating point. 
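The switch from exact equality to ASSERT_NEAR with kRtEps above reflects that floating point addition is not associative: workers that accumulate the same ranks and weights in a different order can disagree in the last bits. A minimal illustration of the tolerance-based comparison (the epsilon below is arbitrary; kRtEps is the constant the test actually uses):

#include <cassert>
#include <cmath>

// Tolerance-based comparison in the spirit of ASSERT_NEAR.
bool NearlyEqual(float a, float b, float eps = 1e-5f) {
  return std::fabs(a - b) <= eps;
}

int main() {
  // The same three weights accumulated in two different orders.
  float a = (0.1f + 0.2f) + 0.3f;
  float b = 0.1f + (0.2f + 0.3f);
  // The sums may differ by a unit in the last place, but stay within eps.
  assert(NearlyEqual(a, b));
  return 0;
}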
--- tests/cpp/common/test_quantile.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu index 581b0ec6f37a..caef6c2c7fad 100644 --- a/tests/cpp/common/test_quantile.cu +++ b/tests/cpp/common/test_quantile.cu @@ -200,7 +200,7 @@ TEST(GPUQuantile, MergeBasic) { AdapterDeviceSketchWeighted(adapter_0.Value(), n_bins, info, std::numeric_limits::quiet_NaN(), &sketch_0); - SketchContainer sketch_1(n_bins, kCols, kRows, 0); + SketchContainer sketch_1(n_bins, kCols, kRows * kRows, 0); HostDeviceVector storage_1; std::string interface_str_1 = RandomDataGenerator{kRows, kCols, 0} .Device(0) @@ -430,6 +430,7 @@ TEST(GPUQuantile, SameOnAllWorkers) { std::numeric_limits::quiet_NaN(), &sketch_distributed); sketch_distributed.AllReduce(); + sketch_distributed.Unique(); TestQuantileElemRank(0, sketch_distributed.Data(), sketch_distributed.ColumnsPtr()); // Test for all workers having the same sketch. From 17a77899cc5dd96386fd51b3617b88f31390bfbc Mon Sep 17 00:00:00 2001 From: fis Date: Sat, 4 Jul 2020 01:32:49 +0800 Subject: [PATCH 14/15] Hypothesis test on dask. --- tests/python-gpu/test_gpu_with_dask.py | 41 +++++++++++++++++++++++++- tests/python/testing.py | 1 + 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py index 9bed8fa09b3f..8fca441bbeae 100644 --- a/tests/python-gpu/test_gpu_with_dask.py +++ b/tests/python-gpu/test_gpu_with_dask.py @@ -5,6 +5,8 @@ import unittest import xgboost import subprocess +from hypothesis import given, strategies, settings, note +from test_gpu_updaters import parameter_strategy if sys.platform.startswith("win"): pytest.skip("Skipping dask tests on Windows", allow_module_level=True) @@ -14,11 +16,13 @@ from test_with_dask import generate_array # noqa import testing as tm # noqa + try: import dask.dataframe as dd from xgboost import dask as dxgb from dask_cuda import LocalCUDACluster from dask.distributed import Client + from dask import array as da import cudf except ImportError: pass @@ -64,7 +68,8 @@ def test_dask_dataframe(self): xgboost.DMatrix(X.compute())) cp.testing.assert_allclose(single_node, predictions) - np.testing.assert_allclose(single_node, series_predictions.to_array()) + np.testing.assert_allclose(single_node, + series_predictions.to_array()) predt = dxgb.predict(client, out, X) assert isinstance(predt, dd.Series) @@ -80,6 +85,40 @@ def is_df(part): cp.testing.assert_allclose( predt.values.compute(), single_node) + @given(parameter_strategy, strategies.integers(1, 20), + tm.dataset_strategy) + @settings(deadline=None) + def test_gpu_hist(self, params, num_rounds, dataset): + with LocalCUDACluster(n_workers=2) as cluster: + with Client(cluster) as client: + params['tree_method'] = 'gpu_hist' + params = dataset.set_params(params) + # multi class doesn't handle empty dataset well (empty + # means at least 1 worker has data). + if params['objective'] == "multi:softmax": + return + # It doesn't make sense to distribute a completely + # empty dataset. 
+ if dataset.X.shape[0] == 0: + return + + chunk = 128 + X = da.from_array(dataset.X, + chunks=(chunk, dataset.X.shape[1])) + y = da.from_array(dataset.y, chunks=(chunk, )) + if dataset.w is not None: + w = da.from_array(dataset.w, chunks=(chunk, )) + else: + w = None + + m = dxgb.DaskDMatrix( + client, data=X, label=y, weight=w) + history = dxgb.train(client, params=params, dtrain=m, + num_boost_round=num_rounds, + evals=[(m, 'train')])['history'] + note(history) + assert tm.non_increasing(history['train'][dataset.metric]) + @pytest.mark.skipif(**tm.no_cupy()) @pytest.mark.mgpu def test_dask_array(self): diff --git a/tests/python/testing.py b/tests/python/testing.py index 411d96493695..f8d2431bf19e 100644 --- a/tests/python/testing.py +++ b/tests/python/testing.py @@ -188,6 +188,7 @@ def _dataset_and_weight(draw): data.w = draw(arrays(np.float64, (len(data.y)), elements=strategies.floats(0.1, 2.0))) return data + # A strategy for drawing from a set of example datasets # May add random weights to the dataset dataset_strategy = _dataset_and_weight() From 2f22f3c1e6ae6aa3f34b3769f4ad4d8167886da8 Mon Sep 17 00:00:00 2001 From: fis Date: Sat, 4 Jul 2020 02:35:57 +0800 Subject: [PATCH 15/15] mgpu tag. --- tests/python-gpu/test_gpu_with_dask.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py index 8fca441bbeae..382835c584aa 100644 --- a/tests/python-gpu/test_gpu_with_dask.py +++ b/tests/python-gpu/test_gpu_with_dask.py @@ -88,6 +88,7 @@ def is_df(part): @given(parameter_strategy, strategies.integers(1, 20), tm.dataset_strategy) @settings(deadline=None) + @pytest.mark.mgpu def test_gpu_hist(self, params, num_rounds, dataset): with LocalCUDACluster(n_workers=2) as cluster: with Client(cluster) as client:
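The dask hypothesis test added above ends by asserting that the training history is non-increasing across boosting rounds. A compact C++ rendition of what such a check amounts to (the Python helper tm.non_increasing may also allow a small tolerance; that detail is assumed away here):

#include <algorithm>
#include <cassert>
#include <vector>

// True when no element is larger than the one before it.
bool NonIncreasing(std::vector<double> const& history) {
  return std::adjacent_find(history.cbegin(), history.cend(),
                            [](double prev, double next) { return next > prev; }) ==
         history.cend();
}

int main() {
  assert(NonIncreasing({0.9, 0.7, 0.7, 0.5}));   // metric keeps improving or holds
  assert(!NonIncreasing({0.9, 0.7, 0.8}));       // a regression in round 3 fails
  return 0;
}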