Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-41190: [C++] support for single threaded joins #41125

Merged
merged 11 commits into from
May 29, 2024
21 changes: 5 additions & 16 deletions cpp/src/arrow/acero/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,8 @@ add_arrow_acero_test(hash_join_node_test SOURCES hash_join_node_test.cc
bloom_filter_test.cc)
add_arrow_acero_test(pivot_longer_node_test SOURCES pivot_longer_node_test.cc)

# asof_join_node and sorted_merge_node use std::thread internally
# and don't use ThreadPool, so they will
# be broken if threading is turned off
if(ARROW_ENABLE_THREADING)
add_arrow_acero_test(asof_join_node_test SOURCES asof_join_node_test.cc)
add_arrow_acero_test(sorted_merge_node_test SOURCES sorted_merge_node_test.cc)
endif()
add_arrow_acero_test(asof_join_node_test SOURCES asof_join_node_test.cc)
add_arrow_acero_test(sorted_merge_node_test SOURCES sorted_merge_node_test.cc)

add_arrow_acero_test(tpch_node_test SOURCES tpch_node_test.cc)
add_arrow_acero_test(union_node_test SOURCES union_node_test.cc)
Expand Down Expand Up @@ -228,9 +223,7 @@ if(ARROW_BUILD_BENCHMARKS)
add_arrow_acero_benchmark(project_benchmark SOURCES benchmark_util.cc
project_benchmark.cc)

if(ARROW_ENABLE_THREADING)
add_arrow_acero_benchmark(asof_join_benchmark SOURCES asof_join_benchmark.cc)
endif()
add_arrow_acero_benchmark(asof_join_benchmark SOURCES asof_join_benchmark.cc)

add_arrow_acero_benchmark(tpch_benchmark SOURCES tpch_benchmark.cc)

Expand All @@ -253,9 +246,7 @@ if(ARROW_BUILD_BENCHMARKS)
target_link_libraries(arrow-acero-expression-benchmark PUBLIC arrow_acero_static)
target_link_libraries(arrow-acero-filter-benchmark PUBLIC arrow_acero_static)
target_link_libraries(arrow-acero-project-benchmark PUBLIC arrow_acero_static)
if(ARROW_ENABLE_THREADING)
target_link_libraries(arrow-acero-asof-join-benchmark PUBLIC arrow_acero_static)
endif()
target_link_libraries(arrow-acero-asof-join-benchmark PUBLIC arrow_acero_static)
target_link_libraries(arrow-acero-tpch-benchmark PUBLIC arrow_acero_static)
if(ARROW_BUILD_OPENMP_BENCHMARKS)
target_link_libraries(arrow-acero-hash-join-benchmark PUBLIC arrow_acero_static)
Expand All @@ -264,9 +255,7 @@ if(ARROW_BUILD_BENCHMARKS)
target_link_libraries(arrow-acero-expression-benchmark PUBLIC arrow_acero_shared)
target_link_libraries(arrow-acero-filter-benchmark PUBLIC arrow_acero_shared)
target_link_libraries(arrow-acero-project-benchmark PUBLIC arrow_acero_shared)
if(ARROW_ENABLE_THREADING)
target_link_libraries(arrow-acero-asof-join-benchmark PUBLIC arrow_acero_shared)
endif()
target_link_libraries(arrow-acero-asof-join-benchmark PUBLIC arrow_acero_shared)
target_link_libraries(arrow-acero-tpch-benchmark PUBLIC arrow_acero_shared)
if(ARROW_BUILD_OPENMP_BENCHMARKS)
target_link_libraries(arrow-acero-hash-join-benchmark PUBLIC arrow_acero_shared)
Expand Down
85 changes: 77 additions & 8 deletions cpp/src/arrow/acero/asof_join_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1014,6 +1014,8 @@ class AsofJoinNode : public ExecNode {
}
}

#ifdef ARROW_ENABLE_THREADING

