diff --git a/cpp/src/arrow/util/task_group.cc b/cpp/src/arrow/util/task_group.cc
index 8046a5291ff54..54dc5ed969191 100644
--- a/cpp/src/arrow/util/task_group.cc
+++ b/cpp/src/arrow/util/task_group.cc
@@ -23,6 +23,7 @@
 #include <mutex>
 #include <utility>
 
+#include "arrow/util/checked_cast.h"
 #include "arrow/util/logging.h"
 #include "arrow/util/thread_pool.h"
 
@@ -88,13 +89,15 @@ class ThreadedTaskGroup : public TaskGroup {
     // Only if an error occurs is the lock taken
     if (ok_.load(std::memory_order_acquire)) {
       nremaining_.fetch_add(1, std::memory_order_acquire);
-      Status st = thread_pool_->Spawn([this, task]() {
-        if (ok_.load(std::memory_order_acquire)) {
+
+      auto self = checked_pointer_cast<ThreadedTaskGroup>(shared_from_this());
+      Status st = thread_pool_->Spawn([self, task]() {
+        if (self->ok_.load(std::memory_order_acquire)) {
           // XXX what about exceptions?
           Status st = task();
-          UpdateStatus(std::move(st));
+          self->UpdateStatus(std::move(st));
         }
-        OneTaskDone();
+        self->OneTaskDone();
       });
       UpdateStatus(std::move(st));
     }
diff --git a/cpp/src/arrow/util/task_group.h b/cpp/src/arrow/util/task_group.h
index 390d9476e59bd..6ee5163d58ffe 100644
--- a/cpp/src/arrow/util/task_group.h
+++ b/cpp/src/arrow/util/task_group.h
@@ -40,7 +40,7 @@ class ThreadPool;
 /// implementation. When Finish() returns, it is guaranteed that all
 /// tasks have finished, or at least one has errored.
 ///
-class ARROW_EXPORT TaskGroup {
+class ARROW_EXPORT TaskGroup : public std::enable_shared_from_this<TaskGroup> {
  public:
   /// Add a Status-returning function to execute. Execution order is
   /// undefined. The function may be executed immediately or later.
diff --git a/cpp/src/arrow/util/task_group_test.cc b/cpp/src/arrow/util/task_group_test.cc
index 8328972bd173d..50c1cc57330d8 100644
--- a/cpp/src/arrow/util/task_group_test.cc
+++ b/cpp/src/arrow/util/task_group_test.cc
@@ -212,6 +212,75 @@ void TestTasksSpawnTasks(std::shared_ptr<TaskGroup> task_group) {
   ASSERT_EQ(count.load(), (1 << (N + 1)) - 1);
 }
 
+// A task that keeps recursing until a barrier is set.
+// Using a lambda for this doesn't play well with Thread Sanitizer.
+struct BarrierTask {
+  std::atomic<bool>* barrier_;
+  std::weak_ptr<TaskGroup> weak_group_ptr_;
+  Status final_status_;
+
+  Status operator()() {
+    if (!barrier_->load()) {
+      sleep_for(1e-5);
+      // Note the TaskGroup should be kept alive by the fact this task
+      // is still running...
+      weak_group_ptr_.lock()->Append(*this);
+    }
+    return final_status_;
+  }
+};
+
+// Try to replicate subtle lifetime issues when destroying a TaskGroup
+// where all tasks may not have finished running.
+void StressTaskGroupLifetime(std::function<std::shared_ptr<TaskGroup>()> factory) {
+  const int NTASKS = 100;
+  auto task_group = factory();
+  auto weak_group_ptr = std::weak_ptr<TaskGroup>(task_group);
+
+  std::atomic<bool> barrier(false);
+
+  BarrierTask task{&barrier, weak_group_ptr, Status::OK()};
+
+  for (int i = 0; i < NTASKS; ++i) {
+    task_group->Append(task);
+  }
+
+  // Lose strong reference
+  barrier.store(true);
+  task_group.reset();
+
+  // Wait for finish
+  while (!weak_group_ptr.expired()) {
+    sleep_for(1e-5);
+  }
+}
+
+// Same, but with also a failing task
+void StressFailingTaskGroupLifetime(std::function<std::shared_ptr<TaskGroup>()> factory) {
+  const int NTASKS = 100;
+  auto task_group = factory();
+  auto weak_group_ptr = std::weak_ptr<TaskGroup>(task_group);
+
+  std::atomic<bool> barrier(false);
+
+  BarrierTask task{&barrier, weak_group_ptr, Status::OK()};
+  BarrierTask failing_task{&barrier, weak_group_ptr, Status::Invalid("XXX")};
+
+  for (int i = 0; i < NTASKS; ++i) {
+    task_group->Append(task);
+  }
+  task_group->Append(failing_task);
+
+  // Lose strong reference
+  barrier.store(true);
+  task_group.reset();
+
+  // Wait for finish
+  while (!weak_group_ptr.expired()) {
+    sleep_for(1e-5);
+  }
+}
+
 TEST(SerialTaskGroup, Success) { TestTaskGroupSuccess(TaskGroup::MakeSerial()); }
 
 TEST(SerialTaskGroup, Errors) { TestTaskGroupErrors(TaskGroup::MakeSerial()); }
@@ -259,5 +328,20 @@ TEST(ThreadedTaskGroup, SubGroupsErrors) {
   TestTaskSubGroupsErrors(TaskGroup::MakeThreaded(thread_pool.get()));
 }
 
+TEST(ThreadedTaskGroup, StressTaskGroupLifetime) {
+  std::shared_ptr<ThreadPool> thread_pool;
+  ASSERT_OK(ThreadPool::Make(16, &thread_pool));
+
+  StressTaskGroupLifetime([&] { return TaskGroup::MakeThreaded(thread_pool.get()); });
+}
+
+TEST(ThreadedTaskGroup, StressFailingTaskGroupLifetime) {
+  std::shared_ptr<ThreadPool> thread_pool;
+  ASSERT_OK(ThreadPool::Make(16, &thread_pool));
+
+  StressFailingTaskGroupLifetime(
+      [&] { return TaskGroup::MakeThreaded(thread_pool.get()); });
+}
+
 }  // namespace internal
 }  // namespace arrow