Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions cpp/src/arrow/acero/asof_join_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -514,9 +514,9 @@ class InputState : public util::SerialSequencingQueue::Processor {
std::unique_ptr<BackpressureControl> backpressure_control =
std::make_unique<BackpressureController>(
/*node=*/asof_input, /*output=*/asof_node, backpressure_counter);
ARROW_ASSIGN_OR_RAISE(
auto handler, BackpressureHandler::Make(asof_input, low_threshold, high_threshold,
std::move(backpressure_control)));
ARROW_ASSIGN_OR_RAISE(auto handler,
BackpressureHandler::Make(low_threshold, high_threshold,
std::move(backpressure_control)));
return std::make_unique<InputState>(index, tolerance, must_hash, may_rehash,
key_hasher, asof_node, std::move(handler), schema,
time_col_index, key_col_index);
Expand Down Expand Up @@ -763,10 +763,10 @@ class InputState : public util::SerialSequencingQueue::Processor {
total_batches_ = n;
}

Status ForceShutdown() {
void ForceShutdown() {
// Force the upstream input node to unpause. Necessary to avoid deadlock when we
// terminate the process thread
return queue_.ForceShutdown();
queue_.ForceShutdown();
}

private:
Expand Down Expand Up @@ -1046,8 +1046,10 @@ class AsofJoinNode : public ExecNode {
if (st.ok()) {
st = output_->InputFinished(this, batches_produced_);
}
for (const auto& s : state_) {
st &= s->ForceShutdown();
for (size_t i = 0; i < state_.size(); ++i) {
const auto& s = state_[i];
s->ForceShutdown();
st &= inputs_[i]->StopProducing();
}
}));
}
Expand Down Expand Up @@ -1499,8 +1501,11 @@ class AsofJoinNode : public ExecNode {
if (st.ok()) {
st = output_->InputFinished(this, batches_produced_);
}
for (const auto& s : state_) {
st &= s->ForceShutdown();

for (size_t i = 0; i < state_.size(); ++i) {
const auto& s = state_[i];
s->ForceShutdown();
st &= inputs_[i]->StopProducing();
}
}

Expand Down
18 changes: 4 additions & 14 deletions cpp/src/arrow/acero/backpressure_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,15 @@ namespace arrow::acero {

class BackpressureHandler {
private:
BackpressureHandler(ExecNode* input, size_t low_threshold, size_t high_threshold,
BackpressureHandler(size_t low_threshold, size_t high_threshold,
std::unique_ptr<BackpressureControl> backpressure_control)
: input_(input),
low_threshold_(low_threshold),
: low_threshold_(low_threshold),
high_threshold_(high_threshold),
backpressure_control_(std::move(backpressure_control)) {}

public:
static Result<BackpressureHandler> Make(
ExecNode* input, size_t low_threshold, size_t high_threshold,
size_t low_threshold, size_t high_threshold,
std::unique_ptr<BackpressureControl> backpressure_control) {
if (low_threshold >= high_threshold) {
return Status::Invalid("low threshold (", low_threshold,
Expand All @@ -43,7 +42,7 @@ class BackpressureHandler {
if (backpressure_control == NULLPTR) {
return Status::Invalid("null backpressure control parameter");
}
BackpressureHandler backpressure_handler(input, low_threshold, high_threshold,
BackpressureHandler backpressure_handler(low_threshold, high_threshold,
std::move(backpressure_control));
return backpressure_handler;
}
Expand All @@ -56,16 +55,7 @@ class BackpressureHandler {
}
}

Status ForceShutdown() {
// It may be unintuitive to call Resume() here, but this is to avoid a deadlock.
// Since acero's executor won't terminate if any one node is paused, we need to
// force resume the node before stopping production.
backpressure_control_->Resume();
return input_->StopProducing();
}

private:
ExecNode* input_;
size_t low_threshold_;
size_t high_threshold_;
std::unique_ptr<BackpressureControl> backpressure_control_;
Expand Down
12 changes: 10 additions & 2 deletions cpp/src/arrow/acero/concurrent_queue_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class ConcurrentQueue {
};

template <typename T>
class BackpressureConcurrentQueue : public ConcurrentQueue<T> {
class BackpressureConcurrentQueue : private ConcurrentQueue<T> {
private:
struct DoHandle {
explicit DoHandle(BackpressureConcurrentQueue& queue)
Expand All @@ -134,6 +134,9 @@ class BackpressureConcurrentQueue : public ConcurrentQueue<T> {
explicit BackpressureConcurrentQueue(BackpressureHandler handler)
: handler_(std::move(handler)) {}

using ConcurrentQueue<T>::Empty;
using ConcurrentQueue<T>::Front;

// Pops the last item from the queue but waits if the queue is empty until new items are
// pushed.
T WaitAndPop() {
Expand All @@ -152,6 +155,7 @@ class BackpressureConcurrentQueue : public ConcurrentQueue<T> {

// Pushes an item to the queue
void Push(const T& item) {
if (shutdown_) return;
std::unique_lock<std::mutex> lock(ConcurrentQueue<T>::GetMutex());
DoHandle do_handle(*this);
ConcurrentQueue<T>::PushUnlocked(item);
Expand All @@ -164,10 +168,14 @@ class BackpressureConcurrentQueue : public ConcurrentQueue<T> {
ConcurrentQueue<T>::ClearUnlocked();
}

Status ForceShutdown() { return handler_.ForceShutdown(); }
void ForceShutdown() {
shutdown_ = true;
Clear();
}

private:
BackpressureHandler handler_;
std::atomic<bool> shutdown_{false};
};

} // namespace arrow::acero
2 changes: 1 addition & 1 deletion cpp/src/arrow/acero/sorted_merge_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class InputState {
std::unique_ptr<arrow::acero::BackpressureControl> backpressure_control =
std::make_unique<BackpressureController>(input, output, backpressure_counter);
ARROW_ASSIGN_OR_RAISE(auto handler,
BackpressureHandler::Make(input, low_threshold, high_threshold,
BackpressureHandler::Make(low_threshold, high_threshold,
std::move(backpressure_control)));
return PtrType(new InputState(index, std::move(handler), schema, time_col_index));
}
Expand Down
29 changes: 23 additions & 6 deletions cpp/src/arrow/acero/util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,7 @@ class TestBackpressureControl : public BackpressureControl {
TEST(BackpressureConcurrentQueue, BasicTest) {
BackpressureTestExecNode dummy_node;
auto ctrl = std::make_unique<TestBackpressureControl>(&dummy_node);
ASSERT_OK_AND_ASSIGN(auto handler,
BackpressureHandler::Make(&dummy_node, 2, 4, std::move(ctrl)));
ASSERT_OK_AND_ASSIGN(auto handler, BackpressureHandler::Make(2, 4, std::move(ctrl)));
BackpressureConcurrentQueue<int> queue(std::move(handler));

ConcurrentQueueBasicTest(queue);
Expand All @@ -275,8 +274,7 @@ TEST(BackpressureConcurrentQueue, BasicTest) {
TEST(BackpressureConcurrentQueue, BackpressureTest) {
BackpressureTestExecNode dummy_node;
auto ctrl = std::make_unique<TestBackpressureControl>(&dummy_node);
ASSERT_OK_AND_ASSIGN(auto handler,
BackpressureHandler::Make(&dummy_node, 2, 4, std::move(ctrl)));
ASSERT_OK_AND_ASSIGN(auto handler, BackpressureHandler::Make(2, 4, std::move(ctrl)));
BackpressureConcurrentQueue<int> queue(std::move(handler));

queue.Push(6);
Expand All @@ -299,9 +297,28 @@ TEST(BackpressureConcurrentQueue, BackpressureTest) {
queue.Push(11);
ASSERT_TRUE(dummy_node.paused);
ASSERT_FALSE(dummy_node.stopped);
ASSERT_OK(queue.ForceShutdown());
queue.ForceShutdown();
ASSERT_FALSE(dummy_node.paused);
}

TEST(BackpressureConcurrentQueue, BackpressureTestStayUnpaused) {
BackpressureTestExecNode dummy_node;
auto ctrl = std::make_unique<TestBackpressureControl>(&dummy_node);
ASSERT_OK_AND_ASSIGN(
auto handler, BackpressureHandler::Make(/*low_threshold=*/2, /*high_threshold=*/4,
std::move(ctrl)));
BackpressureConcurrentQueue<int> queue(std::move(handler));

queue.Push(6);
queue.Push(7);
queue.Push(8);
ASSERT_FALSE(dummy_node.paused);
ASSERT_FALSE(dummy_node.stopped);
queue.ForceShutdown();
for (int i = 0; i < 10; ++i) {
queue.Push(i);
}
ASSERT_FALSE(dummy_node.paused);
ASSERT_TRUE(dummy_node.stopped);
}

} // namespace acero
Expand Down
Loading