diff --git a/cpp/src/arrow/acero/asof_join_node.cc b/cpp/src/arrow/acero/asof_join_node.cc index 55fa45543e4..3970050e502 100644 --- a/cpp/src/arrow/acero/asof_join_node.cc +++ b/cpp/src/arrow/acero/asof_join_node.cc @@ -514,9 +514,9 @@ class InputState : public util::SerialSequencingQueue::Processor { std::unique_ptr backpressure_control = std::make_unique( /*node=*/asof_input, /*output=*/asof_node, backpressure_counter); - ARROW_ASSIGN_OR_RAISE( - auto handler, BackpressureHandler::Make(asof_input, low_threshold, high_threshold, - std::move(backpressure_control))); + ARROW_ASSIGN_OR_RAISE(auto handler, + BackpressureHandler::Make(low_threshold, high_threshold, + std::move(backpressure_control))); return std::make_unique(index, tolerance, must_hash, may_rehash, key_hasher, asof_node, std::move(handler), schema, time_col_index, key_col_index); @@ -763,10 +763,10 @@ class InputState : public util::SerialSequencingQueue::Processor { total_batches_ = n; } - Status ForceShutdown() { + void ForceShutdown() { // Force the upstream input node to unpause. Necessary to avoid deadlock when we // terminate the process thread - return queue_.ForceShutdown(); + queue_.ForceShutdown(); } private: @@ -1046,8 +1046,10 @@ class AsofJoinNode : public ExecNode { if (st.ok()) { st = output_->InputFinished(this, batches_produced_); } - for (const auto& s : state_) { - st &= s->ForceShutdown(); + for (size_t i = 0; i < state_.size(); ++i) { + const auto& s = state_[i]; + s->ForceShutdown(); + st &= inputs_[i]->StopProducing(); } })); } @@ -1499,8 +1501,11 @@ class AsofJoinNode : public ExecNode { if (st.ok()) { st = output_->InputFinished(this, batches_produced_); } - for (const auto& s : state_) { - st &= s->ForceShutdown(); + + for (size_t i = 0; i < state_.size(); ++i) { + const auto& s = state_[i]; + s->ForceShutdown(); + st &= inputs_[i]->StopProducing(); } } diff --git a/cpp/src/arrow/acero/backpressure_handler.h b/cpp/src/arrow/acero/backpressure_handler.h index db6c3799354..c6a47e60197 100644 --- a/cpp/src/arrow/acero/backpressure_handler.h +++ b/cpp/src/arrow/acero/backpressure_handler.h @@ -25,16 +25,15 @@ namespace arrow::acero { class BackpressureHandler { private: - BackpressureHandler(ExecNode* input, size_t low_threshold, size_t high_threshold, + BackpressureHandler(size_t low_threshold, size_t high_threshold, std::unique_ptr backpressure_control) - : input_(input), - low_threshold_(low_threshold), + : low_threshold_(low_threshold), high_threshold_(high_threshold), backpressure_control_(std::move(backpressure_control)) {} public: static Result Make( - ExecNode* input, size_t low_threshold, size_t high_threshold, + size_t low_threshold, size_t high_threshold, std::unique_ptr backpressure_control) { if (low_threshold >= high_threshold) { return Status::Invalid("low threshold (", low_threshold, @@ -43,7 +42,7 @@ class BackpressureHandler { if (backpressure_control == NULLPTR) { return Status::Invalid("null backpressure control parameter"); } - BackpressureHandler backpressure_handler(input, low_threshold, high_threshold, + BackpressureHandler backpressure_handler(low_threshold, high_threshold, std::move(backpressure_control)); return backpressure_handler; } @@ -56,16 +55,7 @@ class BackpressureHandler { } } - Status ForceShutdown() { - // It may be unintuitive to call Resume() here, but this is to avoid a deadlock. - // Since acero's executor won't terminate if any one node is paused, we need to - // force resume the node before stopping production. - backpressure_control_->Resume(); - return input_->StopProducing(); - } - private: - ExecNode* input_; size_t low_threshold_; size_t high_threshold_; std::unique_ptr backpressure_control_; diff --git a/cpp/src/arrow/acero/concurrent_queue_internal.h b/cpp/src/arrow/acero/concurrent_queue_internal.h index a751db70262..b91daae8b04 100644 --- a/cpp/src/arrow/acero/concurrent_queue_internal.h +++ b/cpp/src/arrow/acero/concurrent_queue_internal.h @@ -113,7 +113,7 @@ class ConcurrentQueue { }; template -class BackpressureConcurrentQueue : public ConcurrentQueue { +class BackpressureConcurrentQueue : private ConcurrentQueue { private: struct DoHandle { explicit DoHandle(BackpressureConcurrentQueue& queue) @@ -134,6 +134,9 @@ class BackpressureConcurrentQueue : public ConcurrentQueue { explicit BackpressureConcurrentQueue(BackpressureHandler handler) : handler_(std::move(handler)) {} + using ConcurrentQueue::Empty; + using ConcurrentQueue::Front; + // Pops the last item from the queue but waits if the queue is empty until new items are // pushed. T WaitAndPop() { @@ -152,6 +155,7 @@ class BackpressureConcurrentQueue : public ConcurrentQueue { // Pushes an item to the queue void Push(const T& item) { + if (shutdown_) return; std::unique_lock lock(ConcurrentQueue::GetMutex()); DoHandle do_handle(*this); ConcurrentQueue::PushUnlocked(item); @@ -164,10 +168,14 @@ class BackpressureConcurrentQueue : public ConcurrentQueue { ConcurrentQueue::ClearUnlocked(); } - Status ForceShutdown() { return handler_.ForceShutdown(); } + void ForceShutdown() { + shutdown_ = true; + Clear(); + } private: BackpressureHandler handler_; + std::atomic shutdown_{false}; }; } // namespace arrow::acero diff --git a/cpp/src/arrow/acero/sorted_merge_node.cc b/cpp/src/arrow/acero/sorted_merge_node.cc index 37997232cd0..43f5b7b930a 100644 --- a/cpp/src/arrow/acero/sorted_merge_node.cc +++ b/cpp/src/arrow/acero/sorted_merge_node.cc @@ -119,7 +119,7 @@ class InputState { std::unique_ptr backpressure_control = std::make_unique(input, output, backpressure_counter); ARROW_ASSIGN_OR_RAISE(auto handler, - BackpressureHandler::Make(input, low_threshold, high_threshold, + BackpressureHandler::Make(low_threshold, high_threshold, std::move(backpressure_control))); return PtrType(new InputState(index, std::move(handler), schema, time_col_index)); } diff --git a/cpp/src/arrow/acero/util_test.cc b/cpp/src/arrow/acero/util_test.cc index d86577d3584..557574a40c9 100644 --- a/cpp/src/arrow/acero/util_test.cc +++ b/cpp/src/arrow/acero/util_test.cc @@ -263,8 +263,7 @@ class TestBackpressureControl : public BackpressureControl { TEST(BackpressureConcurrentQueue, BasicTest) { BackpressureTestExecNode dummy_node; auto ctrl = std::make_unique(&dummy_node); - ASSERT_OK_AND_ASSIGN(auto handler, - BackpressureHandler::Make(&dummy_node, 2, 4, std::move(ctrl))); + ASSERT_OK_AND_ASSIGN(auto handler, BackpressureHandler::Make(2, 4, std::move(ctrl))); BackpressureConcurrentQueue queue(std::move(handler)); ConcurrentQueueBasicTest(queue); @@ -275,8 +274,7 @@ TEST(BackpressureConcurrentQueue, BasicTest) { TEST(BackpressureConcurrentQueue, BackpressureTest) { BackpressureTestExecNode dummy_node; auto ctrl = std::make_unique(&dummy_node); - ASSERT_OK_AND_ASSIGN(auto handler, - BackpressureHandler::Make(&dummy_node, 2, 4, std::move(ctrl))); + ASSERT_OK_AND_ASSIGN(auto handler, BackpressureHandler::Make(2, 4, std::move(ctrl))); BackpressureConcurrentQueue queue(std::move(handler)); queue.Push(6); @@ -299,9 +297,28 @@ TEST(BackpressureConcurrentQueue, BackpressureTest) { queue.Push(11); ASSERT_TRUE(dummy_node.paused); ASSERT_FALSE(dummy_node.stopped); - ASSERT_OK(queue.ForceShutdown()); + queue.ForceShutdown(); + ASSERT_FALSE(dummy_node.paused); +} + +TEST(BackpressureConcurrentQueue, BackpressureTestStayUnpaused) { + BackpressureTestExecNode dummy_node; + auto ctrl = std::make_unique(&dummy_node); + ASSERT_OK_AND_ASSIGN( + auto handler, BackpressureHandler::Make(/*low_threshold=*/2, /*high_threshold=*/4, + std::move(ctrl))); + BackpressureConcurrentQueue queue(std::move(handler)); + + queue.Push(6); + queue.Push(7); + queue.Push(8); + ASSERT_FALSE(dummy_node.paused); + ASSERT_FALSE(dummy_node.stopped); + queue.ForceShutdown(); + for (int i = 0; i < 10; ++i) { + queue.Push(i); + } ASSERT_FALSE(dummy_node.paused); - ASSERT_TRUE(dummy_node.stopped); } } // namespace acero