template <typename Callable>
struct Defer {
Callable callable;
Expand Down Expand Up @@ -1100,6 +1102,7 @@ class AsofJoinNode : public ExecNode {
}

// Static entry point that forwards to node->ProcessThread(); StartProducing
// passes it to std::thread together with `this`.
static void ProcessThreadWrapper(AsofJoinNode* node) { node->ProcessThread(); }
#endif

public:
AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector<std::string> input_labels,
Expand Down Expand Up @@ -1131,8 +1134,10 @@ class AsofJoinNode : public ExecNode {
}

// Destructor: sends the `false` poison pill and, when ARROW_ENABLE_THREADING
// is defined, joins process_thread_.
// NOTE(review): this view shows both `process_.Push(false)` and
// `PushProcess(false)`. They look like the pre-change line and its replacement
// rendered together -- confirm the real statement sequence against the file.
virtual ~AsofJoinNode() {
process_.Push(false); // poison pill
#ifdef ARROW_ENABLE_THREADING
PushProcess(false);
process_thread_.join();
#endif
}

const std::vector<col_index_t>& indices_of_on_key() { return indices_of_on_key_; }
Expand Down Expand Up @@ -1410,7 +1415,8 @@ class AsofJoinNode : public ExecNode {
rb->ToString(), DEBUG_MANIP(std::endl));

ARROW_RETURN_NOT_OK(state_.at(k)->Push(rb));
process_.Push(true);
PushProcess(true);

return Status::OK();
}

Expand All @@ -1425,31 +1431,88 @@ class AsofJoinNode : public ExecNode {
// The reason for this is that there are cases at the end of a table where we don't
// know whether the RHS of the join is up-to-date until we know that the table is
// finished.
process_.Push(true);
PushProcess(true);

return Status::OK();
}
// Signals the processing logic.
//
// With ARROW_ENABLE_THREADING the value is pushed onto the process_ queue.
// Without it the signal is handled inline on the calling thread:
//   - true  -> run ProcessNonThreaded()
//   - false -> call EndFromSingleThread(), unless process_task_ is already
//              finished
void PushProcess(bool value) {
#ifndef ARROW_ENABLE_THREADING
  if (value) {
    ProcessNonThreaded();
    return;
  }
  if (process_task_.is_finished()) {
    return;
  }
  EndFromSingleThread();
#else
  process_.Push(value);
#endif
}

