Skip to content

Commit

Permalink
Enhance the threadpool implementation. (#10531)
Browse files Browse the repository at this point in the history

- Accept an initialization function.
- Support void return tasks.
  • Loading branch information
trivialfis authored Jul 3, 2024
1 parent 9cb4c93 commit 628411a
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 15 deletions.
30 changes: 20 additions & 10 deletions src/common/threadpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,25 @@ class ThreadPool {
bool stop_{false};

public:
explicit ThreadPool(std::int32_t n_threads) {
/**
* @param n_threads The number of threads this pool should hold.
* @param init_fn Function called once during thread creation.
*/
template <typename InitFn>
explicit ThreadPool(std::int32_t n_threads, InitFn&& init_fn) {
for (std::int32_t i = 0; i < n_threads; ++i) {
pool_.emplace_back([&] {
pool_.emplace_back([&, init_fn = std::forward<InitFn>(init_fn)] {
init_fn();

while (true) {
std::unique_lock lock{mu_};
cv_.wait(lock, [this] { return !this->tasks_.empty() || stop_; });

if (this->stop_) {
if (!tasks_.empty()) {
while (!tasks_.empty()) {
auto fn = tasks_.front();
tasks_.pop();
fn();
}
while (!tasks_.empty()) {
auto fn = tasks_.front();
tasks_.pop();
fn();
}
return;
}
Expand Down Expand Up @@ -81,8 +86,13 @@ class ThreadPool {
// Use shared ptr to make the task copy constructible.
auto p{std::make_shared<std::promise<R>>()};
auto fut = p->get_future();
auto ffn = std::function{[task = std::move(p), fn = std::move(fn)]() mutable {
task->set_value(fn());
auto ffn = std::function{[task = std::move(p), fn = std::forward<Fn>(fn)]() mutable {
if constexpr (std::is_void_v<R>) {
fn();
task->set_value();
} else {
task->set_value(fn());
}
}};

std::unique_lock lock{mu_};
Expand Down
9 changes: 5 additions & 4 deletions src/data/sparse_page_source.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,16 +236,14 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol

exce_.Rethrow();

auto const config = *GlobalConfigThreadLocalStore::Get();
for (std::int32_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
fetch_it %= n_batches_; // ring
if (ring_->at(fetch_it).valid()) {
continue;
}
auto const* self = this; // make sure it's const
CHECK_LT(fetch_it, cache_info_->offset.size());
ring_->at(fetch_it) = this->workers_.Submit([fetch_it, self, config, this] {
*GlobalConfigThreadLocalStore::Get() = config;
ring_->at(fetch_it) = this->workers_.Submit([fetch_it, self, this] {
auto page = std::make_shared<S>();
this->exce_.Run([&] {
std::unique_ptr<typename FormatStreamPolicy::FormatT> fmt{this->CreatePageFormat()};
Expand Down Expand Up @@ -297,7 +295,10 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
public:
SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, bst_idx_t n_batches,
std::shared_ptr<Cache> cache)
: workers_{std::max(2, std::min(nthreads, 16))}, // Don't use too many threads.
: workers_{std::max(2, std::min(nthreads, 16)),
[config = *GlobalConfigThreadLocalStore::Get()] {
*GlobalConfigThreadLocalStore::Get() = config;
}},
missing_{missing},
nthreads_{nthreads},
n_features_{n_features},
Expand Down
26 changes: 25 additions & 1 deletion tests/cpp/common/test_threadpool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
* Copyright 2024, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/global_config.h> // for GlobalConfigThreadLocalStore

#include <cstddef> // for size_t
#include <cstdint> // for int32_t
Expand All @@ -13,7 +14,23 @@
namespace xgboost::common {
TEST(ThreadPool, Basic) {
std::int32_t n_threads = std::thread::hardware_concurrency();
ThreadPool pool{n_threads};

// Set verbosity to 0 for thread-local variable.
auto orig = GlobalConfigThreadLocalStore::Get()->verbosity;
GlobalConfigThreadLocalStore::Get()->verbosity = 4;
// 4 is an invalid value, it's only possible to set it by bypassing the parameter
// validation.
ASSERT_NE(orig, GlobalConfigThreadLocalStore::Get()->verbosity);
ThreadPool pool{n_threads, [config = *GlobalConfigThreadLocalStore::Get()] {
*GlobalConfigThreadLocalStore::Get() = config;
}};
GlobalConfigThreadLocalStore::Get()->verbosity = orig; // restore

{
auto fut = pool.Submit([] { return GlobalConfigThreadLocalStore::Get()->verbosity; });
ASSERT_EQ(fut.get(), 4);
ASSERT_EQ(GlobalConfigThreadLocalStore::Get()->verbosity, orig);
}
{
auto fut = pool.Submit([] { return 3; });
ASSERT_EQ(fut.get(), 3);
Expand Down Expand Up @@ -45,5 +62,12 @@ TEST(ThreadPool, Basic) {
ASSERT_EQ(futures[i].get(), i);
}
}
{
std::int32_t val{0};
auto fut = pool.Submit([&] { val = 3; });
static_assert(std::is_void_v<decltype(fut.get())>);
fut.get();
ASSERT_EQ(val, 3);
}
}
} // namespace xgboost::common

0 comments on commit 628411a

Please sign in to comment.