diff --git a/dali/pipeline/util/thread_pool.cc b/dali/pipeline/util/thread_pool.cc index df90c82ab15..3589a9f546b 100644 --- a/dali/pipeline/util/thread_pool.cc +++ b/dali/pipeline/util/thread_pool.cc @@ -46,7 +46,7 @@ ThreadPool::ThreadPool(int num_thread, int device_id, bool set_affinity, const c ThreadPool::~ThreadPool() { WaitForWork(false); - std::unique_lock lock(mutex_); + std::unique_lock lock(lock_); running_ = false; condition_.notify_all(); lock.unlock(); @@ -59,7 +59,7 @@ ThreadPool::~ThreadPool() { void ThreadPool::AddWork(Work work, int64_t priority, bool start_immediately) { bool started_before = false; { - std::lock_guard lock(mutex_); + std::lock_guard lock(lock_); work_queue_.push({priority, std::move(work)}); work_complete_ = false; started_before = started_; @@ -75,7 +75,7 @@ void ThreadPool::AddWork(Work work, int64_t priority, bool start_immediately) { // Blocks until all work issued to the thread pool is complete void ThreadPool::WaitForWork(bool checkForErrors) { - std::unique_lock lock(mutex_); + std::unique_lock lock(lock_); completed_.wait(lock, [this] { return this->work_complete_; }); started_ = false; if (checkForErrors) { @@ -93,7 +93,7 @@ void ThreadPool::WaitForWork(bool checkForErrors) { void ThreadPool::RunAll(bool wait) { { - std::lock_guard lock(mutex_); + std::lock_guard lock(lock_); started_ = true; } condition_.notify_all(); // other threads will be waken up if needed @@ -145,7 +145,7 @@ void ThreadPool::ThreadMain(int thread_id, int device_id, bool set_affinity, while (running_) { // Block on the condition to wait for work - std::unique_lock lock(mutex_); + std::unique_lock lock(lock_); condition_.wait(lock, [this] { return !running_ || (!work_queue_.empty() && started_); }); // If we're no longer running, exit the run loop if (!running_) break; diff --git a/dali/pipeline/util/thread_pool.h b/dali/pipeline/util/thread_pool.h index 4ae2082aadc..6482f099e86 100644 --- a/dali/pipeline/util/thread_pool.h +++ b/dali/pipeline/util/thread_pool.h @@ -19,12 +19,12 @@ #include #include #include -#include #include #include #include #include #include "dali/core/common.h" +#include "dali/core/spinlock.h" #if NVML_ENABLED #include "dali/util/nvml.h" #endif @@ -90,9 +90,9 @@ class DLL_PUBLIC ThreadPool { bool work_complete_; bool started_; int active_threads_; - std::mutex mutex_; - std::condition_variable condition_; - std::condition_variable completed_; + spinlock lock_; + std::condition_variable_any condition_; + std::condition_variable_any completed_; // Stored error strings for each thread vector> tl_errors_;