Status StartProducing() override {
#ifndef ARROW_ENABLE_THREADING
return Status::NotImplemented("ASOF join requires threading enabled");
bool ProcessNonThreaded() {
while (!process_task_.is_finished()) {
joemarshall marked this conversation as resolved.
Show resolved Hide resolved
Result<std::shared_ptr<RecordBatch>> result = ProcessInner();

if (result.ok()) {
auto out_rb = *result;
if (!out_rb) break;
ExecBatch out_b(*out_rb);
out_b.index = batches_produced_++;
DEBUG_SYNC(this, "produce batch ", out_b.index, ":", DEBUG_MANIP(std::endl),
out_rb->ToString(), DEBUG_MANIP(std::endl));
Status st = output_->InputReceived(this, std::move(out_b));
if (!st.ok()) {
// this isn't really from a thread,
// but we call through to this for consistency
EndFromSingleThread(std::move(st));
return false;
}
} else {
// this isn't really from a thread,
// but we call through to this for consistency
EndFromSingleThread(result.status());
return false;
}
}
auto& lhs = *state_.at(0);
if (lhs.Finished() && !process_task_.is_finished()) {
EndFromSingleThread(Status::OK());
}
return true;
}

// Ends processing on the calling thread; used by PushProcess and
// ProcessNonThreaded.
//
// Marks process_task_ finished with `st`. If `st` is OK, it then calls
// output_->InputFinished() with batches_produced_. Finally it calls
// ForceShutdown() on every entry of state_.
//
// NOTE(review): process_task_ is marked finished with the incoming status
// before InputFinished()/ForceShutdown() run, so any error those calls return
// is only accumulated into the local `st` and then dropped. Callers in this
// file test process_task_.is_finished() before calling, so the early
// MarkFinished may be acting as a re-entrancy guard -- confirm before
// reordering.
void EndFromSingleThread(Status st = Status::OK()) {
process_task_.MarkFinished(st);
if (st.ok()) {
st = output_->InputFinished(this, batches_produced_);
}
for (const auto& s : state_) {
st &= s->ForceShutdown();
}
}

#endif

// Begins an external task named "AsofJoinNode::ProcessThread" on the plan's
// query context and stores it in process_task_. If the stored task is not
// valid, returns OK without starting anything. When ARROW_ENABLE_THREADING is
// defined, starts process_thread_ running ProcessThreadWrapper(this).
Status StartProducing() override {
ARROW_ASSIGN_OR_RAISE(process_task_, plan_->query_context()->BeginExternalTask(
"AsofJoinNode::ProcessThread"));
if (!process_task_.is_valid()) {
// Plan has already aborted. Do not start process thread
return Status::OK();
}
#ifdef ARROW_ENABLE_THREADING
process_thread_ = std::thread(&AsofJoinNode::ProcessThreadWrapper, this);
#endif
return Status::OK();
}

// Both overrides have empty bodies: this node takes no action when asked to
// pause or resume producing.
void PauseProducing(ExecNode* output, int32_t counter) override {}
void ResumeProducing(ExecNode* output, int32_t counter) override {}

// Stops processing: when ARROW_ENABLE_THREADING is defined, clears the
// process_ queue; then sends the `false` signal and returns OK.
// NOTE(review): as shown, a threaded build would send `false` twice
// (`process_.Push(false)` inside the #ifdef and `PushProcess(false)` after it).
// This looks like the pre-change line and its replacement rendered together --
// confirm against the file.
Status StopProducingImpl() override {
#ifdef ARROW_ENABLE_THREADING
process_.Clear();
process_.Push(false);
#endif
PushProcess(false);
return Status::OK();
}

Expand Down Expand Up @@ -1479,11 +1542,13 @@ class AsofJoinNode : public ExecNode {

// Backpressure counter common to all inputs
std::atomic<int32_t> backpressure_counter_;
#ifdef ARROW_ENABLE_THREADING
// Queue for triggering processing of a given input
// (a false value is a poison pill)
ConcurrentQueue<bool> process_;
// Worker thread
std::thread process_thread_;
#endif
Future<> process_task_;

// In-progress batches produced
Expand Down Expand Up @@ -1511,9 +1576,13 @@ AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs,
debug_os_(join_options.debug_opts ? join_options.debug_opts->os : nullptr),
debug_mutex_(join_options.debug_opts ? join_options.debug_opts->mutex : nullptr),
#endif
backpressure_counter_(1),
backpressure_counter_(1)
#ifdef ARROW_ENABLE_THREADING
,
process_(),
process_thread_() {
process_thread_()
#endif
{
for (auto& key_hasher : key_hashers_) {
key_hasher->node_ = this;
}
Expand Down
52 changes: 43 additions & 9 deletions cpp/src/arrow/acero/sorted_merge_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,19 +262,22 @@ class SortedMergeNode : public ExecNode {
: ExecNode(plan, inputs, GetInputLabels(inputs), std::move(output_schema)),
ordering_(std::move(new_ordering)),
input_counter(inputs_.size()),
output_counter(inputs_.size()),
process_thread() {
output_counter(inputs_.size())
#ifdef ARROW_ENABLE_THREADING
,
process_thread()
#endif
{
SetLabel("sorted_merge");
}

// Destructor: sends kPoisonPill and, when ARROW_ENABLE_THREADING is defined,
// joins process_thread if it is joinable.
// NOTE(review): both `process_queue.Push(kPoisonPill)` and
// `PushTask(kPoisonPill)` appear here. They look like the pre-change line and
// its replacement rendered together -- confirm against the file.
~SortedMergeNode() override {
process_queue.Push(
kPoisonPill); // poison pill
// We might create a temporary (such as to inspect the output
// schema), in which case there isn't anything to join
PushTask(kPoisonPill);
#ifdef ARROW_ENABLE_THREADING
if (process_thread.joinable()) {
process_thread.join();
}
#endif
}

static arrow::Result<arrow::acero::ExecNode*> Make(
Expand Down Expand Up @@ -355,10 +358,25 @@ class SortedMergeNode : public ExecNode {
// InputState's ConcurrentQueue manages locking
input_counter[index] += rb->num_rows();
ARROW_RETURN_NOT_OK(state[index]->Push(rb));
process_queue.Push(kNewTask);
PushTask(kNewTask);
return Status::OK();
}

// Signals the merge logic with either kNewTask or kPoisonPill.
//
// With ARROW_ENABLE_THREADING the value is pushed onto process_queue.
// Without it the signal is handled inline: nothing happens once process_task
// is finished; otherwise kNewTask runs PollOnce() and any other value calls
// EndFromProcessThread().
void PushTask(bool ok) {
#ifndef ARROW_ENABLE_THREADING
  if (!process_task.is_finished()) {
    if (ok != kNewTask) {
      EndFromProcessThread();
    } else {
      PollOnce();
    }
  }
#else
  process_queue.Push(ok);
#endif
}

arrow::Status InputFinished(arrow::acero::ExecNode* input, int total_batches) override {
ARROW_DCHECK(std_has(inputs_, input));
{
Expand All @@ -368,7 +386,8 @@ class SortedMergeNode : public ExecNode {
state.at(k)->set_total_batches(total_batches);
}
// Trigger a final process call for stragglers
process_queue.Push(kNewTask);
PushTask(kNewTask);

return Status::OK();
}

Expand All @@ -379,13 +398,17 @@ class SortedMergeNode : public ExecNode {
// Plan has already aborted. Do not start process thread
return Status::OK();
}
#ifdef ARROW_ENABLE_THREADING
process_thread = std::thread(&SortedMergeNode::StartPoller, this);
#endif
return Status::OK();
}

// Stops processing: when ARROW_ENABLE_THREADING is defined, clears
// process_queue; then sends kPoisonPill and returns OK.
// NOTE(review): as shown, a threaded build would send kPoisonPill twice
// (`process_queue.Push(kPoisonPill)` inside the #ifdef and
// `PushTask(kPoisonPill)` after it). This looks like the pre-change line and
// its replacement rendered together -- confirm against the file.
arrow::Status StopProducingImpl() override {
#ifdef ARROW_ENABLE_THREADING
process_queue.Clear();
process_queue.Push(kPoisonPill);
#endif
PushTask(kPoisonPill);
return Status::OK();
}

Expand All @@ -408,13 +431,20 @@ class SortedMergeNode : public ExecNode {
<< input_counter[i] << " != " << output_counter[i];
}

#ifdef ARROW_ENABLE_THREADING
ARROW_UNUSED(
plan_->query_context()->executor()->Spawn([this, st = std::move(st)]() mutable {
Defer cleanup([this, &st]() { process_task.MarkFinished(st); });
if (st.ok()) {
st = output_->InputFinished(this, batches_produced);
}
}));
#else
process_task.MarkFinished(st);
if (st.ok()) {
st = output_->InputFinished(this, batches_produced);
}
#endif
}

bool CheckEnded() {
Expand Down Expand Up @@ -552,6 +582,7 @@ class SortedMergeNode : public ExecNode {
return true;
}

#ifdef ARROW_ENABLE_THREADING
void EmitBatches() {
while (true) {
// Implementation note: If the queue is empty, we will block here
Expand All @@ -567,6 +598,7 @@ class SortedMergeNode : public ExecNode {

/// The entry point for processThread
static void StartPoller(SortedMergeNode* node) { node->EmitBatches(); }
#endif

arrow::Ordering ordering_;

Expand All @@ -583,11 +615,13 @@ class SortedMergeNode : public ExecNode {

std::atomic<int32_t> batches_produced{0};

#ifdef ARROW_ENABLE_THREADING
// Queue to trigger processing of a given input. False acts as a poison pill
ConcurrentQueue<bool> process_queue;
// Once StartProducing is called, we initialize this thread to poll the
// input states and emit batches
std::thread process_thread;
#endif
arrow::Future<> process_task;

// Map arg index --> completion counter
Expand Down
Loading