Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for per-thread initialization function #105

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions BS_thread_pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,10 +248,11 @@ class [[nodiscard]] thread_pool
* @brief Construct a new thread pool.
*
* @param thread_count_ The number of threads to use. The default value is the total number of hardware threads available, as reported by the implementation. This is usually determined by the number of cores in the CPU. If a core is hyperthreaded, it will count as two threads.
* @param thread_init_function_ A function to run in every thread once after creation and before any tasks, accepting thread id in range [0, thread count) as argument. Default is to do nothing. Constructor will block until all threads have finished executing it.
*/
thread_pool(const concurrency_t thread_count_ = 0) : thread_count(determine_thread_count(thread_count_)), threads(std::make_unique<std::thread[]>(determine_thread_count(thread_count_)))
thread_pool(const concurrency_t thread_count_ = 0, const std::function<void(concurrency_t)>& thread_init_function_ = [](concurrency_t) {}) : thread_count(determine_thread_count(thread_count_)), threads(std::make_unique<std::thread[]>(determine_thread_count(thread_count_)))
{
create_threads();
create_threads(thread_init_function_);
}

/**
Expand Down Expand Up @@ -437,8 +438,9 @@ class [[nodiscard]] thread_pool
* @brief Reset the number of threads in the pool. Waits for all currently running tasks to be completed, then destroys all threads in the pool and creates a new thread pool with the new number of threads. Any tasks that were waiting in the queue before the pool was reset will then be executed by the new threads. If the pool was paused before resetting it, the new pool will be paused as well.
*
* @param thread_count_ The number of threads to use. The default value is the total number of hardware threads available, as reported by the implementation. This is usually determined by the number of cores in the CPU. If a core is hyperthreaded, it will count as two threads.
* @param thread_init_function_ A function to run in every thread once after creation and before any tasks, accepting thread id in range [0, thread count) as argument. Default is to do nothing. Call to this function will block until all threads have finished executing it.
*/
void reset(const concurrency_t thread_count_ = 0)
void reset(const concurrency_t thread_count_ = 0, const std::function<void(concurrency_t)>& thread_init_function_ = [](concurrency_t) {})
{
const bool was_paused = paused;
paused = true;
Expand All @@ -447,7 +449,7 @@ class [[nodiscard]] thread_pool
thread_count = determine_thread_count(thread_count_);
threads = std::make_unique<std::thread[]>(thread_count);
paused = was_paused;
create_threads();
create_threads(thread_init_function_);
}

/**
Expand Down Expand Up @@ -520,14 +522,31 @@ class [[nodiscard]] thread_pool

/**
* @brief Create the threads in the pool and assign a worker to each thread.
*
* @param thread_init_function A function to run in every thread once after creation and before any tasks, accepting thread id in range [0, thread count) as argument. It will be completed in every thread before returning.
*/
void create_threads()
void create_threads(const std::function<void(concurrency_t)>& thread_init_function)
{
concurrency_t initialized_thread_count = 0;
std::condition_variable thread_initialization_done_cv = {};
std::mutex initialized_thread_count_mutex = {};
running = true;
for (concurrency_t i = 0; i < thread_count; ++i)
{
threads[i] = std::thread(&thread_pool::worker, this);
threads[i] = std::thread(&thread_pool::worker, this,
[&initialized_thread_count, &thread_initialization_done_cv, &initialized_thread_count_mutex, thread_init_function, i]
{
thread_init_function(i);
{
std::unique_lock<std::mutex> initialized_thread_count_lock(initialized_thread_count_mutex);
++initialized_thread_count;
}
thread_initialization_done_cv.notify_one();
}
);
}
std::unique_lock<std::mutex> initialized_thread_count_lock(initialized_thread_count_mutex);
thread_initialization_done_cv.wait(initialized_thread_count_lock, [this, &initialized_thread_count] { return (initialized_thread_count == thread_count); });
}

