diff --git a/src/common/hist_util.h b/src/common/hist_util.h
index d86b73135f34..0334b901224a 100644
--- a/src/common/hist_util.h
+++ b/src/common/hist_util.h
@@ -116,26 +116,14 @@ inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins) {
   for (auto& column : column_sizes) {
     column.resize(info.num_col_, 0);
   }
-  for (auto const& page : m->GetBatches<SparsePage>()) {
-    page.data.HostVector();
-    page.offset.HostVector();
-    ParallelFor(page.Size(), threads, [&](size_t i) {
-      auto &local_column_sizes = column_sizes.at(omp_get_thread_num());
-      auto row = page[i];
-      auto const *p_row = row.data();
-      for (size_t j = 0; j < row.size(); ++j) {
-        local_column_sizes.at(p_row[j].index)++;
-      }
-    });
-  }
   std::vector<size_t> reduced(info.num_col_, 0);
-
-  ParallelFor(info.num_col_, threads, [&](size_t i) {
-    for (auto const &thread : column_sizes) {
-      reduced[i] += thread[i];
+  for (auto const& page : m->GetBatches<SparsePage>()) {
+    auto const &entries_per_column =
+        HostSketchContainer::CalcColumnSize(page, info.num_col_, threads);
+    for (size_t i = 0; i < entries_per_column.size(); ++i) {
+      reduced[i] += entries_per_column[i];
     }
-  });
-
+  }
   HostSketchContainer container(reduced, max_bins,
                                 HostSketchContainer::UseGroup(info));
   for (auto const &page : m->GetBatches<SparsePage>()) {
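Note on the hunk above: SketchOnDMatrix now delegates the per-batch counting to the new HostSketchContainer::CalcColumnSize and only accumulates the per-column totals across pages. Below is a minimal standalone sketch of the underlying two-phase pattern (thread-private counters, then a per-column reduction); Entry and CountColumnEntries are illustrative stand-ins, not the XGBoost API.

    #include <cstddef>
    #include <vector>
    #include <omp.h>

    struct Entry { std::size_t index; float fvalue; };

    std::vector<std::size_t> CountColumnEntries(
        std::vector<std::vector<Entry>> const& rows,
        std::size_t n_columns, int n_threads) {
      // Phase 1: each thread counts into its own buffer, so no atomics needed.
      std::vector<std::vector<std::size_t>> per_thread(
          n_threads, std::vector<std::size_t>(n_columns, 0));
    #pragma omp parallel for num_threads(n_threads)
      for (std::ptrdiff_t i = 0; i < static_cast<std::ptrdiff_t>(rows.size()); ++i) {
        auto& local = per_thread[omp_get_thread_num()];
        for (auto const& e : rows[i]) {
          local[e.index]++;  // no sharing: each thread owns its buffer
        }
      }
      // Phase 2: reduce the thread-private buffers column by column.
      std::vector<std::size_t> reduced(n_columns, 0);
      for (auto const& local : per_thread) {
        for (std::size_t c = 0; c < n_columns; ++c) {
          reduced[c] += local[c];
        }
      }
      return reduced;
    }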
diff --git a/src/common/quantile.cc b/src/common/quantile.cc
index 374864c8f4b0..9ab48a304b77 100644
--- a/src/common/quantile.cc
+++ b/src/common/quantile.cc
@@ -25,34 +25,67 @@ HostSketchContainer::HostSketchContainer(std::vector<size_t> columns_size,
   }
 }
 
-std::vector<size_t> LoadBalance(SparsePage const &page,
-                                std::vector<size_t> columns_size,
-                                size_t const nthreads) {
-  /* Some sparse datasets have their mass concentrating on small
-   * number of features. To avoid wating for a few threads running
-   * forever, we here distirbute different number of columns to
-   * different threads according to number of entries. */
-  size_t const total_entries = page.data.Size();
+std::vector<size_t>
+HostSketchContainer::CalcColumnSize(SparsePage const &batch,
+                                    bst_feature_t const n_columns,
+                                    size_t const nthreads) {
+  auto page = batch.GetView();
+  std::vector<std::vector<size_t>> column_sizes(nthreads);
+  for (auto &column : column_sizes) {
+    column.resize(n_columns, 0);
+  }
+
+  ParallelFor(page.Size(), nthreads, [&](size_t i) {
+    auto &local_column_sizes = column_sizes.at(omp_get_thread_num());
+    auto row = page[i];
+    auto const *p_row = row.data();
+    for (size_t j = 0; j < row.size(); ++j) {
+      local_column_sizes.at(p_row[j].index)++;
+    }
+  });
+  std::vector<size_t> entries_per_columns(n_columns, 0);
+  ParallelFor(n_columns, nthreads, [&](size_t i) {
+    for (auto const &thread : column_sizes) {
+      entries_per_columns[i] += thread[i];
+    }
+  });
+  return entries_per_columns;
+}
+
+std::vector<size_t> HostSketchContainer::LoadBalance(
+    SparsePage const &batch, bst_feature_t n_columns, size_t const nthreads) {
+  /* Some sparse datasets have their mass concentrated on a small number of features.
+   * To avoid waiting for a few threads running forever, we here distribute different
+   * numbers of columns to different threads according to the number of entries.
+   */
+  auto page = batch.GetView();
+  size_t const total_entries = page.data.size();
   size_t const entries_per_thread = common::DivRoundUp(total_entries, nthreads);
 
-  std::vector<size_t> cols_ptr(nthreads+1, 0);
+  std::vector<std::vector<size_t>> column_sizes(nthreads);
+  for (auto& column : column_sizes) {
+    column.resize(n_columns, 0);
+  }
+  std::vector<size_t> entries_per_columns =
+      CalcColumnSize(batch, n_columns, nthreads);
+  std::vector<size_t> cols_ptr(nthreads + 1, 0);
   size_t count {0};
   size_t current_thread {1};
 
-  for (auto col : columns_size) {
-    cols_ptr[current_thread]++;  // add one column to thread
+  for (auto col : entries_per_columns) {
+    cols_ptr.at(current_thread)++;  // add one column to thread
     count += col;
-    if (count > entries_per_thread + 1) {
+    CHECK_LE(count, total_entries);
+    if (count > entries_per_thread) {
       current_thread++;
       count = 0;
-      cols_ptr[current_thread] = cols_ptr[current_thread-1];
+      cols_ptr.at(current_thread) = cols_ptr[current_thread-1];
     }
   }
   // Idle threads.
   for (; current_thread < cols_ptr.size() - 1; ++current_thread) {
     cols_ptr[current_thread+1] = cols_ptr[current_thread];
   }
-
   return cols_ptr;
 }
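Note: LoadBalance turns the per-column entry counts into a CSC-style pointer array, greedily cutting a new thread segment whenever the running entry count exceeds the per-thread budget. A self-contained sketch of that greedy partition follows; PartitionColumns is a hypothetical name, not part of the patch.

    #include <cstddef>
    #include <vector>

    std::vector<std::size_t> PartitionColumns(
        std::vector<std::size_t> const& entries_per_column, std::size_t n_threads) {
      std::size_t total = 0;
      for (auto n : entries_per_column) total += n;
      std::size_t const budget = (total + n_threads - 1) / n_threads;  // ceil div

      std::vector<std::size_t> cols_ptr(n_threads + 1, 0);  // CSC-style pointers
      std::size_t count = 0, thread = 1;
      for (auto n : entries_per_column) {
        cols_ptr[thread]++;  // assign this column to the current thread
        count += n;
        if (count > budget && thread < n_threads) {
          ++thread;          // budget exceeded: start the next segment
          count = 0;
          cols_ptr[thread] = cols_ptr[thread - 1];
        }
      }
      for (; thread < n_threads; ++thread) {  // idle threads get empty segments
        cols_ptr[thread + 1] = cols_ptr[thread];
      }
      return cols_ptr;
    }

The result has n_threads + 1 entries, so thread t owns the consecutive column range [cols_ptr[t], cols_ptr[t+1]).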
@@ -67,11 +100,10 @@ void HostSketchContainer::PushRowPage(SparsePage const &page,
   // Use group index for weights?
   auto batch = page.GetView();
   dmlc::OMPException exec;
-  // Parallel over columns.  Asumming the data is dense, each thread owns a set of
-  // consecutive columns.
+  // Parallel over columns.  Each thread owns a set of consecutive columns.
   auto const ncol = static_cast<uint32_t>(info.num_col_);
   auto const is_dense = info.num_nonzero_ == info.num_col_ * info.num_row_;
-  auto thread_columns_ptr = LoadBalance(page, columns_size_, nthread);
+  auto thread_columns_ptr = LoadBalance(page, info.num_col_, nthread);
 
 #pragma omp parallel num_threads(nthread)
   {
@@ -112,58 +144,158 @@ void HostSketchContainer::PushRowPage(SparsePage const &page,
   monitor_.Stop(__func__);
 }
 
-void AddCutPoint(WQuantileSketch<bst_float, bst_float>::SummaryContainer const &summary,
-                 int max_bin, HistogramCuts *cuts) {
-  size_t required_cuts = std::min(summary.size, static_cast<size_t>(max_bin));
-  auto& cut_values = cuts->cut_values_.HostVector();
-  for (size_t i = 1; i < required_cuts; ++i) {
-    bst_float cpt = summary.data[i].value;
-    if (i == 1 || cpt > cuts->cut_values_.ConstHostVector().back()) {
-      cut_values.push_back(cpt);
-    }
+void HostSketchContainer::GatherSketchInfo(
+    std::vector<WQSketch::SummaryContainer> const &reduced,
+    std::vector<size_t> *p_worker_segments,
+    std::vector<bst_row_t> *p_sketches_scan,
+    std::vector<WQSketch::Entry> *p_global_sketches) {
+  auto& worker_segments = *p_worker_segments;
+  worker_segments.resize(1, 0);
+  auto world = rabit::GetWorldSize();
+  auto rank = rabit::GetRank();
+  auto n_columns = sketches_.size();
+
+  std::vector<bst_row_t> sketch_size;
+  for (auto const& sketch : reduced) {
+    sketch_size.push_back(sketch.size);
+  }
+  std::vector<bst_row_t>& sketches_scan = *p_sketches_scan;
+  sketches_scan.resize((n_columns + 1) * world, 0);
+  size_t beg_scan = rank * (n_columns + 1);
+  std::partial_sum(sketch_size.cbegin(), sketch_size.cend(),
+                   sketches_scan.begin() + beg_scan + 1);
+  // Gather all column pointers
+  rabit::Allreduce<rabit::op::Sum>(sketches_scan.data(), sketches_scan.size());
+
+  for (int32_t i = 0; i < world; ++i) {
+    size_t back = (i + 1) * (n_columns + 1) - 1;
+    auto n_entries = sketches_scan.at(back);
+    worker_segments.push_back(n_entries);
+  }
+  // Offset of sketch from each worker.
+  std::partial_sum(worker_segments.begin(), worker_segments.end(),
+                   worker_segments.begin());
+  CHECK_GE(worker_segments.size(), 1);
+  auto total = worker_segments.back();
+
+  auto& global_sketches = *p_global_sketches;
+  global_sketches.resize(total, WQSketch::Entry{0, 0, 0, 0});
+  auto worker_sketch = Span<WQSketch::Entry>{global_sketches}.subspan(
+      worker_segments[rank],
+      worker_segments[rank + 1] - worker_segments[rank]);
+  size_t cursor = 0;
+  for (auto const &sketch : reduced) {
+    std::copy(sketch.data, sketch.data + sketch.size,
+              worker_sketch.begin() + cursor);
+    cursor += sketch.size;
   }
+
+  static_assert(sizeof(WQSketch::Entry) / 4 == sizeof(float), "");
+  rabit::Allreduce<rabit::op::Sum>(
+      reinterpret_cast<float *>(global_sketches.data()),
+      global_sketches.size() * sizeof(WQSketch::Entry) / sizeof(float));
 }
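Note: GatherSketchInfo relies on every worker writing only its own slice of a zero-initialized buffer, so a sum-Allreduce behaves like an allgather of the per-worker scans. A simplified single-process sketch of that trick follows; FakeAllreduceSum is a placeholder for rabit::Allreduce with a sum operator, and GatherScans is an illustrative name.

    #include <cstddef>
    #include <numeric>
    #include <vector>

    // Placeholder for the real collective: a no-op when there is one worker.
    void FakeAllreduceSum(std::vector<std::size_t>* buf) { (void)buf; }

    std::vector<std::size_t> GatherScans(std::vector<std::size_t> const& my_sizes,
                                         int world, int rank) {
      std::size_t n = my_sizes.size();
      // One (n + 1)-entry exclusive scan per worker, all initially zero.
      std::vector<std::size_t> scans((n + 1) * world, 0);
      std::size_t beg = rank * (n + 1);
      // Write this worker's prefix sums into its own slice only.
      std::partial_sum(my_sizes.cbegin(), my_sizes.cend(),
                       scans.begin() + beg + 1);
      // Every other slice is zero, so summing across workers concatenates
      // the slices, i.e. the allreduce acts as an allgather.
      FakeAllreduceSum(&scans);
      return scans;
    }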
-void HostSketchContainer::MakeCuts(HistogramCuts* cuts) {
+void HostSketchContainer::AllReduce(
+    std::vector<WQSketch::SummaryContainer> *p_reduced,
+    std::vector<int32_t>* p_num_cuts) {
   monitor_.Start(__func__);
-  rabit::Allreduce<rabit::op::Sum>(columns_size_.data(), columns_size_.size());
-  std::vector<WQSketch::SummaryContainer> reduced(sketches_.size());
-  std::vector<int32_t> num_cuts;
-  size_t nbytes = 0;
+  auto& num_cuts = *p_num_cuts;
+  CHECK_EQ(num_cuts.size(), 0);
+  auto &reduced = *p_reduced;
+  reduced.resize(sketches_.size());
+
+  size_t n_columns = sketches_.size();
+  rabit::Allreduce<rabit::op::Max>(&n_columns, 1);
+  CHECK_EQ(n_columns, sketches_.size()) << "Number of columns differs across workers";
+
+  // Prune the intermediate num cuts for synchronization.
+  std::vector<size_t> global_column_size(columns_size_);
+  rabit::Allreduce<rabit::op::Sum>(global_column_size.data(),
+                                   global_column_size.size());
+
+  size_t nbytes = 0;
   for (size_t i = 0; i < sketches_.size(); ++i) {
     int32_t intermediate_num_cuts = static_cast<int32_t>(std::min(
-        columns_size_[i], static_cast<size_t>(max_bins_ * WQSketch::kFactor)));
-    if (columns_size_[i] != 0) {
+        global_column_size[i], static_cast<size_t>(max_bins_ * WQSketch::kFactor)));
+    if (global_column_size[i] != 0) {
       WQSketch::SummaryContainer out;
       sketches_[i].GetSummary(&out);
       reduced[i].Reserve(intermediate_num_cuts);
       CHECK(reduced[i].data);
       reduced[i].SetPrune(out, intermediate_num_cuts);
+      nbytes = std::max(
+          WQSketch::SummaryContainer::CalcMemCost(intermediate_num_cuts),
+          nbytes);
     }
+
     num_cuts.push_back(intermediate_num_cuts);
-    nbytes = std::max(
-        WQSketch::SummaryContainer::CalcMemCost(intermediate_num_cuts), nbytes);
   }
+
+  auto world = rabit::GetWorldSize();
+  if (world == 1) {
+    return;
+  }
+
+  std::vector<size_t> worker_segments(1, 0);  // CSC pointer to sketches.
+  std::vector<bst_row_t> sketches_scan((n_columns + 1) * world, 0);
+
+  std::vector<WQSketch::Entry> global_sketches;
+  this->GatherSketchInfo(reduced, &worker_segments, &sketches_scan,
+                         &global_sketches);
+
+  std::vector<WQSketch::SummaryContainer> final_sketches(n_columns);
+  ParallelFor(n_columns, omp_get_max_threads(), [&](size_t fidx) {
+    int32_t intermediate_num_cuts = num_cuts[fidx];
+    auto nbytes =
+        WQSketch::SummaryContainer::CalcMemCost(intermediate_num_cuts);
+
+    for (int32_t i = 1; i < world + 1; ++i) {
+      auto size = worker_segments.at(i) - worker_segments[i - 1];
+      auto worker_sketches = Span<WQSketch::Entry>{global_sketches}.subspan(
+          worker_segments[i - 1], size);
+      auto worker_scan =
+          Span<bst_row_t>(sketches_scan)
+              .subspan((i - 1) * (n_columns + 1), (n_columns + 1));
+
+      auto worker_feature = worker_sketches.subspan(
+          worker_scan[fidx], worker_scan[fidx + 1] - worker_scan[fidx]);
+      CHECK(worker_feature.data());
+      WQSummary<float, float> summary(worker_feature.data(),
+                                      worker_feature.size());
+      auto &out = final_sketches.at(fidx);
+      out.Reduce(summary, nbytes);
+    }
+
+    reduced.at(fidx).Reserve(intermediate_num_cuts);
+    reduced.at(fidx).SetPrune(final_sketches.at(fidx), intermediate_num_cuts);
+  });
+  monitor_.Stop(__func__);
+}
 
-  if (rabit::IsDistributed()) {
-    // FIXME(trivialfis): This call will allocate nbytes * num_columns on rabit, which
-    // may generate oom error when data is sparse.  To fix it, we need to:
-    // - gather the column offsets over all workers.
-    // - run rabit::allgather on sketch data to collect all data.
-    // - merge all gathered sketches based on worker offsets and column offsets of data
-    //   from each worker.
-    // See GPU implementation for details.
-    rabit::SerializeReducer<WQSketch::SummaryContainer> sreducer;
-    sreducer.Allreduce(dmlc::BeginPtr(reduced), nbytes, reduced.size());
+void AddCutPoint(WQuantileSketch<bst_float, bst_float>::SummaryContainer const &summary,
+                 int max_bin, HistogramCuts *cuts) {
+  size_t required_cuts = std::min(summary.size, static_cast<size_t>(max_bin));
+  auto& cut_values = cuts->cut_values_.HostVector();
+  for (size_t i = 1; i < required_cuts; ++i) {
+    bst_float cpt = summary.data[i].value;
+    if (i == 1 || cpt > cuts->cut_values_.ConstHostVector().back()) {
+      cut_values.push_back(cpt);
+    }
   }
+}
+
+void HostSketchContainer::MakeCuts(HistogramCuts* cuts) {
+  monitor_.Start(__func__);
+  std::vector<WQSketch::SummaryContainer> reduced;
+  std::vector<int32_t> num_cuts;
+  this->AllReduce(&reduced, &num_cuts);
 
   cuts->min_vals_.HostVector().resize(sketches_.size(), 0.0f);
+
   for (size_t fid = 0; fid < reduced.size(); ++fid) {
     WQSketch::SummaryContainer a;
     size_t max_num_bins = std::min(num_cuts[fid], max_bins_);
     a.Reserve(max_num_bins + 1);
     CHECK(a.data);
-    if (columns_size_[fid] != 0) {
+    if (num_cuts[fid] != 0) {
       a.SetPrune(reduced[fid], max_num_bins + 1);
       CHECK(a.data && reduced[fid].data);
       const bst_float mval = a.data[0].value;
@@ -173,6 +305,7 @@ void HostSketchContainer::MakeCuts(HistogramCuts* cuts) {
       const float mval = 1e-5f;
       cuts->min_vals_.HostVector()[fid] = mval;
     }
+
     AddCutPoint(a, max_num_bins, cuts);
     // push a value that is greater than anything
     const bst_float cpt
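Note: after the gather, every worker holds an identical packed buffer, and AllReduce merges each feature's per-worker summaries independently before pruning back to the intermediate cut count. The subtle part is the indexing that extracts worker i's summary for feature fidx; a standalone sketch of that arithmetic is below, with Entry and Span as simplified stand-ins for WQSketch::Entry and common::Span.

    #include <cstddef>
    #include <vector>

    struct Entry { float rmin, rmax, wmin, value; };
    template <typename T> struct Span { T* ptr; std::size_t len; };

    Span<Entry const> WorkerFeature(std::vector<Entry> const& global,
                                    std::vector<std::size_t> const& worker_segments,
                                    std::vector<std::size_t> const& sketches_scan,
                                    std::size_t n_columns, int worker,
                                    std::size_t fidx) {
      // worker_segments has world + 1 entries: all entries owned by `worker`
      // start at worker_segments[worker] in the packed buffer.
      Entry const* base = global.data() + worker_segments[worker];
      // Each worker contributed an (n_columns + 1)-entry scan; slice it out.
      std::size_t const* scan = sketches_scan.data() + worker * (n_columns + 1);
      // The feature's summary is the half-open range [scan[fidx], scan[fidx+1]).
      return {base + scan[fidx], scan[fidx + 1] - scan[fidx]};
    }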
diff --git a/src/common/quantile.h b/src/common/quantile.h
index 11e2530f748e..a70bf809ea28 100644
--- a/src/common/quantile.h
+++ b/src/common/quantile.h
@@ -166,6 +166,16 @@ struct WQSummary {
    * \param src source sketch
    */
   inline void CopyFrom(const WQSummary &src) {
+    if (!src.data) {
+      CHECK_EQ(src.size, 0);
+      size = 0;
+      return;
+    }
+    if (!data) {
+      CHECK_EQ(this->size, 0);
+      CHECK_EQ(src.size, 0);
+      return;
+    }
     size = src.size;
     std::memcpy(data, src.data, sizeof(Entry) * size);
   }
@@ -721,6 +731,14 @@ class HostSketchContainer {
     return use_group_ind;
   }
 
+  static std::vector<size_t> CalcColumnSize(SparsePage const &page,
+                                            bst_feature_t const n_columns,
+                                            size_t const nthreads);
+
+  static std::vector<size_t> LoadBalance(SparsePage const &page,
+                                         bst_feature_t n_columns,
+                                         size_t const nthreads);
+
   static uint32_t SearchGroupIndFromRow(std::vector<bst_group_t> const &group_ptr,
                                         size_t const base_rowid) {
     CHECK_LT(base_rowid, group_ptr.back())
@@ -730,6 +748,14 @@ class HostSketchContainer {
                    group_ptr.cbegin() - 1;
     return group_ind;
   }
+  // Gather sketches from all workers.
+  void GatherSketchInfo(std::vector<WQSketch::SummaryContainer> const &reduced,
+                        std::vector<size_t> *p_worker_segments,
+                        std::vector<bst_row_t> *p_sketches_scan,
+                        std::vector<WQSketch::Entry> *p_global_sketches);
+  // Merge sketches from all workers.
+  void AllReduce(std::vector<WQSketch::SummaryContainer> *p_reduced,
+                 std::vector<int32_t>* p_num_cuts);
   /* \brief Push a CSR matrix. */
   void PushRowPage(SparsePage const& page, MetaInfo const& info);
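Note: the new guard in WQSummary::CopyFrom handles summaries of empty columns, whose data pointer is null, so a blind memcpy would otherwise touch a null pointer. A reduced illustration of the same guard with simplified types follows; Summary here is a stand-in for the real WQSummary, and assert replaces the library's CHECK macros.

    #include <cassert>
    #include <cstddef>
    #include <cstring>

    struct Summary {
      float* data{nullptr};
      std::size_t size{0};
      void CopyFrom(Summary const& src) {
        if (!src.data) {  // empty source: nothing to copy
          assert(src.size == 0);
          size = 0;
          return;
        }
        if (!data) {      // empty destination can only receive empty input
          assert(size == 0 && src.size == 0);
          return;
        }
        size = src.size;
        std::memcpy(data, src.data, sizeof(float) * size);
      }
    };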
diff --git a/tests/cpp/c_api/test_c_api.cc b/tests/cpp/c_api/test_c_api.cc
index f4c2722fe92c..664118780cc3 100644
--- a/tests/cpp/c_api/test_c_api.cc
+++ b/tests/cpp/c_api/test_c_api.cc
@@ -23,9 +23,9 @@ TEST(CAPI, XGDMatrixCreateFromMatDT) {
   std::shared_ptr<xgboost::DMatrix> *dmat =
       static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
   xgboost::MetaInfo &info = (*dmat)->Info();
-  ASSERT_EQ(info.num_col_, 2);
-  ASSERT_EQ(info.num_row_, 3);
-  ASSERT_EQ(info.num_nonzero_, 6);
+  ASSERT_EQ(info.num_col_, 2ul);
+  ASSERT_EQ(info.num_row_, 3ul);
+  ASSERT_EQ(info.num_nonzero_, 6ul);
 
   for (const auto &batch : (*dmat)->GetBatches<xgboost::SparsePage>()) {
     ASSERT_EQ(batch[0][0].fvalue, 0.0f);
@@ -38,9 +38,9 @@ TEST(CAPI, XGDMatrixCreateFromMatDT) {
 }
 
 TEST(CAPI, XGDMatrixCreateFromMatOmp) {
-  std::vector<int> num_rows = {100, 11374, 15000};
+  std::vector<bst_ulong> num_rows = {100, 11374, 15000};
   for (auto row : num_rows) {
-    int num_cols = 50;
+    bst_ulong num_cols = 50;
     int num_missing = 5;
     DMatrixHandle handle;
     std::vector<float> data(num_cols * row, 1.5);
diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc
index 0fad360f4298..24c23b3e2608 100644
--- a/tests/cpp/common/test_hist_util.cc
+++ b/tests/cpp/common/test_hist_util.cc
@@ -159,10 +159,10 @@ TEST(CutsBuilder, SearchGroupInd) {
   HistogramCuts hmat;
 
   size_t group_ind = HostSketchContainer::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 0);
-  ASSERT_EQ(group_ind, 0);
+  ASSERT_EQ(group_ind, 0ul);
 
   group_ind = HostSketchContainer::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 5);
-  ASSERT_EQ(group_ind, 2);
+  ASSERT_EQ(group_ind, 2ul);
 
   EXPECT_ANY_THROW(HostSketchContainer::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 17));
 
@@ -189,7 +189,7 @@ TEST(HistUtil, DenseCutsCategorical) {
     EXPECT_LT(cuts.MinValues()[0], x_sorted.front());
     EXPECT_GT(cuts_from_sketch.front(), x_sorted.front());
     EXPECT_GE(cuts_from_sketch.back(), x_sorted.back());
-    EXPECT_EQ(cuts_from_sketch.size(), num_categories);
+    EXPECT_EQ(cuts_from_sketch.size(), static_cast<size_t>(num_categories));
   }
 }
diff --git a/tests/cpp/common/test_hist_util.h b/tests/cpp/common/test_hist_util.h
index bd88d14ef1f2..d025e5ea60bf 100644
--- a/tests/cpp/common/test_hist_util.h
+++ b/tests/cpp/common/test_hist_util.h
@@ -162,7 +162,7 @@ inline void ValidateColumn(const HistogramCuts& cuts, int column_idx,
 
   // Check all cut points are unique
   EXPECT_EQ(std::set<float>(cuts_begin, cuts_end).size(),
-            cuts_end - cuts_begin);
+            static_cast<size_t>(cuts_end - cuts_begin));
 
   auto unique = std::set<float>(sorted_column.begin(), sorted_column.end());
   if (unique.size() <= num_bins) {
@@ -189,7 +189,7 @@ inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat,
   // Collect data into columns
   std::vector<std::vector<float>> columns(dmat->Info().num_col_);
   for (auto& batch : dmat->GetBatches<SparsePage>()) {
-    ASSERT_GT(batch.Size(), 0);
+    ASSERT_GT(batch.Size(), 0ul);
     for (auto i = 0ull; i < batch.Size(); i++) {
       for (auto e : batch[i]) {
         columns[e.index].push_back(e.fvalue);
diff --git a/tests/cpp/common/test_json.cc b/tests/cpp/common/test_json.cc
index 8665420d684a..029beee8d48b 100644
--- a/tests/cpp/common/test_json.cc
+++ b/tests/cpp/common/test_json.cc
@@ -222,7 +222,7 @@ TEST(Json, ParseArray) {
   auto json = Json::Load(StringView{str.c_str(), str.size()});
   json = json["nodes"];
   std::vector<Json> arr = get<JsonArray>(json);
-  ASSERT_EQ(arr.size(), 3);
+  ASSERT_EQ(arr.size(), 3ul);
   Json v0 = arr[0];
   ASSERT_EQ(get<Integer>(v0["depth"]), 3);
   ASSERT_NEAR(get<Number>(v0["gain"]), 10.4866, kRtEps);
@@ -284,7 +284,7 @@ TEST(Json, EmptyArray) {
   std::istringstream iss(str);
   auto json = Json::Load(StringView{str.c_str(), str.size()});
   auto arr = get<JsonArray>(json["leaf_vector"]);
-  ASSERT_EQ(arr.size(), 0);
+  ASSERT_EQ(arr.size(), 0ul);
 }
 
 TEST(Json, Boolean) {
@@ -315,7 +315,7 @@ TEST(Json, AssigningObjects) {
     Json json;
     json = JsonObject();
     json["Okay"] = JsonArray();
-    ASSERT_EQ(get<JsonArray>(json["Okay"]).size(), 0);
+    ASSERT_EQ(get<JsonArray>(json["Okay"]).size(), 0ul);
   }
 
   {
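Note on the test changes above: most of them just add unsigned literal suffixes or casts. gtest's ASSERT_EQ/EXPECT_EQ expand to a comparison of the two arguments, so a size_t result compared against a signed literal trips -Wsign-compare once warnings are treated as errors. A generic example, not part of this patch:

    #include <gtest/gtest.h>
    #include <vector>

    TEST(Example, UnsignedLiteral) {
      std::vector<int> v(3);
      // size_t vs. signed int would warn; keeping both sides unsigned does not.
      ASSERT_EQ(v.size(), 3ul);
      ASSERT_EQ(v.size(), static_cast<std::size_t>(3));  // equivalent alternative
    }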
diff --git a/tests/cpp/common/test_quantile.cc b/tests/cpp/common/test_quantile.cc
index c273658e54cb..fa748de1cc6c 100644
--- a/tests/cpp/common/test_quantile.cc
+++ b/tests/cpp/common/test_quantile.cc
@@ -5,14 +5,122 @@
 namespace xgboost {
 namespace common {
+
+TEST(Quantile, LoadBalance) {
+  size_t constexpr kRows = 1000, kCols = 100;
+  auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
+  std::vector<size_t> cols_ptr;
+  for (auto const &page : m->GetBatches<SparsePage>()) {
+    cols_ptr = HostSketchContainer::LoadBalance(page, kCols, 13);
+  }
+  size_t n_cols = 0;
+  for (size_t i = 1; i < cols_ptr.size(); ++i) {
+    n_cols += cols_ptr[i] - cols_ptr[i - 1];
+  }
+  CHECK_EQ(n_cols, kCols);
+}
+
+void TestDistributedQuantile(size_t rows, size_t cols) {
+  std::string msg {"Skipping AllReduce test"};
+  int32_t constexpr kWorkers = 4;
+  InitRabitContext(msg, kWorkers);
+  auto world = rabit::GetWorldSize();
+  if (world != 1) {
+    ASSERT_EQ(world, kWorkers);
+  } else {
+    return;
+  }
+
+  std::vector<MetaInfo> infos(2);
+  auto& h_weights = infos.front().weights_.HostVector();
+  h_weights.resize(rows);
+  SimpleLCG lcg;
+  SimpleRealUniformDistribution<float> dist(3, 1000);
+  std::generate(h_weights.begin(), h_weights.end(), [&]() { return dist(&lcg); });
+  std::vector<size_t> column_size(cols, rows);
+  size_t n_bins = 64;
+
+  // Generate cuts for distributed environment.
+  auto sparsity = 0.5f;
+  auto rank = rabit::GetRank();
+  HostSketchContainer sketch_distributed(column_size, n_bins, false);
+  auto m = RandomDataGenerator{rows, cols, sparsity}
+               .Seed(rank)
+               .Lower(.0f)
+               .Upper(1.0f)
+               .GenerateDMatrix();
+  for (auto const &page : m->GetBatches<SparsePage>()) {
+    sketch_distributed.PushRowPage(page, m->Info());
+  }
+  HistogramCuts distributed_cuts;
+  sketch_distributed.MakeCuts(&distributed_cuts);
+
+  // Generate cuts for single node environment
+  rabit::Finalize();
+  CHECK_EQ(rabit::GetWorldSize(), 1);
+  std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; });
+  HostSketchContainer sketch_on_single_node(column_size, n_bins, false);
+  for (auto rank = 0; rank < world; ++rank) {
+    auto m = RandomDataGenerator{rows, cols, sparsity}
+                 .Seed(rank)
+                 .Lower(.0f)
+                 .Upper(1.0f)
+                 .GenerateDMatrix();
+    for (auto const &page : m->GetBatches<SparsePage>()) {
+      sketch_on_single_node.PushRowPage(page, m->Info());
+    }
+  }
+
+  HistogramCuts single_node_cuts;
+  sketch_on_single_node.MakeCuts(&single_node_cuts);
+
+  auto const& sptrs = single_node_cuts.Ptrs();
+  auto const& dptrs = distributed_cuts.Ptrs();
+  auto const& svals = single_node_cuts.Values();
+  auto const& dvals = distributed_cuts.Values();
+  auto const& smins = single_node_cuts.MinValues();
+  auto const& dmins = distributed_cuts.MinValues();
+
+  ASSERT_EQ(sptrs.size(), dptrs.size());
+  for (size_t i = 0; i < sptrs.size(); ++i) {
+    ASSERT_EQ(sptrs[i], dptrs[i]);
+  }
+
+  ASSERT_EQ(svals.size(), dvals.size());
+  for (size_t i = 0; i < svals.size(); ++i) {
+    ASSERT_NEAR(svals[i], dvals[i], 2e-2f);
+  }
+
+  ASSERT_EQ(smins.size(), dmins.size());
+  for (size_t i = 0; i < smins.size(); ++i) {
+    ASSERT_FLOAT_EQ(smins[i], dmins[i]);
+  }
+}
+
+TEST(Quantile, DistributedBasic) {
+#if defined(__unix__)
+  constexpr size_t kRows = 10, kCols = 10;
+  TestDistributedQuantile(kRows, kCols);
+#endif
+}
+
+TEST(Quantile, Distributed) {
+#if defined(__unix__)
+  constexpr size_t kRows = 1000, kCols = 200;
+  TestDistributedQuantile(kRows, kCols);
+#endif
+}
+
 TEST(Quantile, SameOnAllWorkers) {
+#if defined(__unix__)
   std::string msg{"Skipping Quantile AllreduceBasic test"};
-  size_t constexpr kWorkers = 4;
+  int32_t constexpr kWorkers = 4;
   InitRabitContext(msg, kWorkers);
   auto world = rabit::GetWorldSize();
   if (world != 1) {
     CHECK_EQ(world, kWorkers);
   } else {
+    LOG(WARNING) << msg;
     return;
   }
 
@@ -72,6 +180,8 @@ TEST(Quantile, SameOnAllWorkers) {
       }
     }
   });
+  rabit::Finalize();
+#endif  // defined(__unix__)
 }
 }  // namespace common
 }  // namespace xgboost
diff --git a/tests/cpp/common/test_quantile.h b/tests/cpp/common/test_quantile.h
index 7dea0b17deb3..e91f19ef84a8 100644
--- a/tests/cpp/common/test_quantile.h
+++ b/tests/cpp/common/test_quantile.h
@@ -7,7 +7,7 @@
 namespace xgboost {
 namespace common {
 
-inline void InitRabitContext(std::string msg, size_t n_workers) {
+inline void InitRabitContext(std::string msg, int32_t n_workers) {
   auto port = std::getenv("DMLC_TRACKER_PORT");
   std::string port_str;
   if (port) {
@@ -35,7 +35,7 @@ template <typename Fn> void RunWithSeedsAndBins(size_t rows, Fn fn) {
   for (size_t i = 0; i < bins.size() - 1; ++i) {
     bins[i] = i * 35 + 2;
   }
-  bins.back() = rows + 80;   // provide a bin number greater than rows.
+  bins.back() = rows + 160;  // provide a bin number greater than rows.
 
   std::vector<MetaInfo> infos(2);
   auto& h_weights = infos.front().weights_.HostVector();
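Note: TestDistributedQuantile checks distributed sketching against a single-node replay, which only works because each worker's data is generated deterministically from its rank. A generic sketch of that replay idea follows, using std::mt19937 in place of the project's RandomDataGenerator; MakeRankData and ReplayAllRanks are hypothetical names.

    #include <cstdint>
    #include <random>
    #include <vector>

    std::vector<float> MakeRankData(std::uint32_t rank, std::size_t n) {
      std::mt19937 rng(rank);  // rank as seed -> reproducible per-worker stream
      std::uniform_real_distribution<float> dist(0.0f, 1.0f);
      std::vector<float> data(n);
      for (auto& v : data) v = dist(rng);
      return data;
    }

    // Single-node replay: feed every rank's stream into one accumulator, the
    // way the test pushes every rank's pages into one sketch container.
    float ReplayAllRanks(int world, std::size_t n) {
      float sum = 0.0f;  // stands in for the sketch being accumulated
      for (int rank = 0; rank < world; ++rank) {
        for (float v : MakeRankData(rank, n)) sum += v;
      }
      return sum;
    }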
diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py
index dc5c155e6027..145fa0b524cd 100644
--- a/tests/python/test_with_dask.py
+++ b/tests/python/test_with_dask.py
@@ -501,17 +501,20 @@ def run_updater_test(self, client, params, num_rounds, dataset,
                             num_boost_round=num_rounds,
                             evals=[(m, 'train')])['history']
         note(history)
-        assert tm.non_increasing(history['train'][dataset.metric])
+        history = history['train'][dataset.metric]
+        assert tm.non_increasing(history)
+        # Make sure that it's decreasing
+        assert history[-1] < history[0]
 
     @given(params=hist_parameter_strategy,
-           num_rounds=strategies.integers(10, 20),
+           num_rounds=strategies.integers(20, 30),
            dataset=tm.dataset_strategy)
     @settings(deadline=None)
     def test_hist(self, params, num_rounds, dataset, client):
         self.run_updater_test(client, params, num_rounds, dataset, 'hist')
 
     @given(params=exact_parameter_strategy,
-           num_rounds=strategies.integers(10, 20),
+           num_rounds=strategies.integers(20, 30),
            dataset=tm.dataset_strategy)
     @settings(deadline=None)
     def test_approx(self, client, params, num_rounds, dataset):
@@ -524,8 +527,7 @@ def run_quantile(self, name):
         exe = None
         for possible_path in {'./testxgboost', './build/testxgboost',
                               '../build/testxgboost',
-                              '../cpu-build/testxgboost',
-                              '../gpu-build/testxgboost'}:
+                              '../cpu-build/testxgboost'}:
             if os.path.exists(possible_path):
                 exe = possible_path
         if exe is None:
@@ -542,7 +544,7 @@ def runit(worker_addr, rabit_args):
             port = port.split('=')
             env = os.environ.copy()
             env[port[0]] = port[1]
-            return subprocess.run([exe, test], env=env, stdout=subprocess.PIPE)
+            return subprocess.run([exe, test], env=env, capture_output=True)
 
         with LocalCluster(n_workers=4) as cluster:
             with Client(cluster) as client:
@@ -555,6 +557,7 @@ def runit(worker_addr, rabit_args):
                                         workers=workers,
                                         rabit_args=rabit_args)
                 results = client.gather(futures)
+
                 for ret in results:
                     msg = ret.stdout.decode('utf-8')
                     assert msg.find('1 test from Quantile') != -1, msg
@@ -563,4 +566,14 @@ def runit(worker_addr, rabit_args):
     @pytest.mark.skipif(**tm.no_dask())
     @pytest.mark.gtest
     def test_quantile_basic(self):
+        self.run_quantile('DistributedBasic')
+
+    @pytest.mark.skipif(**tm.no_dask())
+    @pytest.mark.gtest
+    def test_quantile(self):
+        self.run_quantile('Distributed')
+
+    @pytest.mark.skipif(**tm.no_dask())
+    @pytest.mark.gtest
+    def test_quantile_same_on_all_workers(self):
         self.run_quantile('SameOnAllWorkers')