/**
Expand Down Expand Up @@ -565,8 +584,10 @@ class [[nodiscard]] thread_pool
/**
* @brief A worker function to be assigned to each thread in the pool. Waits until it is notified by push_task() that a task is available, and then retrieves the task from the queue and executes it. Once the task finishes, the worker notifies wait_for_tasks() in case it is waiting.
*/
void worker()
void worker(std::function<void()> thread_init_function)
{
thread_init_function();
thread_init_function = {};
while (running)
{
std::function<void()> task;
Expand Down
32 changes: 27 additions & 5 deletions BS_thread_pool_light.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,11 @@ class [[nodiscard]] thread_pool_light
* @brief Construct a new thread pool.
*
* @param thread_count_ The number of threads to use. The default value is the total number of hardware threads available, as reported by the implementation. This is usually determined by the number of cores in the CPU. If a core is hyperthreaded, it will count as two threads.
* @param thread_init_function_ A function to run in every thread once after creation and before any tasks, accepting thread id in range [0, thread count) as argument. Default is to do nothing. Constructor will block until all threads have finished executing it.
*/
thread_pool_light(const concurrency_t thread_count_ = 0) : thread_count(determine_thread_count(thread_count_)), threads(std::make_unique<std::thread[]>(determine_thread_count(thread_count_)))
thread_pool_light(const concurrency_t thread_count_ = 0, const std::function<void(concurrency_t)>& thread_init_function_ = [](concurrency_t) {}) : thread_count(determine_thread_count(thread_count_)), threads(std::make_unique<std::thread[]>(determine_thread_count(thread_count_)))
{
create_threads();
create_threads(thread_init_function_);
}

/**
Expand Down Expand Up @@ -206,14 +207,31 @@ class [[nodiscard]] thread_pool_light

/**
* @brief Create the threads in the pool and assign a worker to each thread.
*
* @param thread_init_function A function to run in every thread once after creation and before any tasks, accepting thread id in range [0, thread count) as argument. It will be completed in every thread before returning.
*/
void create_threads()
void create_threads(const std::function<void(concurrency_t)>& thread_init_function)
{
concurrency_t initialized_thread_count = 0;
std::condition_variable thread_initialization_done_cv = {};
std::mutex initialized_thread_count_mutex = {};
running = true;
for (concurrency_t i = 0; i < thread_count; ++i)
{
threads[i] = std::thread(&thread_pool_light::worker, this);
threads[i] = std::thread(&thread_pool_light::worker, this,
[&initialized_thread_count, &thread_initialization_done_cv, &initialized_thread_count_mutex, thread_init_function, i]
{
thread_init_function(i);
{
std::unique_lock<std::mutex> initialized_thread_count_lock(initialized_thread_count_mutex);
++initialized_thread_count;
}
thread_initialization_done_cv.notify_one();
}
);
}
std::unique_lock<std::mutex> initialized_thread_count_lock(initialized_thread_count_mutex);
thread_initialization_done_cv.wait(initialized_thread_count_lock, [this, &initialized_thread_count] { return (initialized_thread_count == thread_count); });
}

/**
Expand Down Expand Up @@ -250,9 +268,13 @@ class [[nodiscard]] thread_pool_light

/**
* @brief A worker function to be assigned to each thread in the pool. Waits until it is notified by push_task() that a task is available, and then retrieves the task from the queue and executes it. Once the task finishes, the worker notifies wait_for_tasks() in case it is waiting.
*
* @param thread_init_function Function to execute once before any of the tasks.
*/
void worker()
void worker(std::function<void()> thread_init_function)
{
thread_init_function();
thread_init_function = {};
while (running)
{
std::function<void()> task;
Expand Down
7 changes: 6 additions & 1 deletion BS_thread_pool_light_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,11 @@
// Global variables
// ================

// Global variable that counts how many timed per-thread initializaton function was called
std::atomic<BS::concurrency_t> thread_init_invocation_count = {};

// A global thread pool object to be used throughout the test.
BS::thread_pool_light pool;
BS::thread_pool_light pool(0, [](BS::concurrency_t) { thread_init_invocation_count.fetch_add(1, std::memory_order_release); });

// A global random_device object to be used to seed some random number generators.
std::random_device rd;
Expand Down Expand Up @@ -180,6 +183,8 @@ void check_constructor()
check(std::thread::hardware_concurrency(), pool.get_thread_count());
println("Checking that the manually counted number of unique thread IDs is equal to the reported number of threads...");
check(pool.get_thread_count(), count_unique_threads());
println("Checking that provided thread initializaton function was executed once per thread...");
check(pool.get_thread_count(), thread_init_invocation_count.load(std::memory_order_consume));
}

// =======================================
Expand Down
30 changes: 26 additions & 4 deletions BS_thread_pool_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,11 @@ BS::synced_stream sync_cout(std::cout);
std::ofstream log_file;
BS::synced_stream sync_file(log_file);

// Global variable that counts how many timed per-thread initializaton function was called
std::atomic<BS::concurrency_t> thread_init_invocation_count = {};

// A global thread pool object to be used throughout the test.
BS::thread_pool pool;
BS::thread_pool pool(0, [](BS::concurrency_t) { thread_init_invocation_count.fetch_add(1, std::memory_order_release); });

// A global random_device object to be used to seed some random number generators.
std::random_device rd;
Expand Down Expand Up @@ -214,23 +217,42 @@ void check_constructor()
check(std::thread::hardware_concurrency(), pool.get_thread_count());
dual_println("Checking that the manually counted number of unique thread IDs is equal to the reported number of threads...");
check(pool.get_thread_count(), count_unique_threads());
dual_println("Checking that provided thread initializaton function was executed once per thread...");
check(pool.get_thread_count(), thread_init_invocation_count.load(std::memory_order_consume));
}

/**
* @brief Check that reset() works.
*/
void check_reset()
{
pool.reset(std::thread::hardware_concurrency() / 2);
const BS::concurrency_t half_hardware_concurrency = std::max(std::thread::hardware_concurrency() / 2, static_cast<BS::concurrency_t>(1u));
std::unique_ptr<bool[]> initializaton_results_array = std::make_unique<bool[]>(half_hardware_concurrency);
pool.reset(half_hardware_concurrency, [&initializaton_results_array](BS::concurrency_t thread_id) { initializaton_results_array[thread_id] = true; });
dual_println("Checking that after reset() the thread pool reports a number of threads equal to half the hardware concurrency...");
check(std::thread::hardware_concurrency() / 2, pool.get_thread_count());
check(half_hardware_concurrency, pool.get_thread_count());
dual_println("Checking that after reset() the manually counted number of unique thread IDs is equal to the reported number of threads...");
check(pool.get_thread_count(), count_unique_threads());
pool.reset(std::thread::hardware_concurrency());
dual_println("Checking that reset() call executed initializaton function in every thread...");
bool every_thread_initialized = true;
for (BS::concurrency_t i = 0; i < half_hardware_concurrency; ++i)
{
every_thread_initialized = every_thread_initialized && initializaton_results_array[i];
}
check(every_thread_initialized);
initializaton_results_array = std::make_unique<bool[]>(std::thread::hardware_concurrency());
pool.reset(std::thread::hardware_concurrency(), [&initializaton_results_array](BS::concurrency_t thread_id) { initializaton_results_array[thread_id] = true; });
dual_println("Checking that after a second reset() the thread pool reports a number of threads equal to the hardware concurrency...");
check(std::thread::hardware_concurrency(), pool.get_thread_count());
dual_println("Checking that after a second reset() the manually counted number of unique thread IDs is equal to the reported number of threads...");
check(pool.get_thread_count(), count_unique_threads());
dual_println("Checking that second reset() call executed initializaton function in every thread...");
every_thread_initialized = true;
for (BS::concurrency_t i = 0; i < std::thread::hardware_concurrency(); ++i)
{
every_thread_initialized = every_thread_initialized && initializaton_results_array[i];
}
check(every_thread_initialized);
}

// =======================================
Expand Down